Wrong gradient when using Taichi autodiff (kernel.grad) and PyTorch autograd.Function together #8534

Open
zjcs opened this issue May 30, 2024 · 3 comments

zjcs commented May 30, 2024

Describe the bug

Wrong gradient when using Taichi autodiff (kernel.grad) and PyTorch autograd.Function together.

To Reproduce

import taichi as ti
import torch

ti.init(arch=ti.cpu)

@ti.kernel
def func_x2(x: ti.types.ndarray(ndim=1),
            y: ti.types.ndarray(ndim=1)):
    for i in ti.ndrange(x.shape[0]):
        y[0] += x[i]**2

class TaichiKernel(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        print("taichi kernel forward:", x.grad, y.grad)
        func_x2(x, y)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        # x, y = ctx.saved_tensors
        x, = ctx.saved_tensors
        # print("taichi kernel backward:", x.grad, y.grad, grad_y)
        print("taichi kernel backward+1:", x.grad, grad_y)
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        y.grad = grad_y
        print("taichi kernel backward+2:", x.grad, y.grad, grad_y)
        func_x2.grad(x, y)
        print("taichi kernel backward+3:", x.grad, y.grad, grad_y)
        return x.grad

class TaichiModule(torch.nn.Module):
    def forward(self, x):
        return TaichiKernel().apply(x)

print("===================>")
x = torch.arange(4, dtype=torch.float32, requires_grad=True)
y = TaichiModule()(x)
loss = y.sum()
loss.backward()

print("y.grad final", y.grad)
print("x.grad final", x.grad)

Log/Screenshots

$ python my_sample_code.py
[Taichi] version 1.7.1, llvm 15.0.4, commit 0f143b2f, linux, python 3.8.8
[Taichi] Starting on arch=x64
===================>
taichi kernel forward: None None
taichi kernel backward+1: tensor([0., 0., 0., 0.]) tensor([1.])
taichi kernel backward+2: tensor([0., 0., 0., 0.]) tensor([1.]) tensor([1.])
taichi kernel backward+3: tensor([0., 2., 4., 6.]) tensor([1.]) tensor([1.])
y.grad final tensor([0.])
x.grad final tensor([ 0.,  4.,  8., 12.])



zjcs commented May 30, 2024

The gradients of x and y in the log are wrong; the correct results should be:
y.grad: tensor([0.]) -> tensor([1.])
x.grad: tensor([0., 4., 8., 12.]) -> tensor([0., 2., 4., 6.])


bobcao3 commented Jun 22, 2024

Taichi accumulates directly into the gradient tensor. For correct interop behavior with PyTorch, you need to declare new zeroed gradient tensors, pass them into Taichi, and then return those from backward (see the sketch below).
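
A minimal sketch of what that might look like, reusing the imports and the func_x2 kernel from the report above. The detached-copy approach and the class name TaichiKernelFixed are my own illustration of the suggestion, not a confirmed fix:

class TaichiKernelFixed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        func_x2(x, y)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        # Give Taichi fresh, zero-initialised gradient buffers to accumulate
        # into, instead of the real x.grad that PyTorch manages.
        x_tmp = x.detach().clone().requires_grad_(True)
        x_tmp.grad = torch.zeros_like(x_tmp)
        y_tmp = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        y_tmp.grad = grad_y.clone()
        func_x2.grad(x_tmp, y_tmp)   # accumulates d(loss)/dx into x_tmp.grad
        return x_tmp.grad            # PyTorch adds this into x.grad exactly once

With this, x.grad after loss.backward() should end up as tensor([0., 2., 4., 6.]) rather than the doubled values, since Taichi never writes into the leaf's .grad directly.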

oliver-batchelor (Contributor) commented

Relates to #8339 - IMO ideally Taichi should not touch the .grad attribute at all and should use some other attribute or method to pass gradients around.

If you are careful, you can replace the .grad tensor with zeros before the Taichi grad kernel call, then restore whatever was in .grad afterwards, and it works (see the sketch below).
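
A rough sketch of that save/zero/restore trick, again reusing func_x2 from the report. This is one reading of the suggestion, assuming x.grad may be swapped out temporarily inside backward; the class name is illustrative, not a verified implementation:

class TaichiKernelSwapGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        func_x2(x, y)
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        x, = ctx.saved_tensors
        # Stash whatever PyTorch currently holds in x.grad and hand Taichi
        # a zeroed buffer to accumulate into.
        saved_x_grad = x.grad
        x.grad = torch.zeros_like(x)
        y = torch.zeros((1,), dtype=torch.float32, requires_grad=True)
        y.grad = grad_y.clone()
        func_x2.grad(x, y)        # Taichi accumulates d(loss)/dx into x.grad
        taichi_grad = x.grad
        # Restore the original contents of x.grad; PyTorch itself accumulates
        # the returned gradient, so it must not also remain in x.grad.
        x.grad = saved_x_grad
        return taichi_grad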
