
[MRG] FIX: make proposal for sdml formulation #162

Merged on Mar 22, 2019 (78 commits).

Commits
02cc937  FIX: make proposal for sdml formulation (Jan 24, 2019)
aebd47f  MAINT clearer formulation to make the prior appear (Jan 29, 2019)
40b2c88  MAINT call the prior prior (Jan 29, 2019)
fb04cc9  Use skggm instead of graphical lasso (Feb 1, 2019)
518d6e8  Be more severe for the class separation (Feb 1, 2019)
c912e93  Merge branch 'master' into fix/proposal_for_sdml (Feb 1, 2019)
8f0b113  Put back verbose param (Feb 1, 2019)
c57a35a  MAINT: make more explicit the fact that to use identity (i.e. an SPD … (Feb 1, 2019)
f0eb938  Add skggm as a requirement for SDML (Feb 1, 2019)
821db0b  Add skggm to required packages for travis (Feb 1, 2019)
bd2862d  Also add cython as a dependency (Feb 1, 2019)
c6a2daa  FIX: install all except skggm and then skggm (Feb 1, 2019)
93d790e  Remove cython dependency (Feb 1, 2019)
cae6c28  Install skggm only if we have at least python 3.6 (Feb 15, 2019)
5d673ba  Should work if we want other versions superior to 3.6 (Feb 15, 2019)
e8a28d5  Fix bash >= which should be written -ge (Feb 15, 2019)
e740702  Deal with tests when skggm is not installed and fix some PEP8 warnings (Feb 15, 2019)
333675b  replace manual calls of algorithms with tuples_learners (Feb 15, 2019)
1a6e97b  Remove another call of SDML if skggm is not installed (Feb 15, 2019)
7cecf27  FIX fix the test_error_message_tuple_size (Feb 15, 2019)
5303e1a  FIX fix test_sdml_supervised (Feb 15, 2019)
377760a  FIX: fix another sdml test (Feb 15, 2019)
0a46ad5  FIX quic call for python 2.7 (Feb 15, 2019)
391d773  Fix quic import (Feb 15, 2019)
6654769  Add Sigma0 initalization (both sigma zero and theta zero should be sp… (Feb 15, 2019)
ac4e18a  Deal with SDML making some tests fail (Feb 15, 2019)
458d646  Remove epsilon that was unnecessary (Feb 15, 2019)
fd7c9fb  FIX: use latest commit of skggm that fixes the non deterministic problem (Feb 19, 2019)
e118cd8  MAINT: add message for SDML when not SPD (Mar 5, 2019)
b0c4753  MAINT: add test for error message if skggm not installed (Mar 5, 2019)
13146d8  Try other syntax for installing the right commit of skggm (Mar 5, 2019)
db4a799  MAINT: make sklearn compat sdml test be run only if skggm is installed (Mar 5, 2019)
1011391  Try another syntax for running travis (Mar 5, 2019)
5ea7ba0  Better bash syntax (Mar 5, 2019)
45d3b7b  Fix tests by removing duplicates (Mar 6, 2019)
dbf5257  FIX: fix for sdml by reducing balance parameter (Mar 6, 2019)
4b0bae9  FIX: update code to work with old version of numpy that does not have… (Mar 6, 2019)
f3c690e  Remove the need for skggm (Mar 7, 2019)
57b0567  Update travis not to use skggm (Mar 7, 2019)
04316b2  Add a stable init for sklearn checks (Mar 7, 2019)
b641641  FIX test_sdml_supervised (Mar 7, 2019)
fedfb8e  Revert "Update travis not to use skggm" (Mar 8, 2019)
f0bbf6d  Add fallback on skggm (Mar 8, 2019)
520d7c2  FIX: fix versions comparison and tests (Mar 8, 2019)
0437c62  MAINT: improve test of no warning (Mar 8, 2019)
be1a5e6  FIX: fix wrap pairs that was returning column y (we need line y), and… (Mar 8, 2019)
56efa09  FIX: force travis to do the right check (Mar 8, 2019)
142eea9  TST: add non SPD test that works with skggm's quic but not sklearn's … (Mar 8, 2019)
fcfd44c  Try again travis this time installing cython (Mar 8, 2019)
019e28b  Try to make travis work with build_essential (Mar 8, 2019)
04a5107  Try with installing liblapack (Mar 8, 2019)
be3a2ad  TST: fix tests for when skggm is not installed (Mar 8, 2019)
1ee8d1f  TST: use better pytest skipif syntax (Mar 8, 2019)
03f4158  FIX: fix broken link in README.md (Mar 8, 2019)
e621e27  use rst syntax for link (Mar 8, 2019)
0086c98  use rst syntax for link (Mar 8, 2019)
001600e  use rst syntax for link (Mar 8, 2019)
8c50a0d  MAINT: remove test_sdml that was remaining from drafts tests (Mar 8, 2019)
e4132d6  TST: remove skipping SDML in test_cross_validation_manual_vs_scikit (Mar 8, 2019)
b3bf6a8  FIX link also in getting started (Mar 8, 2019)
49f3b9e  Put back right indent (Mar 8, 2019)
e1664c7  Remove unnecessary changes (Mar 8, 2019)
187e22c  merging (Mar 18, 2019)
60866cb  Nitpick for concatenation and refactor HAS_SKGGM (Mar 18, 2019)
eb95719  ENH: Deal better with errors and skggm/scikit-learn (Mar 18, 2019)
4d61dba  Better creation of prior (Mar 18, 2019)
71a02e0  Simplification for init of sdml (Mar 18, 2019)
1e6d440  Put skggm as optional (Mar 18, 2019)
a7ed1bb  Specify skggm version (Mar 18, 2019)
31072d3  TST: make test about 1 feature arrays more readable (Mar 18, 2019)
000f29a  DOC: fix rst formatting (Mar 18, 2019)
169dccf  DOC: reformulated skggm optional dependency (Mar 18, 2019)
bfb0f8f  TST: give an example for sdml_supervised with skggm where it indeed f… (Mar 20, 2019)
6f5666b  TST: fix test that fails weirdly when executing the whole test file a… (Mar 20, 2019)
0973ef2  Revert "TST: fix test that fails weirdly when executing the whole tes… (Mar 20, 2019)
1c28ecd  Merge branch 'master' into fix/proposal_for_sdml (wdevazelhes, Mar 21, 2019)
df2ae9c  Add coverage for all versions of python (Mar 21, 2019)
9683934  Install pytest-cov for all versions (Mar 21, 2019)
Files changed
19 changes: 10 additions & 9 deletions .travis.yml
@@ -4,18 +4,19 @@ cache: pip
 python:
   - "2.7"
   - "3.4"
   - "3.6"
 before_install:
+  - sudo apt-get install liblapack-dev
   - pip install --upgrade pip pytest
-  - pip install wheel
-  - pip install codecov
-  - if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
-      then pip install pytest-cov;
+  - pip install wheel cython numpy scipy scikit-learn codecov pytest-cov
+  - if [[ ($TRAVIS_PYTHON_VERSION == "3.6") ||
+        ($TRAVIS_PYTHON_VERSION == "2.7")]]; then
+      pip install git+https://github.com/skggm/skggm.git@a0ed406586c4364ea3297a658f415e13b5cbdaf8;
     fi
-  - pip install numpy scipy scikit-learn
 script:
-  - if [[ $TRAVIS_PYTHON_VERSION == "3.4" ]];
-      then pytest test --cov;
-      else pytest test;
-    fi
+  # we do coverage for all versions so that codecov will merge them: this
+  # way we will see that both paths (with or without skggm) are tested
+  - pytest test --cov;
 after_success:
   - bash <(curl -s https://codecov.io/bash)

7 changes: 6 additions & 1 deletion README.rst
@@ -21,7 +21,12 @@ Metric Learning algorithms in Python.

 - Python 2.7+, 3.4+
 - numpy, scipy, scikit-learn
-- (for running the examples only: matplotlib)
+
+**Optional dependencies**
+
+- For SDML, using skggm will allow the algorithm to solve problematic cases
+  (install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+- For running the examples only: matplotlib

**Installation/Setup**

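(Side note, not part of the PR: a minimal way to check locally which solver SDML will end up using is to mirror the HAS_SKGGM import guard that this PR adds to metric_learn/sdml.py, shown further down. The print messages below are illustrative.)

```python
# Minimal availability check mirroring the HAS_SKGGM guard added in
# metric_learn/sdml.py by this PR; the messages are illustrative only.
try:
    from inverse_covariance import quic  # noqa: F401 -- provided by skggm
    print("skggm found: SDML will use skggm's graphical lasso solver (quic).")
except ImportError:
    print("skggm not found: SDML will fall back to "
          "sklearn.covariance.graphical_lasso.")
```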
7 changes: 6 additions & 1 deletion doc/getting_started.rst
@@ -16,7 +16,12 @@ Alternately, download the source repository and run:

 - Python 2.7+, 3.4+
 - numpy, scipy, scikit-learn
-- (for running the examples only: matplotlib)
+
+**Optional dependencies**
+
+- For SDML, using skggm will allow the algorithm to solve problematic cases
+  (install from commit `a0ed406 <https://github.com/skggm/skggm/commit/a0ed406586c4364ea3297a658f415e13b5cbdaf8>`_).
+- For running the examples only: matplotlib

**Notes**

2 changes: 1 addition & 1 deletion metric_learn/constraints.py
@@ -96,6 +96,6 @@ def wrap_pairs(X, constraints):
   c = np.array(constraints[2])
   d = np.array(constraints[3])
   constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
-  y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
+  y = np.concatenate([np.ones_like(a), -np.ones_like(c)])
   pairs = X[constraints]
   return pairs, y
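The motivation for the wrap_pairs change above (commit be1a5e6) is that the old np.vstack construction returned a column of labels, while the pairs learners expect one flat label per pair. A small sketch with hypothetical index arrays `a` and `c` (not taken from the PR) shows the shape difference:

```python
import numpy as np

a = np.array([0, 1, 2])  # hypothetical indices of similar pairs
c = np.array([3, 4])     # hypothetical indices of dissimilar pairs

y_old = np.vstack([np.ones((len(a), 1)), -np.ones((len(c), 1))])
y_new = np.concatenate([np.ones_like(a), -np.ones_like(c)])

print(y_old.shape)  # (5, 1): a column of labels ("column y")
print(y_new.shape)  # (5,): one flat label per pair ("line y"), as the fix requires
```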
81 changes: 69 additions & 12 deletions metric_learn/sdml.py
@@ -12,12 +12,19 @@
 import warnings
 import numpy as np
 from sklearn.base import TransformerMixin
-from sklearn.covariance import graph_lasso
-from sklearn.utils.extmath import pinvh
+from scipy.linalg import pinvh
+from sklearn.covariance import graphical_lasso
+from sklearn.exceptions import ConvergenceWarning

 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
 from .constraints import Constraints, wrap_pairs
 from ._util import transformer_from_metric
+try:
+  from inverse_covariance import quic
+except ImportError:
+  HAS_SKGGM = False
+else:
+  HAS_SKGGM = True


 class _BaseSDML(MahalanobisMixin):
@@ -52,24 +59,74 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
     super(_BaseSDML, self).__init__(preprocessor)

   def _fit(self, pairs, y):
+    if not HAS_SKGGM:
+      if self.verbose:
+        print("SDML will use scikit-learn's graphical lasso solver.")
+    else:
+      if self.verbose:
+        print("SDML will use skggm's graphical lasso solver.")
     pairs, y = self._prepare_inputs(pairs, y,
                                     type_of_inputs='tuples')

-    # set up prior M
+    # set up (the inverse of) the prior M
     if self.use_cov:
       X = np.vstack({tuple(row) for row in pairs.reshape(-1, pairs.shape[2])})
-      M = pinvh(np.atleast_2d(np.cov(X, rowvar = False)))
+      prior_inv = np.atleast_2d(np.cov(X, rowvar=False))
     else:
-      M = np.identity(pairs.shape[2])
+      prior_inv = np.identity(pairs.shape[2])
     diff = pairs[:, 0] - pairs[:, 1]
     loss_matrix = (diff.T * y).dot(diff)
-    P = M + self.balance_param * loss_matrix
-    emp_cov = pinvh(P)
-    # hack: ensure positive semidefinite
-    emp_cov = emp_cov.T.dot(emp_cov)
-    _, M = graph_lasso(emp_cov, self.sparsity_param, verbose=self.verbose)
-
-    self.transformer_ = transformer_from_metric(M)
+    emp_cov = prior_inv + self.balance_param * loss_matrix

+    # our initialization will be the matrix with emp_cov's eigenvalues,
+    # with a constant added so that they are all positive (plus an epsilon
+    # to ensure definiteness). This is empirical.

    [Review comment, wdevazelhes (member, author): This is an init that we talked about
    with @bellet, and that I found worked better (it allowed the tests to pass, whereas
    with the identity init I got a lot of LinAlgError failures).]

+    w, V = np.linalg.eigh(emp_cov)
+    min_eigval = np.min(w)
+    if min_eigval < 0.:
+      warnings.warn("Warning, the input matrix of graphical lasso is not "
+                    "positive semi-definite (PSD). The algorithm may diverge, "
+                    "and lead to degenerate solutions. "
+                    "To prevent that, try to decrease the balance parameter "
+                    "`balance_param` and/or to set use_covariance=False.",
+                    ConvergenceWarning)

    [Review comment, wdevazelhes (member, author): Is ConvergenceWarning OK? It seemed
    the most appropriate, but here we raise it before even running the graphical lasso,
    so maybe it is a bit weird... Maybe some kind of PossibleConvergenceWarning would be
    better?]

    [Review comment, contributor: I'm fine with it.]

+    w -= min_eigval  # we translate the eigenvalues to make them all positive
+    w += 1e-10  # we add a small offset to avoid definiteness problems
+    sigma0 = (V * w).dot(V.T)
+    try:
+      if HAS_SKGGM:
+        theta0 = pinvh(sigma0)
+        M, _, _, _, _, _ = quic(emp_cov, lam=self.sparsity_param,
+                                msg=self.verbose,
+                                Theta0=theta0, Sigma0=sigma0)
+      else:
+        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+                               verbose=self.verbose,
+                               cov_init=sigma0)
+      raised_error = None
+      w_mahalanobis, _ = np.linalg.eigh(M)
+      not_spd = any(w_mahalanobis < 0.)
+      not_finite = not np.isfinite(M).all()
+    except Exception as e:
+      raised_error = e
+      not_spd = False  # not_spd not applicable here so we set to False
+      not_finite = False  # not_finite not applicable here so we set to False
+    if raised_error is not None or not_spd or not_finite:
+      msg = ("There was a problem in SDML when using {}'s graphical "
+             "lasso solver.").format("skggm" if HAS_SKGGM else "scikit-learn")
+      if not HAS_SKGGM:
+        skggm_advice = (" skggm's graphical lasso can sometimes converge "
+                        "on non SPD cases where scikit-learn's graphical "
+                        "lasso fails to converge. Try to install skggm and "
+                        "rerun the algorithm (see the README.md for the "
+                        "right version of skggm).")
+        msg += skggm_advice
+      if raised_error is not None:
+        msg += " The following error message was thrown: {}.".format(
+            raised_error)
+      raise RuntimeError(msg)
+
+    self.transformer_ = transformer_from_metric(np.atleast_2d(M))
     return self


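The numerically interesting piece of the new `_fit` is the warm start passed to the graphical lasso: the eigenvalues of `emp_cov` are translated so they are all positive (plus a 1e-10 offset), and the resulting matrix is used as `Sigma0` (with its pseudo-inverse as `Theta0` for skggm's `quic`). A standalone sketch of that step follows; the helper name `shifted_psd_init` and the toy matrix are not part of the PR, only the translation and offset mirror the diff above:

```python
import numpy as np
from scipy.linalg import pinvh


def shifted_psd_init(emp_cov, eps=1e-10):
    """Sketch of the warm start used in the new SDML._fit: translate the
    eigenvalues of a symmetric matrix so they are all >= eps, then rebuild it."""
    w, V = np.linalg.eigh(emp_cov)  # eigendecomposition of the symmetric input
    w = w - w.min() + eps           # shift the spectrum into the strictly positive range
    sigma0 = (V * w).dot(V.T)       # same eigenvectors, shifted eigenvalues
    theta0 = pinvh(sigma0)          # its inverse, used as Theta0 when skggm is available
    return sigma0, theta0


# toy check on a symmetric matrix that is not PSD (eigenvalues 3 and -1)
m = np.array([[1., 2.],
              [2., 1.]])
sigma0, _ = shifted_psd_init(m)
print(np.linalg.eigvalsh(sigma0))  # both eigenvalues are now positive
```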
1 change: 1 addition & 0 deletions setup.py
@@ -38,6 +38,7 @@
   extras_require=dict(
     docs=['sphinx', 'shinx_rtd_theme', 'numpydoc'],
     demo=['matplotlib'],
+    sdml=['skggm>=0.2.9']
   ),
   test_suite='test',
   keywords=[
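Finally, since setup.py now declares skggm under an sdml extra, the optional solver can be pulled in with pip's extras syntax (pip install metric-learn[sdml]). Below is a hypothetical end-to-end sketch of the behaviour this PR introduces; the class and parameter names come from the diffs above, while the toy pairs and labels are made up:

```python
import numpy as np
from metric_learn import SDML

# toy pairs of shape (n_pairs, 2, n_features); y is +1 for similar, -1 for dissimilar
pairs = np.array([[[1.0, 0.0], [1.1, 0.1]],
                  [[0.0, 1.0], [5.0, 5.0]]])
y = np.array([1, -1])

sdml = SDML(balance_param=0.5, sparsity_param=0.01, use_cov=False, verbose=True)
try:
    # verbose=True prints which graphical lasso solver is used (skggm's quic
    # if installed, otherwise scikit-learn's graphical_lasso)
    sdml.fit(pairs, y)
except RuntimeError as e:
    # raised by the new error handling when the chosen solver fails; the message
    # suggests installing skggm if the scikit-learn solver was used
    print(e)
```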