
Added Ledoit Wolf shrinkage covariance estimator #304

Merged: 14 commits into elixir-nx:main, Nov 7, 2024

Conversation

@norm4nn (Contributor) commented Oct 26, 2024:

Added the Ledoit-Wolf shrinkage covariance matrix estimator.

Reference: scikit-learn docs
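For readers skimming the thread, a minimal usage sketch (assuming the module follows Scholar's usual fit/2 convention; the struct fields :covariance, :shrinkage, and :location appear in the diff quoted later in this review):

    key = Nx.Random.key(42)
    {x, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {50, 3})
    model = Scholar.Covariance.LedoitWolf.fit(x)
    # model.covariance is the shrunk {3, 3} covariance estimate,
    # model.shrinkage the estimated shrinkage coefficient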

@josevalim (Contributor) left a comment:

LGTM! We only need some tests :)

@norm4nn (Contributor, Author) commented Oct 30, 2024:

I added some tests ;)

@@ -0,0 +1,352 @@

    defmodule Scholar.Covariance.LedoitWolf do
      @moduledoc """
      Ledoit-Wolf is a particular form of shrinkage covariance estimator, where the shrinkage coefficient is computed using O.
A contributor commented:

"computed using 0" feels like an incomplete sentence.

And I'm a bit confused. As someone who's not familiar with the algorithm, this phrase seems to contradict the :shrinkage option provided below.

Comment on lines 248 to 251
    Nx.slice_along_axis(x_t, rows_from, block_size, axis: 0)
    |> Nx.dot(Nx.slice_along_axis(x, cols_from, block_size, axis: 1))
    |> Nx.pow(2)
    |> Nx.sum()
A contributor suggested:

Suggested change:

    - Nx.slice_along_axis(x_t, rows_from, block_size, axis: 0)
    - |> Nx.dot(Nx.slice_along_axis(x, cols_from, block_size, axis: 1))
    + Nx.slice_along_axis(x, rows_from, block_size, axis: 1)
    + |> Nx.dot([0], Nx.slice_along_axis(x, cols_from, block_size, axis: 1), [0])
      |> Nx.pow(2)
      |> Nx.sum()
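The heart of the suggestion is Nx.dot/4: passing [0] as the contraction axis for both operands computes the transposed product directly, so the precomputed x_t becomes unnecessary. A tiny standalone illustration (shapes chosen arbitrarily):

    x = Nx.iota({4, 3}, type: :f32)
    a = x |> Nx.transpose() |> Nx.dot(x)  # explicit transpose: {3, 3}
    b = Nx.dot(x, [0], x, [0])            # contract axis 0 of both operands: same {3, 3} result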

A contributor asked:

this Nx.pow(2) here seems odd, mathematically speaking. is it correct?

@norm4nn (Author) replied:

I’ll double-check that, but at first glance - it looks correct.

@norm4nn (Author) commented Nov 1, 2024:

@polvalente So, if we take the scikit-learn implementation as a reference, then it is correct. But I have to agree with you - this looks weird. I don't see a reason to implement it that way - I will simplify it in the next commit.

@norm4nn (Author) added:

I will simplify this whole section significantly in the next commit. After a longer investigation of the scikit-learn code, I found it highly overcomplicated for Scholar. Maybe in Python there is some reason to calculate beta and delta in the loop, but I don't believe that applies here.
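For readers following along: without the block loop, scikit-learn's shrinkage computation reduces to a handful of whole-tensor reductions. A rough Nx transcription (an illustrative sketch assuming x is already centered, not the code that landed in this PR):

    def shrinkage_sketch(x) do
      {n, p} = Nx.shape(x)
      x2 = Nx.pow(x, 2)
      # per-feature variances; their mean is mu = trace(S) / p
      emp_cov_trace = x2 |> Nx.sum(axes: [0]) |> Nx.divide(n)
      mu = Nx.divide(Nx.sum(emp_cov_trace), p)
      # sum of squared entries of the Gram matrix x'x -- the Nx.pow(2) discussed above
      delta_ = x |> Nx.dot([0], x, [0]) |> Nx.pow(2) |> Nx.sum() |> Nx.divide(n * n)
      beta_ = x2 |> Nx.dot([0], x2, [0]) |> Nx.sum()
      beta = beta_ |> Nx.divide(n) |> Nx.subtract(delta_) |> Nx.divide(p * n)

      delta =
        delta_
        |> Nx.subtract(Nx.multiply(2, Nx.multiply(mu, Nx.sum(emp_cov_trace))))
        |> Nx.add(Nx.multiply(p, Nx.pow(mu, 2)))
        |> Nx.divide(p)

      beta = Nx.min(beta, delta)
      Nx.select(Nx.equal(beta, 0), 0, Nx.divide(beta, delta))
    end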

@msluszniak (Contributor) commented:
blocks like

      Nx.multiply(x, mask)
      |> Nx.dot([0], Nx.multiply(x, mask), [0])
      |> Nx.pow(2)
      |> Nx.sum()

occur multiple times. You might create a function to simplify this.
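Something along these lines would do it (a hypothetical helper; the name is illustrative):

    # squared Frobenius norm of a'b, contracting the sample axis of both operands
    defnp sum_squared_dot(a, b) do
      a
      |> Nx.dot([0], b, [0])
      |> Nx.pow(2)
      |> Nx.sum()
    end

The repeated block above would then read sum_squared_dot(Nx.multiply(x, mask), Nx.multiply(x, mask)).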

@norm4nn (Contributor, Author) commented Nov 1, 2024:

Sorry for changing a lot of code in the middle of the review, but I believe it's much more readable now.

@msluszniak (Contributor) left a comment:

LGTM :))

@josevalim josevalim merged commit 473060f into elixir-nx:main Nov 7, 2024
2 checks passed
@josevalim (Contributor) commented:
💚 💙 💜 💛 ❤️

@krstopro (Member) left a comment:

Two comments from my side.

UPDATE: Nevermind, got merged while I was reviewing. Looks good anyway.

      end
    end

    defnp empirical_covariance(x) do
@krstopro commented:

I think this can be replaced with Nx.covariance/3.
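For reference, Nx.covariance/3 takes the tensor, an optional precomputed mean, and options (such as :ddof), so the hand-rolled helper could plausibly collapse to something like this (illustrative; num_features stands for the column count of x):

    # let Nx subtract the column means itself
    cov = Nx.covariance(x)

    # or, when assume_centered is set, pass an explicit zero mean
    cov = Nx.covariance(x, Nx.broadcast(0.0, {num_features}))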

    defstruct [:covariance, :shrinkage, :location]

    opts_schema = [
      assume_centered: [
@krstopro commented:

I would rename it to assume_centered? (add ?).

@msluszniak (Contributor) commented:
@krstopro if Nx.covariance works, then feel free to merge these changes in a separate PR/commit

@krstopro (Member) commented Nov 7, 2024:

> @krstopro if Nx.covariance works, then feel free to merge these changes in a separate PR/commit

Sure, but lemme fix #301 first. :)

@josevalim (Contributor) commented:

@norm4nn can also submit a pull request with the Nx.covariance fix if they have the time :)

@norm4nn norm4nn mentioned this pull request Nov 8, 2024