From bcc3069e91380ba6d48ef71407bd526716f90903 Mon Sep 17 00:00:00 2001
From: laurallu
Date: Sun, 3 Nov 2019 21:54:54 -0500
Subject: [PATCH 1/4] added safe-level-smote method

---
 imblearn/over_sampling/__init__.py |   2 +
 imblearn/over_sampling/_smote.py   | 328 ++++++++++++++++++++++++++++-
 2 files changed, 326 insertions(+), 4 deletions(-)

diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index bd20b76ea..e6625f098 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
 from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
+from ._smote import SLSMOTE

 __all__ = [
     "ADASYN",
@@ -19,4 +20,5 @@
     "BorderlineSMOTE",
     "SVMSMOTE",
     "SMOTENC",
+    "SLSMOTE",
 ]

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index a81c492bf..bd30ea2f1 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -586,12 +586,14 @@ def _fit_resample(self, X, y):
             n_generated_samples = int(fractions * (n_samples + 1))
             if np.count_nonzero(danger_bool) > 0:
                 nns = self.nn_k_.kneighbors(
-                    _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(danger_bool)),
                     return_distance=False,
                 )[:, 1:]

                 X_new_1, y_new_1 = self._make_samples(
-                    _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(danger_bool)),
                     y.dtype,
                     class_sample,
                     X_class,
@@ -602,12 +604,14 @@ def _fit_resample(self, X, y):

             if np.count_nonzero(safety_bool) > 0:
                 nns = self.nn_k_.kneighbors(
-                    _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(safety_bool)),
                     return_distance=False,
                 )[:, 1:]

                 X_new_2, y_new_2 = self._make_samples(
-                    _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(safety_bool)),
                     y.dtype,
                     class_sample,
                     X_class,
@@ -1308,3 +1312,319 @@ def _fit_resample(self, X, y):
             y_resampled = np.hstack((y_resampled, y_new))

         return X_resampled, y_resampled
+
+
+@Substitution(
+    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+    random_state=_random_state_docstring,
+)
+class SLSMOTE(BaseSMOTE):
+    """Class to perform over-sampling using safe-level SMOTE.
+    This is an implementation of the Safe-level-SMOTE described in [2]_.
+
+    Parameters
+    ----------
+    {sampling_strategy}
+
+    {random_state}
+
+    k_neighbors : int or object, optional (default=5)
+        If ``int``, number of nearest neighbours used to construct synthetic
+        samples. If object, an estimator that inherits from
+        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
+        find the k_neighbors.
+
+    m_neighbors : int or object, optional (default=10)
+        If ``int``, number of nearest neighbours to use to determine the safe
+        level of an instance. If object, an estimator that inherits from
+        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
+        to find the m_neighbors.
+
+    n_jobs : int or None, optional (default=None)
+        Number of CPU cores used during the cross-validation loop.
+        ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
+        ``-1`` means using all processors. See
+        `Glossary `_
+        for more details.
+
+    Notes
+    -----
+    See the original paper [2]_ for more details.
+
+    Supports multi-class resampling. A one-vs.-rest scheme is used as
+    originally proposed in [1]_.
+
+    See also
+    --------
+    SMOTE : Over-sample using SMOTE.
+
+    SMOTENC : Over-sample using SMOTE for continuous and categorical
+        features.
+
+    SVMSMOTE : Over-sample using SVM-SMOTE variant.
+
+    BorderlineSMOTE : Over-sample using Borderline-SMOTE.
+
+    ADASYN : Over-sample using ADASYN.
+
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    References
+    ----------
+    .. [1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE:
+       synthetic minority over-sampling technique," Journal of artificial
+       intelligence research, 321-357, 2002.
+
+    .. [2] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap, "Safe-level-
+       SMOTE: Safe-level-synthetic minority over-sampling technique for
+       handling the class imbalanced problem," In: Theeramunkong T.,
+       Kijsirikul B., Cercone N., Ho TB. (eds) Advances in Knowledge Discovery
+       and Data Mining. PAKDD 2009. Lecture Notes in Computer Science,
+       vol 5476. Springer, Berlin, Heidelberg, 475-482, 2009.
+
+    Examples
+    --------
+
+    >>> from collections import Counter
+    >>> from sklearn.datasets import make_classification
+    >>> from imblearn.over_sampling import \
+SLSMOTE # doctest: +NORMALIZE_WHITESPACE
+    >>> X, y = make_classification(n_classes=2, class_sep=2,
+    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
+    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
+    >>> print('Original dataset shape %s' % Counter(y))
+    Original dataset shape Counter({{1: 900, 0: 100}})
+    >>> sm = SLSMOTE(random_state=42)
+    >>> X_res, y_res = sm.fit_resample(X, y)
+    >>> print('Resampled dataset shape %s' % Counter(y_res))
+    Resampled dataset shape Counter({{0: 900, 1: 900}})
+
+    """
+
+    def __init__(self,
+                 sampling_strategy='auto',
+                 random_state=None,
+                 k_neighbors=5,
+                 m_neighbors=10,
+                 n_jobs=None):
+
+        super().__init__(sampling_strategy=sampling_strategy,
+                         random_state=random_state, k_neighbors=k_neighbors,
+                         n_jobs=n_jobs)
+
+        self.m_neighbors = m_neighbors
+
+    def _assign_sl(self, nn_estimator, samples, target_class, y):
+        """
+        Assign the safe levels to the instances in the target class.
+
+        Parameters
+        ----------
+        nn_estimator : estimator
+            An estimator that inherits from
+            :class:`sklearn.neighbors.base.KNeighborsMixin`. It gets the
+            nearest neighbours that are used to determine the safe levels.
+
+        samples : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The samples to which the safe levels are assigned.
+
+        target_class : int or str
+            The class currently being over-sampled.
+
+        y : array-like, shape (n_samples,)
+            The true labels, used to calculate the safe levels.
+
+        Returns
+        -------
+        output : ndarray, shape (n_samples,)
+            An ndarray whose values are the safe levels of the
+            instances in the target class.
+ ''' + + x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:] + nn_label = (y[x] == target_class).astype(int) + sl = np.sum(nn_label, axis=1) + return sl + + def _validate_estimator(self): + super()._validate_estimator() + self.nn_m_ = check_neighbors_object('m_neighbors', self.m_neighbors, + additional_neighbor=1) + self.nn_m_.set_params(**{"n_jobs": self.n_jobs}) + + def _fit_resample(self, X, y): + self._validate_estimator() + + X_resampled = X.copy() + y_resampled = y.copy() + + for class_sample, n_samples in self.sampling_strategy_.items(): + if n_samples == 0: + continue + target_class_indices = np.flatnonzero(y == class_sample) + X_class = _safe_indexing(X, target_class_indices) + + self.nn_m_.fit(X) + sl = self._assign_sl(self.nn_m_, X_class, class_sample, y) + + # filter the points in X_class that have safe level >0 + # If safe level = 0, the point is not used to + # generate synthetic instances + X_safe_indices = np.flatnonzero(sl != 0) + X_safe_class = _safe_indexing(X_class, X_safe_indices) + + self.nn_k_.fit(X_class) + nns = self.nn_k_.kneighbors(X_safe_class, + return_distance=False)[:, 1:] + + sl_safe_class = sl[X_safe_indices] + sl_nns = sl[nns] + sl_safe_t = np.array([sl_safe_class]).transpose() + with np.errstate(divide='ignore'): + sl_ratio = np.divide(sl_safe_t, sl_nns) + + X_new, y_new = self._make_samples_sl(X_safe_class, y.dtype, + class_sample, X_class, + nns, n_samples, sl_ratio, + 1.0) + + if sparse.issparse(X_new): + X_resampled = sparse.vstack([X_resampled, X_new]) + else: + X_resampled = np.vstack((X_resampled, X_new)) + y_resampled = np.hstack((y_resampled, y_new)) + + return X_resampled, y_resampled + + def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num, + n_samples, sl_ratio, step_size=1.): + """A support function that returns artificial samples using + safe-level SMOTE. It is similar to _make_samples method for SMOTE. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples_safe, n_features) + Points from which the points will be created. + + y_dtype : dtype + The data type of the targets. + + y_type : str or int + The minority target value, just so the function can return the + target values for the synthetic variables with correct length in + a clear format. + + nn_data : ndarray, shape (n_samples_all, n_features) + Data set carrying all the neighbours to be used + + nn_num : ndarray, shape (n_samples_safe, k_nearest_neighbours) + The nearest neighbours of each sample in `nn_data`. + + n_samples : int + The number of samples to generate. + + sl_ratio: ndarray, shape (n_samples_safe, k_nearest_neighbours) + + step_size : float, optional (default=1.) + The step size to create samples. + + + Returns + ------- + X_new : {ndarray, sparse matrix}, shape (n_samples_new, n_features) + Synthetically generated samples using the safe-level method. + + y_new : ndarray, shape (n_samples_new,) + Target values for synthetic samples. 
+        """
+
+        random_state = check_random_state(self.random_state)
+        samples_indices = random_state.randint(low=0,
+                                               high=len(nn_num.flatten()),
+                                               size=n_samples)
+        rows = np.floor_divide(samples_indices, nn_num.shape[1])
+        cols = np.mod(samples_indices, nn_num.shape[1])
+        gap_arr = step_size * self._vgenerate_gap(sl_ratio)
+        gaps = gap_arr.flatten()[samples_indices]
+
+        y_new = np.array([y_type] * n_samples, dtype=y_dtype)
+
+        if sparse.issparse(X):
+            row_indices, col_indices, samples = [], [], []
+            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
+                if X[row].nnz:
+                    sample = self._generate_sample(
+                        X, nn_data, nn_num, row, col, gap)
+                    row_indices += [i] * len(sample.indices)
+                    col_indices += sample.indices.tolist()
+                    samples += sample.data.tolist()
+            return (
+                sparse.csr_matrix(
+                    (samples, (row_indices, col_indices)),
+                    [len(samples_indices), X.shape[1]],
+                    dtype=X.dtype,
+                ),
+                y_new,
+            )
+        else:
+            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
+            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
+                X_new[i] = self._generate_sample(X, nn_data, nn_num,
+                                                 row, col, gap)
+            return X_new, y_new
+
+    def _generate_gap(self, a_ratio, rand_state=None):
+        """Generate a gap according to sl_ratio, non-vectorized version.
+
+        Parameters
+        ----------
+        a_ratio : float
+            sl_ratio of a single data point.
+
+        rand_state : random state object or int
+
+        Returns
+        -------
+        gap : float
+            A number between 0 and 1.
+
+        """
+        random_state = check_random_state(rand_state)
+        if np.isinf(a_ratio):
+            gap = 0
+        elif a_ratio >= 1:
+            gap = random_state.uniform(0, 1/a_ratio)
+        elif 0 < a_ratio < 1:
+            gap = random_state.uniform(1-a_ratio, 1)
+        else:
+            raise ValueError('sl_ratio should be nonnegative')
+        return gap
+
+    def _vgenerate_gap(self, sl_ratio):
+        """Generate gaps according to sl_ratio, vectorized version of
+        _generate_gap.
+
+        Parameters
+        ----------
+        sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+            sl_ratio of all instances with safe_level > 0 in the specified
+            class.
+
+        Returns
+        -------
+        gap_arr : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+            The gap for all instances with safe_level > 0 in the specified
+            class.
+
+        """
+        prng = check_random_state(self.random_state)
+        rand_state = prng.randint(sl_ratio.size+1, size=sl_ratio.shape)
+        vgap = np.vectorize(self._generate_gap)
+        gap_arr = vgap(sl_ratio, rand_state)
+        return gap_arr
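
The heart of [PATCH 1/4] is the gap rule in ``_generate_gap``: the ratio between the safe level of the seed point and the safe level of its selected neighbour decides which part of the connecting segment a synthetic sample may fall on. The following is a minimal standalone sketch of that rule, assuming plain NumPy and illustrative names and values; it is not part of the patch itself::

    import numpy as np

    rng = np.random.RandomState(42)

    def gap_interval(sl_ratio):
        # Interval the interpolation gap is drawn from, mirroring _generate_gap.
        if np.isinf(sl_ratio):          # neighbour has safe level 0:
            return 0.0, 0.0             # duplicate the safe seed point
        if sl_ratio >= 1:               # seed at least as safe as neighbour:
            return 0.0, 1.0 / sl_ratio  # stay close to the seed
        return 1.0 - sl_ratio, 1.0      # neighbour is safer: land close to it

    for slr in [np.inf, 4.0, 1.0, 0.25]:
        low, high = gap_interval(slr)
        gap = rng.uniform(low, high)
        print("ratio %s: gap in [%.2f, %.2f], drew %.3f" % (slr, low, high, gap))

The synthetic point is then ``X[row] - gap * (X[row] - neighbour)`` (the interpolation form used by ``_generate_sample`` in ``_smote.py``), so a gap near 0 places the new sample next to the seed and a gap near 1 next to the neighbour.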
From 394d686364725763de8ea2cc3f504d8c08fe111a Mon Sep 17 00:00:00 2001
From: laurallu
Date: Tue, 5 Nov 2019 23:55:29 -0500
Subject: [PATCH 2/4] unit tests added for safe-level SMOTE

---
 imblearn/over_sampling/_smote.py              | 22 +++---
 imblearn/over_sampling/tests/test_sl_smote.py | 66 +++++++++++++++++++
 2 files changed, 77 insertions(+), 11 deletions(-)
 create mode 100644 imblearn/over_sampling/tests/test_sl_smote.py

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index bd30ea2f1..6acc9b498 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -1324,18 +1324,18 @@ class SLSMOTE(BaseSMOTE):

     Parameters
     ----------
-    {sampling_strategy}
+    {sampling_strategy}

-    {random_state}
+    {random_state}

-    k_neighbors : int or object, optional (default=5)
+    k_neighbors : int or object, optional (default=5)
         If ``int``, number of nearest neighbours used to construct synthetic
         samples. If object, an estimator that inherits from
         :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
         find the k_neighbors.

     m_neighbors : int or object, optional (default=10)
-        If ``int``, number of nearest neighbours to use to determine the safe
+        If ``int``, number of nearest neighbours used to determine the safe
         level of an instance. If object, an estimator that inherits from
         :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
         to find the m_neighbors.
@@ -1582,16 +1582,16 @@ def _generate_gap(self, a_ratio, rand_state=None):

         Parameters
         ----------
-        a_ratio : float
-            sl_ratio of a single data point.
+        a_ratio : float
+            sl_ratio of a single data point.

-        rand_state : random state object or int
+        rand_state : random state object or int

-        Returns
-        -------
-        gap : float
-            A number between 0 and 1.
+        Returns
+        -------
+        gap : float
+            A number between 0 and 1.

         """

diff --git a/imblearn/over_sampling/tests/test_sl_smote.py b/imblearn/over_sampling/tests/test_sl_smote.py
new file mode 100644
index 000000000..a4e357ce4
--- /dev/null
+++ b/imblearn/over_sampling/tests/test_sl_smote.py
@@ -0,0 +1,66 @@
+import pytest
+import numpy as np
+
+from sklearn.neighbors import NearestNeighbors
+from scipy import sparse
+
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_array_equal
+
+from imblearn.over_sampling import SLSMOTE
+
+
+def data_np():
+    rng = np.random.RandomState(42)
+    X = rng.randn(20, 2)
+    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    return X, y
+
+
+def data_sparse(format):
+    X = sparse.random(20, 2, density=0.3, format=format, random_state=42)
+    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    return X, y
+
+
+@pytest.mark.parametrize(
+    "data",
+    [data_np(), data_sparse('csr'),
+     data_sparse('csc')]
+)
+def test_slsmote(data):
+    y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,
+                     0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0])
+    X, y = data
+    slsmote = SLSMOTE(random_state=42)
+    X_res, y_res = slsmote.fit_resample(X, y)
+
+    assert X_res.shape == (24, 2)
+    assert_array_equal(y_res, y_gt)
+
+
+def test_slsmote_nn():
+    X, y = data_np()
+    slsmote = SLSMOTE(random_state=42)
+    slsmote_nn = SLSMOTE(
+        random_state=42,
+        k_neighbors=NearestNeighbors(n_neighbors=6),
+        m_neighbors=NearestNeighbors(n_neighbors=11),
+    )
+
+    X_res_1, y_res_1 = slsmote.fit_resample(X, y)
+    X_res_2, y_res_2 = slsmote_nn.fit_resample(X, y)
+
+    assert_allclose(X_res_1, X_res_2)
+    assert_array_equal(y_res_1, y_res_2)
+
+
+def test_slsmote_pd():
+    pd = pytest.importorskip("pandas")
+    X, y = data_np()
+    X_pd = pd.DataFrame(X)
+    slsmote = SLSMOTE(random_state=42)
+    X_res, y_res = slsmote.fit_resample(X, y)
+    X_res_pd, y_res_pd = slsmote.fit_resample(X_pd, y)
+
+    assert X_res_pd.tolist() == X_res.tolist()
+    assert_allclose(y_res_pd, y_res)
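
[PATCH 3/4] below renames ``_assign_sl`` to ``_assign_safe_levels``; under either name, the safe level of a candidate sample is the count of same-class instances among its ``m_neighbors`` nearest neighbours. A standalone sketch of that computation, on made-up data and with illustrative variable names, not taken from the patch::

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.RandomState(0)
    X = rng.randn(20, 2)
    y = np.array([0, 1] * 10)
    target_class = 0

    # m_neighbors=10 in the patch; ask for 11 neighbours because each
    # training sample comes back as its own nearest neighbour, dropped below.
    nn = NearestNeighbors(n_neighbors=11).fit(X)
    X_class = X[y == target_class]
    nn_idx = nn.kneighbors(X_class, return_distance=False)[:, 1:]
    safe_levels = (y[nn_idx] == target_class).sum(axis=1)
    print(safe_levels)  # one count in [0, 10] per sample of the target class

Samples whose safe level is 0 are treated as noise and skipped as generation seeds, which is the filtering step in ``_fit_resample``.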
From 609c4fc4fec5dc58ddff1f18cbaadefce0fe88a2 Mon Sep 17 00:00:00 2001
From: laurallu
Date: Mon, 11 Nov 2019 13:53:09 -0500
Subject: [PATCH 3/4] fixed variable name, added doc and test

---
 doc/over_sampling.rst                         | 18 +++-
 imblearn/over_sampling/__init__.py            |  4 +-
 imblearn/over_sampling/_smote.py              | 85 ++++++++++++-------
 ...st_sl_smote.py => test_safelevel_smote.py} | 40 ++++++---
 4 files changed, 98 insertions(+), 49 deletions(-)
 rename imblearn/over_sampling/tests/{test_sl_smote.py => test_safelevel_smote.py} (55%)

diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index 6159e925b..bc6bf66d0 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -152,8 +152,9 @@ nearest neighbors class. Those variants are presented in the figure below.
    :align: center


-The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_, and
-:class:`KMeansSMOTE` [LDB2017]_ offer some variant of the SMOTE algorithm::
+The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_,
+:class:`KMeansSMOTE` [LDB2017]_ and :class:`SafeLevelSMOTE` [BSL2009]_
+offer some variant of the SMOTE algorithm::

     >>> from imblearn.over_sampling import BorderlineSMOTE
     >>> X_resampled, y_resampled = BorderlineSMOTE().fit_resample(X, y)
@@ -213,6 +214,14 @@ other extra interpolation.
        Imbalanced Learning Based on K-Means and SMOTE"
        https://arxiv.org/abs/1711.00837

+   [BSL2009] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap,
+       "Safe-level-SMOTE: Safe-level-synthetic minority over-sampling
+       technique for handling the class imbalanced problem," In:
+       Theeramunkong T., Kijsirikul B., Cercone N., Ho TB. (eds)
+       Advances in Knowledge Discovery and Data Mining. PAKDD 2009.
+       Lecture Notes in Computer Science, vol 5476. Springer, Berlin,
+       Heidelberg, 475-482, 2009.
+
 Mathematical formulation
 ========================
@@ -274,6 +283,11 @@ parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise.
 method before to apply SMOTE. The clustering will group samples together and
 generate new samples depending of the cluster density.

+**SafeLevel** SMOTE --- cf. to :class:`SafeLevelSMOTE` --- uses the safe level
+(the number of minority instances among a sample's nearest neighbors) to
+position a synthetic instance. Compared to the regular SMOTE, the new
+instance is placed closer to the endpoint whose safe level is larger.
+
 ADASYN works similarly to the regular SMOTE. However, the number of samples
 generated for each :math:`x_i` is proportional to the number of samples which
 are not from the same class than :math:`x_i` in a given

diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index e6625f098..8027b18a2 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -10,7 +10,7 @@
 from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
-from ._smote import SLSMOTE
+from ._smote import SafeLevelSMOTE

 __all__ = [
     "ADASYN",
@@ -20,5 +20,5 @@
     "BorderlineSMOTE",
     "SVMSMOTE",
     "SMOTENC",
-    "SLSMOTE",
+    "SafeLevelSMOTE",
 ]

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index 6acc9b498..d5bc2a123 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -284,6 +284,11 @@ class BorderlineSMOTE(BaseSMOTE):

     SVMSMOTE : Over-sample using SVM-SMOTE variant.

+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
+
     ADASYN : Over-sample using ADASYN.

     References
@@ -484,6 +489,10 @@ class SVMSMOTE(BaseSMOTE):

     BorderlineSMOTE : Over-sample using Borderline-SMOTE.

+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.

     References
@@ -695,6 +704,10 @@ class SMOTE(BaseSMOTE):

     SVMSMOTE : Over-sample using the SVM-SMOTE variant.

+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.

     References
@@ -864,6 +877,10 @@ class SMOTENC(SMOTE):

     BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.

+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.

     References
@@ -1318,7 +1335,7 @@ def _fit_resample(self, X, y):
     sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
     random_state=_random_state_docstring,
 )
-class SLSMOTE(BaseSMOTE):
+class SafeLevelSMOTE(BaseSMOTE):
     """Class to perform over-sampling using safe-level SMOTE.
     This is an implementation of the Safe-level-SMOTE described in [2]_.

@@ -1389,13 +1406,13 @@ class SLSMOTE(BaseSMOTE):
     >>> from collections import Counter
     >>> from sklearn.datasets import make_classification
     >>> from imblearn.over_sampling import \
-SLSMOTE # doctest: +NORMALIZE_WHITESPACE
+SafeLevelSMOTE # doctest: +NORMALIZE_WHITESPACE
     >>> X, y = make_classification(n_classes=2, class_sep=2,
     ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
     ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
     >>> print('Original dataset shape %s' % Counter(y))
     Original dataset shape Counter({{1: 900, 0: 100}})
-    >>> sm = SLSMOTE(random_state=42)
+    >>> sm = SafeLevelSMOTE(random_state=42)
     >>> X_res, y_res = sm.fit_resample(X, y)
     >>> print('Resampled dataset shape %s' % Counter(y_res))
     Resampled dataset shape Counter({{0: 900, 1: 900}})
@@ -1415,7 +1432,7 @@ def __init__(self,

         self.m_neighbors = m_neighbors

-    def _assign_sl(self, nn_estimator, samples, target_class, y):
+    def _assign_safe_levels(self, nn_estimator, samples, target_class, y):
         """
         Assign the safe levels to the instances in the target class.

@@ -1444,8 +1461,8 @@ def _assign_safe_levels(self, nn_estimator, samples, target_class, y):

         x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
         nn_label = (y[x] == target_class).astype(int)
-        sl = np.sum(nn_label, axis=1)
-        return sl
+        safe_levels = np.sum(nn_label, axis=1)
+        return safe_levels

     def _validate_estimator(self):
         super()._validate_estimator()
@@ -1466,28 +1483,30 @@ def _fit_resample(self, X, y):
             X_class = _safe_indexing(X, target_class_indices)

             self.nn_m_.fit(X)
-            sl = self._assign_sl(self.nn_m_, X_class, class_sample, y)
+            safe_levels = self._assign_safe_levels(
+                self.nn_m_, X_class, class_sample, y)

             # filter the points in X_class that have a safe level > 0;
             # if the safe level is 0, the point is not used to
             # generate synthetic instances
-            X_safe_indices = np.flatnonzero(sl != 0)
+            X_safe_indices = np.flatnonzero(safe_levels != 0)
             X_safe_class = _safe_indexing(X_class, X_safe_indices)

             self.nn_k_.fit(X_class)
             nns = self.nn_k_.kneighbors(X_safe_class,
                                         return_distance=False)[:, 1:]

-            sl_safe_class = sl[X_safe_indices]
-            sl_nns = sl[nns]
+            sl_safe_class = safe_levels[X_safe_indices]
+            sl_nns = safe_levels[nns]
             sl_safe_t = np.array([sl_safe_class]).transpose()
             with np.errstate(divide='ignore'):
-                sl_ratio = np.divide(sl_safe_t, sl_nns)
+                safe_level_ratio = np.divide(sl_safe_t, sl_nns)

-            X_new, y_new = self._make_samples_sl(X_safe_class, y.dtype,
-                                                 class_sample, X_class,
-                                                 nns, n_samples, sl_ratio,
-                                                 1.0)
+            X_new, y_new = self._make_samples_safelevel(X_safe_class, y.dtype,
+                                                        class_sample, X_class,
+                                                        nns, n_samples,
+                                                        safe_level_ratio,
+                                                        1.0)

             if sparse.issparse(X_new):
@@ -1497,8 +1516,8 @@ def _fit_resample(self, X, y):

         return X_resampled, y_resampled

-    def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
-                         n_samples, sl_ratio, step_size=1.):
+    def _make_samples_safelevel(self, X, y_dtype, y_type, nn_data, nn_num,
+                                n_samples, safe_level_ratio, step_size=1.):
         """A support function that returns artificial samples using
         safe-level SMOTE. It is similar to the _make_samples method for
         SMOTE.
@@ -1524,7 +1543,7 @@ def _make_samples_sl(self, X, y_dtype, y_type, nn_data, nn_num,
         n_samples : int
             The number of samples to generate.

-        sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+        safe_level_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
             The ratio between the safe level of each sample in ``X`` and the
             safe levels of its nearest neighbours.

         step_size : float, optional (default=1.)
             The step size to create samples.
@@ -1546,8 +1565,8 @@ def _make_samples_safelevel(self, X, y_dtype, y_type, nn_data, nn_num,
                                                size=n_samples)
         rows = np.floor_divide(samples_indices, nn_num.shape[1])
         cols = np.mod(samples_indices, nn_num.shape[1])
-        gap_arr = step_size * self._vgenerate_gap(sl_ratio)
-        gaps = gap_arr.flatten()[samples_indices]
+        gap_array = step_size * self._vgenerate_gap(safe_level_ratio)
+        gaps = gap_array.flatten()[samples_indices]

         y_new = np.array([y_type] * n_samples, dtype=y_dtype)
@@ -1578,12 +1597,12 @@ def _make_samples_safelevel(self, X, y_dtype, y_type, nn_data, nn_num,
         return X_new, y_new

     def _generate_gap(self, a_ratio, rand_state=None):
-        """Generate a gap according to sl_ratio, non-vectorized version.
+        """Generate a gap according to safe_level_ratio, non-vectorized
+        version.

         Parameters
         ----------
         a_ratio : float
-            sl_ratio of a single data point.
+            safe_level_ratio of a single data point.

         rand_state : random state object or int
@@ -1603,28 +1622,30 @@ def _generate_gap(self, a_ratio, rand_state=None):
             gap = random_state.uniform(0, 1/a_ratio)
         elif 0 < a_ratio < 1:
             gap = random_state.uniform(1-a_ratio, 1)
         else:
-            raise ValueError('sl_ratio should be nonnegative')
+            raise ValueError('safe_level_ratio should be nonnegative')
         return gap

-    def _vgenerate_gap(self, sl_ratio):
-        """Generate gaps according to sl_ratio, vectorized version of
-        _generate_gap.
+    def _vgenerate_gap(self, safe_level_ratio):
+        """Generate gaps according to safe_level_ratio, vectorized version
+        of _generate_gap.

         Parameters
         ----------
-        sl_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
-            sl_ratio of all instances with safe_level > 0 in the specified
-            class.
+        safe_level_ratio : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+            safe_level_ratio of all instances with safe_level > 0 in the
+            specified class.

         Returns
         -------
-        gap_arr : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+        gap_array : ndarray, shape (n_samples_safe, k_nearest_neighbours)
             The gap for all instances with safe_level > 0 in the specified
             class.

         """
         prng = check_random_state(self.random_state)
-        rand_state = prng.randint(sl_ratio.size+1, size=sl_ratio.shape)
+        rand_state = prng.randint(
+            safe_level_ratio.size+1, size=safe_level_ratio.shape)
         vgap = np.vectorize(self._generate_gap)
-        gap_arr = vgap(sl_ratio, rand_state)
-        return gap_arr
+        gap_array = vgap(safe_level_ratio, rand_state)
+        return gap_array

diff --git a/imblearn/over_sampling/tests/test_sl_smote.py b/imblearn/over_sampling/tests/test_safelevel_smote.py
similarity index 55%
rename from imblearn/over_sampling/tests/test_sl_smote.py
rename to imblearn/over_sampling/tests/test_safelevel_smote.py
index a4e357ce4..ad4a3c8aa 100644
--- a/imblearn/over_sampling/tests/test_sl_smote.py
+++ b/imblearn/over_sampling/tests/test_safelevel_smote.py
@@ -1,5 +1,6 @@
 import pytest
 import numpy as np
+from collections import Counter

 from sklearn.neighbors import NearestNeighbors
 from scipy import sparse
@@ -7,7 +8,7 @@
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils._testing import assert_array_equal

-from imblearn.over_sampling import SLSMOTE
+from imblearn.over_sampling import SafeLevelSMOTE


 def data_np():
@@ -27,40 +28,53 @@ def data_sparse(format):
     "data",
     [data_np(), data_sparse('csr'),
      data_sparse('csc')]
 )
-def test_slsmote(data):
+def test_safelevel_smote(data):
     y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0,
                      0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0])
     X, y = data
-    slsmote = SLSMOTE(random_state=42)
-    X_res, y_res = slsmote.fit_resample(X, y)
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)

     assert X_res.shape == (24, 2)
     assert_array_equal(y_res, y_gt)


-def test_slsmote_nn():
+def test_sl_smote_nn():
     X, y = data_np()
-    slsmote = SLSMOTE(random_state=42)
-    slsmote_nn = SLSMOTE(
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    safelevel_smote_nn = SafeLevelSMOTE(
         random_state=42,
         k_neighbors=NearestNeighbors(n_neighbors=6),
         m_neighbors=NearestNeighbors(n_neighbors=11),
     )

-    X_res_1, y_res_1 = slsmote.fit_resample(X, y)
-    X_res_2, y_res_2 = slsmote_nn.fit_resample(X, y)
+    X_res_1, y_res_1 = safelevel_smote.fit_resample(X, y)
+    X_res_2, y_res_2 = safelevel_smote_nn.fit_resample(X, y)

     assert_allclose(X_res_1, X_res_2)
     assert_array_equal(y_res_1, y_res_2)


-def test_slsmote_pd():
+def test_sl_smote_pd():
     pd = pytest.importorskip("pandas")
     X, y = data_np()
     X_pd = pd.DataFrame(X)
-    slsmote = SLSMOTE(random_state=42)
-    X_res, y_res = slsmote.fit_resample(X, y)
-    X_res_pd, y_res_pd = slsmote.fit_resample(X_pd, y)
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)
+    X_res_pd, y_res_pd = safelevel_smote.fit_resample(X_pd, y)

     assert X_res_pd.tolist() == X_res.tolist()
     assert_allclose(y_res_pd, y_res)
+
+
+def test_sl_smote_multiclass():
+    rng = np.random.RandomState(42)
+    X = rng.randn(50, 2)
+    y = np.array([0] * 10 + [1] * 15 + [2] * 25)
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)
+
+    count_y_res = Counter(y_res)
+    assert count_y_res[0] == 25
+    assert count_y_res[1] == 25
+    assert count_y_res[2] == 25

From 866a04f35acab4e226fdb5b9b930922c1ce0f742 Mon Sep 17 00:00:00 2001
From: laurallu
Date: Mon, 11 Nov 2019 21:30:00 -0500
Subject: [PATCH 4/4] removed redundant lines

---
 imblearn/over_sampling/_smote.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index ad187de9d..09d695885 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -1619,10 +1619,8 @@ def _generate_gap(self, a_ratio, rand_state=None):
             gap = 0
         elif a_ratio >= 1:
             gap = random_state.uniform(0, 1/a_ratio)
-        elif 0 < a_ratio < 1:
-            gap = random_state.uniform(1-a_ratio, 1)
         else:
-            raise ValueError('safe_level_ratio should be nonnegative')
+            gap = random_state.uniform(1-a_ratio, 1)
         return gap

     def _vgenerate_gap(self, safe_level_ratio):
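
Taken together, the four patches generate each synthetic sample by picking a safe seed, picking one of its k nearest same-class neighbours, turning the two safe levels into a ratio, and interpolating with a gap drawn from the interval that ratio selects. A compact numeric walk-through, with both safe levels assumed purely for illustration and not taken from the patches::

    import numpy as np

    rng = np.random.RandomState(42)
    p = np.array([0.0, 0.0])   # seed minority sample, safe level 4 (assumed)
    n = np.array([1.0, 1.0])   # selected neighbour, safe level 2 (assumed)

    sl_ratio = 4 / 2           # ratio > 1: the seed is the safer endpoint
    gap = rng.uniform(0, 1 / sl_ratio)
    x_new = p - gap * (p - n)  # same interpolation as _generate_sample
    print(x_new)               # both coordinates fall in [0, 0.5]

Because the admissible gap shrinks as the ratio grows, samples generated around very safe seeds stay close to them, which is the bias that distinguishes Safe-level-SMOTE from the uniform interpolation of regular SMOTE.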