From 07f2b94dd045ac7b9d25cf9ae4ab3d288724b7c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Thu, 8 Aug 2024 20:07:03 +0200 Subject: [PATCH 1/7] remove deprecated features --- elleelleaime/core/benchmarks/bears/bears.py | 32 -- .../core/benchmarks/bears/bearsbug.py | 89 ---- elleelleaime/core/utils/benchmarks.py | 2 - .../sample/prompting/strategies/__init__.py | 0 .../strategies/fill_in_the_middle.py | 147 ------ .../strategies/function_to_function.py | 67 --- .../sample/{prompting => }/registry.py | 8 +- .../bears => sample/strategies}/__init__.py | 0 .../infilling.py} | 11 +- .../{prompting => }/strategies/instruct.py | 2 +- .../sample/{prompting => }/strategy.py | 0 generate_samples.py | 2 +- tests/core/benchmarks/bears/__init__.py | 0 tests/core/benchmarks/bears/test_bears.py | 163 ------- tests/sample/fill_in_the_middle/__init__.py | 0 .../fill_in_the_middle/test_starcoder.py | 415 ---------------- tests/sample/function_to_function/__init__.py | 0 .../test_function_to_function.py | 399 ---------------- .../sample/infilling}/__init__.py | 0 .../test_codellama.py | 164 +++---- tests/sample/zero_shot_cloze/__init__.py | 0 tests/sample/zero_shot_cloze/test_incoder.py | 446 ------------------ 22 files changed, 89 insertions(+), 1858 deletions(-) delete mode 100644 elleelleaime/core/benchmarks/bears/bears.py delete mode 100644 elleelleaime/core/benchmarks/bears/bearsbug.py delete mode 100644 elleelleaime/sample/prompting/strategies/__init__.py delete mode 100644 elleelleaime/sample/prompting/strategies/fill_in_the_middle.py delete mode 100644 elleelleaime/sample/prompting/strategies/function_to_function.py rename elleelleaime/sample/{prompting => }/registry.py (60%) rename elleelleaime/{core/benchmarks/bears => sample/strategies}/__init__.py (100%) rename elleelleaime/sample/{prompting/strategies/zero_shot_cloze.py => strategies/infilling.py} (96%) rename elleelleaime/sample/{prompting => }/strategies/instruct.py (97%) rename elleelleaime/sample/{prompting => }/strategy.py (100%) delete mode 100644 tests/core/benchmarks/bears/__init__.py delete mode 100644 tests/core/benchmarks/bears/test_bears.py delete mode 100644 tests/sample/fill_in_the_middle/__init__.py delete mode 100644 tests/sample/fill_in_the_middle/test_starcoder.py delete mode 100644 tests/sample/function_to_function/__init__.py delete mode 100644 tests/sample/function_to_function/test_function_to_function.py rename {elleelleaime/sample/prompting => tests/sample/infilling}/__init__.py (100%) rename tests/sample/{zero_shot_cloze => infilling}/test_codellama.py (79%) delete mode 100644 tests/sample/zero_shot_cloze/__init__.py delete mode 100644 tests/sample/zero_shot_cloze/test_incoder.py diff --git a/elleelleaime/core/benchmarks/bears/bears.py b/elleelleaime/core/benchmarks/bears/bears.py deleted file mode 100644 index ffa54c6c..00000000 --- a/elleelleaime/core/benchmarks/bears/bears.py +++ /dev/null @@ -1,32 +0,0 @@ -from pathlib import Path -from elleelleaime.core.benchmarks.benchmark import Benchmark -from elleelleaime.core.benchmarks.bears.bearsbug import BearsBug - -import logging -import json - - -class Bears(Benchmark): - """ - The class for representing the Bears benchmark. - """ - - def __init__(self, path: Path = Path("benchmarks/bears").absolute()) -> None: - super().__init__("bears", path) - - def initialize(self) -> None: - """ - Initializes the Bears benchmark object by collecting the list of all projects and bugs. - """ - logging.info("Initializing Bears benchmark...") - - # Get all bug ids - json_path = Path(self.path, "scripts", "data", "bug_id_and_branch.json") - with open(json_path, "r") as json_file: - bugs = json.load(json_file) - logging.info("Found %3d bugs" % len(bugs)) - - # Initialize dataset - # TODO: compute diffs, store them, load them from file - for bug in bugs: - self.add_bug(BearsBug(self, bug["bugId"], "")) diff --git a/elleelleaime/core/benchmarks/bears/bearsbug.py b/elleelleaime/core/benchmarks/bears/bearsbug.py deleted file mode 100644 index 6e382c4f..00000000 --- a/elleelleaime/core/benchmarks/bears/bearsbug.py +++ /dev/null @@ -1,89 +0,0 @@ -import subprocess -import tempfile -import backoff -import shutil -import getpass -import os - -from pathlib import Path -from uuid import uuid4 - -from elleelleaime.core.benchmarks.bug import Bug -from elleelleaime.core.benchmarks.test_result import TestResult -from elleelleaime.core.benchmarks.compile_result import CompileResult - - -class BearsBug(Bug): - """ - The class for representing Bears bugs - """ - - def checkout(self, path: str, fixed: bool = False) -> bool: - try: - # Remove the directory if it exists - shutil.rmtree(path, ignore_errors=True) - - # Create the directory - Path(path).mkdir(parents=True, exist_ok=True) - - # Copy the benchmark elsewhere to avoid conflicts - temp_benchmark_path = Path( - tempfile.gettempdir(), f"elleelleaime-{getpass.getuser()}", str(uuid4()) - ) - shutil.copytree( - self.benchmark.get_path(), temp_benchmark_path, dirs_exist_ok=True - ) - # We must copy also the .git directory - os.remove(Path(temp_benchmark_path, ".git/")) - shutil.copytree( - Path(self.benchmark.get_path(), "../../.git/modules/benchmarks/bears"), - Path(temp_benchmark_path, ".git/"), - dirs_exist_ok=True, - ) - # And remove the worktree from the config file - # Remove " worktree = ../../../../benchmarks/bears" from the config file using sed - config_path = Path(temp_benchmark_path, ".git", "config") - sed_run = subprocess.run( - f"sed -i '/worktree = \.\.\/\.\.\/\.\.\/\.\.\/benchmarks\/bears/d' {config_path}", - shell=True, - capture_output=True, - check=True, - ) - - # Run checkout script - checkout_run = subprocess.run( - f"cd {temp_benchmark_path} && python scripts/checkout_bug.py --bugId {self.identifier} --workspace {path}", - shell=True, - capture_output=True, - check=True, - ) - - # Checkout the fixed version if needed - if fixed: - subprocess.run( - f"cd {path} && git checkout -", - shell=True, - capture_output=True, - check=True, - ) - - return sed_run.returncode == 0 and checkout_run.returncode == 0 - finally: - # Remove the temporary directory - shutil.rmtree(temp_benchmark_path, ignore_errors=True) - - def compile(self, path: str) -> CompileResult: - run = subprocess.run( - f"cd {self.benchmark.get_path()} && timeout {5*60} python scripts/compile_bug.py --bugId {self.identifier} --workspace {path}", - shell=True, - capture_output=True, - ) - return CompileResult(run.returncode == 0) - - def test(self, path: str) -> TestResult: - run = subprocess.run( - f"cd {self.benchmark.get_path()} && timeout {30*60} python scripts/run_tests_bug.py --bugId {self.identifier} --workspace {path}", - shell=True, - capture_output=True, - ) - return TestResult(run.returncode == 0) diff --git a/elleelleaime/core/utils/benchmarks.py b/elleelleaime/core/utils/benchmarks.py index 9b2cf99e..2c421db6 100644 --- a/elleelleaime/core/utils/benchmarks.py +++ b/elleelleaime/core/utils/benchmarks.py @@ -2,7 +2,6 @@ from elleelleaime.core.benchmarks.defects4j.defects4j import Defects4J from elleelleaime.core.benchmarks.humanevaljava.humanevaljava import HumanEvalJava from elleelleaime.core.benchmarks.quixbugs.quixbugs import QuixBugs -from elleelleaime.core.benchmarks.bears.bears import Bears from elleelleaime.core.benchmarks.gitbugjava.gitbugjava import GitBugJava from typing import Optional @@ -11,7 +10,6 @@ "Defects4J": Defects4J, "HumanEvalJava": HumanEvalJava, "QuixBugs": QuixBugs, - "Bears": Bears, "GitBugJava": GitBugJava, } diff --git a/elleelleaime/sample/prompting/strategies/__init__.py b/elleelleaime/sample/prompting/strategies/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/elleelleaime/sample/prompting/strategies/fill_in_the_middle.py b/elleelleaime/sample/prompting/strategies/fill_in_the_middle.py deleted file mode 100644 index 4b2a3197..00000000 --- a/elleelleaime/sample/prompting/strategies/fill_in_the_middle.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Optional, Tuple, Union -from unidiff import PatchSet - -from elleelleaime.sample.prompting.strategy import PromptingStrategy -from elleelleaime.core.benchmarks.bug import Bug -from elleelleaime.core.utils.java.java import ( - extract_single_function, - compute_diff, - remove_java_comments, - remove_empty_lines, -) - - -class FillInTheMiddlePrompting(PromptingStrategy): - """ - Implements the fill-in-the-middle style prompt strategy for single diff file. - """ - - # MODEL_DICT is a dictionary of model names and their corresponding kwargs - MODEL_DICT = { - "starcoder": { - "prefix_token": "", - "middle_token": "", - "sufix_token": "", - }, - # Add the model you want to use here - } - - def __init__(self, **kwargs): - super().__init__("fill-in-the-middle") - - self.model_name: str = kwargs.get("model_name", "").strip().lower() - assert ( - self.model_name in self.MODEL_DICT.keys() - ), f"Unknown model name: {kwargs.get('model_name', None)}" - model_kwargs = self.MODEL_DICT.get(self.model_name, {}) - self.prefix_token: str = model_kwargs["prefix_token"] - self.middle_token: str = model_kwargs["middle_token"] - self.sufix_token: str = model_kwargs["sufix_token"] - self.keep_buggy_code: bool = kwargs.get("keep_buggy_code", False) - self.keep_comments: bool = kwargs.get("keep_comments", True) - - def build_fim_prompt(self, buggy_code: str, fixed_code: str) -> str: - fdiff = compute_diff(buggy_code, fixed_code) - - # Iterate over the diff to get the prefix, middle, and suffix parts - prefix = [True, ""] - middle = "" - suffix = [False, ""] - for line in fdiff: - if any(line.startswith(x) for x in ["---", "+++", "@@"]): - continue - elif any(line.startswith(x) for x in ["+", "-"]): - prefix[0] = False - suffix[0] = True - middle += suffix[1] - suffix[1] = "" - if line.startswith("-"): - middle += line[1:] - else: - if prefix[0]: - prefix[1] += line[1:] - elif suffix[0]: - suffix[1] += line[1:] - - if self.keep_buggy_code: - buggy_comment = "// buggy code\n" - if middle.strip() != "": - for line in middle.splitlines(keepends=True): - buggy_comment += "//" + line - prompt = ( - f"{self.prefix_token}" - + prefix[1] - + buggy_comment - + f"{self.sufix_token}" - + suffix[1] - + f"{self.middle_token}" - ) - else: - prompt = ( - f"{self.prefix_token}" - + prefix[1] - + f"{self.sufix_token}" - + suffix[1] - + f"{self.middle_token}" - ) - - return prompt - - def cloze_prompt( - self, bug: Bug - ) -> Tuple[Optional[str], Optional[str], Optional[str]]: - """ - Builds a cloze prompt for the given bug. - - Args: - bug: The bug to generate the prompt for. - Returns: - Tuple: A tuple of the form (buggy_code, fixed_code, prompt). - """ - result = extract_single_function(bug) - - if result is None: - return None, None, None - - buggy_code, fixed_code = result - - if not self.keep_comments: - buggy_code_prompt = remove_java_comments(buggy_code) - fixed_code_prompt = remove_java_comments(fixed_code) - else: - buggy_code_prompt = buggy_code - fixed_code_prompt = fixed_code - - buggy_code_prompt = remove_empty_lines(buggy_code_prompt) - fixed_code_prompt = remove_empty_lines(fixed_code_prompt) - - prompt = self.build_fim_prompt(buggy_code_prompt, fixed_code_prompt) - - return buggy_code, fixed_code, prompt - - def prompt(self, bug: Bug) -> dict[str, Optional[str]]: - """ - Returns the prompt for the given bug. - - :param bug: The bug to generate the prompt for. - """ - result = { - "identifier": bug.get_identifier(), - "buggy_code": None, - "fixed_code": None, - "prompt_strategy": self.strategy_name, - "prompt": None, - "ground_truth": bug.get_ground_truth(), - } - - diff = PatchSet(bug.get_ground_truth()) - # This strategy only supports single-file prompts - if len(diff) != 1: - return result - - ( - result["buggy_code"], - result["fixed_code"], - result["prompt"], - ) = self.cloze_prompt(bug) - return result diff --git a/elleelleaime/sample/prompting/strategies/function_to_function.py b/elleelleaime/sample/prompting/strategies/function_to_function.py deleted file mode 100644 index db6d5cf4..00000000 --- a/elleelleaime/sample/prompting/strategies/function_to_function.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Optional, Tuple, Union -from unidiff import PatchSet - -from elleelleaime.sample.prompting.strategy import PromptingStrategy -from elleelleaime.core.benchmarks.bug import Bug -from elleelleaime.core.utils.java.java import extract_single_function - - -class FunctionToFunctionPrompting(PromptingStrategy): - """ - Implements the function-to-function style prompt strategy. - """ - - def __init__(self, **kwargs): - super().__init__("function-to-function") - - self.model_name: str = kwargs.get("model_name", "").strip().lower() - # TODO: add_fault_localization - - def function_to_function( - self, bug: Bug - ) -> Tuple[Optional[str], Optional[str], Optional[str]]: - """ - Builds a function-to-function prompt for the given bug. - - Args: - bug: The bug to generate the prompt for. - Returns: - Tuple: A tuple of the form (buggy_code, fixed_code, prompt). - """ - result = extract_single_function(bug) - if result is None: - return None, None, None - - buggy_code, fixed_code = result - - # TODO: add fault localization option - prompt = buggy_code - - return buggy_code, fixed_code, prompt - - def prompt(self, bug: Bug) -> dict[str, Optional[str]]: - """ - Returns the prompt for the given bug. - - :param bug: The bug to generate the prompt for. - """ - result = { - "identifier": bug.get_identifier(), - "buggy_code": None, - "fixed_code": None, - "prompt_strategy": self.strategy_name, - "prompt": None, - "ground_truth": bug.get_ground_truth(), - } - - diff = PatchSet(bug.get_ground_truth()) - # This strategy only supports single-file prompts - if len(diff) != 1: - return result - - ( - result["buggy_code"], - result["fixed_code"], - result["prompt"], - ) = self.function_to_function(bug) - return result diff --git a/elleelleaime/sample/prompting/registry.py b/elleelleaime/sample/registry.py similarity index 60% rename from elleelleaime/sample/prompting/registry.py rename to elleelleaime/sample/registry.py index 250dcdd4..e1cb18d3 100644 --- a/elleelleaime/sample/prompting/registry.py +++ b/elleelleaime/sample/registry.py @@ -1,7 +1,5 @@ from .strategy import PromptingStrategy -from .strategies.zero_shot_cloze import ZeroShotClozePrompting -from .strategies.fill_in_the_middle import FillInTheMiddlePrompting -from .strategies.function_to_function import FunctionToFunctionPrompting +from .strategies.infilling import InfillingPrompting from .strategies.instruct import InstructPrompting @@ -11,9 +9,7 @@ class PromptStrategyRegistry: """ __STRATEGIES: dict[str, type] = { - "zero-shot-cloze": ZeroShotClozePrompting, - "fill-in-the-middle": FillInTheMiddlePrompting, - "function-to-function": FunctionToFunctionPrompting, + "infilling": InfillingPrompting, "instruct": InstructPrompting, } diff --git a/elleelleaime/core/benchmarks/bears/__init__.py b/elleelleaime/sample/strategies/__init__.py similarity index 100% rename from elleelleaime/core/benchmarks/bears/__init__.py rename to elleelleaime/sample/strategies/__init__.py diff --git a/elleelleaime/sample/prompting/strategies/zero_shot_cloze.py b/elleelleaime/sample/strategies/infilling.py similarity index 96% rename from elleelleaime/sample/prompting/strategies/zero_shot_cloze.py rename to elleelleaime/sample/strategies/infilling.py index 28d8cd9d..a9bd1065 100644 --- a/elleelleaime/sample/prompting/strategies/zero_shot_cloze.py +++ b/elleelleaime/sample/strategies/infilling.py @@ -1,8 +1,8 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple from unidiff import PatchSet import re -from elleelleaime.sample.prompting.strategy import PromptingStrategy +from elleelleaime.sample.strategy import PromptingStrategy from elleelleaime.core.benchmarks.bug import Bug from elleelleaime.core.utils.java.java import ( extract_single_function, @@ -12,18 +12,13 @@ ) -class ZeroShotClozePrompting(PromptingStrategy): +class InfillingPrompting(PromptingStrategy): """ Implements the zero-shot cloze style prompt strategy for single diff file. """ # MODEL_DICT is a dictionary of model names and their corresponding kwargs MODEL_DICT = { - "incoder": { - "mask_token": "<|mask:{}|>", - "extra_mask_token": True, - "single_chunk": False, - }, "codellama": { "mask_token": "", "extra_mask_token": False, diff --git a/elleelleaime/sample/prompting/strategies/instruct.py b/elleelleaime/sample/strategies/instruct.py similarity index 97% rename from elleelleaime/sample/prompting/strategies/instruct.py rename to elleelleaime/sample/strategies/instruct.py index c29a88f6..8c71d800 100644 --- a/elleelleaime/sample/prompting/strategies/instruct.py +++ b/elleelleaime/sample/strategies/instruct.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple from unidiff import PatchSet -from elleelleaime.sample.prompting.strategy import PromptingStrategy +from elleelleaime.sample.strategy import PromptingStrategy from elleelleaime.core.benchmarks.bug import RichBug from elleelleaime.core.utils.java.java import ( extract_single_function, diff --git a/elleelleaime/sample/prompting/strategy.py b/elleelleaime/sample/strategy.py similarity index 100% rename from elleelleaime/sample/prompting/strategy.py rename to elleelleaime/sample/strategy.py diff --git a/generate_samples.py b/generate_samples.py index d73d667a..1152476b 100644 --- a/generate_samples.py +++ b/generate_samples.py @@ -3,7 +3,7 @@ from elleelleaime.core.utils.jsonl import write_jsonl from elleelleaime.core.benchmarks.bug import Bug from typing import Optional, Union -from elleelleaime.sample.prompting.registry import PromptStrategyRegistry +from elleelleaime.sample.registry import PromptStrategyRegistry import fire import traceback diff --git a/tests/core/benchmarks/bears/__init__.py b/tests/core/benchmarks/bears/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/core/benchmarks/bears/test_bears.py b/tests/core/benchmarks/bears/test_bears.py deleted file mode 100644 index 75349ec9..00000000 --- a/tests/core/benchmarks/bears/test_bears.py +++ /dev/null @@ -1,163 +0,0 @@ -from elleelleaime.core.utils.benchmarks import get_benchmark -from elleelleaime.core.benchmarks.bug import Bug - -from pathlib import Path -import uuid -import shutil -import pytest -import tqdm -import getpass, tempfile -import concurrent.futures - - -@pytest.mark.skip(reason="Bears will be deprecated") -class TestBears: - def test_get_benchmark(self): - bears = get_benchmark("bears") - assert bears is not None - bears.initialize() - - bugs = bears.get_bugs() - - assert bugs is not None - assert len(bugs) == 77 - assert len(set([bug.get_identifier() for bug in bugs])) == 77 - - def checkout_bug(self, bug: Bug) -> bool: - buggy_path = f"{tempfile.gettempdir()}/elleelleaime-{getpass.getuser()}/{bug.get_identifier()}-buggy-{uuid.uuid4()}" - fixed_path = f"{tempfile.gettempdir()}/elleelleaime-{getpass.getuser()}/{bug.get_identifier()}-fixed-{uuid.uuid4()}" - try: - # Checkout buggy version - ret = bug.checkout(buggy_path, fixed=False) - if not ret: - return False - # Checkout fixed version - ret = bug.checkout(fixed_path, fixed=True) - if not ret: - return False - - # Assert that there are files in the directories - if len(list(Path(buggy_path).glob("**/*"))) == 0: - return False - if len(list(Path(fixed_path).glob("**/*"))) == 0: - return False - - # Assert that we can reach the java file - if not Path(buggy_path, "pom.xml").exists(): - return False - if not Path(fixed_path, "pom.xml").exists(): - return False - - return True - finally: - shutil.rmtree(buggy_path, ignore_errors=True) - shutil.rmtree(fixed_path, ignore_errors=True) - - def test_checkout_bugs(self): - bears = get_benchmark("bears") - assert bears is not None - bears.initialize() - - # This test takes a while, so we limit to 3 bugs. - bugs = list(bears.get_bugs())[:3] - assert bugs is not None - - for bug in bugs: - assert self.checkout_bug(bug), f"Failed checkout for {bug.get_identifier()}" - - @pytest.mark.skip(reason="This test is too slow to run on CI.") - def test_checkout_all_bugs(self): - bears = get_benchmark("bears") - assert bears is not None - bears.initialize() - - bugs = list(bears.get_bugs()) - assert bugs is not None - - for bug in bugs: - assert self.checkout_bug(bug), f"Failed checkout for {bug.get_identifier()}" - - def run_bug(self, bug: Bug) -> bool: - buggy_path = f"{tempfile.gettempdir()}/elleelleaime-{getpass.getuser()}/{bug.get_identifier()}-buggy-{uuid.uuid4()}" - fixed_path = f"{tempfile.gettempdir()}/elleelleaime-{getpass.getuser()}/{bug.get_identifier()}-fixed-{uuid.uuid4()}" - - try: - # Checkout buggy version - ret = bug.checkout(buggy_path, fixed=False) - assert ret, "Failed checkout for {bug.get_identifier()}-buggy" - # Checkout fixed version - ret = bug.checkout(fixed_path, fixed=True) - assert ret, f"Failed checkout for {bug.get_identifier()}-fixed" - - # Compile buggy version - compile_result = bug.compile(buggy_path) - if not compile_result.is_passing(): - return False - - # Test buggy version - test_result = bug.test(buggy_path) - if test_result.is_passing(): - return False - - # Compile fixed version - compile_result = bug.compile(fixed_path) - if not compile_result.is_passing(): - return False - - # Test fixed version - test_result = bug.test(fixed_path) - if not test_result.is_passing(): - return False - - return True - finally: - shutil.rmtree(buggy_path, ignore_errors=True) - shutil.rmtree(fixed_path, ignore_errors=True) - - def test_run_bugs(self): - bears = get_benchmark("bears") - assert bears is not None - bears.initialize() - - # This test takes a while, so we limit to 3 bugs. - bugs = list(bears.get_bugs())[:3] - assert bugs is not None - - with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - futures = [] - futures_to_bugs = {} - for bug in bugs: - # Submit the bug to be tested as a separate task - futures.append(executor.submit(self.run_bug, bug)) - futures_to_bugs[futures[-1]] = bug - # Wait for all tasks to complete - for future in tqdm.tqdm(concurrent.futures.as_completed(futures)): - result = future.result() - if not result: - assert ( - result - ), f"Failed for {futures_to_bugs[future].get_identifier()}" - - @pytest.mark.skip(reason="This test is too slow to run on CI.") - def test_run_all_bugs(self): - bears = get_benchmark("bears") - assert bears is not None - bears.initialize() - - bugs = list(bears.get_bugs()) - assert bugs is not None - - with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - futures = [] - futures_to_bugs = {} - for bug in bugs: - # Submit the bug to be tested as a separate task - futures.append(executor.submit(self.run_bug, bug)) - futures_to_bugs[futures[-1]] = bug - # Wait for all tasks to complete - for future in tqdm.tqdm(concurrent.futures.as_completed(futures)): - result = future.result() - if not result: - assert ( - result - ), f"Failed for {futures_to_bugs[future].get_identifier()}" diff --git a/tests/sample/fill_in_the_middle/__init__.py b/tests/sample/fill_in_the_middle/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/sample/fill_in_the_middle/test_starcoder.py b/tests/sample/fill_in_the_middle/test_starcoder.py deleted file mode 100644 index 5748538e..00000000 --- a/tests/sample/fill_in_the_middle/test_starcoder.py +++ /dev/null @@ -1,415 +0,0 @@ -from generate_samples import generate_sample -from elleelleaime.core.utils.benchmarks import get_benchmark -from elleelleaime.core.benchmarks.benchmark import Benchmark - - -class TestFillInTheMiddleSamplesStarCoder: - DEFECTS4J: Benchmark - PROMPT_STRATEGY: str = "fill-in-the-middle" - MODEL_NAME: str = "starcoder" - - @classmethod - def setup_class(cls): - TestFillInTheMiddleSamplesStarCoder.DEFECTS4J = get_benchmark("defects4j") - assert TestFillInTheMiddleSamplesStarCoder.DEFECTS4J is not None - TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.initialize() - - def test_closure_46(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-46") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-46" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "public JSType getLeastSupertype(JSType that) {" in sample["buggy_code"] - assert sample["fixed_code"] == "" - - def test_closure_115(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-115") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-115" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "boolean hasSideEffects = false;" in sample["buggy_code"] - assert "boolean hasSideEffects = false;" not in sample["fixed_code"] - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - in sample["buggy_code"] - ) - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - not in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n * Determines whether a function can be inlined at a particular call site." - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_closure_4(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-4" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "if (detectImplicitPrototypeCycle()) {" in sample["buggy_code"] - assert "if (detectImplicitPrototypeCycle()) {" not in sample["fixed_code"] - assert "if (detectInheritanceCycle()) {" not in sample["buggy_code"] - assert "if (detectInheritanceCycle()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n * Resolve the referenced type within the enclosing scope.\n */" - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_chart_4(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Chart-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-4" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - not in sample["buggy_code"] - ) - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n * Returns the range for the specified axis." - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_chart_2(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Chart-2") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-2" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_math_99(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Math-99") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Math-99" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_chart_18(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Chart-18") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-18" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_closure_11(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-11") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-11" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert ( - "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["buggy_code"] - ) - assert ( - not "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert sample["prompt"].startswith( - " /**\n * Visits a GETPROP node." - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_closure_5(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-5") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-5" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "if (gramps.isDelProp()) {" not in sample["buggy_code"] - assert "if (gramps.isDelProp()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert sample["prompt"].startswith( - " /**\n * Counts the number of direct (full) references to an object." - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_chart_6(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Chart-6") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-6" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "return super.equals(obj);" in sample["buggy_code"] - assert "return super.equals(obj);" not in sample["fixed_code"] - assert "ShapeList that = (ShapeList) obj;" not in sample["buggy_code"] - assert "ShapeList that = (ShapeList) obj;" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert sample["prompt"].startswith( - " /**\n * Tests the list for equality with another object (typically also a list)." - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_lang_3(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Lang-3") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-3" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "if(numDecimals <= 7){" not in sample["buggy_code"] - assert "if(numDecimals <= 7){" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n *

