Skip to content

Commit

Permalink
Merge branch 'codellama-instruct' of github.com:ASSERT-KTH/elle-elle-…
Browse files Browse the repository at this point in the history
…aime into codellama-instruct
  • Loading branch information
andre15silva committed Aug 22, 2024
2 parents 264c960 + 0d27b0f commit dc58515
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 32 deletions.
49 changes: 39 additions & 10 deletions elleelleaime/core/benchmarks/gitbugjava/gitbugjava.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import subprocess
import logging
import tqdm
import json
import re
import os


Expand Down Expand Up @@ -53,12 +53,41 @@ def initialize(self) -> None:
logging.info("Found %3d bugs" % len(bids))

for bid in tqdm.tqdm(bids, "Loading GitBug-Java"):
pid = bid.rsplit("-", 1)[0]
diff = ""
with open(f"{self.path}/data/bugs/{pid}.json", "r") as f:
for line in f:
bug_info = json.loads(line)
if bug_info["commit_hash"][:12] in bid:
diff = bug_info["bug_patch"]
break
self.add_bug(GitBugJavaBug(self, bid, diff))
# Run info command
run = self.run_command(
f"info {bid}",
check=True,
)
stdout = run.stdout.decode("utf-8")

# Get diff (after "### Bug Patch", between triple ticks)
diff = stdout.split("### Bug Patch")[1].split("```diff")[1].split("```")[0]

# Get failing tests
# The info command prints out the failing tests in the following format
# - failing test
# - type of failure
# - failure message
failing_tests = {}
stdout = stdout.split("### Failing Tests")[1]
for test in re.split(r"(^-)", stdout):
# Split the three lines
info = test.strip().split("\n")

# Extract failing test class and method
failing_test_case = info[0].replace("-", "", 1).strip()
failing_test_case = (
failing_test_case.replace(":", "::")
.replace("#", "::")
.replace("()", "")
)
# Remove value between '$' and '::' if it exists (happens for jitterted tests)
failing_test_case = re.sub(r"\$.*?::", "::", failing_test_case)

# Extract cause
cause = info[2].replace("-", "", 1).strip()
if cause == "None":
cause = info[1].replace("-", "", 1).strip()
failing_tests[failing_test_case] = cause

self.add_bug(GitBugJavaBug(self, bid, diff, failing_tests))
17 changes: 13 additions & 4 deletions elleelleaime/core/benchmarks/gitbugjava/gitbugjavabug.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,24 @@
import os
from elleelleaime.core.benchmarks.benchmark import Benchmark

from elleelleaime.core.benchmarks.bug import Bug
from elleelleaime.core.benchmarks.bug import RichBug
from elleelleaime.core.benchmarks.test_result import TestResult
from elleelleaime.core.benchmarks.compile_result import CompileResult


class GitBugJavaBug(Bug):
class GitBugJavaBug(RichBug):
"""
The class for representing GitBug-Java bugs
"""

def __init__(self, benchmark: Benchmark, bid: str, ground_truth: str) -> None:
super().__init__(benchmark, bid, ground_truth, False)
def __init__(
self,
benchmark: Benchmark,
bid: str,
ground_truth: str,
failing_tests: dict[str, str],
) -> None:
super().__init__(benchmark, bid, ground_truth, failing_tests, False)

def checkout(self, path: str, fixed: bool = False) -> bool:
# Remove the directory if it exists
Expand Down Expand Up @@ -44,3 +50,6 @@ def test(self, path: str) -> TestResult:
)
except subprocess.TimeoutExpired:
return TestResult(False)

def get_src_test_dir(self, path: str) -> str:
return path
35 changes: 30 additions & 5 deletions elleelleaime/core/utils/java/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unidiff import PatchSet
from uuid import uuid4
from pathlib import Path
import logging
import getpass, tempfile, difflib, shutil
import subprocess
import re
Expand Down Expand Up @@ -239,6 +240,32 @@ def extract_single_function(bug: Bug) -> Optional[Tuple[str, str]]:
shutil.rmtree(fixed_path, ignore_errors=True)


