Skip to content

Commit

Permalink
fix another flaky
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Oct 29, 2024
1 parent e50df32 commit 81443f5
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,6 +2048,9 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
self.skipTest(reason="Stateful models don't support assisted generation")

config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
config.use_cache = True
config.is_decoder = True

Expand All @@ -2064,7 +2067,6 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
"output_scores": True,
}

assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate(**generation_kwargs, **inputs_dict, num_logits_to_keep=0)
# By default, num_logits_to_keep is automatically set to 1 if not provided (new behavior)
Expand Down

0 comments on commit 81443f5

Please sign in to comment.