torch.func interaction with torch.compile
==================================================

So you want to use a ``torch.func`` ("functorch") transform (like ``vmap``, ``grad``, ``jacrev``, etc.) with ``torch.compile``. Here's a guide to what works today, what doesn't, and how to work around it.

Applying a ``torch.func`` transform to a ``torch.compile``'d function
----------------------------------------------------------------------

This doesn't work and is being tracked by https://github.com/pytorch/pytorch/issues/100320.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        # grad requires a function with a single-element (scalar) output
        return torch.sin(x).sum()

    def g(x):
        # Applying the transform to an already-compiled function
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

As a workaround, put the ``torch.compile`` outside of the ``torch.func`` transform, so that the transform is applied to the plain Python function:

.. code:: python

    import torch

    def f(x):
        return torch.sin(x).sum()

    @torch.compile
    def g(x):
        # The transform is now called inside the compiled function
        return torch.func.grad(f)(x)

    x = torch.randn(2, 3)
    g(x)

On PT 2.0 this moves the transform call inside the compiled function, which runs into the limitation described in the next section; combine it with the ``allow_in_graph`` workaround below.

Doesn't work (PT 2.0): calling a ``torch.func`` transform inside a ``torch.compile``'d function
----------------------------------------------------------------------------------------------------

.. code:: python

    import torch

    @torch.compile
    def f(x):
        # Calling a torch.func transform inside the compiled function
        return torch.vmap(torch.sum)(x)

    x = torch.randn(2, 3)
    f(x)

This doesn't work yet. See the workaround in the next section.

Workaround: use ``torch._dynamo.allow_in_graph``
--------------------------------------------------

``allow_in_graph`` is an escape hatch. If your code does not work with ``torch.compile``, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like ``jax.jit``), then use ``allow_in_graph``. The annotated function is written into the graph as a single call instead of having its bytecode introspected.

By annotating a function with ``allow_in_graph``, you promise PyTorch three things that we are unable to fully verify:

- Your function is pure. That is, all outputs depend only on the inputs and not on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This may be relaxed; we actually support functions that appear to be functional from the outside: they may use in-place PyTorch operations internally, but may not mutate global state or the function's inputs.
- Your function does not raise data-dependent errors.

.. code:: python

    import torch

    @torch.compile
    def f(x):
        # The vmapped function is recorded in the graph as a single call
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

A common pitfall is using ``allow_in_graph`` to annotate a function that invokes an ``nn.Module``. This breaks the purity promise above, because the outputs then depend on the module's parameters, which are captured Tensors rather than inputs. To make this work, pass the module's parameters (and buffers) into the function as explicit inputs and run the module with them via ``torch.func.functional_call``.
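
The pitfall above can be sketched as follows. This is a minimal example under stated assumptions: the toy ``nn.Linear`` model and the helper names ``per_sample_forward`` and ``batched_forward`` are hypothetical, chosen only for illustration. The module's parameters are passed into the ``allow_in_graph`` function as explicit inputs, and ``torch.func.functional_call`` runs the module with those tensors, so the function's output depends only on its inputs. ``backend="eager"`` is used only so the sketch runs without a C++ toolchain; drop it to use the default backend.

.. code:: python

    import torch
    import torch.nn as nn
    from torch.func import functional_call, vmap

    torch.manual_seed(0)
    model = nn.Linear(3, 1)  # hypothetical toy module

    def per_sample_forward(weight, bias, x):
        # functional_call runs model.forward using the tensors passed in here
        # instead of the module's own registered parameters, so the result
        # depends only on this function's inputs.
        return functional_call(model, {"weight": weight, "bias": bias}, (x,))

    # Every tensor the output depends on is an explicit input.
    @torch._dynamo.allow_in_graph
    def batched_forward(weight, bias, x):
        # Map over the batch dimension of x only; weight and bias are shared.
        return vmap(per_sample_forward, in_dims=(None, None, 0))(weight, bias, x)

    # backend="eager" skips code generation; remove it to use the default backend.
    @torch.compile(backend="eager")
    def f(weight, bias, x):
        return batched_forward(weight, bias, x)

    x = torch.randn(4, 3)
    out = f(model.weight, model.bias, x)  # shape (4, 1); matches model(x) up to rounding

Calling ``model(x)`` directly inside ``batched_forward`` instead would make the output depend on ``model.weight`` and ``model.bias`` as captured Tensors, which is exactly the pitfall described above.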