def find_test_class(path: Path, bug, class_name: str) -> Optional[Path]:
# Get the base test directory
base_test_dir = Path(path, bug.get_src_test_dir(str(path)))

# Convert class name to the relative path format
class_relative_path = f"{class_name.replace('.', '/')}.java"

# Iterate through all the subdirectories under the base test directory
candidates = []
for java_file in base_test_dir.rglob("*.java"):
# Check if the file ends with the class relative path
if java_file.as_posix().endswith(class_relative_path):
candidates.append(
java_file
) # Return the full path to the matched Java file

if len(candidates) == 0:
logging.error(f"No test class found for {class_name}")
return None
elif len(candidates) == 1:
return candidates[0]
else:
logging.error(f"Multiple test classes found for {class_name}")
return None


def extract_failing_test_cases(bug: RichBug) -> dict[str, str]:
"""
Extracts the code of the failing test cases of a bug.
Expand All @@ -263,11 +290,9 @@ def extract_failing_test_cases(bug: RichBug) -> dict[str, str]:
)
try:
bug.checkout(str(path), fixed=False)
test_class_path = Path(
path,
bug.get_src_test_dir(str(path)),
f"{class_name.replace('.', '/')}.java",
)
test_class_path = find_test_class(path, bug, class_name)
if test_class_path is None:
return {}

# Run code extractor for the failing test case
run = subprocess.run(
Expand Down
14 changes: 9 additions & 5 deletions evaluate_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,13 @@ def compute_statistics(samples: list) -> dict:
if sample["prompt"]:
statistics["num_bugs_with_prompt"] += 1
if sample["generation"] and any(
candidate["generation"] for candidate in sample["evaluation"]
candidate["generation"] if candidate is not None else None
for candidate in sample["evaluation"]
):
statistics["num_bugs_with_candidates"] += 1
statistics["num_patches"] += sum(
bool(candidate["generation"]) for candidate in sample["evaluation"]
bool(candidate["generation"]) if candidate is not None else False
for candidate in sample["evaluation"]
)
statistics["num_compilable_patches"] += sum(
compilable(candidate) for candidate in sample["evaluation"]
Expand Down Expand Up @@ -214,7 +216,8 @@ def export_patches(samples: list, dir_path: str) -> None:

for sample in tqdm.tqdm(samples):
if not sample["generation"] or all(
candidate["generation"] is None for candidate in sample["evaluation"]
candidate["generation"] is None if candidate is not None else None
for candidate in sample["evaluation"]
):
continue

Expand All @@ -238,7 +241,7 @@ def export_patches(samples: list, dir_path: str) -> None:
f.write(sample["prompt"])

for i, candidate in enumerate(sample["evaluation"]):
if not candidate["generation"]:
if candidate is None or not candidate["generation"]:
continue

# Compute diff between generated code and buggy code
Expand Down Expand Up @@ -284,7 +287,8 @@ def export_bugs(samples, dir_path):
if sample["generation"] is not None
and len(sample["generation"]) > 0
and not all(
candidate["generation"] is None for candidate in sample["evaluation"]
candidate["generation"] is None if candidate is not None else None
for candidate in sample["evaluation"]
)
]
)
Expand Down
96 changes: 88 additions & 8 deletions tests/sample/instruct/test_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,27 @@
from elleelleaime.core.utils.benchmarks import get_benchmark
from elleelleaime.core.benchmarks.benchmark import Benchmark

import pytest
import os

class TestInstructPrompting:

class TestInstructPromptingDefects4J:
DEFECTS4J: Benchmark
PROMPT_STRATEGY: str = "instruct"

@classmethod
def setup_class(cls):
TestInstructPrompting.DEFECTS4J = get_benchmark("defects4j")
assert TestInstructPrompting.DEFECTS4J is not None
TestInstructPrompting.DEFECTS4J.initialize()
TestInstructPromptingDefects4J.DEFECTS4J = get_benchmark("defects4j")
assert TestInstructPromptingDefects4J.DEFECTS4J is not None
TestInstructPromptingDefects4J.DEFECTS4J.initialize()

def test_closure_115(self):
bug = TestInstructPrompting.DEFECTS4J.get_bug("Closure-115")
bug = TestInstructPromptingDefects4J.DEFECTS4J.get_bug("Closure-115")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestInstructPrompting.PROMPT_STRATEGY,
prompt_strategy=TestInstructPromptingDefects4J.PROMPT_STRATEGY,
)

# Assert we are dealing with the correct bug and strategy
Expand All @@ -45,12 +48,12 @@ def test_closure_115(self):
)

def test_closure_4(self):
bug = TestInstructPrompting.DEFECTS4J.get_bug("Closure-4")
bug = TestInstructPromptingDefects4J.DEFECTS4J.get_bug("Closure-4")
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestInstructPrompting.PROMPT_STRATEGY,
prompt_strategy=TestInstructPromptingDefects4J.PROMPT_STRATEGY,
)

# Assert we are dealing with the correct bug and strategy
Expand All @@ -68,3 +71,80 @@ def test_closure_4(self):
"/**\n * Resolve the referenced type within the enclosing scope.\n */"
in sample["prompt"]
)


class TestInstructPromptingGitBugJava:
GITBUGJAVA: Benchmark
PROMPT_STRATEGY: str = "instruct"

@classmethod
def setup_class(cls):
TestInstructPromptingGitBugJava.GITBUGJAVA = get_benchmark("gitbugjava")
assert TestInstructPromptingGitBugJava.GITBUGJAVA is not None
TestInstructPromptingGitBugJava.GITBUGJAVA.initialize()

@pytest.mark.skipif(
os.environ.get("CI") is not None,
reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.",
)
def test_traccar_traccar_37ed394724c0(self):
bug = TestInstructPromptingGitBugJava.GITBUGJAVA.get_bug(
"traccar-traccar-37ed394724c0"
)
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestInstructPromptingGitBugJava.PROMPT_STRATEGY,
)

# Assert we are dealing with the correct bug and strategy
assert sample["identifier"] == "traccar-traccar-37ed394724c0"
assert sample["prompt_strategy"] == "instruct"

# Assert that the prompt is properly constructed
assert sample["prompt"] is not None

@pytest.mark.skipif(
os.environ.get("CI") is not None,
reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.",
)
def test_TheAlgorithms_Java_e5c7a08874a6(self):
bug = TestInstructPromptingGitBugJava.GITBUGJAVA.get_bug(
"TheAlgorithms-Java-e5c7a08874a6"
)
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestInstructPromptingGitBugJava.PROMPT_STRATEGY,
)

# Assert we are dealing with the correct bug and strategy
assert sample["identifier"] == "TheAlgorithms-Java-e5c7a08874a6"
assert sample["prompt_strategy"] == "instruct"

# Assert that the prompt is properly constructed
assert sample["prompt"] is not None

@pytest.mark.skipif(
os.environ.get("CI") is not None,
reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.",
)
def test_BrightSpots_rcv_688920f27706(self):
bug = TestInstructPromptingGitBugJava.GITBUGJAVA.get_bug(
"BrightSpots-rcv-688920f27706"
)
assert bug is not None

sample = generate_sample(
bug=bug,
prompt_strategy=TestInstructPromptingGitBugJava.PROMPT_STRATEGY,
)

# Assert we are dealing with the correct bug and strategy
assert sample["identifier"] == "BrightSpots-rcv-688920f27706"
assert sample["prompt_strategy"] == "instruct"

# Assert that the prompt is properly constructed
assert sample["prompt"] is None

0 comments on commit dc58515

Please sign in to comment.