Skip to content

Commit

Permalink
fix: handle google exceptions and with backoff (#153)
Browse files Browse the repository at this point in the history
* fix: handle google exceptions and add backoff

* add backoff
  • Loading branch information
andre15silva authored Sep 6, 2024
1 parent 6602e3b commit cf065ea
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions elleelleaime/generate/strategies/models/google/google.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import google.api_core
import google.api_core.exceptions
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy

from dotenv import load_dotenv
from typing import Any, List

import os
import tqdm
import google.generativeai as genai
import google
import backoff

import google.api


class GoogleModels(PatchGenerationStrategy):
def __init__(self, model_name: str, **kwargs) -> None:
self.model_name = model_name
self.model = genai.GenerativeModel(self.model_name)
self.temperature = kwargs.get("temperature", 0.0)
self.n_samples = kwargs.get("n_samples", 1)

Expand All @@ -21,18 +29,20 @@ def __get_config(self):
temperature=self.temperature,
)

@backoff.on_exception(backoff.expo, google.api_core.exceptions.ResourceExhausted)
def __generate_with_backoff(self, prompt: str) -> dict:
completion = self.model.generate_content(
prompt, generation_config=self.__get_config()
)
return completion.to_dict()

def _generate_impl(self, chunk: List[str]) -> Any:
result = []

model = genai.GenerativeModel(self.model_name)

for prompt in chunk:
for prompt in tqdm.tqdm(chunk, "Generating patches for prompt..."):
p_results = []
for _ in range(self.n_samples):
completion = model.generate_content(
prompt, generation_config=self.__get_config()
)
p_results.append(completion.to_dict())
p_results.append(self.__generate_with_backoff(prompt))
result.append(p_results)

return result

0 comments on commit cf065ea

Please sign in to comment.