From b3639a995490c0e398b92fccc6d9ddcaf98b71e4 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 10:54:21 +0200 Subject: [PATCH 01/14] add robust metric maker --- sklearn_extra/robust/__init__.py | 3 +++ sklearn_extra/robust/mean_estimators.py | 24 +++++++++++++++++++ .../robust/tests/test_mean_estimators.py | 16 ++++++++++++- 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/sklearn_extra/robust/__init__.py b/sklearn_extra/robust/__init__.py index 640ad475..97861ee6 100644 --- a/sklearn_extra/robust/__init__.py +++ b/sklearn_extra/robust/__init__.py @@ -3,9 +3,12 @@ RobustWeightedKMeans, RobustWeightedRegressor, ) +from sklearn_extra.robust.mean_estimators import huber, make_huber_metric __all__ = [ "RobustWeightedClassifier", "RobustWeightedKMeans", "RobustWeightedRegressor", + "huber", + "make_huber_metric", ] diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py index d2817a14..c0e44aa3 100644 --- a/sklearn_extra/robust/mean_estimators.py +++ b/sklearn_extra/robust/mean_estimators.py @@ -4,6 +4,7 @@ # License: BSD 3 clause import numpy as np +from scipy.stats import iqr def block_mom(X, k, random_state): @@ -136,3 +137,26 @@ def psisx(x, c): # new weights. mu = np.sum(np.array(w[ind_pos]) * x[ind_pos]) / np.sum(w[ind_pos]) return mu + + +def make_huber_metric(metric_func, c=None, T=20): + """ + Make a robust metric using Huber estimator. + """ + + def metric(y_true, y_pred): + # change size in order to use the raw multisample + # to have individual values + y1 = [y_true] + y2 = [y_pred] + values = metric_func(y1, y2, multioutput="raw_values") + if c is None: + c_ = iqr(values) + else: + c_ = c + if c_ == 0: + return np.median(values) + else: + return huber(values, c_, T) + + return metric diff --git a/sklearn_extra/robust/tests/test_mean_estimators.py b/sklearn_extra/robust/tests/test_mean_estimators.py index be5c78f9..d35cfe15 100644 --- a/sklearn_extra/robust/tests/test_mean_estimators.py +++ b/sklearn_extra/robust/tests/test_mean_estimators.py @@ -1,7 +1,12 @@ import numpy as np import pytest -from sklearn_extra.robust.mean_estimators import median_of_means, huber +from sklearn_extra.robust.mean_estimators import ( + median_of_means, + huber, + make_huber_metric, +) +from sklearn.metrics import mean_squared_error rng = np.random.RandomState(42) @@ -29,3 +34,12 @@ def test_huber(): with pytest.warns(None) as record: huber(X) assert len(record) == 0 + + +def test_robust_metric(): + robust_mse = make_huber_metric(mean_squared_error, c=5) + y_true = np.hstack([np.zeros(95), 20 * np.ones(5)]) + np.random.shuffle(y_true) + y_pred = np.zeros(100) + + assert robust_mse(y_true, y_pred) < 1 From ae055b0b38daf765733da3d1636ec2b988a8b886 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 11:09:08 +0200 Subject: [PATCH 02/14] add docstring --- sklearn_extra/robust/mean_estimators.py | 41 +++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py index c0e44aa3..f3a75bdd 100644 --- a/sklearn_extra/robust/mean_estimators.py +++ b/sklearn_extra/robust/mean_estimators.py @@ -5,6 +5,7 @@ import numpy as np from scipy.stats import iqr +from sklearn.metrics import mean_squared_error def block_mom(X, k, random_state): @@ -139,9 +140,45 @@ def psisx(x, c): return mu -def make_huber_metric(metric_func, c=None, T=20): +def make_huber_metric(score_func=mean_squared_error, c=None, T=20): """ Make a robust metric using Huber estimator. + + Parameters + ---------- + + score_func : callable + Score function (or loss function) with signature + ``score_func(y, y_pred, **kwargs)``. + + c : float >0, default = 1.35 + parameter that control the robustness of the estimator. + c going to zero gives a behavior close to the median. + c going to infinity gives a behavior close to sample mean. + if c is None, the iqr is used as heuristic. + + T : int, default = 20 + Number of iterations of the algorithm. + + Return + ------ + + Robust metric function, a callable with signature + ``score_func(y, y_pred, **kwargs). + + Examples + -------- + + >>> import numpy as np + >>> from sklearn.metrics import mean_squared_error + >>> from sklearn_extra.robust import make_huber_metric + >>> robust_mse = make_huber_metric(mean_squared_error, c=5) + >>> y_true = np.hstack([np.zeros(98), 20*np.ones(2)]) # corrupted test values + >>> np.random.shuffle(y_true) # shuffle them + >>> _ = clf.fit(X, y) + >>> y_pred = np.zeros(100) # predicted values + >>> robust_mse(y_true, y_pred) + 0.26315789473684204 """ def metric(y_true, y_pred): @@ -149,7 +186,7 @@ def metric(y_true, y_pred): # to have individual values y1 = [y_true] y2 = [y_pred] - values = metric_func(y1, y2, multioutput="raw_values") + values = score_func(y1, y2, multioutput="raw_values") if c is None: c_ = iqr(values) else: From c43de902f972313e3f41f3c1b950b95e51f0dd58 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 11:20:24 +0200 Subject: [PATCH 03/14] fix doctring example --- sklearn_extra/robust/mean_estimators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py index f3a75bdd..4a3fd7ce 100644 --- a/sklearn_extra/robust/mean_estimators.py +++ b/sklearn_extra/robust/mean_estimators.py @@ -175,10 +175,9 @@ def make_huber_metric(score_func=mean_squared_error, c=None, T=20): >>> robust_mse = make_huber_metric(mean_squared_error, c=5) >>> y_true = np.hstack([np.zeros(98), 20*np.ones(2)]) # corrupted test values >>> np.random.shuffle(y_true) # shuffle them - >>> _ = clf.fit(X, y) >>> y_pred = np.zeros(100) # predicted values >>> robust_mse(y_true, y_pred) - 0.26315789473684204 + 0.1020408163265306 """ def metric(y_true, y_pred): From c266b21950459e09357f1ce5daea3e09a322cd58 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 14:30:34 +0200 Subject: [PATCH 04/14] add doc --- doc/modules/robust.rst | 51 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index cdd47308..deda431c 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -142,6 +142,57 @@ This algorithm has been studied in the context of "mom" weights in the article [1]_, the context of "huber" weights has been mentioned in [2]_. Both weighting schemes can be seen as special cases of the algorithm in [3]_. + +Robust model selection +---------------------- + +one of the big challenge of robust machine learning is that the usual scoring +scheme (cross_validation with mean squared error for instance) is not robust. +Indeed, if the dataset has some outliers, then the test sets in cross-validation +may have outliers and then the cross_validation MSE would give us a huge error +for our robust algorithm on any corrupted data. + +To solve this problem, one can use robust score methods when doing +cross-validation using `make_huber_metric`. The following example show how +`make_huber_metric` can be used and it shows that `HuberRegressor` is robust +to outliers in the variable y. + +Example : + +Import the libraries + + >>>import numpy as np + >>>from sklearn.metrics import mean_squared_error, make_scorer + >>>from sklearn.model_selection import cross_val_score + >>>from sklearn_extra.robust import make_huber_metric + >>>from sklearn.linear_model import LinearRegression, HuberRegressor + +Define the robust metric + + >>>robust_mse = make_huber_metric(mean_squared_error, c=9) + +Define a corrupted dataset + + >>>rng = np.random.RandomState(42) + >>>X = rng.uniform(size=100)[:,np.newaxis] + >>>y = 3*X.ravel() + >>>y[[42//2,42, 42*2]] = 200 # outliers + +Get the non robust errors: + + >>>for reg in [LinearRegression(), HuberRegressor()]: + >>> print(reg, " mse : %.2F" %(np.mean(cross_val_score(reg, X, y, scoring = make_scorer(mean_squared_error))))) + LinearRegression() mse : 1154.63 + HuberRegressor() mse : 1194.19 + +Get the robust errors: + + >>>for reg in [LinearRegression(), HuberRegressor()]: + >>> print(reg, " mse : %.2F" %(np.mean(cross_val_score(reg, X, y, scoring = make_scorer(robust_mse))))) + LinearRegression() mse : 51.93 + HuberRegressor() mse : 0.28 + + Comparison with other robust estimators --------------------------------------- From 81b0333fae2a9564c515b048699fef3c3f189940 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 15:21:27 +0200 Subject: [PATCH 05/14] Add example and doc --- doc/modules/robust.rst | 44 ++++------------------- examples/robust/robust_cv_example.py | 54 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 37 deletions(-) create mode 100644 examples/robust/robust_cv_example.py diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index deda431c..573b2f45 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -153,45 +153,11 @@ may have outliers and then the cross_validation MSE would give us a huge error for our robust algorithm on any corrupted data. To solve this problem, one can use robust score methods when doing -cross-validation using `make_huber_metric`. The following example show how -`make_huber_metric` can be used and it shows that `HuberRegressor` is robust -to outliers in the variable y. +cross-validation using `make_huber_metric`. See the following example: -Example : - -Import the libraries - - >>>import numpy as np - >>>from sklearn.metrics import mean_squared_error, make_scorer - >>>from sklearn.model_selection import cross_val_score - >>>from sklearn_extra.robust import make_huber_metric - >>>from sklearn.linear_model import LinearRegression, HuberRegressor - -Define the robust metric - - >>>robust_mse = make_huber_metric(mean_squared_error, c=9) - -Define a corrupted dataset - - >>>rng = np.random.RandomState(42) - >>>X = rng.uniform(size=100)[:,np.newaxis] - >>>y = 3*X.ravel() - >>>y[[42//2,42, 42*2]] = 200 # outliers - -Get the non robust errors: - - >>>for reg in [LinearRegression(), HuberRegressor()]: - >>> print(reg, " mse : %.2F" %(np.mean(cross_val_score(reg, X, y, scoring = make_scorer(mean_squared_error))))) - LinearRegression() mse : 1154.63 - HuberRegressor() mse : 1194.19 - -Get the robust errors: - - >>>for reg in [LinearRegression(), HuberRegressor()]: - >>> print(reg, " mse : %.2F" %(np.mean(cross_val_score(reg, X, y, scoring = make_scorer(robust_mse))))) - LinearRegression() mse : 51.93 - HuberRegressor() mse : 0.28 +:ref:`../auto_examples/robust/robust_cv_example.html#sphx-glr-download-auto-examples-robust-robust-cv-example-py` +This type of robust cross-validation was mentioned for instance in [4]_. Comparison with other robust estimators --------------------------------------- @@ -254,3 +220,7 @@ the example with California housing real dataset, for further discussion. .. [3] Stanislav Minsker and Timothée Mathieu. `"Excess risk bounds in robust empirical risk minimization" `_ arXiv preprint (2019). arXiv:1910.07485. + + .. [4] Elvezio Ronchetti , Christopher Field & Wade Blanchard + `" Robust Linear Model Selection by Cross-Validation" _ + Journal of the American Statistical Association (1995). diff --git a/examples/robust/robust_cv_example.py b/examples/robust/robust_cv_example.py new file mode 100644 index 00000000..0c65bf38 --- /dev/null +++ b/examples/robust/robust_cv_example.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +""" +================================================================ +An example of a robust cross-validation evaluation in regression +================================================================ +In this example we compare `LinearRegression` (OLS) with `HuberRegressor` from +scikit-learn using cross-validation. + +We show that a robust cross-validation scheme is better to have a better +evaluation of the generalisation error on the majority of the data. +""" +print(__doc__) + +import numpy as np +from sklearn.metrics import mean_squared_error, make_scorer +from sklearn.model_selection import cross_val_score +from sklearn_extra.robust import make_huber_metric +from sklearn.linear_model import LinearRegression, HuberRegressor + +robust_mse = make_huber_metric(mean_squared_error, c=9) +rng = np.random.RandomState(42) + +X = rng.uniform(size=100)[:, np.newaxis] +y = 3 * X.ravel() +# Remark y <= 3 + +y[[42 // 2, 42, 42 * 2]] = 200 # outliers + +print("Non robust error:") +for reg in [LinearRegression(), HuberRegressor()]: + print( + reg, + " mse : %.2F" + % ( + np.mean( + cross_val_score( + reg, X, y, scoring=make_scorer(mean_squared_error) + ) + ) + ), + ) + +print("\n") +print("Robust error:") +for reg in [LinearRegression(), HuberRegressor()]: + print( + reg, + " mse : %.2F" + % ( + np.mean( + cross_val_score(reg, X, y, scoring=make_scorer(robust_mse)) + ) + ), + ) From 8f829e9c14646f5701dcd9accd0c9b6c149be8a3 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 15:38:22 +0200 Subject: [PATCH 06/14] make example executable and try another link --- doc/modules/robust.rst | 2 +- .../robust/{robust_cv_example.py => plot_robust_cv_example.py} | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) rename examples/robust/{robust_cv_example.py => plot_robust_cv_example.py} (99%) diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index 573b2f45..8ebd3a4e 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -155,7 +155,7 @@ for our robust algorithm on any corrupted data. To solve this problem, one can use robust score methods when doing cross-validation using `make_huber_metric`. See the following example: -:ref:`../auto_examples/robust/robust_cv_example.html#sphx-glr-download-auto-examples-robust-robust-cv-example-py` +:ref:`../auto_examples/robust/plot_robust_cv_example.html` This type of robust cross-validation was mentioned for instance in [4]_. diff --git a/examples/robust/robust_cv_example.py b/examples/robust/plot_robust_cv_example.py similarity index 99% rename from examples/robust/robust_cv_example.py rename to examples/robust/plot_robust_cv_example.py index 0c65bf38..4f1f59db 100644 --- a/examples/robust/robust_cv_example.py +++ b/examples/robust/plot_robust_cv_example.py @@ -40,6 +40,7 @@ ), ) + print("\n") print("Robust error:") for reg in [LinearRegression(), HuberRegressor()]: From 56cd4e7db770dc18d06cac80b07403021b784fc6 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 15:43:17 +0200 Subject: [PATCH 07/14] add make_huber_metric to api doc --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 57b36246..e5a8ba86 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -43,3 +43,4 @@ Robust robust.RobustWeightedClassifier robust.RobustWeightedRegressor robust.RobustWeightedKMeans + robust.make_huber_metric From e141e8a7c500ce8c9560232f6b3ff993a1debbf3 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 15:46:18 +0200 Subject: [PATCH 08/14] fix doc --- doc/modules/robust.rst | 2 +- examples/robust/plot_robust_cv_example.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index 8ebd3a4e..c7faa87b 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -146,7 +146,7 @@ Both weighting schemes can be seen as special cases of the algorithm in [3]_. Robust model selection ---------------------- -one of the big challenge of robust machine learning is that the usual scoring +One of the big challenge of robust machine learning is that the usual scoring scheme (cross_validation with mean squared error for instance) is not robust. Indeed, if the dataset has some outliers, then the test sets in cross-validation may have outliers and then the cross_validation MSE would give us a huge error diff --git a/examples/robust/plot_robust_cv_example.py b/examples/robust/plot_robust_cv_example.py index 4f1f59db..680dd7ac 100644 --- a/examples/robust/plot_robust_cv_example.py +++ b/examples/robust/plot_robust_cv_example.py @@ -6,8 +6,8 @@ In this example we compare `LinearRegression` (OLS) with `HuberRegressor` from scikit-learn using cross-validation. -We show that a robust cross-validation scheme is better to have a better -evaluation of the generalisation error on the majority of the data. +We show that a robust cross-validation scheme gives a better +evaluation of the generalisation error in a corrupted dataset. """ print(__doc__) From 834a00beab458b6074d3940f7fc8ad2592eb8cf7 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Fri, 25 Jun 2021 15:52:29 +0200 Subject: [PATCH 09/14] fix doc api --- doc/api.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index e5a8ba86..57175274 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -43,4 +43,9 @@ Robust robust.RobustWeightedClassifier robust.RobustWeightedRegressor robust.RobustWeightedKMeans + +.. autosummary:: + :toctree: generated/ + :template: function.rst + robust.make_huber_metric From e7a006f20e4db70e0898ad4a69381de2438ce2cb Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Sat, 26 Jun 2021 09:03:03 +0200 Subject: [PATCH 10/14] add more explanation and change names of variables --- doc/modules/robust.rst | 19 +++++++++++++++++ sklearn_extra/robust/mean_estimators.py | 28 +++++++++++++++++-------- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index c7faa87b..87c9136d 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -145,6 +145,7 @@ Both weighting schemes can be seen as special cases of the algorithm in [3]_. Robust model selection ---------------------- +.. _make_huber_metric: One of the big challenge of robust machine learning is that the usual scoring scheme (cross_validation with mean squared error for instance) is not robust. @@ -159,6 +160,24 @@ cross-validation using `make_huber_metric`. See the following example: This type of robust cross-validation was mentioned for instance in [4]_. + +Here is what `make_huber_metric` computes: suppose that we compute a +loss function as such: + +.. math:: + + \widehat L = \frac{1}{n}\sum_{i=1}^n \ell(Y_i, f(X_i)) + +`make_huber_metric` propose to change this computation for + +.. math:: + \widehat L_{rob}=\widehat{\mathrm{Hub}}\left(\ell(Y_i, f(X_i))\right) + +where :math:`\widehat{\mathrm{Hub}}` is the Huber estimator of location. It is a +robust estimator of the mean, and :math:`\widehat{L}_{rob}` is robust in the sense +that an especially large value of :math:`\ell(Y_i, f(X_i))` would not change the +value of the result by a lot. + Comparison with other robust estimators --------------------------------------- diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py index 4a3fd7ce..190c1c0c 100644 --- a/sklearn_extra/robust/mean_estimators.py +++ b/sklearn_extra/robust/mean_estimators.py @@ -90,7 +90,7 @@ def median_of_means(X, k, random_state=np.random.RandomState(42)): return median_of_means_blocked(x, blocks)[0] -def huber(X, c=1.35, T=20): +def huber(X, c=1.35, n_iter=20): """Compute the Huber estimator of location of X with parameter c Parameters @@ -104,7 +104,7 @@ def huber(X, c=1.35, T=20): c going to zero gives a behavior close to the median. c going to infinity gives a behavior close to sample mean. - T : int, default = 20 + n_iter : int, default = 20 Number of iterations of the algorithm. Return @@ -127,7 +127,7 @@ def psisx(x, c): return res # Run the iterative reweighting algorithm to compute M-estimator. - for t in range(T): + for t in range(n_iter): # Compute the weights w = psisx(x - mu, c) @@ -140,10 +140,14 @@ def psisx(x, c): return mu -def make_huber_metric(score_func=mean_squared_error, c=None, T=20): +def make_huber_metric( + score_func=mean_squared_error, sample_weight=None, c=None, n_iter=20 +): """ Make a robust metric using Huber estimator. + Read more in the :ref:`User Guide `. + Parameters ---------- @@ -151,13 +155,17 @@ def make_huber_metric(score_func=mean_squared_error, c=None, T=20): Score function (or loss function) with signature ``score_func(y, y_pred, **kwargs)``. - c : float >0, default = 1.35 + sample_weight: array-like of shape (n_samples,), default=None + Sample weights. + + + c : float >0, default = None parameter that control the robustness of the estimator. c going to zero gives a behavior close to the median. c going to infinity gives a behavior close to sample mean. - if c is None, the iqr is used as heuristic. + if c is None, the iqr (inter quartile range) is used as heuristic. - T : int, default = 20 + n_iter : int, default = 20 Number of iterations of the algorithm. Return @@ -185,7 +193,9 @@ def metric(y_true, y_pred): # to have individual values y1 = [y_true] y2 = [y_pred] - values = score_func(y1, y2, multioutput="raw_values") + values = score_func( + y1, y2, sample_weight=sample_weight, multioutput="raw_values" + ) if c is None: c_ = iqr(values) else: @@ -193,6 +203,6 @@ def metric(y_true, y_pred): if c_ == 0: return np.median(values) else: - return huber(values, c_, T) + return huber(values, c_, n_iter) return metric From a48bc1b559b36fec1f533820989f6cd6e5cc90d4 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Sat, 26 Jun 2021 09:53:48 +0200 Subject: [PATCH 11/14] add test robust cv --- doc/modules/robust.rst | 4 +++- .../robust/tests/test_mean_estimators.py | 22 +++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index 87c9136d..6ef3a42b 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -176,7 +176,9 @@ loss function as such: where :math:`\widehat{\mathrm{Hub}}` is the Huber estimator of location. It is a robust estimator of the mean, and :math:`\widehat{L}_{rob}` is robust in the sense that an especially large value of :math:`\ell(Y_i, f(X_i))` would not change the -value of the result by a lot. +value of the result by a lot. The constant `c` used when tuning +:math:`\widehat{\mathrm{Hub}}` has the same role of tuning the robustness as in +the case of regression and classification using Huber weights. Comparison with other robust estimators --------------------------------------- diff --git a/sklearn_extra/robust/tests/test_mean_estimators.py b/sklearn_extra/robust/tests/test_mean_estimators.py index d35cfe15..59cbc10a 100644 --- a/sklearn_extra/robust/tests/test_mean_estimators.py +++ b/sklearn_extra/robust/tests/test_mean_estimators.py @@ -6,8 +6,9 @@ huber, make_huber_metric, ) -from sklearn.metrics import mean_squared_error - +from sklearn.metrics import mean_squared_error, make_scorer +from sklearn.model_selection import cross_val_score +from sklearn.linear_model import HuberRegressor rng = np.random.RandomState(42) @@ -43,3 +44,20 @@ def test_robust_metric(): y_pred = np.zeros(100) assert robust_mse(y_true, y_pred) < 1 + + +def test_check_robust_cv(): + + robust_mse = make_huber_metric(mean_squared_error, c=9) + rng = np.random.RandomState(42) + + X = rng.uniform(size=100)[:, np.newaxis] + y = 3 * X.ravel() + + y[[42 // 2, 42, 42 * 2]] = 200 # outliers + + huber_reg = HuberRegressor() + error_Hub_reg = error_ols = np.mean( + cross_val_score(huber_reg, X, y, scoring=make_scorer(robust_mse)) + ) + assert error_Hub_reg < 1 From 1e16ab268b69a34584daa3651f8bc9265ad9f364 Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Sun, 27 Jun 2021 10:03:13 +0200 Subject: [PATCH 12/14] add to changelog, add trimmed mean example to example --- doc/changelog.rst | 3 +++ doc/modules/robust.rst | 3 ++- examples/robust/plot_robust_cv_example.py | 7 +++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 053b9197..82103d50 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -4,6 +4,9 @@ Changelog Unreleased ---------- +- Add `make_huber_metric` which transform a non-robust to a robust metric using + Huber estimator. + - Add `CLARA` (Clustering for Large Applications) which extends k-medoids to be more scalable using a sampling approach. [`#83 `_]. diff --git a/doc/modules/robust.rst b/doc/modules/robust.rst index 6ef3a42b..00845c41 100644 --- a/doc/modules/robust.rst +++ b/doc/modules/robust.rst @@ -174,7 +174,8 @@ loss function as such: \widehat L_{rob}=\widehat{\mathrm{Hub}}\left(\ell(Y_i, f(X_i))\right) where :math:`\widehat{\mathrm{Hub}}` is the Huber estimator of location. It is a -robust estimator of the mean, and :math:`\widehat{L}_{rob}` is robust in the sense +robust estimator of the mean (similar result can also be attained using the +trimmed mean), and :math:`\widehat{L}_{rob}` is robust in the sense that an especially large value of :math:`\ell(Y_i, f(X_i))` would not change the value of the result by a lot. The constant `c` used when tuning :math:`\widehat{\mathrm{Hub}}` has the same role of tuning the robustness as in diff --git a/examples/robust/plot_robust_cv_example.py b/examples/robust/plot_robust_cv_example.py index 680dd7ac..19d66fd5 100644 --- a/examples/robust/plot_robust_cv_example.py +++ b/examples/robust/plot_robust_cv_example.py @@ -8,6 +8,13 @@ We show that a robust cross-validation scheme gives a better evaluation of the generalisation error in a corrupted dataset. + +In this example, we do robust cross-validation by using an alternative to the +empirical mean to aggregate the errors. This alternative is a robust estimator +of the mean (the trimmed mean is an example of such a robust estimator, but here +we use Huber's estimator). This robust estimator of the mean is used on each +fold of the cross-validation and then, we return the empirical mean of the +obtained robust scores to get the final score. """ print(__doc__) From 495b852ff7e84e24f56b71f1bf327f76a032aecf Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Tue, 7 Sep 2021 10:29:42 +0200 Subject: [PATCH 13/14] black reformat --- sklearn_extra/robust/tests/test_mean_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn_extra/robust/tests/test_mean_estimators.py b/sklearn_extra/robust/tests/test_mean_estimators.py index ab4750df..f93f0f6d 100644 --- a/sklearn_extra/robust/tests/test_mean_estimators.py +++ b/sklearn_extra/robust/tests/test_mean_estimators.py @@ -37,6 +37,7 @@ def test_huber(): assert len(record) == 0 assert np.abs(mu) < 0.1 + def test_robust_metric(): robust_mse = make_huber_metric(mean_squared_error, c=5) y_true = np.hstack([np.zeros(95), 20 * np.ones(5)]) @@ -61,4 +62,3 @@ def test_check_robust_cv(): cross_val_score(huber_reg, X, y, scoring=make_scorer(robust_mse)) ) assert error_Hub_reg < 1 - From 0a3843a5a9d90ac75b2ad909da1757bce167b08d Mon Sep 17 00:00:00 2001 From: TimotheeMathieu Date: Tue, 7 Sep 2021 10:58:47 +0200 Subject: [PATCH 14/14] fix docstring --- sklearn_extra/robust/mean_estimators.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn_extra/robust/mean_estimators.py b/sklearn_extra/robust/mean_estimators.py index cf3619bd..e6040352 100644 --- a/sklearn_extra/robust/mean_estimators.py +++ b/sklearn_extra/robust/mean_estimators.py @@ -204,8 +204,7 @@ def make_huber_metric( >>> y_true = np.hstack([np.zeros(98), 20*np.ones(2)]) # corrupted test values >>> np.random.shuffle(y_true) # shuffle them >>> y_pred = np.zeros(100) # predicted values - >>> robust_mse(y_true, y_pred) - 0.1020408163265306 + >>> result = robust_mse(y_true, y_pred) """ def metric(y_true, y_pred):