diff --git a/elleelleaime/generate/strategies/models/google/google.py b/elleelleaime/generate/strategies/models/google/google.py index 3242c5d6..7e1ac2f2 100644 --- a/elleelleaime/generate/strategies/models/google/google.py +++ b/elleelleaime/generate/strategies/models/google/google.py @@ -1,3 +1,5 @@ +import google.api_core +import google.api_core.exceptions from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy from dotenv import load_dotenv @@ -6,6 +8,10 @@ import os import tqdm import google.generativeai as genai +import google +import backoff + +import google.api class GoogleModels(PatchGenerationStrategy): @@ -23,17 +29,20 @@ def __get_config(self): temperature=self.temperature, ) + @backoff.on_exception(backoff.expo, google.api_core.exceptions.ResourceExhausted) + def __generate_with_backoff(self, prompt: str) -> dict: + completion = self.model.generate_content( + prompt, generation_config=self.__get_config() + ) + return completion.to_dict() + def _generate_impl(self, chunk: List[str]) -> Any: result = [] - for prompt in tqdm.tqdm(chunk, "Generating patches for prompt..."): p_results = [] for _ in range(self.n_samples): - completion = self.model.generate_content( - prompt, generation_config=self.__get_config() - ) - p_results.append(completion.to_dict()) + p_results.append(self.__generate_with_backoff(prompt)) result.append(p_results) return result