Turns a string value into a java.lang.Number.

\n *" - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_closure_101(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Closure-101") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-101" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert ( - not "options.closurePass = flags.process_closure_primitives;" - in sample["buggy_code"] - ) - assert ( - "options.closurePass = flags.process_closure_primitives;" - in sample["fixed_code"] - ) - assert "if (flags.process_closure_primitives) {" in sample["buggy_code"] - assert "if (flags.process_closure_primitives) {" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " @Override\n protected CompilerOptions createOptions() {" - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_lang_10(self): - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Lang-10") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-10" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the buggy code and fixed code are properly separated - assert "if(Character.isWhitespace(c)) {" in sample["buggy_code"] - assert "if(Character.isWhitespace(c)) {" not in sample["fixed_code"] - assert "boolean wasWhite= false;" in sample["buggy_code"] - assert "boolean wasWhite= false;" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n * Escape constant fields into regular expression" - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - - def test_chart_7(self): - # This is a special case that requires latin-1 encoding - bug = TestFillInTheMiddleSamplesStarCoder.DEFECTS4J.get_bug("Chart-7") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFillInTheMiddleSamplesStarCoder.PROMPT_STRATEGY, - model_name=TestFillInTheMiddleSamplesStarCoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-7" - assert sample["prompt_strategy"] == "fill-in-the-middle" - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - " /**\n * Update the index values for the maximum and minimum bounds." - ) - ) - assert sample["prompt"].endswith("") - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 - assert sample["prompt"].count("") == 1 diff --git a/tests/sample/function_to_function/__init__.py b/tests/sample/function_to_function/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/sample/function_to_function/test_function_to_function.py b/tests/sample/function_to_function/test_function_to_function.py deleted file mode 100644 index ccb974e0..00000000 --- a/tests/sample/function_to_function/test_function_to_function.py +++ /dev/null @@ -1,399 +0,0 @@ -from generate_samples import generate_sample -from elleelleaime.core.utils.benchmarks import get_benchmark -from elleelleaime.core.benchmarks.benchmark import Benchmark - - -class TestFunctionToFunctionSamples: - DEFECTS4J: Benchmark - PROMPT_STRATEGY: str = "function-to-function" - - @classmethod - def setup_class(cls): - TestFunctionToFunctionSamples.DEFECTS4J = get_benchmark("defects4j") - assert TestFunctionToFunctionSamples.DEFECTS4J is not None - TestFunctionToFunctionSamples.DEFECTS4J.initialize() - - def test_closure_46(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-46") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-46" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "public JSType getLeastSupertype(JSType that) {" in sample["buggy_code"] - assert sample["fixed_code"] == "" - - def test_closure_115(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-115") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-115" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "boolean hasSideEffects = false;" in sample["buggy_code"] - assert "boolean hasSideEffects = false;" not in sample["fixed_code"] - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - in sample["buggy_code"] - ) - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - not in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Determines whether a function can be inlined at a particular call site." - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_closure_4(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-4" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "if (detectImplicitPrototypeCycle()) {" in sample["buggy_code"] - assert "if (detectImplicitPrototypeCycle()) {" not in sample["fixed_code"] - assert "if (detectInheritanceCycle()) {" not in sample["buggy_code"] - assert "if (detectInheritanceCycle()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Resolve the referenced type within the enclosing scope.\n */" - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_chart_4(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Chart-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-4" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - not in sample["buggy_code"] - ) - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("/**\n * Returns the range for the specified axis.") - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_chart_2(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Chart-2") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-2" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_math_99(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Math-99") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Math-99" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_chart_18(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Chart-18") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-18" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_closure_11(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-11") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-11" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert ( - "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["buggy_code"] - ) - assert ( - not "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert sample["prompt"].strip().startswith("/**\n * Visits a GETPROP node.") - assert sample["prompt"] == sample["buggy_code"] - - def test_closure_5(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-5") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-5" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "if (gramps.isDelProp()) {" not in sample["buggy_code"] - assert "if (gramps.isDelProp()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Counts the number of direct (full) references to an object." - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_chart_6(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Chart-6") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-6" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "return super.equals(obj);" in sample["buggy_code"] - assert "return super.equals(obj);" not in sample["fixed_code"] - assert "ShapeList that = (ShapeList) obj;" not in sample["buggy_code"] - assert "ShapeList that = (ShapeList) obj;" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Tests the list for equality with another object (typically also a list)." - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_lang_3(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Lang-3") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-3" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "if(numDecimals <= 7){" not in sample["buggy_code"] - assert "if(numDecimals <= 7){" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n *

Turns a string value into a java.lang.Number.

\n *" - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_closure_101(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Closure-101") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-101" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert ( - not "options.closurePass = flags.process_closure_primitives;" - in sample["buggy_code"] - ) - assert ( - "options.closurePass = flags.process_closure_primitives;" - in sample["fixed_code"] - ) - assert "if (flags.process_closure_primitives) {" in sample["buggy_code"] - assert "if (flags.process_closure_primitives) {" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("@Override\n protected CompilerOptions createOptions() {") - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_lang_10(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Lang-10") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-10" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "if(Character.isWhitespace(c)) {" in sample["buggy_code"] - assert "if(Character.isWhitespace(c)) {" not in sample["fixed_code"] - assert "boolean wasWhite= false;" in sample["buggy_code"] - assert "boolean wasWhite= false;" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("/**\n * Escape constant fields into regular expression") - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_chart_7(self): - # This is a special case that requires latin-1 encoding - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Chart-7") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-7" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Update the index values for the maximum and minimum bounds." - ) - ) - assert sample["prompt"] == sample["buggy_code"] - - def test_cli_29(self): - bug = TestFunctionToFunctionSamples.DEFECTS4J.get_bug("Cli-29") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestFunctionToFunctionSamples.PROMPT_STRATEGY, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Cli-29" - assert sample["prompt_strategy"] == "function-to-function" - - # Assert that the buggy code and fixed code are properly separated - assert "str = str.substring(1, str.length());" in sample["buggy_code"] - assert "str = str.substring(1, str.length());" not in sample["fixed_code"] - assert "str = str.substring(1, length - 1);" not in sample["buggy_code"] - assert "str = str.substring(1, length - 1);" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Remove the leading and trailing quotes from str." - ) - ) - assert sample["prompt"] == sample["buggy_code"] diff --git a/elleelleaime/sample/prompting/__init__.py b/tests/sample/infilling/__init__.py similarity index 100% rename from elleelleaime/sample/prompting/__init__.py rename to tests/sample/infilling/__init__.py diff --git a/tests/sample/zero_shot_cloze/test_codellama.py b/tests/sample/infilling/test_codellama.py similarity index 79% rename from tests/sample/zero_shot_cloze/test_codellama.py rename to tests/sample/infilling/test_codellama.py index aa7744f2..b4083ba3 100644 --- a/tests/sample/zero_shot_cloze/test_codellama.py +++ b/tests/sample/infilling/test_codellama.py @@ -6,7 +6,7 @@ import os -class TestClozeSamplesCodeLLaMA: +class TestInfillingCodeLLaMADefects4J: """ We test the generation of cloze prompts for several types of bug fixes. We only generate samples for bugs that are single-function and single-file. @@ -48,24 +48,24 @@ class TestClozeSamplesCodeLLaMA: @classmethod def setup_class(cls): - TestClozeSamplesCodeLLaMA.DEFECTS4J = get_benchmark("defects4j") - assert TestClozeSamplesCodeLLaMA.DEFECTS4J is not None - TestClozeSamplesCodeLLaMA.DEFECTS4J.initialize() - TestClozeSamplesCodeLLaMA.HUMANEVALJAVA = get_benchmark("humanevaljava") - assert TestClozeSamplesCodeLLaMA.HUMANEVALJAVA is not None - TestClozeSamplesCodeLLaMA.HUMANEVALJAVA.initialize() - TestClozeSamplesCodeLLaMA.GITBUGJAVA = get_benchmark("gitbugjava") - assert TestClozeSamplesCodeLLaMA.GITBUGJAVA is not None - TestClozeSamplesCodeLLaMA.GITBUGJAVA.initialize() + TestInfillingCodeLLaMADefects4J.DEFECTS4J = get_benchmark("defects4j") + assert TestInfillingCodeLLaMADefects4J.DEFECTS4J is not None + TestInfillingCodeLLaMADefects4J.DEFECTS4J.initialize() + TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA = get_benchmark("humanevaljava") + assert TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA is not None + TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.initialize() + TestInfillingCodeLLaMADefects4J.GITBUGJAVA = get_benchmark("gitbugjava") + assert TestInfillingCodeLLaMADefects4J.GITBUGJAVA is not None + TestInfillingCodeLLaMADefects4J.GITBUGJAVA.initialize() def test_closure_46(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-46") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-46") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -80,13 +80,13 @@ def test_closure_46(self): assert sample["prompt"].count("") == 1 def test_closure_115(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-115") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-115") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -116,13 +116,13 @@ def test_closure_115(self): assert sample["prompt"].count("") == 1 def test_closure_4(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-4") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-4") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -146,13 +146,13 @@ def test_closure_4(self): assert sample["prompt"].count("") == 1 def test_chart_4(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-4") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-4") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -180,13 +180,13 @@ def test_chart_4(self): assert sample["prompt"].count("") == 1 def test_chart_2(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-2") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-2") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -197,13 +197,13 @@ def test_chart_2(self): assert sample["prompt"] is None def test_math_99(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Math-99") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Math-99") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -214,13 +214,13 @@ def test_math_99(self): assert sample["prompt"] is None def test_chart_18(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-18") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-18") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -231,13 +231,13 @@ def test_chart_18(self): assert sample["prompt"] is None def test_closure_11(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-11") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-11") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -259,13 +259,13 @@ def test_closure_11(self): assert sample["prompt"].count("") == 1 def test_chart_1_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-1") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-1") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) @@ -315,13 +315,13 @@ def test_chart_1_keep_buggy_code(self): ) def test_chart_5_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-5") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-5") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) @@ -368,13 +368,13 @@ def test_chart_5_keep_buggy_code(self): ) def test_closure_11_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-11") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-11") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) @@ -409,13 +409,13 @@ def test_closure_11_keep_buggy_code(self): ) def test_closure_2_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-2") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-2") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) @@ -457,13 +457,13 @@ def test_closure_2_keep_buggy_code(self): ) def test_closure_5(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-5") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-5") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -485,13 +485,13 @@ def test_closure_5(self): assert sample["prompt"].count("") == 1 def test_chart_6(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-6") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-6") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -515,13 +515,13 @@ def test_chart_6(self): assert sample["prompt"].count("") == 1 def test_lang_3(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Lang-3") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Lang-3") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -543,13 +543,13 @@ def test_lang_3(self): assert sample["prompt"].count("") == 1 def test_closure_101(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Closure-101") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-101") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -577,13 +577,13 @@ def test_closure_101(self): assert sample["prompt"].count("") == 1 def test_lang_10(self): - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Lang-10") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Lang-10") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -606,13 +606,13 @@ def test_lang_10(self): def test_chart_7(self): # This is a special case that requires latin-1 encoding - bug = TestClozeSamplesCodeLLaMA.DEFECTS4J.get_bug("Chart-7") + bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-7") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -630,13 +630,13 @@ def test_chart_7(self): assert sample["prompt"].count("") == 1 def test_GET_ROW(self): - bug = TestClozeSamplesCodeLLaMA.HUMANEVALJAVA.get_bug("GET_ROW") + bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("GET_ROW") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -648,13 +648,13 @@ def test_GET_ROW(self): assert sample["prompt"].count("") == 1 def test_GET_ROW_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.HUMANEVALJAVA.get_bug("GET_ROW") + bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("GET_ROW") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, ) @@ -671,13 +671,13 @@ def test_GET_ROW_keep_buggy_code(self): assert sample["prompt"].count("") == 1 def test_ADD(self): - bug = TestClozeSamplesCodeLLaMA.HUMANEVALJAVA.get_bug("ADD") + bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("ADD") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy @@ -689,13 +689,13 @@ def test_ADD(self): assert sample["prompt"].count("") == 1 def test_ADD_keep_buggy_code(self): - bug = TestClozeSamplesCodeLLaMA.HUMANEVALJAVA.get_bug("ADD") + bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("ADD") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, ) @@ -713,15 +713,15 @@ def test_ADD_keep_buggy_code(self): reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.", ) def test_traccar_traccar_37ed394724c0(self): - bug = TestClozeSamplesCodeLLaMA.GITBUGJAVA.get_bug( + bug = TestInfillingCodeLLaMADefects4J.GITBUGJAVA.get_bug( "traccar-traccar-37ed394724c0" ) assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, ) @@ -742,15 +742,15 @@ def test_traccar_traccar_37ed394724c0(self): reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.", ) def test_BrightSpots_rcv_688920f27706(self): - bug = TestClozeSamplesCodeLLaMA.GITBUGJAVA.get_bug( + bug = TestInfillingCodeLLaMADefects4J.GITBUGJAVA.get_bug( "BrightSpots-rcv-688920f27706" ) assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestClozeSamplesCodeLLaMA.PROMPT_STRATEGY, - model_name=TestClozeSamplesCodeLLaMA.MODEL_NAME, + prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, + model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, keep_buggy_code=True, ) diff --git a/tests/sample/zero_shot_cloze/__init__.py b/tests/sample/zero_shot_cloze/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/sample/zero_shot_cloze/test_incoder.py b/tests/sample/zero_shot_cloze/test_incoder.py deleted file mode 100644 index f706575e..00000000 --- a/tests/sample/zero_shot_cloze/test_incoder.py +++ /dev/null @@ -1,446 +0,0 @@ -from generate_samples import generate_sample -from elleelleaime.core.utils.benchmarks import get_benchmark -from elleelleaime.core.benchmarks.benchmark import Benchmark - - -class TestClozeSamplesIncoder: - """ - We test the generation of cloze prompts for several types of bug fixes. - We only generate samples for bugs that are single-function and single-file. - The bugs in parenthesis are the examples tested in this class. - - We test the following types of bug fixes: - - Addition only - - Single-Hunk - - N continuous lines (Closure-5) - - N non-continous lines (Lang-3) - - Whole function (Chart-23) - - Multi-Hunk - - N hunks of addition only (Chart-4) - - Removal only - - Single-Hunk - - N continuous lines (Closure-11) - - N non-continous lines (Lang-10) - - Whole function (no example found, other than Closure-46 which also changes the annotation) - - Multi-Hunk - - N hunks of removal only (Closure-115) - - - Addition and removal - - Single-Hunk - - N continuous lines (Chart-6) - - N non-continuous lines (Closure-101) - - Multi-Hunk - - N hunks of addition and removal (Closure-4) - - Unsupported bug types: - - non single-function, single-file (Chart-2, Math-99, Closure-46 (special case, due to annotation change!)) - - non single-function, non single-file (Chart-18) - """ - - DEFECTS4J: Benchmark - PROMPT_STRATEGY: str = "zero-shot-cloze" - MODEL_NAME: str = "incoder" - - @classmethod - def setup_class(cls): - TestClozeSamplesIncoder.DEFECTS4J = get_benchmark("defects4j") - assert TestClozeSamplesIncoder.DEFECTS4J is not None - TestClozeSamplesIncoder.DEFECTS4J.initialize() - - def test_closure_46(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-46") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-46" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "public JSType getLeastSupertype(JSType that) {" in sample["buggy_code"] - assert sample["fixed_code"] == "" - - def test_closure_115(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-115") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-115" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "boolean hasSideEffects = false;" in sample["buggy_code"] - assert "boolean hasSideEffects = false;" not in sample["fixed_code"] - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - in sample["buggy_code"] - ) - assert ( - "if (hasSideEffects && NodeUtil.canBeSideEffected(cArg)) {" - not in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Determines whether a function can be inlined at a particular call site." - ) - ) - assert sample["prompt"].count("<|mask:") == 3 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 - - def test_closure_4(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-4" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "if (detectImplicitPrototypeCycle()) {" in sample["buggy_code"] - assert "if (detectImplicitPrototypeCycle()) {" not in sample["fixed_code"] - assert "if (detectInheritanceCycle()) {" not in sample["buggy_code"] - assert "if (detectInheritanceCycle()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Resolve the referenced type within the enclosing scope.\n */" - ) - ) - assert sample["prompt"].count("<|mask:") == 3 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 - - def test_chart_4(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Chart-4") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-4" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - not in sample["buggy_code"] - ) - assert ( - """ if (r != null) { - Collection c = r.getAnnotations();""" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("/**\n * Returns the range for the specified axis.") - ) - assert sample["prompt"].count("<|mask:") == 3 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 - - def test_chart_2(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Chart-2") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-2" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_math_99(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Math-99") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Math-99" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_chart_18(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Chart-18") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-18" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the prompt was not generated - assert sample["prompt"] is None - - def test_closure_11(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-11") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-11" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert ( - "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["buggy_code"] - ) - assert ( - not "} else if (n.getJSType() != null && parent.isAssign()) {" - in sample["fixed_code"] - ) - - # Assert that the prompt is properly constructed - assert sample["prompt"].strip().startswith("/**\n * Visits a GETPROP node.") - assert sample["prompt"].count("<|mask:") == 2 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - - def test_closure_5(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-5") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-5" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "if (gramps.isDelProp()) {" not in sample["buggy_code"] - assert "if (gramps.isDelProp()) {" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Counts the number of direct (full) references to an object." - ) - ) - assert sample["prompt"].count("<|mask:") == 2 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - - def test_chart_6(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Chart-6") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-6" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "return super.equals(obj);" in sample["buggy_code"] - assert "return super.equals(obj);" not in sample["fixed_code"] - assert "ShapeList that = (ShapeList) obj;" not in sample["buggy_code"] - assert "ShapeList that = (ShapeList) obj;" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Tests the list for equality with another object (typically also a list)." - ) - ) - assert sample["prompt"].count("<|mask:") == 2 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - - def test_lang_3(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Lang-3") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-3" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "if(numDecimals <= 7){" not in sample["buggy_code"] - assert "if(numDecimals <= 7){" in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n *

Turns a string value into a java.lang.Number.

\n *" - ) - ) - assert sample["prompt"].count("<|mask:") == 5 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 - assert sample["prompt"].count("<|mask:3|>") == 1 - assert sample["prompt"].count("<|mask:4|>") == 1 - - def test_closure_101(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Closure-101") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Closure-101" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert ( - not "options.closurePass = flags.process_closure_primitives;" - in sample["buggy_code"] - ) - assert ( - "options.closurePass = flags.process_closure_primitives;" - in sample["fixed_code"] - ) - assert "if (flags.process_closure_primitives) {" in sample["buggy_code"] - assert "if (flags.process_closure_primitives) {" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("@Override\n protected CompilerOptions createOptions() {") - ) - assert sample["prompt"].count("<|mask:") == 2 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - - def test_lang_10(self): - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Lang-10") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Lang-10" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the buggy code and fixed code are properly separated - assert "if(Character.isWhitespace(c)) {" in sample["buggy_code"] - assert "if(Character.isWhitespace(c)) {" not in sample["fixed_code"] - assert "boolean wasWhite= false;" in sample["buggy_code"] - assert "boolean wasWhite= false;" not in sample["fixed_code"] - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith("/**\n * Escape constant fields into regular expression") - ) - assert sample["prompt"].count("<|mask:") == 3 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 - - def test_chart_7(self): - # This is a special case that requires latin-1 encoding - bug = TestClozeSamplesIncoder.DEFECTS4J.get_bug("Chart-7") - assert bug is not None - - sample = generate_sample( - bug=bug, - prompt_strategy=TestClozeSamplesIncoder.PROMPT_STRATEGY, - model_name=TestClozeSamplesIncoder.MODEL_NAME, - ) - - # Assert we are dealing with the correct bug and strategy - assert sample["identifier"] == "Chart-7" - assert sample["prompt_strategy"] == "zero-shot-cloze" - - # Assert that the prompt is properly constructed - assert ( - sample["prompt"] - .strip() - .startswith( - "/**\n * Update the index values for the maximum and minimum bounds." - ) - ) - assert sample["prompt"].count("<|mask:") == 3 - assert sample["prompt"].count("<|mask:0|>") == 1 - assert sample["prompt"].count("<|mask:1|>") == 1 - assert sample["prompt"].count("<|mask:2|>") == 1 From 8b42c883e7c0ce65e35e59a1bceaca2489230324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Thu, 8 Aug 2024 20:27:10 +0200 Subject: [PATCH 2/7] remove deprecated models --- .../strategies/models/huggingface/incoder.py | 185 ------------------ .../models/huggingface/starcoder.py | 125 ------------ elleelleaime/generate/strategies/registry.py | 40 ---- 3 files changed, 350 deletions(-) delete mode 100644 elleelleaime/generate/strategies/models/huggingface/incoder.py delete mode 100644 elleelleaime/generate/strategies/models/huggingface/starcoder.py diff --git a/elleelleaime/generate/strategies/models/huggingface/incoder.py b/elleelleaime/generate/strategies/models/huggingface/incoder.py deleted file mode 100644 index 0966e8c2..00000000 --- a/elleelleaime/generate/strategies/models/huggingface/incoder.py +++ /dev/null @@ -1,185 +0,0 @@ -from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy -from typing import Any - -import torch -import re -import threading -from dataclasses import dataclass -from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import Optional - - -@dataclass -class GenerateSettings: - name: str - num_beams: int = 1 - do_sample: bool = False - temperature: float = 0.0 - max_new_tokens: int = 128 - num_return_sequences: int = 10 - - -class IncoderHFModels(PatchGenerationStrategy): - __SUPPORTED_MODELS = { - "facebook/incoder-6B", - "facebook/incoder-1B", - } - - __GENERATION_STRATEGIES = { - "beam_search": GenerateSettings( - name="beam_search", - ), - "sampling": GenerateSettings( - name="sampling", - do_sample=True, - ), - } - - __MODEL = None - __TOKENIZER = None - __MODELS_LOADED: bool = False - __MODELS_LOCK: threading.Lock = threading.Lock() - - def __init__(self, model_name: str, **kwargs) -> None: - assert ( - model_name in self.__SUPPORTED_MODELS - ), f"Model {model_name} not supported by IncoderModels" - self.model_name = model_name - self.__load_model() - # Generation settings - assert ( - kwargs.get("generation_strategy", "beam_search") - in self.__GENERATION_STRATEGIES - ), f"Generation strategy {kwargs.get('generation_strategy', 'beam_search')} not supported by IncoderHFModels" - self.generate_settings = self.__GENERATION_STRATEGIES[ - kwargs.get("generation_strategy", "beam_search") - ] - self.generate_settings.max_new_tokens = kwargs.get("max_new_tokens", 128) - self.generate_settings.num_return_sequences = kwargs.get( - "num_return_sequences", 10 - ) - self.generate_settings.num_beams = kwargs.get("num_beams", 1) - self.generate_settings.temperature = kwargs.get("temperature", 0.2) - - def __load_model(self): - # Setup environment - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.context_size = 2048 - - # Setup kwargs - if self.model_name == "facebook/incoder-6B": - kwargs = dict( - revision="float16", - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ) - else: - kwargs = dict() - - # Load the model and tokenizer - with self.__MODELS_LOCK: - if self.__MODELS_LOADED: - return - self.__TOKENIZER = AutoTokenizer.from_pretrained(self.model_name) - self.__MODEL = AutoModelForCausalLM.from_pretrained( - self.model_name, device_map="auto", **kwargs - ) - if self.device == "cuda": - self.__MODEL = self.__MODEL.half() - - self.__MODELS_LOADED = True - - def _generate_impl(self, prompt: str) -> Any: - # Setup generation settings - predicted_texts = [] - # signals the start of a document - BOS = "<|endoftext|>" - # signals the end of a generated infill - EOM = "<|endofmask|>" - - def generate( - input: str, generate_settings: GenerateSettings - ) -> Optional[list[str]]: - """ - Do standard left-to-right completion of the prefix `input` by sampling from the model - """ - input_ids = self.__TOKENIZER(input, return_tensors="pt").input_ids - input_ids = input_ids.to(self.device) - max_length = generate_settings.max_new_tokens + input_ids.flatten().size(0) - if max_length > self.context_size: - print( - "warning: max_length %s is greater than the context window %s" - % (max_length, self.context_size) - ) - return None - with torch.no_grad(): - outputs = self.__MODEL.generate( - input_ids, - max_length=max_length, - num_beams=generate_settings.num_beams, - num_return_sequences=generate_settings.num_return_sequences, - early_stopping=True, - do_sample=generate_settings.do_sample, - temperature=generate_settings.temperature, - ) - # pass clean_up_tokenization_spaces=False to avoid removing spaces before punctuation, e.g. "from ." -> "from." - detok_hypo_strs = self.__TOKENIZER.batch_decode( - outputs, clean_up_tokenization_spaces=False - ) - detok_hypo_strs = [ - ( - detok_hypo_str[len(BOS) :] - if detok_hypo_str.startswith(BOS) - else detok_hypo_str - ) - for detok_hypo_str in detok_hypo_strs - ] - for output in detok_hypo_strs: - print(output) - return detok_hypo_strs - - def infill( - prompt: str, generate_settings: GenerateSettings - ) -> Optional[list[str]]: - """ - Generate infills to complete a partial document, e.g. - [A C E] -> [A B C D E], where B and D are infills that have been generated. - """ - completions: list[list[str]] = [ - [] for _ in range(generate_settings.num_return_sequences) - ] - - # Split prompt into parts separated by sentinels - # We identify the sentinels with a regex pattern r"<\|mask:\d\|>" - # We do not include the last sentinel as it is not followed by any text - parts = re.split(r"<\|mask:\d\|>", prompt)[:-1] - - for sentinel_ix, part in enumerate(parts[:-1]): - completions = [ - [] + [part] for _ in range(generate_settings.num_return_sequences) - ] - prompt += "<|mask:%d|>" % sentinel_ix - # TODO: this is inefficient as it requires re-encoding prefixes repeatedly - generations = generate(prompt, generate_settings) - # TODO: save error value - if generations is None: - return None - - for i, generation in enumerate(generations): - completion = generation[len(prompt) :] - if EOM not in completion: - completion += EOM - completion = completion[: completion.index(EOM) + len(EOM)] - infilled = completion[: -len(EOM)] - completions[i].append(infilled) - - # TODO: maybe keep all 10 a generate beam starting on those 10? - if i == 1: - prompt += completion - - completions = [x + [parts[-1]] for x in completions] - return ["".join(completion) for completion in completions] - - predicted_texts = infill(prompt, self.generate_settings) - - return predicted_texts diff --git a/elleelleaime/generate/strategies/models/huggingface/starcoder.py b/elleelleaime/generate/strategies/models/huggingface/starcoder.py deleted file mode 100644 index c0747f1f..00000000 --- a/elleelleaime/generate/strategies/models/huggingface/starcoder.py +++ /dev/null @@ -1,125 +0,0 @@ -from elleelleaime.generate.strategies.strategy import PatchGenerationStrategy -from dataclasses import dataclass -from transformers import AutoModelForCausalLM, AutoTokenizer -from typing import Any - -import torch -import threading -import logging - - -@dataclass -class GenerateSettings: - name: str - num_beams: int = 1 - do_sample: bool = False - temperature: float = 0.0 - max_new_tokens: int = 128 - num_return_sequences: int = 10 - max_new_tokens: int = 1024 - - -class StarCoderHFModels(PatchGenerationStrategy): - __SUPPORTED_MODELS = { - "bigcode/starcoderbase", - "bigcode/starcoder", - "bigcode/starcoderplus", - "bigcode/starcoderbase-1b", - "bigcode/starcoderbase-3b", - "bigcode/starcoderbase-7b", - } - - __GENERATION_STRATEGIES = { - "beam_search": GenerateSettings( - name="beam_search", - ), - "sampling": GenerateSettings( - name="sampling", - do_sample=True, - ), - } - - __MODEL = None - __TOKENIZER = None - __MODELS_LOADED: bool = False - __MODELS_LOCK: threading.Lock = threading.Lock() - - def __init__(self, model_name: str, **kwargs) -> None: - assert ( - model_name in self.__SUPPORTED_MODELS - ), f"Model {model_name} not supported by StarCoderHFModels" - self.model_name = model_name - self.__load_model() - # Generation settings - assert ( - kwargs.get("generation_strategy", "beam_search") - in self.__GENERATION_STRATEGIES - ), f"Generation strategy {kwargs.get('generation_strategy', 'beam_search')} not supported by StarCoderHFModels" - self.generate_settings = self.__GENERATION_STRATEGIES[ - kwargs.get("generation_strategy", "beam_search") - ] - self.generate_settings.max_new_tokens = kwargs.get("max_new_tokens", 128) - self.generate_settings.num_return_sequences = kwargs.get( - "num_return_sequences", 10 - ) - self.generate_settings.num_beams = kwargs.get("num_beams", 1) - self.generate_settings.temperature = kwargs.get("temperature", 0.2) - - def __load_model(self): - # Setup environment - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.context_size = 2048 - - # Setup kwargs - kwargs = dict( - torch_dtype=torch.bfloat16, - ) - - # Load the model and tokenizer - with self.__MODELS_LOCK: - if self.__MODELS_LOADED: - return - self.__TOKENIZER = AutoTokenizer.from_pretrained(self.model_name) - self.__MODEL = AutoModelForCausalLM.from_pretrained( - self.model_name, device_map="auto", **kwargs - ) - self.__MODELS_LOADED = True - - def _generate_impl(self, prompt: str) -> Any: - inputs = self.__TOKENIZER.encode(prompt, return_tensors="pt").to(self.device) - - max_length = self.generate_settings.max_new_tokens + inputs.shape[1] - if max_length > self.context_size: - logging.warning( - "warning: max_length %s is greater than the context window %s" - % (max_length, self.context_size) - ) - return None - - with torch.no_grad(): - generated_ids = self.__MODEL.generate( - inputs, - max_new_tokens=self.generate_settings.max_new_tokens, - num_beams=self.generate_settings.num_beams, - num_return_sequences=self.generate_settings.num_return_sequences, - early_stopping=True, - do_sample=self.generate_settings.do_sample, - temperature=self.generate_settings.temperature, - ) - - input_len = inputs.shape[1] - fillings_ids = generated_ids[:, input_len:] - fillings = self.__TOKENIZER.batch_decode(fillings_ids, skip_special_tokens=True) - - # Reorganize the function with the fillings - # The prompt is organized as follows: - # - # We want to achieve - # - # where is the the filling - prefix = prompt.split("")[1].split("")[0] - suffix = prompt.split("")[1].split("")[0] - - fillings = [prefix + filling + suffix for filling in fillings] - - return fillings diff --git a/elleelleaime/generate/strategies/registry.py b/elleelleaime/generate/strategies/registry.py index b5dfe3f4..ec3edf5c 100644 --- a/elleelleaime/generate/strategies/registry.py +++ b/elleelleaime/generate/strategies/registry.py @@ -2,15 +2,9 @@ from elleelleaime.generate.strategies.models.openai.openai import ( OpenAIChatCompletionModels, ) -from elleelleaime.generate.strategies.models.huggingface.incoder import ( - IncoderHFModels, -) from elleelleaime.generate.strategies.models.huggingface.codellama import ( CodeLlamaHFModels, ) -from elleelleaime.generate.strategies.models.huggingface.starcoder import ( - StarCoderHFModels, -) from typing import Tuple @@ -27,42 +21,8 @@ class PatchGenerationStrategyRegistry: "gpt-3.5-turbo": (OpenAIChatCompletionModels, ("gpt-3.5-turbo",)), "gpt-4o-mini": (OpenAIChatCompletionModels, ("gpt-4o-mini",)), # HuggingFace models - "incoder-1b": (IncoderHFModels, ("facebook/incoder-1B",)), - "incoder-6b": (IncoderHFModels, ("facebook/incoder-6B",)), "codellama-7b": (CodeLlamaHFModels, ("codellama/CodeLlama-7b-hf",)), "codellama-13b": (CodeLlamaHFModels, ("codellama/CodeLlama-13b-hf",)), - "codellama-7b-instruct": ( - CodeLlamaHFModels, - ("codellama/CodeLlama-7b-Instruct-hf",), - ), - "codellama-13b-instruct": ( - CodeLlamaHFModels, - ("codellama/CodeLlama-13b-Instruct-hf",), - ), - "starcoderbase": ( - StarCoderHFModels, - ("bigcode/starcoderbase",), - ), - "starcoder": ( - StarCoderHFModels, - ("bigcode/starcoder",), - ), - "starcoderplus": ( - StarCoderHFModels, - ("bigcode/starcoderplus",), - ), - "starcoderbase-1b": ( - StarCoderHFModels, - ("bigcode/starcoderbase-1b",), - ), - "starcoderbase-3b": ( - StarCoderHFModels, - ("bigcode/starcoderbase-3b",), - ), - "starcoderbase-7b": ( - StarCoderHFModels, - ("bigcode/starcoderbase-7b",), - ), } @classmethod From 3e5de82ce1ec7130b00db23533dc83bfdad1bf37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Thu, 8 Aug 2024 20:34:16 +0200 Subject: [PATCH 3/7] remove bears submodule --- .gitmodules | 3 --- benchmarks/bears | 1 - 2 files changed, 4 deletions(-) delete mode 160000 benchmarks/bears diff --git a/.gitmodules b/.gitmodules index 2753ed4d..021434d7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,9 +7,6 @@ [submodule "benchmarks/quixbugs"] path = benchmarks/quixbugs url = https://github.com/andre15silva/QuixBugs.git -[submodule "benchmarks/bears"] - path = benchmarks/bears - url = https://github.com/andre15silva/bears-benchmark.git [submodule "benchmarks/gitbug-java"] path = benchmarks/gitbug-java url = https://github.com/gitbugactions/gitbug-java.git diff --git a/benchmarks/bears b/benchmarks/bears deleted file mode 160000 index fc60bdb1..00000000 --- a/benchmarks/bears +++ /dev/null @@ -1 +0,0 @@ -Subproject commit fc60bdb16877db2afff66bfd6adab9fbcddb59b0 From cbb5cc6f54a00b0552e2c3aa5e9f96cccd622ab7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Thu, 8 Aug 2024 21:09:03 +0200 Subject: [PATCH 4/7] fix --- .../evaluate/strategies/openai/openai.py | 3 - .../evaluate/strategies/text/replace.py | 3 - elleelleaime/sample/strategies/infilling.py | 5 +- tests/evaluate/test_evaluate_replace.py | 8 +- tests/sample/infilling/test_codellama.py | 218 +++++++++--------- 5 files changed, 112 insertions(+), 125 deletions(-) diff --git a/elleelleaime/evaluate/strategies/openai/openai.py b/elleelleaime/evaluate/strategies/openai/openai.py index 4f496c8e..01cc257c 100644 --- a/elleelleaime/evaluate/strategies/openai/openai.py +++ b/elleelleaime/evaluate/strategies/openai/openai.py @@ -6,9 +6,6 @@ class OpenAIEvaluationStrategy(ReplaceEvaluationStrategy): - """ - Implements the zero-shot cloze style prompt strategy for single diff file. - """ def __init__(self, **kwargs): super().__init__(kwargs=kwargs) diff --git a/elleelleaime/evaluate/strategies/text/replace.py b/elleelleaime/evaluate/strategies/text/replace.py index 890552bc..124b804f 100644 --- a/elleelleaime/evaluate/strategies/text/replace.py +++ b/elleelleaime/evaluate/strategies/text/replace.py @@ -8,9 +8,6 @@ class ReplaceEvaluationStrategy(PatchEvaluationStrategy): - """ - Implements the zero-shot cloze style prompt strategy for single diff file. - """ def __init__(self, **kwargs): super().__init__(kwargs=kwargs) diff --git a/elleelleaime/sample/strategies/infilling.py b/elleelleaime/sample/strategies/infilling.py index a9bd1065..27d61043 100644 --- a/elleelleaime/sample/strategies/infilling.py +++ b/elleelleaime/sample/strategies/infilling.py @@ -13,9 +13,6 @@ class InfillingPrompting(PromptingStrategy): - """ - Implements the zero-shot cloze style prompt strategy for single diff file. - """ # MODEL_DICT is a dictionary of model names and their corresponding kwargs MODEL_DICT = { @@ -28,7 +25,7 @@ class InfillingPrompting(PromptingStrategy): } def __init__(self, **kwargs): - super().__init__("zero-shot-cloze") + super().__init__("infilling") self.model_name: str = kwargs.get("model_name", "").strip().lower() assert ( diff --git a/tests/evaluate/test_evaluate_replace.py b/tests/evaluate/test_evaluate_replace.py index 6df3d4e6..ac7e8d56 100644 --- a/tests/evaluate/test_evaluate_replace.py +++ b/tests/evaluate/test_evaluate_replace.py @@ -9,8 +9,8 @@ class TestEvaluatePatchesReplaceDefects4J: DEFECTS4J: Benchmark - PROMPT_STRATEGY: str = "zero-shot-cloze" - MODEL_NAME: str = "incoder" + PROMPT_STRATEGY: str = "infilling" + MODEL_NAME: str = "codellama-7b" EVALUATE_STRATEGY: str = "replace" @classmethod @@ -220,8 +220,8 @@ def test_plausible_patch(self): ) class TestEvaluatePatchesReplaceGitBugJava: GITBUGJAVA: Benchmark - PROMPT_STRATEGY: str = "zero-shot-cloze" - MODEL_NAME: str = "incoder" + PROMPT_STRATEGY: str = "infilling" + MODEL_NAME: str = "codellama-7b" EVALUATE_STRATEGY: str = "replace" @classmethod diff --git a/tests/sample/infilling/test_codellama.py b/tests/sample/infilling/test_codellama.py index b4083ba3..107d7428 100644 --- a/tests/sample/infilling/test_codellama.py +++ b/tests/sample/infilling/test_codellama.py @@ -6,7 +6,7 @@ import os -class TestInfillingCodeLLaMADefects4J: +class TestInfillingCodellama: """ We test the generation of cloze prompts for several types of bug fixes. We only generate samples for bugs that are single-function and single-file. @@ -43,34 +43,34 @@ class TestInfillingCodeLLaMADefects4J: DEFECTS4J: Benchmark HUMANEVALJAVA: Benchmark GITBUGJAVA: Benchmark - PROMPT_STRATEGY: str = "zero-shot-cloze" + PROMPT_STRATEGY: str = "infilling" MODEL_NAME: str = "codellama" @classmethod def setup_class(cls): - TestInfillingCodeLLaMADefects4J.DEFECTS4J = get_benchmark("defects4j") - assert TestInfillingCodeLLaMADefects4J.DEFECTS4J is not None - TestInfillingCodeLLaMADefects4J.DEFECTS4J.initialize() - TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA = get_benchmark("humanevaljava") - assert TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA is not None - TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.initialize() - TestInfillingCodeLLaMADefects4J.GITBUGJAVA = get_benchmark("gitbugjava") - assert TestInfillingCodeLLaMADefects4J.GITBUGJAVA is not None - TestInfillingCodeLLaMADefects4J.GITBUGJAVA.initialize() + TestInfillingCodellama.DEFECTS4J = get_benchmark("defects4j") + assert TestInfillingCodellama.DEFECTS4J is not None + TestInfillingCodellama.DEFECTS4J.initialize() + TestInfillingCodellama.HUMANEVALJAVA = get_benchmark("humanevaljava") + assert TestInfillingCodellama.HUMANEVALJAVA is not None + TestInfillingCodellama.HUMANEVALJAVA.initialize() + TestInfillingCodellama.GITBUGJAVA = get_benchmark("gitbugjava") + assert TestInfillingCodellama.GITBUGJAVA is not None + TestInfillingCodellama.GITBUGJAVA.initialize() def test_closure_46(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-46") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-46") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-46" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "public JSType getLeastSupertype(JSType that) {" in sample["buggy_code"] @@ -80,18 +80,18 @@ def test_closure_46(self): assert sample["prompt"].count("") == 1 def test_closure_115(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-115") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-115") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-115" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "boolean hasSideEffects = false;" in sample["buggy_code"] @@ -116,18 +116,18 @@ def test_closure_115(self): assert sample["prompt"].count("") == 1 def test_closure_4(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-4") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-4") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-4" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "if (detectImplicitPrototypeCycle()) {" in sample["buggy_code"] @@ -146,18 +146,18 @@ def test_closure_4(self): assert sample["prompt"].count("") == 1 def test_chart_4(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-4") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-4") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-4" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert ( @@ -180,69 +180,69 @@ def test_chart_4(self): assert sample["prompt"].count("") == 1 def test_chart_2(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-2") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-2") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-2" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt was not generated assert sample["prompt"] is None def test_math_99(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Math-99") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Math-99") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Math-99" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt was not generated assert sample["prompt"] is None def test_chart_18(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-18") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-18") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-18" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt was not generated assert sample["prompt"] is None def test_closure_11(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-11") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-11") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-11" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert ( @@ -259,20 +259,20 @@ def test_closure_11(self): assert sample["prompt"].count("") == 1 def test_chart_1_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-1") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-1") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-1" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" assert ( sample["prompt"] @@ -315,20 +315,20 @@ def test_chart_1_keep_buggy_code(self): ) def test_chart_5_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-5") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-5") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-5" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" assert ( sample["prompt"] @@ -368,20 +368,20 @@ def test_chart_5_keep_buggy_code(self): ) def test_closure_11_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-11") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-11") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-11" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert ( @@ -409,20 +409,20 @@ def test_closure_11_keep_buggy_code(self): ) def test_closure_2_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-2") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-2") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, keep_comments=False, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-2" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" assert ( sample["prompt"] @@ -457,18 +457,18 @@ def test_closure_2_keep_buggy_code(self): ) def test_closure_5(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-5") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-5") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-5" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "if (gramps.isDelProp()) {" not in sample["buggy_code"] @@ -485,18 +485,18 @@ def test_closure_5(self): assert sample["prompt"].count("") == 1 def test_chart_6(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-6") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-6") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-6" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "return super.equals(obj);" in sample["buggy_code"] @@ -515,18 +515,18 @@ def test_chart_6(self): assert sample["prompt"].count("") == 1 def test_lang_3(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Lang-3") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Lang-3") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Lang-3" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "if(numDecimals <= 7){" not in sample["buggy_code"] @@ -543,18 +543,18 @@ def test_lang_3(self): assert sample["prompt"].count("") == 1 def test_closure_101(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Closure-101") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Closure-101") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Closure-101" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert ( @@ -577,18 +577,18 @@ def test_closure_101(self): assert sample["prompt"].count("") == 1 def test_lang_10(self): - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Lang-10") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Lang-10") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Lang-10" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the buggy code and fixed code are properly separated assert "if(Character.isWhitespace(c)) {" in sample["buggy_code"] @@ -606,18 +606,18 @@ def test_lang_10(self): def test_chart_7(self): # This is a special case that requires latin-1 encoding - bug = TestInfillingCodeLLaMADefects4J.DEFECTS4J.get_bug("Chart-7") + bug = TestInfillingCodellama.DEFECTS4J.get_bug("Chart-7") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "Chart-7" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert ( @@ -630,37 +630,37 @@ def test_chart_7(self): assert sample["prompt"].count("") == 1 def test_GET_ROW(self): - bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("GET_ROW") + bug = TestInfillingCodellama.HUMANEVALJAVA.get_bug("GET_ROW") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "GET_ROW" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is not None assert sample["prompt"].count("") == 1 def test_GET_ROW_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("GET_ROW") + bug = TestInfillingCodellama.HUMANEVALJAVA.get_bug("GET_ROW") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "GET_ROW" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is not None @@ -671,37 +671,37 @@ def test_GET_ROW_keep_buggy_code(self): assert sample["prompt"].count("") == 1 def test_ADD(self): - bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("ADD") + bug = TestInfillingCodellama.HUMANEVALJAVA.get_bug("ADD") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "ADD" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is not None assert sample["prompt"].count("") == 1 def test_ADD_keep_buggy_code(self): - bug = TestInfillingCodeLLaMADefects4J.HUMANEVALJAVA.get_bug("ADD") + bug = TestInfillingCodellama.HUMANEVALJAVA.get_bug("ADD") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "ADD" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is not None @@ -713,21 +713,19 @@ def test_ADD_keep_buggy_code(self): reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.", ) def test_traccar_traccar_37ed394724c0(self): - bug = TestInfillingCodeLLaMADefects4J.GITBUGJAVA.get_bug( - "traccar-traccar-37ed394724c0" - ) + bug = TestInfillingCodellama.GITBUGJAVA.get_bug("traccar-traccar-37ed394724c0") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "traccar-traccar-37ed394724c0" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is not None @@ -742,21 +740,19 @@ def test_traccar_traccar_37ed394724c0(self): reason="This test requires completing GitBug-Java's setup, which is too heavy for CI.", ) def test_BrightSpots_rcv_688920f27706(self): - bug = TestInfillingCodeLLaMADefects4J.GITBUGJAVA.get_bug( - "BrightSpots-rcv-688920f27706" - ) + bug = TestInfillingCodellama.GITBUGJAVA.get_bug("BrightSpots-rcv-688920f27706") assert bug is not None sample = generate_sample( bug=bug, - prompt_strategy=TestInfillingCodeLLaMADefects4J.PROMPT_STRATEGY, - model_name=TestInfillingCodeLLaMADefects4J.MODEL_NAME, + prompt_strategy=TestInfillingCodellama.PROMPT_STRATEGY, + model_name=TestInfillingCodellama.MODEL_NAME, keep_buggy_code=True, ) # Assert we are dealing with the correct bug and strategy assert sample["identifier"] == "BrightSpots-rcv-688920f27706" - assert sample["prompt_strategy"] == "zero-shot-cloze" + assert sample["prompt_strategy"] == "infilling" # Assert that the prompt is properly constructed assert sample["prompt"] is None From e91414fa172cc7f6d252ac1e0b53fd3407176417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Thu, 8 Aug 2024 21:56:53 +0200 Subject: [PATCH 5/7] debug prints --- elleelleaime/core/benchmarks/defects4j/defects4jbug.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/elleelleaime/core/benchmarks/defects4j/defects4jbug.py b/elleelleaime/core/benchmarks/defects4j/defects4jbug.py index aafda3c9..bba09454 100644 --- a/elleelleaime/core/benchmarks/defects4j/defects4jbug.py +++ b/elleelleaime/core/benchmarks/defects4j/defects4jbug.py @@ -75,6 +75,8 @@ def compile(self, path: str) -> CompileResult: path=path, check=False, ) + print(run.stdout.decode("utf-8")) + print(run.stderr.decode("utf-8")) return CompileResult(run.returncode == 0) def test(self, path: str) -> TestResult: @@ -84,6 +86,8 @@ def test(self, path: str) -> TestResult: path=path, check=False, ) + print(run.stdout.decode("utf-8")) + print(run.stderr.decode("utf-8")) m = re.search(r"Failing tests: ([0-9]+)", run.stdout.decode("utf-8")) if not (run.returncode == 0 and m != None and int(m.group(1)) == 0): return TestResult(False) @@ -94,6 +98,8 @@ def test(self, path: str) -> TestResult: path=path, check=False, ) + print(run.stdout.decode("utf-8")) + print(run.stderr.decode("utf-8")) m = re.search(r"Failing tests: ([0-9]+)", run.stdout.decode("utf-8")) return TestResult(run.returncode == 0 and m != None and int(m.group(1)) == 0) From 4cd82f3dc715c9e3f6ad5e0cf64ec6ebf1d7465c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Fri, 9 Aug 2024 10:43:45 +0200 Subject: [PATCH 6/7] update defects4j submodule --- benchmarks/defects4j | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/defects4j b/benchmarks/defects4j index 5562b5ee..6ce100b9 160000 --- a/benchmarks/defects4j +++ b/benchmarks/defects4j @@ -1 +1 @@ -Subproject commit 5562b5ee775af9bb673c2dd33946a1f4fd2690e3 +Subproject commit 6ce100b902c8ffc346a96f703bf61c8422e58f6b From 2c3fcac38afc16cc2c98ecc36dfa4b563d61ea69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Fri, 9 Aug 2024 10:56:28 +0200 Subject: [PATCH 7/7] remove prints --- elleelleaime/core/benchmarks/defects4j/defects4jbug.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/elleelleaime/core/benchmarks/defects4j/defects4jbug.py b/elleelleaime/core/benchmarks/defects4j/defects4jbug.py index bba09454..aafda3c9 100644 --- a/elleelleaime/core/benchmarks/defects4j/defects4jbug.py +++ b/elleelleaime/core/benchmarks/defects4j/defects4jbug.py @@ -75,8 +75,6 @@ def compile(self, path: str) -> CompileResult: path=path, check=False, ) - print(run.stdout.decode("utf-8")) - print(run.stderr.decode("utf-8")) return CompileResult(run.returncode == 0) def test(self, path: str) -> TestResult: @@ -86,8 +84,6 @@ def test(self, path: str) -> TestResult: path=path, check=False, ) - print(run.stdout.decode("utf-8")) - print(run.stderr.decode("utf-8")) m = re.search(r"Failing tests: ([0-9]+)", run.stdout.decode("utf-8")) if not (run.returncode == 0 and m != None and int(m.group(1)) == 0): return TestResult(False) @@ -98,8 +94,6 @@ def test(self, path: str) -> TestResult: path=path, check=False, ) - print(run.stdout.decode("utf-8")) - print(run.stderr.decode("utf-8")) m = re.search(r"Failing tests: ([0-9]+)", run.stdout.decode("utf-8")) return TestResult(run.returncode == 0 and m != None and int(m.group(1)) == 0)