Skip to content

Commit

Permalink
update script
Browse files Browse the repository at this point in the history
  • Loading branch information
andre15silva committed Aug 16, 2024
1 parent 83c9272 commit 5b160e8
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, model_name: str, **kwargs) -> None:
model_name in self.__SUPPORTED_MODELS
), f"Model {model_name} not supported by {self.__class__.__name__}"
self.model_name = model_name
self.max_prompt_length = 2048
# Generation settings
assert (
kwargs.get("generation_strategy", "sampling")
Expand All @@ -73,7 +74,6 @@ def __init__(self, model_name: str, **kwargs) -> None:
def __load_model(self, **kwargs):
# Setup environment
self.device = "cuda"
self.context_size = self.generate_settings.max_length

# Setup kwargs
model_kwargs = dict(
Expand Down Expand Up @@ -129,7 +129,7 @@ def _generate_batch(self, batch: List[str]) -> Any:
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.context_size,
max_length=self.max_prompt_length,
)
inputs = inputs.to(self.device)

Expand Down

0 comments on commit 5b160e8

Please sign in to comment.