feat: google genai (#150)
* feat: google genai

* fix patch extraction

* add tests
andre15silva authored Sep 4, 2024
1 parent 4a2dc4f commit 6602e3b
Showing 9 changed files with 700 additions and 2 deletions.
Empty file.
32 changes: 32 additions & 0 deletions elleelleaime/evaluate/strategies/google/google.py
@@ -0,0 +1,32 @@
from elleelleaime.evaluate.strategies.text.instruct import InstructEvaluationStrategy
from elleelleaime.core.benchmarks.bug import Bug

from typing import Optional, List


class GoogleEvaluationStrategy(InstructEvaluationStrategy):

def __init__(self, **kwargs):
super().__init__(kwargs=kwargs)

def _evaluate_impl(self, bug: Bug, sample: dict) -> Optional[List[dict]]:
"""
Returns the evaluation for the given bug and sample.
:param bug: The bug to generate the prompt for.
:param sample: The sample to evaluate.
"""
evaluation = []

if sample["generation"] is None:
return evaluation

for generation in sample["generation"]:
for candidate in generation["candidates"]:
candidate_patch = candidate["content"]["parts"][0]["text"]
candidate_patch = self.extract_patch_from_message(candidate_patch)
evaluation.append(
self.evaluate_generation(bug, sample, candidate_patch)
)

return evaluation
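For context, `_evaluate_impl` above walks the dict produced by the Google GenAI SDK's `response.to_dict()`. Below is a minimal sketch of that traversal over a hand-built generation, mirroring the test fixtures later in this commit (illustrative only, not a real API response):

```python
# Hand-built example of the structure _evaluate_impl expects (illustrative only).
generation = {
    "candidates": [
        {
            "content": {
                "parts": [{"text": "```java\n// candidate patch\n```"}],
                "role": "model",
            },
            "finish_reason": 1,
            "index": 0,
        }
    ]
}

for candidate in generation["candidates"]:
    # Same path as the strategy: first text part of the candidate's content.
    candidate_patch = candidate["content"]["parts"][0]["text"]
    print(candidate_patch)
```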
2 changes: 2 additions & 0 deletions elleelleaime/evaluate/strategies/registry.py
@@ -2,6 +2,7 @@
from elleelleaime.evaluate.strategies.text.replace import ReplaceEvaluationStrategy
from elleelleaime.evaluate.strategies.text.instruct import InstructEvaluationStrategy
from elleelleaime.evaluate.strategies.openai.openai import OpenAIEvaluationStrategy
from elleelleaime.evaluate.strategies.google.google import GoogleEvaluationStrategy


class PatchEvaluationStrategyRegistry:
@@ -14,6 +15,7 @@ def __init__(self, **kwargs):
"replace": ReplaceEvaluationStrategy(**kwargs),
"instruct": InstructEvaluationStrategy(**kwargs),
"openai": OpenAIEvaluationStrategy(**kwargs),
"google": GoogleEvaluationStrategy(**kwargs),
}

def get_evaluation(self, name: str) -> PatchEvaluationStrategy:
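A hedged usage sketch of the updated evaluation registry; any strategy-specific constructor keyword arguments are omitted here for brevity:

```python
from elleelleaime.evaluate.strategies.registry import PatchEvaluationStrategyRegistry

# Any kwargs passed here are forwarded to every registered strategy constructor.
registry = PatchEvaluationStrategyRegistry()
strategy = registry.get_evaluation("google")

# bug and sample would come from a benchmark and a prior generation run:
# evaluations = strategy._evaluate_impl(bug, sample)
```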
Empty file.
38 changes: 38 additions & 0 deletions elleelleaime/generate/strategies/models/google/google.py
@@ -0,0 +1,38 @@
from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy

from dotenv import load_dotenv
from typing import Any, List

import os
import google.generativeai as genai


class GoogleModels(PatchGenerationStrategy):
def __init__(self, model_name: str, **kwargs) -> None:
self.model_name = model_name
self.temperature = kwargs.get("temperature", 0.0)
self.n_samples = kwargs.get("n_samples", 1)

load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

def __get_config(self):
return genai.types.GenerationConfig(
temperature=self.temperature,
)

def _generate_impl(self, chunk: List[str]) -> Any:
result = []

model = genai.GenerativeModel(self.model_name)

for prompt in chunk:
p_results = []
for _ in range(self.n_samples):
completion = model.generate_content(
prompt, generation_config=self.__get_config()
)
p_results.append(completion.to_dict())
result.append(p_results)

return result
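A minimal usage sketch of the new generation strategy, assuming `GOOGLE_API_KEY` is available in the environment; the prompt and sampling parameters below are illustrative:

```python
from elleelleaime.generate.strategies.models.google.google import GoogleModels

# model_name is the only mandatory argument registered for this strategy;
# temperature and n_samples default to 0.0 and 1, respectively.
strategy = GoogleModels(model_name="gemini-1.5-flash", temperature=0.0, n_samples=2)

# _generate_impl takes a chunk (list) of prompts and returns, per prompt,
# a list of n_samples response dicts (completion.to_dict()).
results = strategy._generate_impl(["// Fix the bug in the following Java method ..."])
```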
6 changes: 5 additions & 1 deletion elleelleaime/generate/strategies/registry.py
@@ -2,6 +2,9 @@
from elleelleaime.generate.strategies.models.openai.openai import (
OpenAIChatCompletionModels,
)
from elleelleaime.generate.strategies.models.google.google import (
GoogleModels,
)
from elleelleaime.generate.strategies.models.huggingface.codellama.codellama_infilling import (
CodeLLaMAInfilling,
)
@@ -17,10 +20,11 @@ class PatchGenerationStrategyRegistry:
Class for storing and retrieving models based on their name.
"""

# The registry is a dictionary of strategy names to a tuple of the class and the mandatory arguments to pass to the class
# The registry is a dict of strategy names to a tuple of class and mandatory arguments to init the class
    # NOTE: Do not instantiate the model here, as we should only instantiate the class to be used
__MODELS: dict[str, Tuple[type, Tuple]] = {
"openai-chatcompletion": (OpenAIChatCompletionModels, ("model_name",)),
"google": (GoogleModels, ("model_name",)),
"codellama-infilling": (CodeLLaMAInfilling, ("model_name",)),
"codellama-instruct": (CodeLLaMAIntruct, ("model_name",)),
}
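The retrieval logic of this registry sits outside the visible hunk, so the helper below is a hypothetical illustration of how the `(class, mandatory-arguments)` tuples above could be resolved by name:

```python
from typing import Tuple

# Illustrative mirror of the registry mapping above (entries elided).
MODELS: dict[str, Tuple[type, Tuple]] = {
    # "google": (GoogleModels, ("model_name",)),
}


def resolve_strategy(name: str, **kwargs):
    """Hypothetical lookup: verify mandatory arguments, then instantiate the class."""
    if name not in MODELS:
        raise ValueError(f"Unknown strategy: {name}")
    cls, mandatory = MODELS[name]
    missing = [arg for arg in mandatory if arg not in kwargs]
    if missing:
        raise ValueError(f"Missing mandatory arguments for '{name}': {missing}")
    return cls(**kwargs)
```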
346 changes: 345 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ peft = "^0.11.1"
bitsandbytes = "^0.43.1"
evaluate = "^0.4.2"
safetensors = "^0.4.3"
google-generativeai = "^0.7.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
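The generation strategy reads the API key via `python-dotenv` (see `GoogleModels.__init__` above). A quick sanity check, assuming the key is supplied through a `.env` file or the shell environment:

```python
import os

from dotenv import load_dotenv

# GoogleModels expects GOOGLE_API_KEY to be resolvable at construction time
# (loaded here from a .env file if one is present).
load_dotenv()
assert os.getenv("GOOGLE_API_KEY"), "Set GOOGLE_API_KEY before using the google strategy"
```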
277 changes: 277 additions & 0 deletions tests/evaluate/test_evaluate_google.py
@@ -0,0 +1,277 @@
from evaluate_patches import evaluate_candidate
from generate_samples import generate_sample
from elleelleaime.core.utils.benchmarks import get_benchmark
from elleelleaime.core.benchmarks.benchmark import Benchmark


class TestEvaluatePatchesGoogleDefects4J:
DEFECTS4J: Benchmark
PROMPT_STRATEGY: str = "instruct"
MODEL_NAME: str = "gemini-1.5-flash"
EVALUATE_STRATEGY: str = "google"

@classmethod
def setup_class(cls):
TestEvaluatePatchesGoogleDefects4J.DEFECTS4J = get_benchmark("defects4j")
assert TestEvaluatePatchesGoogleDefects4J.DEFECTS4J is not None
TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.initialize()

@classmethod
def get_exact_match_sample(cls):
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY,
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME,
)

sample["generation"] = [
{
"candidates": [
{
"content": {
"parts": [
{
"text": f"```java\n{sample['fixed_code']}"
+ "\n// comment\n```"
}
],
"role": "model",
},
"finish_reason": 1,
"index": 0,
}
]
}
]

return bug, sample

@classmethod
def get_ast_match_sample(cls):
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY,
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME,
)

code = """ public LegendItemCollection getLegendItems() {
LegendItemCollection result = new LegendItemCollection();
if (this.plot == null) {
return result;
}
int index = this.plot.getIndexOf(this);
CategoryDataset dataset = this.plot.getDataset(index);
if (dataset == null)
{
return result;
}
int seriesCount = dataset.getRowCount();
if (plot.getRowRenderingOrder().equals(SortOrder.ASCENDING)) {
for (int i = 0; i < seriesCount; i++) {
if (isSeriesVisibleInLegend(i)) {
LegendItem item = getLegendItem(index, i);
if (item != null) {
result.add(item);
}
}
}
}
else {
for (int i = seriesCount - 1; i >= 0; i--) {
if (isSeriesVisibleInLegend(i)) {
LegendItem item = getLegendItem(index, i);
if (item != null) {
result.add(item);
}
}
}
}
return result;
}
"""

sample["generation"] = [
{
"candidates": [
{
"content": {
"parts": [{"text": f"```java\n{code}\n```"}],
"role": "model",
},
"finish_reason": 1,
"index": 0,
}
]
}
]

return bug, sample

@classmethod
def get_plausible_sample(cls):
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY,
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME,
)
code = """ public LegendItemCollection getLegendItems() {
LegendItemCollection result = new LegendItemCollection();
if (this.plot == null) {
return result;
}
int index = this.plot.getIndexOf(this);
CategoryDataset dataset = this.plot.getDataset(index);
if (dataset == null)
{
return result;
} else {
int a = 0;
}
int seriesCount = dataset.getRowCount();
if (plot.getRowRenderingOrder().equals(SortOrder.ASCENDING)) {
for (int i = 0; i < seriesCount; i++) {
if (isSeriesVisibleInLegend(i)) {
LegendItem item = getLegendItem(index, i);
if (item != null) {
result.add(item);
}
}
}
}
else {
for (int i = seriesCount - 1; i >= 0; i--) {
if (isSeriesVisibleInLegend(i)) {
LegendItem item = getLegendItem(index, i);
if (item != null) {
result.add(item);
}
}
}
}
return result;
}
"""

sample["generation"] = [
{
"candidates": [
{
"content": {
"parts": [{"text": f"```java\n{code}\n```"}],
"role": "model",
},
"finish_reason": 1,
"index": 0,
}
]
}
]

return bug, sample

@classmethod
def get_incorrect_sample(cls):
bug = TestEvaluatePatchesGoogleDefects4J.DEFECTS4J.get_bug("Chart-1")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestEvaluatePatchesGoogleDefects4J.PROMPT_STRATEGY,
model_name=TestEvaluatePatchesGoogleDefects4J.MODEL_NAME,
)

sample["generation"] = [
{
"candidates": [
{
"content": {
"parts": [
{"text": f"```java\n{sample['buggy_code']}\n```"}
],
"role": "model",
},
"finish_reason": 1,
"index": 0,
}
]
}
]

return bug, sample

def test_exact_match_patch(self):
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_exact_match_sample()

sample = evaluate_candidate(
bug=bug,
sample=sample,
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY,
)

assert sample["evaluation"] is not None
assert len(sample["evaluation"]) == 1

assert sample["evaluation"][0]["compile"] == True
assert sample["evaluation"][0]["test"] == True
assert sample["evaluation"][0]["exact_match"] == True
assert sample["evaluation"][0]["ast_match"] == True

def test_ast_match_patch(self):
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_ast_match_sample()

sample = evaluate_candidate(
bug=bug,
sample=sample,
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY,
)

assert sample["evaluation"] is not None
assert len(sample["evaluation"]) == 1

assert sample["evaluation"][0]["compile"] == True
assert sample["evaluation"][0]["test"] == True
assert sample["evaluation"][0]["ast_match"] == True
assert sample["evaluation"][0]["exact_match"] == False

def test_incorrect_patch(self):
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_incorrect_sample()

sample = evaluate_candidate(
bug=bug,
sample=sample,
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY,
)

assert sample["evaluation"] is not None
assert len(sample["evaluation"]) == 1

assert sample["evaluation"][0]["compile"] == True
assert sample["evaluation"][0]["test"] == False
assert sample["evaluation"][0]["exact_match"] == False
assert sample["evaluation"][0]["ast_match"] == False

def test_plausible_patch(self):
bug, sample = TestEvaluatePatchesGoogleDefects4J.get_plausible_sample()

sample = evaluate_candidate(
bug=bug,
sample=sample,
strategy=TestEvaluatePatchesGoogleDefects4J.EVALUATE_STRATEGY,
)

assert sample["evaluation"] is not None
assert len(sample["evaluation"]) == 1

assert sample["evaluation"][0]["compile"] == True
assert sample["evaluation"][0]["test"] == True
assert sample["evaluation"][0]["exact_match"] == False
assert sample["evaluation"][0]["ast_match"] == False
