Skip to content

Commit

Permalink
fix pass@k
Browse files Browse the repository at this point in the history
  • Loading branch information
andre15silva committed Aug 22, 2024
1 parent dc58515 commit d0c057b
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions evaluate_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,28 @@ def exact_match(evaluation: dict) -> bool:
"""
Returns True if the evaluation is an exact match.
"""
return bool(evaluation["exact_match"])
return evaluation is not None and bool(evaluation["exact_match"])


def ast_match(evaluation: dict) -> bool:
    """
    Returns True if the evaluation is an AST match.

    `evaluation` may be None (a candidate with no evaluation result); such a
    candidate is never a match. Otherwise the truthiness of the
    "ast_match" entry is returned.

    Raises:
        KeyError: if `evaluation` is a dict without an "ast_match" key.
    """
    # Guard against None first: indexing None would raise TypeError, and
    # callers aggregate this predicate over lists that can contain None.
    return evaluation is not None and bool(evaluation["ast_match"])


def plausible(evaluation: dict) -> bool:
    """
    Returns True if the evaluation is plausible.

    `evaluation` may be None (a candidate with no evaluation result); such a
    candidate is never plausible. Otherwise the truthiness of the "test"
    entry is returned.

    Raises:
        KeyError: if `evaluation` is a dict without a "test" key.
    """
    # Guard against None first: indexing None would raise TypeError, and
    # callers aggregate this predicate over lists that can contain None.
    return evaluation is not None and bool(evaluation["test"])


def compilable(evaluation: dict) -> bool:
    """
    Returns True if the evaluation is compilable.

    `evaluation` may be None (a candidate with no evaluation result); such a
    candidate is never compilable. Otherwise the truthiness of the
    "compile" entry is returned.

    Raises:
        KeyError: if `evaluation` is a dict without a "compile" key.
    """
    # Guard against None first: indexing None would raise TypeError, and
    # callers aggregate this predicate over lists that can contain None.
    return evaluation is not None and bool(evaluation["compile"])


def compute_diff(buggy_code: str, fixed_code: str, context_len: int = 3) -> str:
Expand Down Expand Up @@ -99,7 +99,7 @@ def compute_statistics(samples: list) -> dict:
statistics = {
"num_bugs": 0,
"num_bugs_with_prompt": 0,
"num_bugs_with_candidates": 0,
"num_bugs_with_patches": 0,
"num_bugs_with_exact_match_candidates": 0,
"num_bugs_with_ast_match_candidates": 0,
"num_bugs_with_plausible_candidates": 0,
Expand All @@ -115,19 +115,16 @@ def compute_statistics(samples: list) -> dict:
"bugs_with_compilable_candidates": [],
}

for sample in tqdm.tqdm(samples):
for sample in tqdm.tqdm(samples, "Computing statistics..."):
statistics["num_bugs"] += 1

if sample["prompt"]:
statistics["num_bugs_with_prompt"] += 1
if sample["generation"] and any(
candidate["generation"] if candidate is not None else None
for candidate in sample["evaluation"]
):
statistics["num_bugs_with_candidates"] += 1
statistics["num_patches"] += sum(
bool(candidate["generation"]) if candidate is not None else False
for candidate in sample["evaluation"]
)

if sample["generation"]:
statistics["num_bugs_with_patches"] += 1
statistics["num_patches"] += len(sample["evaluation"])

statistics["num_compilable_patches"] += sum(
compilable(candidate) for candidate in sample["evaluation"]
)
Expand All @@ -140,6 +137,7 @@ def compute_statistics(samples: list) -> dict:
statistics["num_exact_match_patches"] += sum(
exact_match(candidate) for candidate in sample["evaluation"]
)

if any(exact_match(candidate) for candidate in sample["evaluation"]):
statistics["num_bugs_with_exact_match_candidates"] += 1
statistics["bugs_with_exact_match_candidates"].append(
Expand All @@ -163,7 +161,7 @@ def compute_statistics(samples: list) -> dict:

# geometric progression over k
for k in [1, 10, 100]:
if k < (statistics["num_patches"] // statistics["num_bugs_with_candidates"]):
if k < (statistics["num_patches"] // statistics["num_bugs_with_patches"]):
statistics[f"exact_match@{k}"] = round(
pass_at_k(
statistics["num_patches"],
Expand Down Expand Up @@ -214,7 +212,7 @@ def export_patches(samples: list, dir_path: str) -> None:
if os.path.exists(patches_dir):
shutil.rmtree(patches_dir)

for sample in tqdm.tqdm(samples):
for sample in tqdm.tqdm(samples, "Exporting patches..."):
if not sample["generation"] or all(
candidate["generation"] is None if candidate is not None else None
for candidate in sample["evaluation"]
Expand Down Expand Up @@ -322,20 +320,21 @@ def entry_point(
dir_path = os.path.dirname(samples_path)
prompt_strategy = samples_file_name.split("_")[2].split(".")[0]
model_name = samples_file_name.split("_")[3].split(".")[0]
benchmark_obj = get_benchmark(benchmark)
if benchmark_obj is None:
raise ValueError(f"Unknown benchmark {benchmark}")
benchmark_obj.initialize()

# Read the samples
logging.info("Reading samples...")
samples = list(stream_jsonl(samples_path))

# Correctness evaluation
if correctness:
benchmark_obj = get_benchmark(benchmark)
if benchmark_obj is None:
raise ValueError(f"Unknown benchmark {benchmark}")
benchmark_obj.initialize()

with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = []
for sample in tqdm.tqdm(samples):
for sample in tqdm.tqdm(samples, "Lauching candidate evaluation..."):
bug = benchmark_obj.get_bug(sample["identifier"])
if bug is None:
raise ValueError(f"Unknown bug {sample['identifier']}")
Expand Down

0 comments on commit d0c057b

Please sign in to comment.