From e28a194dcf1e9a7a252abe48a1bf6e5d54fc0dc7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Silva?=
Date: Fri, 16 Aug 2024 14:02:33 +0200
Subject: [PATCH] update script

---
 .../models/huggingface/codellama/codellama_instruct.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py b/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py
index 697638d2..576fffc7 100644
--- a/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py
+++ b/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py
@@ -50,6 +50,7 @@ def __init__(self, model_name: str, **kwargs) -> None:
             model_name in self.__SUPPORTED_MODELS
         ), f"Model {model_name} not supported by {self.__class__.__name__}"
         self.model_name = model_name
+        self.max_prompt_length = 2048
         # Generation settings
         assert (
             kwargs.get("generation_strategy", "sampling")
@@ -73,7 +74,6 @@ def __init__(self, model_name: str, **kwargs) -> None:
     def __load_model(self, **kwargs):
         # Setup environment
         self.device = "cuda"
-        self.context_size = self.generate_settings.max_length
 
         # Setup kwargs
         model_kwargs = dict(
@@ -129,7 +129,7 @@ def _generate_batch(self, batch: List[str]) -> Any:
             return_tensors="pt",
             padding=True,
             truncation=True,
-            max_length=self.context_size,
+            max_length=self.max_prompt_length,
         )
         inputs = inputs.to(self.device)