
Commit

feat: mistral (#181)
* feat: mistral

* implement mistral
andre15silva authored Nov 19, 2024
1 parent bafcd59 commit 2e0eedd
Showing 11 changed files with 284 additions and 5 deletions.
Empty file.
42 changes: 42 additions & 0 deletions elleelleaime/evaluate/strategies/mistral/mistral.py
@@ -0,0 +1,42 @@
from ..text.instruct import InstructEvaluationStrategy
from elleelleaime.core.benchmarks.bug import Bug

from typing import Optional, List


class MistralEvaluationStrategy(InstructEvaluationStrategy):

def __init__(self, **kwargs):
super().__init__(**kwargs)

def __evaluate_generation(self, bug: Bug, sample: dict, generation) -> List[dict]:
"""
Evaluate the generation for the given bug.
:param bug: The bug to generate the prompt for.
:param generation: The generation to evaluate
"""
evaluation = []

for choice in generation["choices"]:
message = choice["message"]["content"]
candidate_patch = self.extract_patch_from_message(message)
evaluation.append(self.evaluate_generation(bug, sample, candidate_patch))

return evaluation

def _evaluate_impl(self, bug: Bug, sample: dict) -> Optional[List[dict]]:
"""
Returns the evaluation for the given bug and sample.
:param bug: The bug to generate the prompt for.
:param sample: The sample to evaluate.
"""
evaluation = []

if sample["generation"] is None:
return evaluation

evaluation.extend(self.__evaluate_generation(bug, sample, sample["generation"]))

return evaluation
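
For reference, a minimal sketch of the generation payload this strategy walks; the shape mirrors the Mistral chat-completion dict used in the test fixture further down, and extract_patch_from_message is inherited from the parent strategy:

# Hypothetical payload shape consumed by __evaluate_generation:
generation = {
    "choices": [
        {
            "message": {
                "content": "```java\n// candidate patch\n```",
                "role": "assistant",
            }
        }
    ]
}
# One evaluation entry is produced per choice: the fenced patch is extracted
# from message["content"] and evaluated against the bug.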
2 changes: 2 additions & 0 deletions elleelleaime/evaluate/strategies/registry.py
@@ -9,6 +9,7 @@
from elleelleaime.evaluate.strategies.anthropic.anthropic import (
AnthropicEvaluationStrategy,
)
from elleelleaime.evaluate.strategies.mistral.mistral import MistralEvaluationStrategy


class PatchEvaluationStrategyRegistry:
@@ -24,6 +25,7 @@ def __init__(self, **kwargs):
"google": GoogleEvaluationStrategy(**kwargs),
"openrouter": OpenRouterEvaluationStrategy(**kwargs),
"anthropic": AnthropicEvaluationStrategy(**kwargs),
"mistral": MistralEvaluationStrategy(**kwargs),
}

def get_evaluation(self, name: str) -> PatchEvaluationStrategy:
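
With the entry registered, the Mistral backend resolves by name like the existing ones; a minimal usage sketch (constructor kwargs are forwarded to every registered strategy, per the dict above):

from elleelleaime.evaluate.strategies.registry import PatchEvaluationStrategyRegistry

registry = PatchEvaluationStrategyRegistry()
strategy = registry.get_evaluation("mistral")  # -> MistralEvaluationStrategy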
2 changes: 2 additions & 0 deletions elleelleaime/export/cost/cost_calculator.py
@@ -2,6 +2,7 @@
from .strategies.google import GoogleCostStrategy
from .strategies.openrouter import OpenRouterCostStrategy
from .strategies.anthropic import AnthropicCostStrategy
from .strategies.mistral import MistralCostStrategy
from typing import Optional


@@ -12,6 +13,7 @@ class CostCalculator:
"google": GoogleCostStrategy,
"openrouter": OpenRouterCostStrategy,
"anthropic": AnthropicCostStrategy,
"mistral": MistralCostStrategy,
}

@staticmethod
50 changes: 50 additions & 0 deletions elleelleaime/export/cost/strategies/mistral.py
@@ -0,0 +1,50 @@
from typing import Optional
from .cost_strategy import CostStrategy

import tqdm


class MistralCostStrategy(CostStrategy):

__COST_PER_MILLION_TOKENS = {
"mistral-large-2411": {
"prompt": 2,
"completion": 6,
},
"codestral-2405": {
"prompt": 0.2,
"completion": 0.6,
},
}

@staticmethod
def compute_costs(samples: list, model_name: str) -> Optional[dict]:
if model_name not in MistralCostStrategy.__COST_PER_MILLION_TOKENS:
return None

costs = {
"prompt_cost": 0.0,
"completion_cost": 0.0,
"total_cost": 0.0,
}

for sample in tqdm.tqdm(samples, f"Computing costs for {model_name}..."):
if sample["generation"]:
g = sample["generation"]
prompt_token_count = g["usage"]["prompt_tokens"]
candidates_token_count = g["usage"]["completion_tokens"]

prompt_cost = MistralCostStrategy.__COST_PER_MILLION_TOKENS[model_name][
"prompt"
]
completion_cost = MistralCostStrategy.__COST_PER_MILLION_TOKENS[
model_name
]["completion"]

costs["prompt_cost"] += prompt_cost * prompt_token_count / 1000000
costs["completion_cost"] += (
completion_cost * candidates_token_count / 1000000
)

costs["total_cost"] = costs["prompt_cost"] + costs["completion_cost"]
return costs
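
As a sanity check on the arithmetic, pricing a single sample with the usage figures that appear in the test fixture below (934 prompt tokens, 604 completion tokens) against codestral-2405 works out as follows; the standalone call itself is hypothetical:

from elleelleaime.export.cost.strategies.mistral import MistralCostStrategy

samples = [
    {"generation": {"usage": {"prompt_tokens": 934, "completion_tokens": 604}}}
]
costs = MistralCostStrategy.compute_costs(samples, "codestral-2405")
# prompt:     0.2 * 934 / 1_000_000 = $0.0001868
# completion: 0.6 * 604 / 1_000_000 = $0.0003624
# total:                               $0.0005492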
Empty file.
45 changes: 45 additions & 0 deletions elleelleaime/generate/strategies/models/mistral/mistral.py
@@ -0,0 +1,45 @@
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy

from dotenv import load_dotenv
from typing import Any, List

import os
import mistralai
import backoff


class MistralModels(PatchGenerationStrategy):
def __init__(self, model_name: str, **kwargs) -> None:
self.model_name = model_name
self.temperature = kwargs.get("temperature", 0.0)
self.n_samples = kwargs.get("n_samples", 1)

load_dotenv()
self.client = mistralai.Mistral(os.getenv("MISTRAL_API_KEY", None))

@backoff.on_exception(
backoff.expo,
(
mistralai.models.SDKError,
mistralai.models.HTTPValidationError,
AssertionError,
),
)
def _completions_with_backoff(self, **kwargs):
response = self.client.chat.complete(**kwargs)
assert response is not None
return response

def _generate_impl(self, chunk: List[str]) -> Any:
result = []

for prompt in chunk:
completion = self._completions_with_backoff(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
temperature=self.temperature,
n=self.n_samples,
)
result.append(completion.model_dump())

return result
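
A minimal usage sketch (assuming MISTRAL_API_KEY is set in the environment or a .env file; the model name and prompt are illustrative, and _generate_impl is called directly here although the pipeline normally drives it through PatchGenerationStrategy):

from elleelleaime.generate.strategies.models.mistral.mistral import MistralModels

strategy = MistralModels("codestral-2405", temperature=0.0, n_samples=1)
results = strategy._generate_impl(["Fix the bug in the following Java method: ..."])
# Each entry is the SDK response serialized via model_dump(), e.g.:
# results[0]["choices"][0]["message"]["content"]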
4 changes: 4 additions & 0 deletions elleelleaime/generate/strategies/registry.py
@@ -17,6 +17,9 @@
from elleelleaime.generate.strategies.models.anthropic.anthropic import (
AnthropicModels,
)
from elleelleaime.generate.strategies.models.mistral.mistral import (
MistralModels,
)

from typing import Tuple

@@ -35,6 +38,7 @@ class PatchGenerationStrategyRegistry:
"codellama-infilling": (CodeLLaMAInfilling, ("model_name",)),
"codellama-instruct": (CodeLLaMAIntruct, ("model_name",)),
"anthropic": (AnthropicModels, ("model_name", "max_tokens")),
"mistral": (MistralModels, ("model_name",)),
}

@classmethod
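
Each registry entry pairs the strategy class with the kwarg names it consumes; a hedged sketch of the lookup pattern the table implies (the actual resolution method lives in the elided part of the class):

from elleelleaime.generate.strategies.models.mistral.mistral import MistralModels

entry = {"mistral": (MistralModels, ("model_name",))}
strategy_cls, expected_args = entry["mistral"]
provided = {"model_name": "codestral-2405"}
strategy = strategy_cls(**{k: provided[k] for k in expected_args})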
72 changes: 67 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ evaluate = "^0.4.2"
safetensors = "^0.4.3"
google-generativeai = "^0.7.2"
anthropic = "^0.34.2"
mistralai = "^1.2.3"

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
71 changes: 71 additions & 0 deletions tests/evaluate/test_evaluate_mistral.py
@@ -0,0 +1,71 @@
from evaluate_patches import evaluate_candidate
from generate_samples import generate_sample
from elleelleaime.core.utils.benchmarks import get_benchmark
from elleelleaime.core.benchmarks.benchmark import Benchmark


class TestEvaluatePatchesMistralDefects4J:
DEFECTS4J: Benchmark
PROMPT_STRATEGY: str = "instruct"
MODEL_NAME: str = "codestral-2405"
EVALUATE_STRATEGY: str = "mistral"

@classmethod
def setup_class(cls):
TestEvaluatePatchesMistralDefects4J.DEFECTS4J = get_benchmark("defects4j")
assert TestEvaluatePatchesMistralDefects4J.DEFECTS4J is not None
TestEvaluatePatchesMistralDefects4J.DEFECTS4J.initialize()

@classmethod
def get_exact_match_sample(cls):
bug = TestEvaluatePatchesMistralDefects4J.DEFECTS4J.get_bug("Chart-1")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestEvaluatePatchesMistralDefects4J.PROMPT_STRATEGY,
model_name=TestEvaluatePatchesMistralDefects4J.MODEL_NAME,
)

sample["generation"] = {
"id": "5f26bfc6f38f46c2a399ef319293634a",
"object": "chat.completion",
"model": "codestral-2405",
"usage": {
"prompt_tokens": 934,
"completion_tokens": 604,
"total_tokens": 1538,
},
"created": 1732015902,
"choices": [
{
"index": 0,
"message": {
"content": f"```java\n{sample['fixed_code']}\n// comment\n```",
"tool_calls": None,
"prefix": False,
"role": "assistant",
},
"finish_reason": "stop",
}
],
}

return bug, sample

def test_exact_match_patch(self):
bug, sample = TestEvaluatePatchesMistralDefects4J.get_exact_match_sample()

sample = evaluate_candidate(
bug=bug,
sample=sample,
strategy=TestEvaluatePatchesMistralDefects4J.EVALUATE_STRATEGY,
)

assert sample["evaluation"] is not None
assert len(sample["evaluation"]) == 1

assert sample["evaluation"][0]["compile"] == True
assert sample["evaluation"][0]["test"] == True
assert sample["evaluation"][0]["exact_match"] == True
assert sample["evaluation"][0]["ast_match"] == True

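To exercise the new path locally, one could run just this test module (a hedged example; it needs a working Defects4J installation, since setup_class initializes the benchmark):

import pytest

# Hypothetical local run of only the exact-match test:
pytest.main(["tests/evaluate/test_evaluate_mistral.py", "-k", "test_exact_match_patch"])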