Skip to content

Commit

Permalink
Merge pull request #6 from azrael417/tkurth/precision-fix
Browse files Browse the repository at this point in the history
Fixing precision mismatch error in weight contractions
  • Loading branch information
azrael417 authored Aug 2, 2023
2 parents 855297a + 562dac1 commit a0d1fbc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion torch_harmonics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

__version__ = '0.6.1'
__version__ = '0.6.2'

from .sht import RealSHT, InverseRealSHT, RealVectorSHT, InverseRealVectorSHT
from . import quadrature
Expand Down
40 changes: 20 additions & 20 deletions torch_harmonics/sht.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def forward(self, x: torch.Tensor):
xout = torch.zeros(out_shape, dtype=x.dtype, device=x.device)

# contraction
xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights )
xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights )
xout[..., 0] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 0], self.weights.to(x.dtype) )
xout[..., 1] = torch.einsum('...km,mlk->...lm', x[..., :self.mmax, 1], self.weights.to(x.dtype) )
x = torch.view_as_complex(xout)

return x
Expand Down Expand Up @@ -185,8 +185,8 @@ def forward(self, x: torch.Tensor):
# Evaluate associated Legendre functions on the output nodes
x = torch.view_as_real(x)

rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct )
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct )
rl = torch.einsum('...lm, mlk->...km', x[..., 0], self.pct.to(x.dtype) )
im = torch.einsum('...lm, mlk->...km', x[..., 1], self.pct.to(x.dtype) )
xs = torch.stack((rl, im), -1)

# apply the inverse (real) FFT
Expand Down Expand Up @@ -282,20 +282,20 @@ def forward(self, x: torch.Tensor):

# contraction - spheroidal component
# real component
xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0]) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1])
xout[..., 0, :, :, 0] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[0].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[1].to(x.dtype))

# iamg component
xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0]) \
+ torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1])
xout[..., 0, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[0].to(x.dtype)) \
+ torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[1].to(x.dtype))

# contraction - toroidal component
# real component
xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1]) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0])
xout[..., 1, :, :, 0] = - torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 1], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 0], self.weights[0].to(x.dtype))
# imag component
xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1]) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0])
xout[..., 1, :, :, 1] = torch.einsum('...km,mlk->...lm', x[..., 0, :, :self.mmax, 0], self.weights[1].to(x.dtype)) \
- torch.einsum('...km,mlk->...lm', x[..., 1, :, :self.mmax, 1], self.weights[0].to(x.dtype))

return torch.view_as_complex(xout)

Expand Down Expand Up @@ -358,19 +358,19 @@ def forward(self, x: torch.Tensor):

# contraction - spheroidal component
# real component
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0]) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1])
srl = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[0].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[1].to(x.dtype))
# iamg component
sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0]) \
+ torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1])
sim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[0].to(x.dtype)) \
+ torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[1].to(x.dtype))

# contraction - toroidal component
# real component
trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1]) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0])
trl = - torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 1], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 0], self.dpct[0].to(x.dtype))
# imag component
tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1]) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0])
tim = torch.einsum('...lm,mlk->...km', x[..., 0, :, :, 0], self.dpct[1].to(x.dtype)) \
- torch.einsum('...lm,mlk->...km', x[..., 1, :, :, 1], self.dpct[0].to(x.dtype))

# reassemble
s = torch.stack((srl, sim), -1)
Expand Down

0 comments on commit a0d1fbc

Please sign in to comment.