Convert Nx.LinAlg.lu to optional callback #1388
Comments
We can start with the translation, I think that's the simplest, and then explore other routes if necessary! |
Agreed! |
Sorry if this is necroing a thread, but I just ran into this while trying to compute determinants:
Do you have any advice off-hand? If not no worries, it doesn't block me because I can try the Torchx backend or in the worst case try porting over the JAX implementation. |
I currently don't have the bandwidth to do a full-fledged implementation of LU on defn.

```elixir
defn vector_dot_slice(u, u_start, v, v_start) do
  # Zero out the entries of u that come before u_start
  {n} = Nx.shape(u)
  u = Nx.select(Nx.iota({n}) >= u_start, u, 0)
  # Zero out the entries of v that come before v_start
  {n} = Nx.shape(v)
  v = Nx.select(Nx.iota({n}) >= v_start, v, 0)
  Nx.dot(u, v)
end
```

I tried porting LU once, and it ended up requiring something like this due to it needing to do something like

All of that being said, I believe we can add a custom call to Eigen, like we have for Nx.LinAlg.qr and Nx.LinAlg.eigh, that will at least let LU be available on CPU. Would you be open to sending a PR on either? |
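As a rough illustration of the masking helper from the comment above, the sketch below wraps it in a module and compares it against an ordinary eager slice. The module name `MaskedDot`, the sample tensors, and the `Mix.install` version constraint are assumptions made for this example. They are not part of the thread.

```elixir
# Assumption: any reasonably recent Nx release; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule MaskedDot do
  import Nx.Defn

  # Dot product that ignores entries before the given start indices.
  # Masking with Nx.select keeps both tensors at their full, static shape.
  defn vector_dot_slice(u, u_start, v, v_start) do
    {n} = Nx.shape(u)
    u = Nx.select(Nx.iota({n}) >= u_start, u, 0)
    {n} = Nx.shape(v)
    v = Nx.select(Nx.iota({n}) >= v_start, v, 0)
    Nx.dot(u, v)
  end
end

u = Nx.tensor([1.0, 2.0, 3.0, 4.0])
v = Nx.tensor([10.0, 20.0, 30.0, 40.0])

# Entries of u before index 2 are zeroed, so only 3*30 + 4*40 contributes.
masked = MaskedDot.vector_dot_slice(u, 2, v, 0) |> Nx.to_number()

# Same quantity computed eagerly with fixed-length slices, for comparison.
sliced = Nx.dot(Nx.slice(u, [2], [2]), Nx.slice(v, [2], [2])) |> Nx.to_number()

IO.inspect({masked, sliced})
```

The masked version does more arithmetic than a real slice, but the size of the region it covers can depend on a runtime value, which is presumably why the port needed it.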
Thanks so much for the detailed reply, I really appreciate it! For now, I am trying to see if computing the determinant using QR is sufficient for my application, even though it will be slower. If that is insufficient I will certainly write back and try to do something based on your excellent notes. Thank you again! |
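A minimal sketch of the QR route mentioned above, assuming only the magnitude of the determinant is needed. Since Q is orthogonal, det(Q) is +1 or -1, so the diagonal of R gives |det(A)| but not the sign on its own. The example matrix and the `Mix.install` version constraint are assumptions for illustration.

```elixir
# Assumption: any reasonably recent Nx release; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

a = Nx.tensor([[4.0, 3.0], [6.0, 3.0]])

# A = Q * R with Q orthogonal, so |det(A)| = |det(R)| = product of |r_ii|.
{_q, r} = Nx.LinAlg.qr(a)

abs_det =
  r
  |> Nx.take_diagonal()
  |> Nx.abs()
  |> Nx.product()
  |> Nx.to_number()

# The exact determinant of this matrix is 4*3 - 3*6 = -6,
# so abs_det should come out close to 6.0 up to floating-point error.
IO.inspect(abs_det)
```

If the sign matters, it has to be recovered separately, for example from the structure of Q, which this sketch does not attempt.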
@jyc I ended up using this as an excuse to try the Cursor AI editor out 😅 The branch adds at least the CPU implementation as a palliative solution. |
XLA doesn't provide an implementation, and we use LU decomposition for `determinant`, which is used in some other LinAlg functions.

Therefore, we should either look into JAX's implementation or just translate the BinaryBackend one to a `defn`, like we did with SVD.

This will also enable LU on MLIR by default.
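For context, this is why the determinant depends on LU. The relation below is a sketch that assumes the factorization is written as $A = PLU$, with $P$ a permutation matrix, $L$ unit lower triangular, and $U$ upper triangular. That convention is an assumption here and is not stated in the thread.

$$
\det(A) = \det(P)\,\det(L)\,\det(U) = \det(P)\prod_{i=1}^{n} u_{ii}, \qquad \det(P) = \pm 1, \quad \det(L) = 1
$$

Under that assumption, once the factors are available the determinant costs only a product over the diagonal of $U$ plus the parity of the permutation.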