Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speculative decoding: Test the target distribution (to prevent issues like #32867) #34553

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

keyboardAnt
Copy link

@keyboardAnt keyboardAnt commented Nov 1, 2024

What does this PR do?

This PR introduces a test for speculative decoding to ensure the target distribution is preserved, addressing potential issues similar to #32867. The added test (test_speculative_sampling_target_distribution) validates that tokens are generated according to their intended likelihood, as defined in the logits, ensuring that the speculative decoding process adheres to expected distributions. Additionally, this is a foundational step toward supporting advanced speculative decoding algorithms, such as token-tree-based rejection sampling, which will enhance flexibility and performance in future implementations.

Motivation and Context

The speculative decoding process has previously encountered issues where the target distribution was not preserved (e.g., in issues #32867 and #33534). This PR implements a test to safeguard against such inconsistencies by verifying that:

  • The most likely tokens are chosen more frequently than less probable ones.
  • Tokens are selected in alignment with the predefined candidate and new logits.

This enhancement not only improves the reliability of speculative sampling by enforcing distributional accuracy but also prepares the ground for implementing more advanced speculative decoding techniques, like token-trees-based sampling.

This PR is an initial step toward advancements in Universal Assisted Generation. In collaboration with @orenpereg, @danielkorat, @mosheber, @jmamou, and @MosheWasserb, we're preparing for a new speculative decoding function that this test will verify for losslessness in target distribution preservation.

Dependencies

No additional dependencies are required.

Linked Issues

#32867, #33534

Before Submitting Checklist

  • I have read the contributor guidelines.
  • Documentation updates are not needed as this is a test enhancement.
  • New test coverage has been added to verify the speculative sampling behavior.

Who can review?

@gante

@keyboardAnt keyboardAnt force-pushed the test-speculative-sampling-distribution branch from 1be059c to 5522333 Compare November 1, 2024 00:14
Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for working on improving our tests 💛

A question: is this test somewhat fast to run (<5s)? If yes, amazing! If no, let's either a) reduce the number in range or b) tag the test as @slow [note: tests with @slow are usually run daily, so bad commits may squeeze in]

Comment on lines 2474 to 2486
[
-inf,
2.0,
-inf,
1.0,
-inf,
-inf,
-inf,
-0.01,
2.0,
-inf,
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value
Copy link
Member

@gante gante Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's make it in one line, so we can quickly compare indexes with other tensors.

(you'll have to remove the comma after the last -inf, otherwise the make fixup command will make it revert back to this format)

Copy link
Author

@keyboardAnt keyboardAnt Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the formatting as requested, but ruff's formatting check then failed the CI. (make fixup still reformats it into a column, even after removing the last comma you mentioned)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use # fmt: off and # fmt: on

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ArthurZucker. I changed all these inline comments to block comments, and it solved the issue while keeping the ruff checks on. 👍

@keyboardAnt
Copy link
Author

Thank you for working on improving our tests 💛

A question: is this test somewhat fast to run (<5s)? If yes, amazing! If no, let's either a) reduce the number in range or b) tag the test as @slow [note: tests with @slow are usually run daily, so bad commits may squeeze in]

The test itself takes 1.89 s, and when you run it with pytest (pytest tests/generation/test_utils.py::UtilsFunctionsTest::test_speculative_sampling_target_distribution), it's not more than 2.57 s. Although it was only tested locally on my laptop, I believe it’s safe to keep it with the rest of the <5s tests.

@keyboardAnt keyboardAnt force-pushed the test-speculative-sampling-distribution branch from 3a05527 to dc3be00 Compare November 5, 2024 02:19
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

Comment on lines 2474 to 2486
[
-inf,
2.0,
-inf,
1.0,
-inf,
-inf,
-inf,
-0.01,
2.0,
-inf,
], # most likely to be 1 or 8, less likely to be 3, then 7, and should never be any other value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use # fmt: off and # fmt: on

@keyboardAnt
Copy link
Author

All checks have successfully passed (screenshot below). Are there any additional workflows to run before merging?

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants