
(No) safe way to wrap taichi kernels in pytorch #8339

Open
oliver-batchelor opened this issue Aug 31, 2023 · 10 comments

@oliver-batchelor
Contributor

See context in #8101. Essentially, I don't think there is a completely safe way to wrap a Taichi kernel in PyTorch at the moment. Below I implement the recommended solution from #8101, but the problem is that the gradient does not propagate.

Using the other method mentioned in #8101, the gradient does propagate, but there are some slightly strange differences from "normal" functions under the Taichi autograd, depending on whether retain_grad is enabled.

import taichi as ti
from taichi.types import ndarray
from taichi.math import vec3, length


import torch

@ti.kernel
def distances_kernel(x:ndarray(ti.math.vec3), y:ndarray(ti.math.vec3), distances:ndarray(ti.f32)):
  for i in range(x.shape[0]):
    distances[i] = length(x[i] - y[i])



def point_distance_func():
  class PointDistance(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x:torch.Tensor, y:torch.Tensor):

        distances = torch.zeros((x.shape[0],),  device=x.device, dtype=torch.float32)
        distances_kernel(x, y, distances)

        ctx.save_for_backward(x, y, distances)
        return distances

    @staticmethod
    def backward(ctx, grad_output):
        x, y, distances = ctx.saved_tensors

        distances.grad = grad_output.contiguous()
        distances_kernel.grad(x, y, distances)

        # x_grad = x.grad.clone()
        # y_grad = y.grad.clone()

        # x.grad.fill_(0)
        # y.grad.fill_(0)

        # return x_grad, y_grad
        
        return torch.zeros_like(x), torch.zeros_like(y)
        
    
  return PointDistance.apply



if __name__ == "__main__":
  ti.init(arch=ti.gpu, debug=True)

  torch.cuda.manual_seed(0)
  
  points1 = torch.randn((5, 3), device='cuda', dtype=torch.float32).requires_grad_(True)
  points2 = torch.randn((5, 3), device='cuda', dtype=torch.float32).requires_grad_(True)

  d = torch.zeros((5,), device='cuda', dtype=torch.float32).requires_grad_(True)
  
  dist = point_distance_func()

  points3 = points2  * 2 
  # points3.retain_grad()

  d = dist(points1, points3)
  # d = torch.norm(points1 - points3, dim=1)
  d.sum().backward()

  # Notice how points2.grad is zero
  print(points1.grad, points2.grad, points3.grad)



@oliver-batchelor
Contributor Author

oliver-batchelor commented Aug 31, 2023

A utility like the one below seems like the right way to restore any manipulation of the .grad attribute:

from contextlib import contextmanager

@contextmanager
def restore_grad(*tensors):
  try:
      grads = [tensor.grad for tensor in tensors]
      yield
  finally:
      for tensor, grad in zip(tensors, grads):
          tensor.grad = grad

However, this still doesn't match the torch autograd system because taichi kernels on ndarrays seem to actually allocate the .grad?! (See #8340).

Just setting all the .grad attributes to None instead seems to give the same behavior as the torch autograd (at least in the very simple examples tested).
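For concreteness, a minimal sketch of that variant (illustrative only, reusing the PointDistance wrapper from the issue body): the backward reads the gradients Taichi accumulated into .grad and then resets the attributes to None rather than restoring old values.

    @staticmethod
    def backward(ctx, grad_output):
        x, y, distances = ctx.saved_tensors

        # Seed the output gradient and run the Taichi grad kernel, which
        # accumulates into x.grad and y.grad.
        distances.grad = grad_output.contiguous()
        distances_kernel.grad(x, y, distances)

        x_grad, y_grad = x.grad, y.grad

        # Reset the .grad attributes to None instead of restoring previous
        # values (sketch of the behaviour described above).
        x.grad = None
        y.grad = None
        distances.grad = None
        return x_grad, y_grad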

@oliver-batchelor
Copy link
Contributor Author

Torch also gives a lot of warnings about use of the .grad attribute for non-leaf Tensors. Perhaps, to avoid all of this nonsense, Taichi could use its own attribute (e.g. x.taichi_grad) or some such, so that there is no danger of breaking the autograd and the code is simplified?

/home/oliver/test_gradients.py:85: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343995026/work/build/aten/src/ATen/core/TensorBody.h:486.)
  if v.requires_grad and v.grad is None:

@ailzhang ailzhang self-assigned this Sep 1, 2023
@ailzhang ailzhang moved this from Untriaged to Todo in Taichi Lang Sep 1, 2023
@chenzhekl

chenzhekl commented Sep 26, 2023

I encountered the same problem. An additional side effect is that a custom autograd.Function wrapping a Taichi kernel can never pass gradcheck: the latter complains that the backward pass is not re-entrant.
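For reference, a gradcheck call along the following lines is what surfaces that complaint (an illustrative sketch reusing point_distance_func from the issue body; a CUDA device is assumed, and gradcheck normally also wants double-precision inputs, which the f32 kernel does not use):

import torch

# gradcheck runs the backward pass more than once and expects identical
# results; leftover accumulation in the .grad attributes between runs makes
# it report that the backward pass is not re-entrant.
dist = point_distance_func()
x = torch.randn(5, 3, device='cuda', dtype=torch.float32, requires_grad=True)
y = torch.randn(5, 3, device='cuda', dtype=torch.float32, requires_grad=True)
torch.autograd.gradcheck(dist, (x, y), eps=1e-3, atol=1e-2)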

@JJBannister

I am encountering a similar issue and was wondering if there is any current guidance or recommendation on how to wrap a Taichi kernel that uses Taichi autodiff in a torch autograd.Function.

If neither of the solutions proposed in #8101 integrates properly with the torch autodiff system, is the current best practice to copy data from tensors into fields before running the kernel and its grad?
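For reference, a minimal sketch of that copy-into-fields approach (names like distances_field_kernel and distances_with_grads are illustrative, not from the proposed solutions): copy the tensors into fields declared with needs_grad=True, run the kernel and its grad, and copy the field gradients back out.

import taichi as ti
import torch

# Assumes ti.init(arch=ti.gpu) has already been called.
n = 5
x_field = ti.Vector.field(3, dtype=ti.f32, shape=n, needs_grad=True)
y_field = ti.Vector.field(3, dtype=ti.f32, shape=n, needs_grad=True)
d_field = ti.field(dtype=ti.f32, shape=n, needs_grad=True)

@ti.kernel
def distances_field_kernel():
    for i in d_field:
        d_field[i] = ti.math.length(x_field[i] - y_field[i])

def distances_with_grads(x: torch.Tensor, y: torch.Tensor, grad_output: torch.Tensor):
    # Device-to-device copies into the fields. In a real wrapper the gradient
    # fields should also be zeroed before each backward, since they accumulate.
    x_field.from_torch(x)
    y_field.from_torch(y)
    distances_field_kernel()

    # Seed the output gradient and run the reverse-mode kernel.
    d_field.grad.from_torch(grad_output)
    distances_field_kernel.grad()
    return (d_field.to_torch(device=x.device),
            x_field.grad.to_torch(device=x.device),
            y_field.grad.to_torch(device=x.device))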

@oliver-batchelor
Contributor Author

If neither of the solutions proposed in #8101 integrates properly with the torch autodiff system, is the current best practice to copy data from tensors into fields before running the kernel and its grad?

That is safe, but it is also undesirable and non-performant for many reasons.

@bobcao3
Collaborator

bobcao3 commented Dec 24, 2023

If neither of the solutions proposed in #8101 integrates properly with the torch autodiff system, is the current best practice to copy data from tensors into fields before running the kernel and its grad?

That is safe, but it is also undesirable and non-performant for many reasons.

Currently, one thing we do for our production applications is to use torch tensors directly in the autograd kernels: we first back up the original .grad attributes, create new zero tensors and assign them to .grad, and after running the backward kernel the .grad attributes are restored and the newly created grad tensors are returned. We are working on a way to do this without ugly merry-go-rounds like that (the ability to pass a tuple of tensors, for example...).
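A minimal sketch of that pattern, reusing distances_kernel from the issue body (illustrative, not the production code referred to above; the class name is made up):

class PointDistanceBackup(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        distances = torch.zeros(x.shape[0], device=x.device, dtype=torch.float32)
        distances_kernel(x, y, distances)
        ctx.save_for_backward(x, y, distances)
        return distances

    @staticmethod
    def backward(ctx, grad_output):
        x, y, distances = ctx.saved_tensors

        # 1. Back up whatever .grad attributes torch may already have set.
        saved = (x.grad, y.grad, distances.grad)

        # 2. Give the Taichi grad kernel fresh tensors to accumulate into.
        x.grad = torch.zeros_like(x)
        y.grad = torch.zeros_like(y)
        distances.grad = grad_output.contiguous()

        distances_kernel.grad(x, y, distances)
        x_grad, y_grad = x.grad, y.grad

        # 3. Restore the original .grad attributes and return the new grads.
        x.grad, y.grad, distances.grad = saved
        return x_grad, y_grad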

@oliver-batchelor
Contributor Author

oliver-batchelor commented Dec 25, 2023

from contextlib import contextmanager

@contextmanager
def restore_grad(*tensors):
  try:
      grads = [tensor.grad for tensor in tensors]
      yield
  finally:
      for tensor, grad in zip(tensors, grads):
          tensor.grad = grad

If neither of the solutions proposed in #8101 integrates properly with the torch autodiff system, is the current best practice to copy data from tensors into fields before running the kernel and its grad?

That is safe, but it is also undesirable and non-performant for many reasons.

Currently, one thing we do for our production applications is to use torch tensors directly in the autograd kernels: we first back up the original .grad attributes, create new zero tensors and assign them to .grad, and after running the backward kernel the .grad attributes are restored and the newly created grad tensors are returned. We are working on a way to do this without ugly merry-go-rounds like that (the ability to pass a tuple of tensors, for example...).

Thanks - this seems like the best solution for now.

@oliver-batchelor
Contributor Author

oliver-batchelor commented Dec 25, 2023

Oh, I see: to make it work, it needs to actually copy the gradient tensors, since the kernel's grad function mutates them in place. This version works with gradcheck.

@contextmanager
def restore_grad(*tensors):
  try:
      grads = [tensor.grad.clone() if tensor.grad is not None else None
                for tensor in tensors]
      yield
  finally:
      for tensor, grad in zip(tensors, grads):
          tensor.grad = grad

An improvement would be to avoid touching the .grad attribute in the forward pass (currently it is created and filled with torch.zeros). That would be more efficient: most of the time .grad is None, but because Taichi has initialized it with zeros we end up cloning and restoring a zero tensor unnecessarily.

The relevant code is in kernel_impl.py around line 762; the FIXME comment there seems quite relevant:

                    # FIXME: only allocate when launching grad kernel
                    if v.requires_grad and v.grad is None:
                        v.grad = torch.zeros_like(v)

@oliver-batchelor
Contributor Author

Commenting out those lines in kernel_impl.py and changing restore_grad to the code below also seems to work nicely; that way the zero tensors are not created until they are actually needed. It would probably be better to do this inside Taichi, though, so that it does not crash otherwise...

@contextmanager
def restore_grad(*tensors):
  try:
      grads = [tensor.grad for tensor in tensors]

      for tensor in tensors:
          if tensor.requires_grad:
              tensor.grad = torch.zeros_like(tensor)
      yield
  finally:
      for tensor, grad in zip(tensors, grads):
          tensor.grad = grad
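
With that version, the backward of the wrapper from the issue body could look roughly like this (a sketch, assuming the allocation in kernel_impl.py above is skipped; the gradients are read while still inside the context, before restore_grad puts the old .grad attributes back):

    @staticmethod
    def backward(ctx, grad_output):
        x, y, distances = ctx.saved_tensors

        with restore_grad(x, y, distances):
            # x.grad and y.grad are fresh zero tensors here; the Taichi grad
            # kernel accumulates into them.
            distances.grad = grad_output.contiguous()
            distances_kernel.grad(x, y, distances)
            x_grad, y_grad = x.grad, y.grad

        return x_grad, y_grad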

@Kiord

Kiord commented Nov 14, 2024

Here is a small Torch module wrapper for Taichi kernels, for anyone interested.

Pros:

  • Generalizes many Taichi kernels.

Cons:
