diff --git a/src/python/interpret/glassbox/ebm/ebm.py b/src/python/interpret/glassbox/ebm/ebm.py index 2e9cbe939..21629ba4b 100644 --- a/src/python/interpret/glassbox/ebm/ebm.py +++ b/src/python/interpret/glassbox/ebm/ebm.py @@ -602,7 +602,8 @@ def __init__( def predict_proba(self, X): check_is_fitted(self, "has_fitted_") - return EBMUtils.classifier_predict_proba(X, self) + prob = EBMUtils.classifier_predict_proba(X, self) + return prob def predict(self, X): check_is_fitted(self, "has_fitted_") @@ -1165,7 +1166,8 @@ def predict_proba(self, X): check_is_fitted(self, "has_fitted_") X, _, _, _ = unify_data(X, None, self.feature_names, self.feature_types) X = self.preprocessor_.transform(X) - return EBMUtils.classifier_predict_proba(X, self) + prob = EBMUtils.classifier_predict_proba(X, self) + return prob def predict(self, X): check_is_fitted(self, "has_fitted_") diff --git a/src/python/interpret/glassbox/ebm/utils.py b/src/python/interpret/glassbox/ebm/utils.py index 87026324b..928ecbe8e 100644 --- a/src/python/interpret/glassbox/ebm/utils.py +++ b/src/python/interpret/glassbox/ebm/utils.py @@ -1,4 +1,5 @@ # Copyright (c) 2019 Microsoft Corporation + # Distributed under the MIT software license # TODO: Test EBMUtils @@ -89,9 +90,8 @@ def classifier_predict_proba(X, estimator, skip_attr_set_idxs=[]): ) # NOTE: Generalize predict when multiclass is supported. - log_odds_trans = np.c_[-log_odds_vector, log_odds_vector] - scores = expit(log_odds_trans) - + prob = expit(log_odds_vector) + scores = np.vstack([1 - prob, prob]).T return scores @staticmethod