Skip to content

Commit

Permalink
decrease tolerance on serialization check for macos test builds (#496)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Jan 24, 2022
1 parent c7af3d5 commit 9d91f8e
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions tests/test_serialize_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def test_old_load_explanation_backcompat(self, iris, tabular_explainer, iris_svm
explanation = explainer.explain_global(iris[DatasetConstants.X_TEST], include_local=False)
loaded_explanation = load_explanation(os.path.join('.', 'tests', 'backcompat_explanation'))
explanation._id = loaded_explanation._id
_assert_explanation_equivalence(explanation, loaded_explanation, rtol=1e-7)
_assert_numpy_explanation_types(explanation, loaded_explanation, rtol=1e-7)
_assert_explanation_equivalence(explanation, loaded_explanation, rtol=0.03, atol=0.002)
_assert_numpy_explanation_types(explanation, loaded_explanation, rtol=0.03, atol=0.002)


def _generate_old_explanation(iris, tabular_explainer, iris_svm_model):
Expand All @@ -104,7 +104,7 @@ def _generate_old_explanation(iris, tabular_explainer, iris_svm_model):
save_explanation(explanation, path, exist_ok=False)


def _assert_explanation_equivalence(actual, expected, rtol=None):
def _assert_explanation_equivalence(actual, expected, rtol=None, atol=None):
# get the non-null properties in the expected explanation
paramkeys = filter(lambda x, expected=expected: hasattr(expected, getattr(ExplainParams, x)),
list(ExplainParams.get_serializable()))
Expand All @@ -122,36 +122,36 @@ def _assert_explanation_equivalence(actual, expected, rtol=None):
else:
expected_dataset = expected_value.original_dataset
if issparse(actual_dataset) and issparse(expected_dataset):
_assert_sparse_data_equivalence(actual_dataset, expected_dataset, rtol=rtol)
_assert_sparse_data_equivalence(actual_dataset, expected_dataset, rtol=rtol, atol=atol)
else:
_assert_allclose_or_eq(actual_dataset, expected_dataset, rtol=rtol)
_assert_allclose_or_eq(actual_dataset, expected_dataset, rtol=rtol, atol=atol)
elif isinstance(actual_value, (np.ndarray, collections.abc.Sequence)):
_assert_allclose_or_eq(actual_value, expected_value, rtol=rtol)
_assert_allclose_or_eq(actual_value, expected_value, rtol=rtol, atol=atol)
elif isinstance(actual_value, pd.DataFrame) and isinstance(expected_value, pd.DataFrame):
_assert_allclose_or_eq(actual_value.values, expected_value.values, rtol=rtol)
_assert_allclose_or_eq(actual_value.values, expected_value.values, rtol=rtol, atol=atol)
elif issparse(actual_value) and issparse(expected_value):
_assert_sparse_data_equivalence(actual_value, expected_value, rtol=rtol)
_assert_sparse_data_equivalence(actual_value, expected_value, rtol=rtol, atol=atol)
else:
assert actual_value == expected_value


def _assert_allclose_or_eq(actual, expected, rtol=None):
def _assert_allclose_or_eq(actual, expected, rtol=None, atol=None):
if rtol is not None:
try:
return np.testing.assert_allclose(actual, expected, rtol=rtol)
return np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
except TypeError:
print("Caught type error, defaulting to regular compare")
np.testing.assert_array_equal(actual, expected)


def _assert_sparse_data_equivalence(actual, expected, rtol=None):
_assert_allclose_or_eq(actual.data, expected.data, rtol=rtol)
_assert_allclose_or_eq(actual.indices, expected.indices, rtol=rtol)
_assert_allclose_or_eq(actual.indptr, expected.indptr, rtol=rtol)
_assert_allclose_or_eq(actual.shape, expected.shape, rtol=rtol)
def _assert_sparse_data_equivalence(actual, expected, rtol=None, atol=None):
_assert_allclose_or_eq(actual.data, expected.data, rtol=rtol, atol=atol)
_assert_allclose_or_eq(actual.indices, expected.indices, rtol=rtol, atol=atol)
_assert_allclose_or_eq(actual.indptr, expected.indptr, rtol=rtol, atol=atol)
_assert_allclose_or_eq(actual.shape, expected.shape, rtol=rtol, atol=atol)


def _assert_numpy_explanation_types(actual, expected, rtol=None):
def _assert_numpy_explanation_types(actual, expected, rtol=None, atol=None):
# assert "_" variables equivalence
if hasattr(actual, ExplainParams.get_private(ExplainParams.LOCAL_IMPORTANCE_VALUES)):
assert(isinstance(actual._local_importance_values, np.ndarray))
Expand All @@ -161,14 +161,19 @@ def _assert_numpy_explanation_types(actual, expected, rtol=None):
expected._local_importance_values)
else:
np.testing.assert_allclose(actual._local_importance_values,
expected._local_importance_values, rtol=rtol)
expected._local_importance_values,
rtol=rtol,
atol=atol)
if hasattr(actual, ExplainParams.get_private(ExplainParams.EVAL_DATA)):
assert(isinstance(actual._eval_data, np.ndarray))
assert(isinstance(expected._eval_data, np.ndarray))
if rtol is None:
np.testing.assert_array_equal(actual._eval_data, expected._eval_data)
else:
np.testing.assert_allclose(actual._eval_data, expected._eval_data, rtol=rtol)
np.testing.assert_allclose(actual._eval_data,
expected._eval_data,
rtol=rtol,
atol=atol)


# performs serialization and de-serialization for any explanation
Expand Down

0 comments on commit 9d91f8e

Please sign in to comment.