From 816917a41111be79f7957f7cb1187d2cd11f7d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Silva?= Date: Fri, 16 Aug 2024 13:43:34 +0200 Subject: [PATCH] fix: fix round in pass@k metrics --- evaluate_patches.py | 46 ++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/evaluate_patches.py b/evaluate_patches.py index 63033769..f14ee467 100644 --- a/evaluate_patches.py +++ b/evaluate_patches.py @@ -161,26 +161,38 @@ def compute_statistics(samples: list) -> dict: # geometric progression over k for k in [1, 10, 100]: - if k < statistics["num_bugs_with_prompt"]: - statistics[f"exact_match@{k}"] = pass_at_k( - statistics["num_patches"], - statistics["num_exact_match_patches"], - k, + if k < (statistics["num_patches"] // statistics["num_bugs_with_candidates"]): + statistics[f"exact_match@{k}"] = round( + pass_at_k( + statistics["num_patches"], + statistics["num_exact_match_patches"], + k, + ), + 3, ) - statistics[f"ast_match@{k}"] = pass_at_k( - statistics["num_patches"], - statistics["num_ast_match_patches"], - k, + statistics[f"ast_match@{k}"] = round( + pass_at_k( + statistics["num_patches"], + statistics["num_ast_match_patches"], + k, + ), + 3, ) - statistics[f"plausible@{k}"] = pass_at_k( - statistics["num_patches"], - statistics["num_plausible_patches"], - k, + statistics[f"plausible@{k}"] = round( + pass_at_k( + statistics["num_patches"], + statistics["num_plausible_patches"], + k, + ), + 3, ) - statistics[f"compilable@{k}"] = pass_at_k( - statistics["num_patches"], - statistics["num_compilable_patches"], - k, + statistics[f"compilable@{k}"] = round( + pass_at_k( + statistics["num_patches"], + statistics["num_compilable_patches"], + k, + ), + 3, ) statistics["bugs_with_exact_match_candidates"].sort()