Skip to content

Commit

Permalink
Data Imputation Fix in erroranalysis (#2436)
Browse files Browse the repository at this point in the history
* data imputation fix

* data imputation fix

* python lint fix

* refactor

* nan test case

* import fix

* python lint fix

* isort fix

* test fix

* test fix
  • Loading branch information
Advitya17 authored Dec 13, 2023
1 parent afcc4ba commit 095b2ab
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
2 changes: 1 addition & 1 deletion erroranalysis/erroranalysis/analyzer/error_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def compute_importances(self, error_correlation_method=MUTUAL_INFO):
except ValueError:
# Impute input_data if it contains NaNs, infinity or a value too
# large for dtype('float64')
input_data = np.nan_to_num(input_data)
input_data = np.nan_to_num(input_data.astype(float))
importances = self._compute_error_correlation(
input_data, diff, error_correlation_method)
return importances
Expand Down
37 changes: 37 additions & 0 deletions erroranalysis/tests/test_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from common_utils import replicate_dataset
from sklearn.base import BaseEstimator

from erroranalysis._internal.constants import (ErrorCorrelationMethods,
ModelTask)
Expand Down Expand Up @@ -139,6 +140,42 @@ def test_small_data_importances(self, num_rows, error_correlation_method):
scores = model_analyzer.compute_importances(error_correlation_method)
assert len(scores) == DEFAULT_SAMPLE_COLS

@pytest.mark.parametrize('num_rows', [3, 4])
@pytest.mark.parametrize('nan_correlation_method',
                         [MUTUAL_INFO, EBM, GBM_SHAP])
def test_nan_data_importances(self, num_rows, nan_correlation_method):
    """Check that compute_importances runs on data containing NaNs.

    A handful of test rows get NaN injected at pseudo-random
    positions, then importances are computed with each error
    correlation method. The test passes if no exception is raised;
    it makes no assertion about the returned scores.

    :param num_rows: Number of test rows kept before injecting NaNs.
    :param nan_correlation_method: Error correlation method passed
        to compute_importances.
    """
    X_train, _, X_test, y_test, _ = \
        create_binary_classification_dataset(NUM_SAMPLE_ROWS)
    feature_names = list(X_train.columns)

    class DummyModel(BaseEstimator):
        """Stub estimator that never looks at the feature values."""

        def fit(self, X, y=None):
            return self

        def predict(self, X):
            # Constant prediction: one False per input row.
            return np.zeros((len(X),), dtype=bool)

        def predict_proba(self, X):
            # NOTE(review): returns a 1-D boolean array rather than
            # per-class probabilities; presumably this method is not
            # exercised by the analyzer here - confirm.
            return np.zeros((len(X),), dtype=bool)

    # A stub model avoids fitting a real estimator on NaN data.
    model = DummyModel()

    X_test = X_test[:num_rows]
    y_test = y_test[:num_rows]

    # Replace a pseudo-random subset of the values in X_test with NaN.
    # A locally seeded generator keeps the mask identical on every
    # run, so a failure is reproducible and the test does not depend
    # on (or disturb) the global numpy random state.
    rng = np.random.RandomState(0)
    nan_mask = rng.choice([True, False], size=X_test.shape)
    # Cast to float first so every column can hold NaN.
    X_test = X_test.astype(float)
    X_test[nan_mask] = np.nan

    categorical_features = []
    model_analyzer = ModelAnalyzer(model, X_test, y_test,
                                   feature_names,
                                   categorical_features)
    model_analyzer.compute_importances(nan_correlation_method)

@pytest.mark.parametrize('error_correlation_method',
[MUTUAL_INFO, EBM, GBM_SHAP])
def test_importances_missings(self, error_correlation_method):
Expand Down

0 comments on commit 095b2ab

Please sign in to comment.