-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
11 changed files
with
284 additions
and
5 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from ..text.instruct import InstructEvaluationStrategy | ||
from elleelleaime.core.benchmarks.bug import Bug | ||
|
||
from typing import Optional, List | ||
|
||
|
||
class MistralEvaluationStrategy(InstructEvaluationStrategy):
    """
    Evaluation strategy for generations shaped like a chat completion
    (a dict with a "choices" list whose items hold message -> content).

    Patch extraction and per-patch evaluation are delegated to
    `extract_patch_from_message` and `evaluate_generation`, which are not
    defined in this class (expected to come from the base strategy).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __evaluate_generation(self, bug: Bug, sample: dict, generation) -> List[dict]:
        """
        Evaluate every choice contained in a single generation.
        :param bug: The bug the candidate patches are evaluated against.
        :param sample: The sample the generation belongs to.
        :param generation: The generation whose choices are evaluated.
        :return: One evaluation result per choice, in choice order.
        """
        return [
            self.evaluate_generation(
                bug,
                sample,
                self.extract_patch_from_message(choice["message"]["content"]),
            )
            for choice in generation["choices"]
        ]

    def _evaluate_impl(self, bug: Bug, sample: dict) -> Optional[List[dict]]:
        """
        Returns the evaluation for the given bug and sample.
        :param bug: The bug the sample was generated for.
        :param sample: The sample to evaluate; its "generation" entry may be None.
        :return: A list of evaluation results (empty when there is no generation).
        """
        generation = sample["generation"]

        # Nothing was generated for this sample: report an empty evaluation.
        if generation is None:
            return []

        return list(self.__evaluate_generation(bug, sample, generation))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Optional | ||
from .cost_strategy import CostStrategy | ||
|
||
import tqdm | ||
|
||
|
||
class MistralCostStrategy(CostStrategy):
    """Computes prompt/completion/total cost from the token usage of samples."""

    # Price per one million tokens, keyed by model name and then by token kind.
    __COST_PER_MILLION_TOKENS = {
        "mistral-large-2411": {
            "prompt": 2,
            "completion": 6,
        },
        "codestral-2405": {
            "prompt": 0.2,
            "completion": 0.6,
        },
    }

    @staticmethod
    def compute_costs(samples: list, model_name: str) -> Optional[dict]:
        """
        Sum the cost of all samples that carry a generation.
        :param samples: Samples whose "generation" entry (when truthy) exposes
            usage -> prompt_tokens / completion_tokens.
        :param model_name: The model the samples were generated with.
        :return: A dict with "prompt_cost", "completion_cost" and "total_cost",
            or None when no pricing is known for the model.
        """
        pricing = MistralCostStrategy.__COST_PER_MILLION_TOKENS.get(model_name)
        if pricing is None:
            return None

        # The rates do not depend on the sample, so read them once.
        prompt_rate = pricing["prompt"]
        completion_rate = pricing["completion"]

        prompt_total = 0.0
        completion_total = 0.0

        for sample in tqdm.tqdm(samples, f"Computing costs for {model_name}..."):
            generation = sample["generation"]
            if not generation:
                continue

            usage = generation["usage"]
            prompt_tokens = usage["prompt_tokens"]
            completion_tokens = usage["completion_tokens"]

            prompt_total += prompt_rate * prompt_tokens / 1000000
            completion_total += completion_rate * completion_tokens / 1000000

        return {
            "prompt_cost": prompt_total,
            "completion_cost": completion_total,
            "total_cost": prompt_total + completion_total,
        }
Empty file.
45 changes: 45 additions & 0 deletions
45
elleelleaime/generate/strategies/models/mistral/mistral.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy | ||
|
||
from dotenv import load_dotenv | ||
from typing import Any, List | ||
|
||
import os | ||
import mistralai | ||
import backoff | ||
|
||
|
||
class MistralModels(PatchGenerationStrategy):
    """Patch generation strategy that queries a Mistral chat-completion model."""

    def __init__(self, model_name: str, **kwargs) -> None:
        """
        :param model_name: Name of the model passed to the chat-completion call.
        :param kwargs: Optional settings; "temperature" (default 0.0) and
            "n_samples" (default 1) are read, anything else is ignored here.
        """
        self.model_name = model_name
        self.temperature = kwargs.get("temperature", 0.0)
        self.n_samples = kwargs.get("n_samples", 1)

        # load_dotenv() runs before the MISTRAL_API_KEY lookup so a .env file
        # can supply the key.
        load_dotenv()
        self.client = mistralai.Mistral(os.getenv("MISTRAL_API_KEY", None))

    @backoff.on_exception(
        backoff.expo,
        (
            mistralai.models.SDKError,
            mistralai.models.HTTPValidationError,
            AssertionError,
        ),
    )
    def _completions_with_backoff(self, **kwargs):
        """
        Call the chat-completion endpoint, retrying with exponential backoff.
        :param kwargs: Forwarded unchanged to `self.client.chat.complete`.
        :return: The non-None completion response.
        :raises AssertionError: When the client returns None (this is one of
            the exception types the backoff decorator retries on).
        """
        response = self.client.chat.complete(**kwargs)
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would let a None response skip the
        # retry and fail later in the caller.
        if response is None:
            raise AssertionError("Mistral chat completion returned no response")
        return response

    def _generate_impl(self, chunk: List[str]) -> Any:
        """
        Generate one completion per prompt.
        :param chunk: The prompts to send, each as a single user message.
        :return: A list with the `model_dump()` of each completion, in prompt order.
        """
        result = []

        for prompt in chunk:
            completion = self._completions_with_backoff(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=self.temperature,
                n=self.n_samples,
            )
            result.append(completion.model_dump())

        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from evaluate_patches import evaluate_candidate | ||
from generate_samples import generate_sample | ||
from elleelleaime.core.utils.benchmarks import get_benchmark | ||
from elleelleaime.core.benchmarks.benchmark import Benchmark | ||
|
||
|
||
class TestEvaluatePatchesMistralDefects4J:
    """Checks the "mistral" evaluation strategy on a Defects4J bug."""

    DEFECTS4J: Benchmark
    PROMPT_STRATEGY: str = "instruct"
    MODEL_NAME: str = "codestral-2405"
    EVALUATE_STRATEGY: str = "mistral"

    @classmethod
    def setup_class(cls):
        """Load and initialize the defects4j benchmark once for the class."""
        suite = TestEvaluatePatchesMistralDefects4J
        suite.DEFECTS4J = get_benchmark("defects4j")
        assert suite.DEFECTS4J is not None
        suite.DEFECTS4J.initialize()

    @classmethod
    def get_exact_match_sample(cls):
        """
        Build a sample for Chart-1 whose canned generation wraps the sample's
        own fixed code (plus a trailing comment line) in a java code fence.
        :return: The (bug, sample) pair.
        """
        suite = TestEvaluatePatchesMistralDefects4J

        bug = suite.DEFECTS4J.get_bug("Chart-1")
        assert bug is not None

        sample = generate_sample(
            bug=bug,
            prompt_strategy=suite.PROMPT_STRATEGY,
            model_name=suite.MODEL_NAME,
        )

        # Canned chat-completion payload with a single choice.
        canned_choice = {
            "index": 0,
            "message": {
                "content": f"```java\n{sample['fixed_code']}\n// comment\n```",
                "tool_calls": None,
                "prefix": False,
                "role": "assistant",
            },
            "finish_reason": "stop",
        }
        canned_generation = {
            "id": "5f26bfc6f38f46c2a399ef319293634a",
            "object": "chat.completion",
            "model": "codestral-2405",
            "usage": {
                "prompt_tokens": 934,
                "completion_tokens": 604,
                "total_tokens": 1538,
            },
            "created": 1732015902,
            "choices": [canned_choice],
        }
        sample["generation"] = canned_generation

        return bug, sample

    def test_exact_match_patch(self):
        """The canned generation must compile, pass tests and match exactly."""
        suite = TestEvaluatePatchesMistralDefects4J
        bug, sample = suite.get_exact_match_sample()

        evaluated = evaluate_candidate(
            bug=bug,
            sample=sample,
            strategy=suite.EVALUATE_STRATEGY,
        )

        evaluation = evaluated["evaluation"]
        assert evaluation is not None
        assert len(evaluation) == 1

        outcome = evaluation[0]
        for check in ("compile", "test", "exact_match", "ast_match"):
            assert outcome[check] == True