-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: google genai * fix patch extraction * add tests
- Loading branch information
1 parent
4a2dc4f
commit 6602e3b
Showing
9 changed files
with
700 additions
and
2 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from elleelleaime.evaluate.strategies.text.instruct import InstructEvaluationStrategy | ||
from elleelleaime.core.benchmarks.bug import Bug | ||
|
||
from typing import Optional, List | ||
|
||
|
||
class GoogleEvaluationStrategy(InstructEvaluationStrategy): | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(kwargs=kwargs) | ||
|
||
def _evaluate_impl(self, bug: Bug, sample: dict) -> Optional[List[dict]]: | ||
""" | ||
Returns the evaluation for the given bug and sample. | ||
:param bug: The bug to generate the prompt for. | ||
:param sample: The sample to evaluate. | ||
""" | ||
evaluation = [] | ||
|
||
if sample["generation"] is None: | ||
return evaluation | ||
|
||
for generation in sample["generation"]: | ||
for candidate in generation["candidates"]: | ||
candidate_patch = candidate["content"]["parts"][0]["text"] | ||
candidate_patch = self.extract_patch_from_message(candidate_patch) | ||
evaluation.append( | ||
self.evaluate_generation(bug, sample, candidate_patch) | ||
) | ||
|
||
return evaluation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy | ||
|
||
from dotenv import load_dotenv | ||
from typing import Any, List | ||
|
||
import os | ||
import google.generativeai as genai | ||
|
||
|
||
class GoogleModels(PatchGenerationStrategy): | ||
def __init__(self, model_name: str, **kwargs) -> None: | ||
self.model_name = model_name | ||
self.temperature = kwargs.get("temperature", 0.0) | ||
self.n_samples = kwargs.get("n_samples", 1) | ||
|
||
load_dotenv() | ||
genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) | ||
|
||
def __get_config(self): | ||
return genai.types.GenerationConfig( | ||
temperature=self.temperature, | ||
) | ||
|
||
def _generate_impl(self, chunk: List[str]) -> Any: | ||
result = [] | ||
|
||
model = genai.GenerativeModel(self.model_name) | ||
|
||
for prompt in chunk: | ||
p_results = [] | ||
for _ in range(self.n_samples): | ||
completion = model.generate_content( | ||
prompt, generation_config=self.__get_config() | ||
) | ||
p_results.append(completion.to_dict()) | ||
result.append(p_results) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
from evaluate_patches import evaluate_candidate | ||
from generate_samples import generate_sample | ||
from elleelleaime.core.utils.benchmarks import get_benchmark | ||
from elleelleaime.core.benchmarks.benchmark import Benchmark | ||
|
||
|
||
class TestEvaluatePatchesGoogleDefects4J: | ||
DEFECTS4J: Benchmark | ||
PROMPT_STRATEGY: str = "instruct" | ||
MODEL_NAME: str = "gemini-1.5-flash" | ||
EVALUATE_STRATEGY: str = "google" | ||
|
||
@classmethod | ||
def setup_class(cls): | ||
TestEvaluatePatchesGoogleDefects4J.DEFECTS4J = get_benchmark("defects4j") | ||
assert TestEvaluatePatchesGoogleDefects4J.DEFECTS4J is not None | ||
TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.initialize() | ||
|
||
@classmethod | ||
def get_exact_match_sample(cls): | ||
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1") | ||
assert bug is not None | ||
|
||
sample = generate_sample( | ||
bug=bug, | ||
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY, | ||
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME, | ||
) | ||
|
||
sample["generation"] = [ | ||
{ | ||
"candidates": [ | ||
{ | ||
"content": { | ||
"parts": [ | ||
{ | ||
"text": f"```java\n{sample['fixed_code']}" | ||
+ "\n// comment\n```" | ||
} | ||
], | ||
"role": "model", | ||
}, | ||
"finish_reason": 1, | ||
"index": 0, | ||
} | ||
] | ||
} | ||
] | ||
|
||
return bug, sample | ||
|
||
@classmethod | ||
def get_ast_match_sample(cls): | ||
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1") | ||
assert bug is not None | ||
|
||
sample = generate_sample( | ||
bug=bug, | ||
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY, | ||
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME, | ||
) | ||
|
||
code = """ public LegendItemCollection getLegendItems() { | ||
LegendItemCollection result = new LegendItemCollection(); | ||
if (this.plot == null) { | ||
return result; | ||
} | ||
int index = this.plot.getIndexOf(this); | ||
CategoryDataset dataset = this.plot.getDataset(index); | ||
if (dataset == null) | ||
{ | ||
return result; | ||
} | ||
int seriesCount = dataset.getRowCount(); | ||
if (plot.getRowRenderingOrder().equals(SortOrder.ASCENDING)) { | ||
for (int i = 0; i < seriesCount; i++) { | ||
if (isSeriesVisibleInLegend(i)) { | ||
LegendItem item = getLegendItem(index, i); | ||
if (item != null) { | ||
result.add(item); | ||
} | ||
} | ||
} | ||
} | ||
else { | ||
for (int i = seriesCount - 1; i >= 0; i--) { | ||
if (isSeriesVisibleInLegend(i)) { | ||
LegendItem item = getLegendItem(index, i); | ||
if (item != null) { | ||
result.add(item); | ||
} | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
""" | ||
|
||
sample["generation"] = [ | ||
{ | ||
"candidates": [ | ||
{ | ||
"content": { | ||
"parts": [{"text": f"```java\n{code}\n```"}], | ||
"role": "model", | ||
}, | ||
"finish_reason": 1, | ||
"index": 0, | ||
} | ||
] | ||
} | ||
] | ||
|
||
return bug, sample | ||
|
||
@classmethod | ||
def get_plausible_sample(cls): | ||
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1") | ||
assert bug is not None | ||
|
||
sample = generate_sample( | ||
bug=bug, | ||
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY, | ||
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME, | ||
) | ||
code = """ public LegendItemCollection getLegendItems() { | ||
LegendItemCollection result = new LegendItemCollection(); | ||
if (this.plot == null) { | ||
return result; | ||
} | ||
int index = this.plot.getIndexOf(this); | ||
CategoryDataset dataset = this.plot.getDataset(index); | ||
if (dataset == null) | ||
{ | ||
return result; | ||
} else { | ||
int a = 0; | ||
} | ||
int seriesCount = dataset.getRowCount(); | ||
if (plot.getRowRenderingOrder().equals(SortOrder.ASCENDING)) { | ||
for (int i = 0; i < seriesCount; i++) { | ||
if (isSeriesVisibleInLegend(i)) { | ||
LegendItem item = getLegendItem(index, i); | ||
if (item != null) { | ||
result.add(item); | ||
} | ||
} | ||
} | ||
} | ||
else { | ||
for (int i = seriesCount - 1; i >= 0; i--) { | ||
if (isSeriesVisibleInLegend(i)) { | ||
LegendItem item = getLegendItem(index, i); | ||
if (item != null) { | ||
result.add(item); | ||
} | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
""" | ||
|
||
sample["generation"] = [ | ||
{ | ||
"candidates": [ | ||
{ | ||
"content": { | ||
"parts": [{"text": f"```java\n{code}\n```"}], | ||
"role": "model", | ||
}, | ||
"finish_reason": 1, | ||
"index": 0, | ||
} | ||
] | ||
} | ||
] | ||
|
||
return bug, sample | ||
|
||
@classmethod | ||
def get_incorrect_sample(cls): | ||
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1") | ||
assert bug is not None | ||
|
||
sample = generate_sample( | ||
bug=bug, | ||
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY, | ||
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME, | ||
) | ||
|
||
sample["generation"] = [ | ||
{ | ||
"candidates": [ | ||
{ | ||
"content": { | ||
"parts": [ | ||
{"text": f"```java\n{sample['buggy_code']}\n```"} | ||
], | ||
"role": "model", | ||
}, | ||
"finish_reason": 1, | ||
"index": 0, | ||
} | ||
] | ||
} | ||
] | ||
|
||
return bug, sample | ||
|
||
def test_exact_match_patch(self): | ||
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_exact_match_sample() | ||
|
||
sample = evaluate_candidate( | ||
bug=bug, | ||
sample=sample, | ||
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY, | ||
) | ||
|
||
assert sample["evaluation"] is not None | ||
assert len(sample["evaluation"]) == 1 | ||
|
||
assert sample["evaluation"][0]["compile"] == True | ||
assert sample["evaluation"][0]["test"] == True | ||
assert sample["evaluation"][0]["exact_match"] == True | ||
assert sample["evaluation"][0]["ast_match"] == True | ||
|
||
def test_ast_match_patch(self): | ||
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_ast_match_sample() | ||
|
||
sample = evaluate_candidate( | ||
bug=bug, | ||
sample=sample, | ||
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY, | ||
) | ||
|
||
assert sample["evaluation"] is not None | ||
assert len(sample["evaluation"]) == 1 | ||
|
||
assert sample["evaluation"][0]["compile"] == True | ||
assert sample["evaluation"][0]["test"] == True | ||
assert sample["evaluation"][0]["ast_match"] == True | ||
assert sample["evaluation"][0]["exact_match"] == False | ||
|
||
def test_incorrect_patch(self): | ||
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_incorrect_sample() | ||
|
||
sample = evaluate_candidate( | ||
bug=bug, | ||
sample=sample, | ||
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY, | ||
) | ||
|
||
assert sample["evaluation"] is not None | ||
assert len(sample["evaluation"]) == 1 | ||
|
||
assert sample["evaluation"][0]["compile"] == True | ||
assert sample["evaluation"][0]["test"] == False | ||
assert sample["evaluation"][0]["exact_match"] == False | ||
assert sample["evaluation"][0]["ast_match"] == False | ||
|
||
def test_plausible_patch(self): | ||
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_plausible_sample() | ||
|
||
sample = evaluate_candidate( | ||
bug=bug, | ||
sample=sample, | ||
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY, | ||
) | ||
|
||
assert sample["evaluation"] is not None | ||
assert len(sample["evaluation"]) == 1 | ||
|
||
assert sample["evaluation"][0]["compile"] == True | ||
assert sample["evaluation"][0]["test"] == True | ||
assert sample["evaluation"][0]["exact_match"] == False | ||
assert sample["evaluation"][0]["ast_match"] == False |