-
Notifications
You must be signed in to change notification settings - Fork 26.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speculative decoding: Test the target distribution (to prevent issues like #32867) #34553
base: main
Are you sure you want to change the base?
Speculative decoding: Test the target distribution (to prevent issues like #32867) #34553
Conversation
1be059c
to
5522333
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for working on improving our tests 💛
A question: is this test somewhat fast to run (<5s)? If yes, amazing! If no, let's either a) reduce the number in range
or b) tag the test as @slow
[note: tests with @slow
are usually run daily, so bad commits may squeeze in]
tests/generation/test_utils.py
Outdated
[ | ||
-inf, | ||
2.0, | ||
-inf, | ||
1.0, | ||
-inf, | ||
-inf, | ||
-inf, | ||
-0.01, | ||
2.0, | ||
-inf, | ||
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: let's make it in one line, so we can quickly compare indexes with other tensors.
(you'll have to remove the comma after the last -inf
, otherwise the make fixup
command will make it revert back to this format)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the formatting as requested, but ruff
's formatting check then failed the CI. (make fixup
still reformats it into a column, even after removing the last comma you mentioned)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use # fmt: off
and # fmt: on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ArthurZucker. I changed all these inline comments to block comments, and it solved the issue while keeping the ruff checks on. 👍
…ub.com/keyboardAnt/transformers into test-speculative-sampling-distribution
The test itself takes 1.89 s, and when you run it with pytest ( |
3a05527
to
dc3be00
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
tests/generation/test_utils.py
Outdated
[ | ||
-inf, | ||
2.0, | ||
-inf, | ||
1.0, | ||
-inf, | ||
-inf, | ||
-inf, | ||
-0.01, | ||
2.0, | ||
-inf, | ||
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can use # fmt: off
and # fmt: on
What does this PR do?
This PR introduces a test for speculative decoding to ensure the target distribution is preserved, addressing potential issues similar to #32867. The added test (
test_speculative_sampling_target_distribution
) validates that tokens are generated according to their intended likelihood, as defined in the logits, ensuring that the speculative decoding process adheres to expected distributions. Additionally, this is a foundational step toward supporting advanced speculative decoding algorithms, such as token-tree-based rejection sampling, which will enhance flexibility and performance in future implementations.Motivation and Context
The speculative decoding process has previously encountered issues where the target distribution was not preserved (e.g., in issues #32867 and #33534). This PR implements a test to safeguard against such inconsistencies by verifying that:
This enhancement not only improves the reliability of speculative sampling by enforcing distributional accuracy but also prepares the ground for implementing more advanced speculative decoding techniques, like token-trees-based sampling.
This PR is an initial step toward advancements in Universal Assisted Generation. In collaboration with @orenpereg, @danielkorat, @mosheber, @jmamou, and @MosheWasserb, we're preparing for a new speculative decoding function that this test will verify for losslessness in target distribution preservation.
Dependencies
No additional dependencies are required.
Linked Issues
#32867, #33534
Before Submitting Checklist
Who can review?
@gante