Skip to content

Commit

Permalink
FIX: set new JIT scopes in PyTorch wrapper methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Speierers authored and njroussel committed Jan 14, 2025
1 parent 81f73b4 commit 3a8364e
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion drjit/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 373,21 @@ def wrapper(args, kwargs):
# See https://github.com/pytorch/pytorch/issues/117491
torch_desc_o = None

def new_drjit_scope():
if dr.has_backend(dr.JitBackend.LLVM):
dr.detail.new_scope(dr.JitBackend.LLVM)
if dr.has_backend(dr.JitBackend.CUDA):
dr.detail.new_scope(dr.JitBackend.CUDA)

def create_torch_wrapper():
from torch import set_grad_enabled as torch_set_grad_enabled
from torch.autograd import Function, function

class TorchWrapper(Function):
@staticmethod
def forward(ctx, func, desc, *inputs):
new_drjit_scope()

# Convert and unflatten the input PyTrees
inputs = to_drjit(inputs, 'torch', enable_grad=True)
args, kwargs = unflatten(desc, *inputs)
Expand Down Expand Up @@ -409,11 417,15 @@ def fn(h):
# Convert the output and return
output_conv = from_drjit(output, 'torch')[0]

new_drjit_scope()

return tuple(output_conv)

@staticmethod
@function.once_differentiable
def backward(ctx, *grad_outputs):
new_drjit_scope()

grad_outputs = to_drjit(grad_outputs, 'torch')
dr.set_grad(ctx.output, grad_outputs)

Expand All @@ -426,11 438,16 @@ def backward(ctx, *grad_outputs):

# Convert
grad_inputs = from_drjit(grad_inputs, 'torch')[0]

new_drjit_scope()

return None, None, *grad_inputs

@staticmethod
@function.once_differentiable
def jvp(ctx, func, desc, *grad_inputs):
new_drjit_scope()

grad_inputs = to_drjit(grad_inputs, 'torch')
dr.set_grad(ctx.inputs, grad_inputs)

Expand All @@ -443,6 460,9 @@ def jvp(ctx, func, desc, *grad_inputs):

# Convert
grad_output = from_drjit(grad_output, 'torch')[0]

new_drjit_scope()

return grad_output


Expand Down Expand Up @@ -480,7 500,7 @@ def wrap(source: typing.Union[str, types.ModuleType],
The following table lists the currently supported conversions:
.. |nbsp| unicode:: 0xA0
.. |nbsp| unicode:: 0xA0
:trim:
.. list-table::
Expand Down

0 comments on commit 3a8364e

Please sign in to comment.