[MRG] FIX: make proposal for sdml formulation #162
@@ -12,12 +12,19 @@
 import warnings
 import numpy as np
 from sklearn.base import TransformerMixin
-from sklearn.covariance import graph_lasso
-from sklearn.utils.extmath import pinvh
+from scipy.linalg import pinvh
+from sklearn.covariance import graphical_lasso
+from sklearn.exceptions import ConvergenceWarning

 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
 from ._util import transformer_from_metric
+try:
+  from inverse_covariance import quic
+except ImportError:
+  HAS_SKGGM = False
+else:
+  HAS_SKGGM = True


 class _BaseSDML(MahalanobisMixin):
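This first hunk makes skggm an optional dependency: if `inverse_covariance.quic` can be imported, SDML prefers skggm's QUIC graphical lasso solver, otherwise it falls back to scikit-learn's `graphical_lasso`. A minimal standalone sketch of the same detection pattern (the print simply mirrors the verbose message added in `_fit` below):

```python
# Standalone sketch of the optional skggm detection used above.
try:
  from inverse_covariance import quic  # skggm's QUIC graphical lasso solver
except ImportError:
  HAS_SKGGM = False
else:
  HAS_SKGGM = True

print("SDML will use {}'s graphical lasso solver."
      .format("skggm" if HAS_SKGGM else "scikit-learn"))
```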
@@ -52,24 +59,74 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     super(_BaseSDML, self).__init__(preprocessor)

   def _fit(self, pairs, y):
+    if not HAS_SKGGM:
+      if self.verbose:
+        print("SDML will use scikit-learn's graphical lasso solver.")
+    else:
+      if self.verbose:
+        print("SDML will use skggm's graphical lasso solver.")
     pairs, y = self._prepare_inputs(pairs, y,
                                     type_of_inputs='tuples')

-    # set up prior M
+    # set up (the inverse of) the prior M
     if self.use_cov:
       X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
-      M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
+      prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
     else:
-      M = np.identity(pairs.shape[2])
+      prior_inv = np.identity(pairs.shape[2])
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
-    P = M + self.balance_param * loss_matrix
-    emp_cov = pinvh(P)
-    # hack: ensure positive semidefinite
-    emp_cov = emp_cov.T.dot(emp_cov)
-    _, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
-
-    self.transformer_ = transformer_from_metric(M)
+    emp_cov = prior_inv + self.balance_param * loss_matrix
+
+    # our initialization will be the matrix with emp_cov's eigenvalues,
+    # with a constant added so that they are all positive (plus an epsilon
+    # to ensure definiteness). This is empirical.
+    w, V = np.linalg.eigh(emp_cov)
+    min_eigval = np.min(w)
+    if min_eigval < 0.:
+      warnings.warn("Warning, the input matrix of graphical lasso is not "
+                    "positive semi-definite (PSD). The algorithm may diverge, "
+                    "and lead to degenerate solutions. "
+                    "To prevent that, try to decrease the balance parameter "
+                    "`balance_param` and/or to set use_covariance=False.",
+                    ConvergenceWarning)
+    w -= min_eigval  # we translate the eigenvalues to make them all positive
+    w += 1e-10  # we add a small offset to avoid definiteness problems
+    sigma0 = (V * w).dot(V.T)
+    try:
+      if HAS_SKGGM:
+        theta0 = pinvh(sigma0)
+        M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
+                                msg=self.verbose,
+                                Theta0=theta0, Sigma0=sigma0)
+      else:
+        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+                               verbose=self.verbose,
+                               cov_init=sigma0)
+      raised_error = None
+      w_mahalanobis, _ = np.linalg.eigh(M)
+      not_spd = any(w_mahalanobis < 0.)
+      not_finite = not np.isfinite(M).all()
+    except Exception as e:
+      raised_error = e
+      not_spd = False  # not_spd not applicable here so we set to False
+      not_finite = False  # not_finite not applicable here so we set to False
+    if raised_error is not None or not_spd or not_finite:
+      msg = ("There was a problem in SDML when using {}'s graphical "
+             "lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
+      if not HAS_SKGGM:
+        skggm_advice = (" skggm's graphical lasso can sometimes converge "
+                        "on non SPD cases where scikit-learn's graphical "
+                        "lasso fails to converge. Try to install skggm and "
+                        "rerun the algorithm (see the README.md for the "
+                        "right version of skggm).")
+        msg += skggm_advice
+      if raised_error is not None:
+        msg += " The following error message was thrown: {}.".format(
+            raised_error)
+      raise RuntimeError(msg)
+
+    self.transformer_ = transformer_from_metric(np.atleast_2d(M))
     return self

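For context, a hedged usage sketch of the code path this hunk modifies, assuming the public `SDML` estimator in metric-learn delegates to `_BaseSDML._fit` with an array of pairs and ±1 labels (parameter names and defaults taken from the hunk header above; the example is illustrative only, and on such a tiny toy input the solver may warn or fail to converge, which is exactly the case this PR tries to handle more gracefully):

```python
import numpy as np
from metric_learn import SDML

# Two pairs of 2D points: the first labeled similar (+1), the second dissimilar (-1).
pairs = np.array([[[1.2, 3.2], [2.3, 5.5]],
                  [[4.5, 2.3], [2.1, 2.3]]])
y = np.array([1, -1])

sdml = SDML(balance_param=0.5, sparsity_param=0.01, use_cov=True, verbose=True)
sdml.fit(pairs, y)        # prints which graphical lasso solver is used
print(sdml.transformer_)  # linear transformation derived from the learned metric
```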
This is an initialization that we discussed with @bellet and that I found works better: with the identity initialization many tests raised LinAlgError, whereas with this one they pass.
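A condensed sketch of that initialization, mirroring the hunk above: translate the whole spectrum of the graphical lasso input so the matrix becomes strictly positive definite, and use the result (plus its pseudo-inverse for skggm) as the starting point. The helper name `psd_shift_init` is hypothetical and only used for illustration:

```python
import numpy as np
from scipy.linalg import pinvh

def psd_shift_init(emp_cov, eps=1e-10):
  """Return (sigma0, theta0): emp_cov with its eigenvalues translated to be
  strictly positive, and the corresponding pseudo-inverse."""
  w, V = np.linalg.eigh(emp_cov)
  w = w - np.min(w) + eps    # translate the spectrum; eps avoids exact zeros
  sigma0 = (V * w).dot(V.T)  # same eigenvectors, shifted eigenvalues
  return sigma0, pinvh(sigma0)
```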