Skip to content

Commit

Permalink
fix error when creating a raw explanation from an engineered explanation that does not have a DatasetsMixin (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Jul 22, 2021
1 parent 348cb2f commit 57641e3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
10 changes: 6 additions & 4 deletions python/interpret_community/explanation/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ def get_raw_explanation(self, feature_maps, raw_feature_names=None, eval_data=No
raw_kwargs[ExplainParams.FEATURES] = raw_feature_names
raw_kwargs[ExplainParams.IS_RAW] = True
raw_kwargs[ExplainParams.EVAL_DATA] = eval_data
raw_kwargs[ExplainParams.EVAL_Y_PRED] = self.eval_y_predicted
raw_kwargs[ExplainParams.EVAL_Y_PRED_PROBA] = self.eval_y_predicted_proba
if _DatasetsMixin._does_quack(self):
raw_kwargs[ExplainParams.EVAL_Y_PRED] = self.eval_y_predicted
raw_kwargs[ExplainParams.EVAL_Y_PRED_PROBA] = self.eval_y_predicted_proba
self._is_eng = True
return _create_raw_feats_local_explanation(self, feature_maps=feature_maps, **raw_kwargs)

Expand Down Expand Up @@ -775,8 +776,9 @@ def get_raw_explanation(self, feature_maps, raw_feature_names=None, eval_data=No
raw_kwargs[ExplainParams.FEATURES] = raw_feature_names
raw_kwargs[ExplainParams.IS_RAW] = True
raw_kwargs[ExplainParams.EVAL_DATA] = eval_data
raw_kwargs[ExplainParams.EVAL_Y_PRED] = self.eval_y_predicted
raw_kwargs[ExplainParams.EVAL_Y_PRED_PROBA] = self.eval_y_predicted_proba
if _DatasetsMixin._does_quack(self):
raw_kwargs[ExplainParams.EVAL_Y_PRED] = self.eval_y_predicted
raw_kwargs[ExplainParams.EVAL_Y_PRED_PROBA] = self.eval_y_predicted_proba
self._is_eng = True
return _create_raw_feats_global_explanation(self, feature_maps=feature_maps, **raw_kwargs)

Expand Down
66 changes: 56 additions & 10 deletions test/raw_explain/test_raw_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@

import numpy as np

from common_utils import create_sklearn_svm_classifier, create_sklearn_random_forest_regressor, \
create_sklearn_linear_regressor, create_multiclass_sparse_newsgroups_data, \
create_sklearn_logistic_regressor, create_binary_sparse_newsgroups_data, LINEAR_METHOD
from common_utils import (
create_sklearn_svm_classifier, create_sklearn_random_forest_regressor,
create_sklearn_linear_regressor, create_multiclass_sparse_newsgroups_data,
create_sklearn_logistic_regressor, create_binary_sparse_newsgroups_data,
LINEAR_METHOD, LIGHTGBM_METHOD)
from constants import DatasetConstants, owner_email_tools_and_ux
from datasets import retrieve_dataset
from sklearn.model_selection import train_test_split
from interpret_community.mimic.models.linear_model import LinearExplainableModel
from interpret_community.mimic.models.lightgbm_model import LGBMExplainableModel
from interpret_community.explanation.explanation import _DatasetsMixin, _create_local_explanation
from interpret_community.common.constants import ExplainParams, ExplainType


@pytest.mark.owner(email=owner_email_tools_and_ux)
Expand Down Expand Up @@ -206,8 +211,49 @@ def test_get_global_raw_explanations_regression_eval_data(self, boston, tabular_
self.validate_global_explanation_regression(global_explanation, global_raw_explanation, feature_map,
has_raw_eval_data=True)

def test_get_raw_explanation_no_datasets_mixin(self, boston, mimic_explainer):
    """Creating a raw explanation must succeed for an explanation without _DatasetsMixin."""
    x_train = boston[DatasetConstants.X_TRAIN]
    x_test = boston[DatasetConstants.X_TEST]
    model = create_sklearn_random_forest_regressor(x_train, boston[DatasetConstants.Y_TRAIN])

    explainer = mimic_explainer(model, x_train, LGBMExplainableModel)
    eng_explanation = explainer.explain_global(x_test)
    assert eng_explanation.method == LIGHTGBM_METHOD

    # Build a minimal synthetic explanation that deliberately lacks dataset information,
    # so _DatasetsMixin._does_quack returns False for it.
    params = {
        ExplainParams.METHOD: eng_explanation.method,
        ExplainParams.FEATURES: eng_explanation.features,
        ExplainParams.MODEL_TASK: ExplainType.REGRESSION,
        ExplainParams.LOCAL_IMPORTANCE_VALUES: eng_explanation._local_importance_values,
        ExplainParams.EXPECTED_VALUES: 0,
        ExplainParams.CLASSIFICATION: False,
        ExplainParams.IS_ENG: True,
    }
    synthetic_explanation = _create_local_explanation(**params)

    num_engineered_feats = x_train.shape[1]
    feature_map = np.eye(5, num_engineered_feats)
    raw_names = [str(i) for i in range(feature_map.shape[0])]
    assert not _DatasetsMixin._does_quack(synthetic_explanation)
    raw_explanation = synthetic_explanation.get_raw_explanation([feature_map],
                                                                raw_feature_names=raw_names)
    self.validate_local_explanation_regression(synthetic_explanation,
                                               raw_explanation,
                                               feature_map,
                                               has_eng_eval_data=False,
                                               has_raw_eval_data=False,
                                               has_dataset_data=False)

def validate_global_explanation_regression(self, eng_explanation, raw_explanation, feature_map,
                                           has_eng_eval_data=True, has_raw_eval_data=False):
    """Validate a raw global regression explanation against its engineered source.

    Runs all of the local-explanation checks first, then verifies the global
    importance values have one entry per raw feature.
    """
    self.validate_local_explanation_regression(eng_explanation,
                                               raw_explanation,
                                               feature_map,
                                               has_eng_eval_data,
                                               has_raw_eval_data)
    # The trailing dimension of the global importances must equal the raw feature count.
    global_shape = np.array(raw_explanation.global_importance_values).shape
    assert global_shape[-1] == feature_map.shape[0]

def validate_local_explanation_regression(self, eng_explanation, raw_explanation, feature_map,
has_eng_eval_data=True, has_raw_eval_data=False,
has_dataset_data=True):
assert not eng_explanation.is_raw
assert hasattr(eng_explanation, 'eval_data') == has_eng_eval_data
assert eng_explanation.is_engineered
Expand All @@ -216,15 +262,15 @@ def validate_global_explanation_regression(self, eng_explanation, raw_explanatio

assert raw_explanation.is_raw
assert not raw_explanation.is_engineered
assert np.array(raw_explanation.global_importance_values).shape[-1] == feature_map.shape[0]

# Test the y_pred and y_pred_proba on the raw explanations
assert raw_explanation.eval_y_predicted is not None
assert raw_explanation.eval_y_predicted_proba is None
if has_dataset_data:
# Test the y_pred and y_pred_proba on the raw explanations
assert raw_explanation.eval_y_predicted is not None
assert raw_explanation.eval_y_predicted_proba is None

# Test the raw data on the raw explanations
assert hasattr(raw_explanation, 'eval_data')
assert (raw_explanation.eval_data is not None) == has_raw_eval_data
# Test the raw data on the raw explanations
assert hasattr(raw_explanation, 'eval_data')
assert (raw_explanation.eval_data is not None) == has_raw_eval_data

def validate_global_explanation_classification(self, eng_explanation, raw_explanation,
feature_map, classes, feature_names, is_sparse=False,
Expand Down

0 comments on commit 57641e3

Please sign in to comment.