From 3204c5b35d90902990d31b258f3bd62f4e205453 Mon Sep 17 00:00:00 2001 From: Timo Rothenpieler Date: Thu, 26 May 2022 17:58:00 +0200 Subject: [PATCH] Add support for comparing multiple baseline images The failing compare with the lowest rms value will be used for the summary. A shape-mismatch has infinitely low "rms" and will be preferred over any comparison mismatch. --- pytest_mpl/plugin.py | 149 ++++++++++++++++++++++++++++--------------- 1 file changed, 98 insertions(+), 51 deletions(-) diff --git a/pytest_mpl/plugin.py b/pytest_mpl/plugin.py index fd8fe03c..a95d04e1 100644 --- a/pytest_mpl/plugin.py +++ b/pytest_mpl/plugin.py @@ -30,6 +30,7 @@ import io import os +import glob import json import shutil import hashlib @@ -370,25 +371,43 @@ def _download_file(self, baseline, filename): tmpfile.write(content) return Path(filename) - def obtain_baseline_image(self, item, target_dir): + def obtain_baseline_images(self, item, target_dir): """ - Copy the baseline image to our working directory. + Copy the baseline image(s) to our working directory. If the image is remote it is downloaded, if it is local it is copied to ensure it is kept in the event of a test failure. """ + compare = self.get_compare(item) + multi = compare.kwargs.get('multi', False) filename = self.generate_filename(item) baseline_dir = self.get_baseline_directory(item) baseline_remote = (isinstance(baseline_dir, str) and # noqa baseline_dir.startswith(('http://', 'https://'))) if baseline_remote: + if multi: + pytest.fail('Multi-baseline testing only works with local baselines.', + pytrace=False) # baseline_dir can be a list of URLs when remote, so we have to # pass base and filename to download - baseline_image = self._download_file(baseline_dir, filename) + baseline_images = [self._download_file(baseline_dir, filename)] + elif not multi: + baseline_images = [(baseline_dir / filename).absolute()] else: - baseline_image = (baseline_dir / filename).absolute() + dirname, ext = os.path.splitext(filename) + baseline_images = glob.glob( + os.path.join(baseline_dir.absolute(), dirname, '**', '*' + ext), + recursive=True) + + return baseline_images + + def obtain_baseline_image(self, item, target_dir): + """ + Backwards-Compatible wrapper for obtain_baseline_images. - return baseline_image + Always returns the first found baseline image. + """ + return self.obtain_baseline_images(item, target_dir)[0] def generate_baseline_image(self, item, fig): """ @@ -396,14 +415,20 @@ def generate_baseline_image(self, item, fig): """ compare = self.get_compare(item) savefig_kwargs = compare.kwargs.get('savefig_kwargs', {}) + multi = compare.kwargs.get('multi', False) if not os.path.exists(self.generate_dir): os.makedirs(self.generate_dir) baseline_filename = self.generate_filename(item) baseline_path = (self.generate_dir / baseline_filename).absolute() - fig.savefig(str(baseline_path), **savefig_kwargs) + if multi: + raw_name, ext = os.path.splitext(str(baseline_path)) + if not os.path.exists(raw_name): + os.makedirs(raw_name) + baseline_path = os.path.join(raw_name, "generated" + ext) + fig.savefig(str(baseline_path), **savefig_kwargs) close_mpl_figure(fig) return baseline_path @@ -440,13 +465,14 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None): tolerance = compare.kwargs.get('tolerance', 2) savefig_kwargs = compare.kwargs.get('savefig_kwargs', {}) - baseline_image_ref = self.obtain_baseline_image(item, result_dir) + baseline_image_refs = self.obtain_baseline_images(item, result_dir) + baseline_image_refs = [p for p in baseline_image_refs if os.path.exists(p)] test_image = (result_dir / "result.png").absolute() fig.savefig(str(test_image), **savefig_kwargs) summary['result_image'] = test_image.relative_to(self.results_dir).as_posix() - if not os.path.exists(baseline_image_ref): + if len(baseline_image_refs) == 0: summary['status'] = 'failed' summary['image_status'] = 'missing' error_message = ("Image file not found for comparison test in: \n\t" @@ -457,49 +483,70 @@ def compare_image_to_baseline(self, item, fig, result_dir, summary=None): summary['status_msg'] = error_message return error_message - # setuptools may put the baseline images in non-accessible places, - # copy to our tmpdir to be sure to keep them in case of failure - baseline_image = (result_dir / "baseline.png").absolute() - shutil.copyfile(baseline_image_ref, baseline_image) - summary['baseline_image'] = baseline_image.relative_to(self.results_dir).as_posix() - - # Compare image size ourselves since the Matplotlib - # exception is a bit cryptic in this case and doesn't show - # the filenames - expected_shape = imread(str(baseline_image)).shape[:2] - actual_shape = imread(str(test_image)).shape[:2] - if expected_shape != actual_shape: - summary['status'] = 'failed' - summary['image_status'] = 'diff' - error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image, - expected_shape=expected_shape, - actual_path=test_image, - actual_shape=actual_shape) - summary['status_msg'] = error_message - return error_message - - results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True) - summary['tolerance'] = tolerance - if results is None: - summary['status'] = 'passed' - summary['image_status'] = 'match' - summary['status_msg'] = 'Image comparison passed.' - return None - else: - summary['status'] = 'failed' - summary['image_status'] = 'diff' - summary['rms'] = results['rms'] - diff_image = (result_dir / 'result-failed-diff.png').absolute() - summary['diff_image'] = diff_image.relative_to(self.results_dir).as_posix() - template = ['Error: Image files did not match.', - 'RMS Value: {rms}', - 'Expected: \n {expected}', - 'Actual: \n {actual}', - 'Difference:\n {diff}', - 'Tolerance: \n {tol}', ] - error_message = '\n '.join([line.format(**results) for line in template]) - summary['status_msg'] = error_message - return error_message + cur_summ = {} + best_rms = float('inf') + all_msgs = '' + i = -1 + + for baseline_image_ref in baseline_image_refs: + # setuptools may put the baseline images in non-accessible places, + # copy to our tmpdir to be sure to keep them in case of failure + i += 1 + baseline_file = f"baseline-{i}.png" if i else "baseline.png" + baseline_image = (result_dir / baseline_file).absolute() + rel_baseline_image = baseline_image.relative_to(self.results_dir).as_posix() + shutil.copyfile(baseline_image_ref, baseline_image) + + # Compare image size ourselves since the Matplotlib + # exception is a bit cryptic in this case and doesn't show + # the filenames + expected_shape = imread(str(baseline_image)).shape[:2] + actual_shape = imread(str(test_image)).shape[:2] + if expected_shape != actual_shape: + best_rms = float('-inf') + cur_summ = {} + cur_summ['baseline_image'] = rel_baseline_image + cur_summ['status'] = 'failed' + cur_summ['image_status'] = 'diff' + error_message = SHAPE_MISMATCH_ERROR.format(expected_path=baseline_image, + expected_shape=expected_shape, + actual_path=test_image, + actual_shape=actual_shape) + cur_summ['status_msg'] = error_message + all_msgs += error_message + '\n\n' + continue + + results = compare_images(str(baseline_image), str(test_image), tol=tolerance, in_decorator=True) + if results is None: + summary['baseline_image'] = rel_baseline_image + summary['tolerance'] = tolerance + summary['status'] = 'passed' + summary['image_status'] = 'match' + summary['status_msg'] = 'Image comparison passed.' + return None + else: + template = ['Error: Image files did not match.', + 'RMS Value: {rms}', + 'Expected: \n {expected}', + 'Actual: \n {actual}', + 'Difference:\n {diff}', + 'Tolerance: \n {tol}', ] + error_message = '\n '.join([line.format(**results) for line in template]) + all_msgs += error_message + '\n\n' + if results['rms'] < best_rms: + best_rms = results['rms'] + cur_summ = {} + cur_summ['baseline_image'] = rel_baseline_image + cur_summ['tolerance'] = tolerance + cur_summ['status'] = 'failed' + cur_summ['image_status'] = 'diff' + cur_summ['rms'] = results['rms'] + diff_image = (result_dir / 'result-failed-diff.png').absolute() + cur_summ['diff_image'] = diff_image.relative_to(self.results_dir).as_posix() + cur_summ['status_msg'] = error_message + + summary.update(cur_summ) + return all_msgs.strip() def load_hash_library(self, library_path): with open(str(library_path)) as fp: