diff --git a/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py b/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py index 697638d2..576fffc7 100644 --- a/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py +++ b/elleelleaime/generate/strategies/models/huggingface/codellama/codellama_instruct.py @@ -50,6 +50,7 @@ def __init__(self, model_name: str, **kwargs) -> None: model_name in self.__SUPPORTED_MODELS ), f"Model {model_name} not supported by {self.__class__.__name__}" self.model_name = model_name + self.max_prompt_length = 2048 # Generation settings assert ( kwargs.get("generation_strategy", "sampling") @@ -73,7 +74,6 @@ def __init__(self, model_name: str, **kwargs) -> None: def __load_model(self, **kwargs): # Setup environment self.device = "cuda" - self.context_size = self.generate_settings.max_length # Setup kwargs model_kwargs = dict( @@ -129,7 +129,7 @@ def _generate_batch(self, batch: List[str]) -> Any: return_tensors="pt", padding=True, truncation=True, - max_length=self.context_size, + max_length=self.max_prompt_length, ) inputs = inputs.to(self.device)