From ccd031237848309670dfb0195c626f41e50f7bee Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Tue, 1 Aug 2023 09:29:05 +0100 Subject: [PATCH] Fix E721 linting errors --- alibi/explainers/tests/test_shap_wrappers.py | 8 ++++---- alibi/utils/distance.py | 2 +- alibi/utils/mapping.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/alibi/explainers/tests/test_shap_wrappers.py b/alibi/explainers/tests/test_shap_wrappers.py index 7a517be7a..95f2b484b 100644 --- a/alibi/explainers/tests/test_shap_wrappers.py +++ b/alibi/explainers/tests/test_shap_wrappers.py @@ -742,10 +742,10 @@ def test__summarise_background_kernel(caplog, msg = "Received option to summarise the data but the background_data object was an " \ "instance of shap_utils.Data" assert_message_in_logs(msg, caplog.records) - assert type(background_data) == type(summary_data) + assert type(background_data) == type(summary_data) # noqa: E721 else: if use_groups or categorical_names: - assert type(background_data) == type(summary_data) + assert type(background_data) == type(summary_data) # noqa: E721 if data_type == 'series': assert summary_data.shape == background_data.shape else: @@ -1181,14 +1181,14 @@ def test__summarise_background_tree(mock_tree_shap_explainer, data_dimension, da assert explainer.summarise_background if n_background_samples > n_instances: if categorical_names: - assert type(background_data) == type(summary_data) + assert type(background_data) == type(summary_data) # noqa: E721 else: assert isinstance(summary_data, shap_utils.Data) assert summary_data.data.shape == background_data.shape else: if categorical_names: assert summary_data.shape[0] == n_background_samples - assert type(background_data) == type(summary_data) + assert type(background_data) == type(summary_data) # noqa: E721 else: assert summary_data.data.shape[0] == n_background_samples assert isinstance(summary_data, shap_utils.Data) diff --git a/alibi/utils/distance.py b/alibi/utils/distance.py index 21f7d73dc..d050915c6 100644 --- a/alibi/utils/distance.py +++ b/alibi/utils/distance.py @@ -299,7 +299,7 @@ def batch_compute_kernel_matrix(x: Union[list, np.ndarray], ------- Kernel matrix in the form of a `numpy` array. """ - if type(x) != type(y): + if type(x) != type(y): # noqa: E721 raise ValueError("x and y should be of the same type") n_x, n_y = len(x), len(y) diff --git a/alibi/utils/mapping.py b/alibi/utils/mapping.py index 047a27024..d689b029b 100644 --- a/alibi/utils/mapping.py +++ b/alibi/utils/mapping.py @@ -52,7 +52,7 @@ def ord_to_num(data: np.ndarray, dist: dict) -> np.ndarray: for k, v in dist.items(): cat_col = X[:, k].copy() cat_col = np.array([v[int(cat_col[i])] for i in range(rng)]) - if type(X) == np.matrix: + if isinstance(X, np.matrix): X[:, k] = cat_col.reshape(-1, 1) else: X[:, k] = cat_col