Commit
update script
andre15silva committed Aug 18, 2024
1 parent e8c1e20 commit 9788b4e
Showing 1 changed file with 7 additions and 0 deletions.
@@ -65,6 +65,9 @@ def __init__(self, model_name: str, **kwargs) -> None:
         self.generate_settings.temperature = kwargs.get(
             "temperature", GenerateSettings.temperature
         )
+        self.generate_settings.max_length = kwargs.get(
+            "max_length", GenerateSettings.max_length
+        )

     def __format_prompt(self, prompt: str) -> str:
         return f"<s>[INST] {prompt} [/INST]"
@@ -85,7 +88,10 @@ def _generate_impl(self, chunk: List[str]) -> Any:
         tok = AutoTokenizer.from_pretrained(self.model_name)
         tok.pad_token = tok.eos_token

+        logging.info(f"Model successfully loaded: {m}")
+
         # Generate patches
+        logging.info(f"Starting generation: {self.generate_settings}")
         result = []
         for prompt in tqdm.tqdm(chunk, "Generating patches...", total=len(chunk)):
             with torch.no_grad():
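
The second hunk only adds logging around steps that already existed: the tokenizer load and the tok.pad_token = tok.eos_token assignment. Many causal LMs ship without a pad token, so reusing EOS is the usual way to make padded batch tokenization work. A hedged sketch of that load step follows; the model name is hypothetical, and the variable m in the diff is presumably the model object loaded just above the visible context.

import logging

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-Instruct-v0.1"  # hypothetical example
m = AutoModelForCausalLM.from_pretrained(model_name)
tok = AutoTokenizer.from_pretrained(model_name)
# Causal LMs often define no pad token; reuse EOS so batches can be padded.
tok.pad_token = tok.eos_token
logging.info(f"Model successfully loaded: {m}")
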
@@ -99,6 +105,7 @@ def _generate_impl(self, chunk: List[str]) -> Any:
                 logging.warning(
                     f"Skipping prompt due to length: {input_length} is larger than {self.generate_settings.max_length}"
                 )
+                continue

             # Generate patch
             inputs = inputs.to("cuda")
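
The third hunk is the behavioral fix in this commit: before it, an over-long prompt triggered only a warning and then still fell through to generation; the added continue actually skips it. Below is a self-contained sketch of that loop under assumed tokenizer/model wiring; names follow the diff where they appear, but the generate call itself is elided from the diff, so its arguments here are assumptions.

import logging
from typing import List

import torch
import tqdm

def generate_patches(tok, model, chunk: List[str], max_length: int) -> List[str]:
    result = []
    for prompt in tqdm.tqdm(chunk, "Generating patches...", total=len(chunk)):
        with torch.no_grad():
            inputs = tok(prompt, return_tensors="pt")
            input_length = inputs.input_ids.shape[1]
            if input_length > max_length:
                logging.warning(
                    f"Skipping prompt due to length: {input_length} is larger than {max_length}"
                )
                continue  # the added line: move on instead of generating anyway

            # Generate patch (generation kwargs assumed; the diff elides them)
            inputs = inputs.to("cuda")
            outputs = model.generate(**inputs, max_new_tokens=max_length)
            result.append(tok.decode(outputs[0], skip_special_tokens=True))
    return result
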
