Complex number wrapper #56

Open
nikopj opened this issue Jul 3, 2024 · 6 comments

@nikopj

nikopj commented Jul 3, 2024

NCCL does not support complex numbers directly and does not plan to (see issue). Would we be willing to add a wrapper to NCCL.jl to make using complex numbers more convenient? Alternatively, the wrapper could live in a higher-level package (e.g., Lux.jl, see issue). I am happy to start working on this but would appreciate some feedback first. My primary motivation is training neural networks with complex-valued weights, and this feature would greatly simplify things.

@maleadt
Member

maleadt commented Jul 3, 2024

I'm not a user of NCCL.jl myself, so cc @avik-pal @simonbyrne.

@avik-pal
Contributor

avik-pal commented Jul 3, 2024

I am honestly okay with either.

I suggested opening this issue because having the wrapper here makes it easier for other downstream libraries (besides Lux) to use it when or if they adopt NCCL.

But if we want to keep NCCL.jl a thin wrapper that only exposes functionality NCCL provides natively, we can implement this in Lux instead.

Let's wait for @simonbyrne's opinion, since he did most of the work getting this package back to life.

@simonbyrne
Contributor

That seems fine. Note, though, that you don't want to use reim/complex; instead, take advantage of the fact that complex arrays are packed in memory the same way as real arrays, but with twice as many elements.

I think the easiest solution is:

  • Define ncclDataType_t(::Type{Complex{T}}) where {T} = ncclDataType_t(T)
  • Instead of using length(X) to determine the count argument, define a custom function that takes the datatype into account, e.g.:
    count(X::CuArray{T}) where {T} = length(X)
    count(X::CuArray{Complex{T}}) where {T} = 2*length(X)
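A minimal CPU-only sketch of why doubling the count works, using only Base Julia (no NCCL or GPU involved): `reinterpret` views a complex vector as its interleaved real and imaginary parts, so an element-wise sum over the real view matches the complex sum. The name `nccl_count` is hypothetical, chosen so the helper does not shadow `Base.count`. This equivalence holds for sums and for pure data movement (broadcast, all-gather, send/recv), but not for `prod`, `max`, or `min`, since complex multiplication and ordering are not componentwise.

```julia
# Hypothetical helper (not part of NCCL.jl): number of real scalars NCCL would
# see for a buffer. Named `nccl_count` to avoid shadowing `Base.count`.
nccl_count(X::AbstractArray) = length(X)
nccl_count(X::AbstractArray{Complex{T}}) where {T} = 2 * length(X)

# Two "ranks", each holding a complex buffer.
a = ComplexF32[1 + 2im, 3 + 4im]
b = ComplexF32[5 + 6im, 7 + 8im]

# Complex arrays are stored as interleaved (re, im) pairs, so reinterpret
# gives a zero-copy real view with twice as many elements.
ra = reinterpret(Float32, a)
rb = reinterpret(Float32, b)

# Summing the real views element-wise (what a Float32 sum reduction does)
# and reinterpreting back yields the same values as the complex sum.
summed = reinterpret(ComplexF32, ra .+ rb)
```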

@simonbyrne
Contributor

Maybe don't tie it to CuArray either, as NCCL should also work with unified memory.
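A minimal sketch of that suggestion in plain Julia, assuming the element type alone should decide both the scalar type handed to NCCL and the count. `scalar_type` and `scalar_count` are hypothetical names, not NCCL.jl functions; mapping `scalar_type` onto an actual NCCL datatype is left out here.

```julia
# Hypothetical helpers (not NCCL.jl API). They dispatch only on the element
# type, so any AbstractArray (Array, CuArray, a unified-memory CuArray, ...)
# goes through the same code path.
scalar_type(::Type{T}) where {T<:Real} = T
scalar_type(::Type{Complex{T}}) where {T<:Real} = T

# Number of scalar elements NCCL would be told to transfer: complex
# elements count as two scalars (real and imaginary parts).
scalar_count(X::AbstractArray{T}) where {T} = T <: Complex ? 2 * length(X) : length(X)

x = Float64[1.0, 2.0, 3.0]
z = ComplexF64[1 + 1im, 2 + 2im, 3 + 3im]
m = zeros(ComplexF32, 2, 3)   # a matrix works too; only the eltype matters
```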

@nikopj
Copy link
Author

nikopj commented Jul 19, 2024

> Maybe don't tie it to CuArray either, as NCCL should also work with unified memory.

@simonbyrne Doesn't your implementation above take care of unified memory automatically? My understanding is that unified memory is just a subtype of CuArray, i.e. something like CuArray{T, N, CUDA.UnifiedMemory}.

@maleadt
Member

maleadt commented Jul 22, 2024

> My understanding is that unified memory is just a subtype of CuArray, i.e. something like CuArray{T, N, CUDA.UnifiedMemory}

Unified memory can also be exposed as an Array (e.g., via unsafe_wrap); however, that is mostly useful for calling into CPU functionality. For GPU-related uses, I would generally expect unified memory to be represented as a CuArray.

Labels: None yet
Projects: None yet
Development: No branches or pull requests

4 participants