Skip to content

Commit

Permalink
Fix E721 linting errors (#958)
Browse files Browse the repository at this point in the history
  • Loading branch information
jklaise authored Aug 1, 2023
1 parent 75cf298 commit 54d0c95
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions alibi/explainers/tests/test_shap_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,10 +742,10 @@ def test__summarise_background_kernel(caplog,
msg = "Received option to summarise the data but the background_data object was an " \
"instance of shap_utils.Data"
assert_message_in_logs(msg, caplog.records)
assert type(background_data) == type(summary_data)
assert type(background_data) == type(summary_data) # noqa: E721
else:
if use_groups or categorical_names:
assert type(background_data) == type(summary_data)
assert type(background_data) == type(summary_data) # noqa: E721
if data_type == 'series':
assert summary_data.shape == background_data.shape
else:
Expand Down Expand Up @@ -1181,14 +1181,14 @@ def test__summarise_background_tree(mock_tree_shap_explainer, data_dimension, da
assert explainer.summarise_background
if n_background_samples > n_instances:
if categorical_names:
assert type(background_data) == type(summary_data)
assert type(background_data) == type(summary_data) # noqa: E721
else:
assert isinstance(summary_data, shap_utils.Data)
assert summary_data.data.shape == background_data.shape
else:
if categorical_names:
assert summary_data.shape[0] == n_background_samples
assert type(background_data) == type(summary_data)
assert type(background_data) == type(summary_data) # noqa: E721
else:
assert summary_data.data.shape[0] == n_background_samples
assert isinstance(summary_data, shap_utils.Data)
Expand Down
2 changes: 1 addition & 1 deletion alibi/utils/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def batch_compute_kernel_matrix(x: Union[list, np.ndarray],
-------
Kernel matrix in the form of a `numpy` array.
"""
if type(x) != type(y):
if type(x) != type(y): # noqa: E721
raise ValueError("x and y should be of the same type")

n_x, n_y = len(x), len(y)
Expand Down
2 changes: 1 addition & 1 deletion alibi/utils/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def ord_to_num(data: np.ndarray, dist: dict) -> np.ndarray:
for k, v in dist.items():
cat_col = X[:, k].copy()
cat_col = np.array([v[int(cat_col[i])] for i in range(rng)])
if type(X) == np.matrix:
if isinstance(X, np.matrix):
X[:, k] = cat_col.reshape(-1, 1)
else:
X[:, k] = cat_col
Expand Down

0 comments on commit 54d0c95

Please sign in to comment.