The Custom Operators Manual
Read Me First
This manual is a comprehensive reference for all things related to PyTorch Custom Operators. We recommend that you first read one of the focused tutorials listed on our landing page and then refer to this document as a manual for edge cases or less-recommended approaches.
The landing page: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
What is an operator?
A kernel is a function that accepts Tensors and/or raw pointers to memory and performs a useful computation (for example, matrix multiplication, attention, etc.).
An operator is glue code for the PyTorch runtime that tells it about the computation. A single operator can be associated with multiple kernels (for example, torch.add has a kernel for CPU and a kernel for CUDA). The glue code is necessary to get PyTorch subsystems (like torch.compile and torch.autograd) to compose with the computation.
Standalone kernels may work directly with PyTorch but will not compose with the majority of PyTorch subsystems. In order to get them to compose, please register an operator for them.
How to make existing operators work with torch.compile
TL;DR
Call an operator pt2_compliant if it works with the new PyTorch compilation APIs (torch.compile, torch.export, etc.) introduced in PyTorch 2.x.
This is a three-step process:
Step 1: Test the custom op with torch.library.opcheck.
Step 2: Fix all problems with the custom op (until the opcheck passes)
Step 3: Mark the custom op as PT2 compliant.
Step 1: How to use opcheck to test the custom op
You have two options: manually use opcheck to test the custom op, or (Meta-only) use `generate_opcheck_tests` to automatically test the custom op. If the custom op is already covered by one of these two mechanisms, skip ahead to step 2.
Step 1a: How to manually use opcheck to test the custom op
Please call opcheck multiple times with different representative sample inputs:
- If your operator works on CPU and CUDA, please pass a set of sample inputs on CPU and a set of sample inputs on CUDA
- If your operator supports training, please pass some sample inputs with requires_grad=True.
Using the operator torch.ops.aten.sin.default as an example:
import torch
import unittest
from torch.library import opcheck

def sin_sample_inputs():
    # CPU inputs always; CUDA inputs only when a GPU is present.
    # requires_grad=True exercises the autograd-related checks.
    sample_inputs = [
        (torch.randn(3, requires_grad=True, device='cpu'),),
    ]
    if torch.cuda.is_available():
        sample_inputs.append((torch.randn(3, requires_grad=True, device='cuda'),))
    return sample_inputs

class TestOps(unittest.TestCase):
    def test_sin(self):
        for args in sin_sample_inputs():
            opcheck(torch.ops.aten.sin.default, args)

if __name__ == "__main__":
    unittest.main()
Step 1b: How to automatically use opcheck to test the custom op
Please only use this if you work at Meta. While this API is included with PyTorch, we do not guarantee backward compatibility (BC) for it.
Use this approach (generate_opcheck_tests) only if:
- You have a large collection of existing tests that exercise multiple custom ops that you would like to be tested
- You are willing to put up with the sharp edges in this API
Please see Working with generate_opcheck_tests for more details, and https://github.com/pytorch/FBGEMM/pull/2050/files#diff-c1c25e22107028a66ff548c7042ba3f39bcc009db9348825e46b15f60754cbffR2452-R2467 for an example.
Step 2: How to fix failing opcheck tests
If opcheck fails, fix the failing tests in the following order.
1. test_schema fails
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_schema")
This means that the schema of the operator is wrong and will lead to silent incorrectness issues. All operators have a schema string (example) that specifies the types of the inputs and outputs as well as some “aliasing information”:
- if any outputs are views of the inputs
- if any inputs are mutated in-place
The fix is usually to update the schema to include the aliasing information. See https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations for more details.
2. test_autograd_registration fails
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_autograd_registration")
This means that the autograd registration is incorrect and will lead to silent incorrectness issues. A common symptom is correct gradients without torch.compile but incorrect gradients with torch.compile.
Please see How to add an autograd formula for what the autograd registration should look like.
3. test_faketensor fails
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_faketensor")
PT2 compilation APIs use “Fake Tensors”, Tensors without storage, to propagate metadata. Every operator needs a “Fake Tensor kernel” which is like a shape formula: given the shapes (and other metadata) of the inputs, it should specify the shapes (and other metadata) of the outputs.
This opcheck test could fail for two reasons:
- UnsupportedFakeTensorException: you didn’t write an abstract_impl or a meta formula. Please see How to add abstract impl / meta formula for how to write one.
- Your abstract_impl/meta formula is wrong. Please debug it.
4. test_aot_dispatch_static fails
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_dispatch_static")
If this test succeeds, then the operator works with torch.compile(dynamic=False).
If the operator in question returns Tensors with data-dependent shapes, then this test is expected to fail.
Otherwise, if it fails, there are some common reasons, listed below. Please match your error message to one of them.
No DispatchKey::Functionalize (functionalization) kernel
torch.compile backends only support functional operators (that is, operators that do not mutate inputs and do not return views of the inputs). If your operator is not functional (indicated by an “(a)” in the schema string), then we need to teach torch.compile how to functionalize your operator.
For now, come find us for help if this is the case.
There is a .item call somewhere
When PT2 APIs like torch.compile see a C++ .item() call, they don’t know what to do.
Please either:
- rewrite your custom operator to not use the offending .item() call
- hide the .item() call in a new custom operator
Backward formula isn’t traceable
PyTorch assumes that the backward pass of your operator only consists of invocations to the PyTorch dispatcher.
That is, the backward pass may call:
- Built-in PyTorch operators:
  - In Python: torch.* APIs
  - In C++: at::{API} operators where API is in native_functions.yaml, and Tensor metadata operations (e.g. Tensor::sizes() / Tensor::strides())
- Custom operators. All calls to custom operators must be made through the PyTorch dispatcher.
  - If they are invoked from Python, you don’t need to do anything.
  - If they are invoked from C++, they must query the PyTorch dispatcher for a TypedOperatorHandle and invoke TypedOperatorHandle::call:
static auto custom_sin_op = torch::Dispatcher::singleton()
.findSchemaOrThrow("custom::sin", "")
.typed<decltype(custom_sin)>();
Tensor result = custom_sin_op.call(x);
Please see How to add an autograd formula for more details.
5. test_aot_dispatch_dynamic fails
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_dispatch_dynamic")
This generally means that your operator doesn’t support dynamic shapes, especially if “test_aot_dispatch_static” succeeds.
Please see The dynamic shapes manual for what to do.
Detailed description of the opcheck tests
test_schema
We test that the schema matches the implementation of the operator. For example: if the schema specifies a Tensor is mutated, then we check the implementation mutates the Tensor. If the schema specifies that we return a new Tensor, then we check that the implementation returns a new Tensor (instead of an existing one or a view of an existing one).
Note that the schema language might look simple, but it encodes a lot of information (about mutations, new tensors, and aliases) that is easy to get wrong.
test_autograd_registration
If the operator supports training (autograd): we check that its autograd formula is registered via torch.library.register_autograd or a manual registration to one or more DispatchKey::Autograd keys. Any other DispatchKey-based registrations may lead to undefined behavior.
test_faketensor
We check that a FakeTensor kernel (also sometimes known as a meta kernel) was registered for the operator and that it is correct. This test takes the result of running the operator on real tensors and the result of running the operator on FakeTensors and checks that they have the same Tensor metadata (sizes/strides/dtype/device/etc).
test_aot_dispatch_dynamic
This test checks multiple things:
- We check that the operator supports functionalization. That is, it is functional or can automatically be functionalized by torch.compile.
- If the operator supports training, we check that the backward pass supports FakeTensor and functionalization.
This test is effectively an end-to-end run with torch.compile(backend="aot_eager") (through Dynamo and AOTDispatcher) that does something like:
outs = op(*args)
outs_compiled = torch.compile(op, backend="aot_eager")(*args)
torch.testing.assert_close(outs_compiled, outs)
if supports_training(op):
    grad_args = torch.autograd.grad(outs, args, ...)
    grad_args_compiled = torch.autograd.grad(outs_compiled, args, ...)
    torch.testing.assert_close(grad_args_compiled, grad_args)
Step 3: How to mark the op as PT2 compliant
If the operator was defined in Python
Using torch.library.custom_op
Custom ops created with torch.library.custom_op automatically get this tag.
Using torch.library.define
lib = torch.library.Library("your_namespace", "FRAGMENT")
lib.define("sin(Tensor x) -> Tensor", tags=[torch.Tag.pt2_compliant_tag])
If the operator was defined in C++
Where the operator was defined, add the at::Tag::pt2_compliant_tag:
m.def("sin(Tensor x) -> Tensor", {at::Tag::pt2_compliant_tag});
This documents the operator as PT2 compliant. Please only add this tag if the operator passes the opcheck tests.
Writing a new Custom Operator
When should I create a Custom Operator?
You may wish to create a Custom Operator for two reasons:
- You have some custom CPU/CUDA/other backend kernel that you’d like to integrate with PyTorch
- You have some code that you want PyTorch to treat as an opaque callable (as a black-box).
For example, you may want to call out to some low-level third-party library like LAPACK or cuBLAS, or you may have written a bunch of CUDA kernels in .cu files.
Your custom operator kernels should include as few PyTorch built-in operators as possible. Including built-in PyTorch operators in a C++ custom operator hides them from PyTorch subsystems like torch.compile, which hides optimization opportunities.
If what you are trying to do is expressible as a composition of built-in PyTorch operators (and does not involve low-level C/C++/CUDA code or third-party Python libraries), then write your routine as a plain Python function and call it instead of creating a custom operator.
Python or C++?
You can define custom operators in both Python and C++. These registration APIs may be mixed: for example, one can define an operator’s CPU kernel from Python and CUDA kernel from C++.
Our general guidance is:
- If you care about AOTInductor (and being able to run in a Python-less environment), you should define the operator and add backend kernels in C++.
- Otherwise, it is generally easier to use the Python custom operator registration APIs.
How to define a custom operator
To define an operator, you must tell us:
- The name of the operator
- Some metadata about the acceptable input/output types of the operator and whether any inputs are being mutated
[From Python] How to define a custom operator
[PyTorch >= 2.4] Using torch.library.custom_op
Use torch.library.custom_op to decorate a function to turn it into a custom operator. The function must have type annotations, and you must correctly declare the inputs it mutates via mutates_args.
import numpy as np
import torch

@torch.library.custom_op("your_namespace::sin", mutates_args=())
def sin(x: torch.Tensor) -> torch.Tensor:
    # NumPy computes on CPU; move the result back to the input's device.
    return torch.from_numpy(np.sin(x.numpy(force=True))).to(x.device)
[PyTorch < 2.4] Using torch.library.define
To define an operator, you must tell us:
- The name of the operator
- The schema string of the operator. The spec for this schema is defined here, with multiple examples in here.
torch.library.define("your_namespace::sin", "(Tensor x) -> Tensor")
[From C++] How to define a custom operator
To define an operator, you must tell us:
- The name of the operator
- The schema string of the operator. The spec for this schema is defined here, with multiple examples in here.
Let’s go through an example of a custom sin operator.
#include <torch/library.h>
// Define the operator
TORCH_LIBRARY(your_namespace, m) {
m.def("sin(Tensor x) -> Tensor");
}
If you define operator schemas in multiple places, use TORCH_LIBRARY_FRAGMENT instead of TORCH_LIBRARY.
How to add CPU/CUDA/Backend implementations
[From Python] How to add CPU/CUDA/Backend implementations
[PyTorch >= 2.4]
Use torch.library.register_kernel.
@torch.library.register_kernel("your_namespace::sin", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CPU implementation
    ...

@torch.library.register_kernel("your_namespace::sin", "cuda")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CUDA implementation
    ...
[PyTorch < 2.4]
Use torch.library.impl
@torch.library.impl("your_namespace::sin", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CPU implementation
    ...

@torch.library.impl("your_namespace::sin", "cuda")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CUDA implementation
    ...
[From C++] How to add CPU/CUDA/Backend implementations
To provide backend-specific implementations for an operator, use TORCH_LIBRARY_IMPL.
Tensor custom_sin_cpu(const Tensor& x) {
// Replace this with at::sin if you want to test it out.
return my_custom_sin_implementation_on_cpu(x);
}
// Register the CPU implementation for the operator
TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
m.impl("sin", &custom_sin_cpu);
}
Tensor custom_sin_cuda(const Tensor& x) {
// Replace this with at::sin if you want to test it out.
return my_custom_sin_implementation_on_cuda(x);
}
// Register the CUDA implementation for the operator
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
m.impl("sin", &custom_sin_cuda);
}
How to invoke a custom operator
How to invoke a custom op defined in Python from Python
When you created a custom operator, you gave it a name. The custom operator is findable under torch.ops:
x = torch.randn(3)
y = torch.ops.your_namespace.sin(x)
How to invoke a custom op defined in C++ from C++
static auto custom_sin_op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("your_namespace::sin", "")
    .typed<decltype(custom_sin_cpu)>();
Tensor result = custom_sin_op.call(x);
In order to invoke the custom operator, we must first query it from the PyTorch dispatcher and then invoke it.
How to invoke a custom op defined in C++ from Python
The C++ custom op gets compiled into a shared library. Use torch.ops.load_library(path_to_shared_library) to load the shared library.
Once the shared library has loaded, the custom op is available from the torch.ops namespace:
x = torch.randn(3)
y = torch.ops.your_namespace.sin(x)
assert torch.allclose(y, torch.sin(x))
How to add FakeTensor support (abstract impl; meta kernel)
In order for your operator to work with PT2, it must have FakeTensor support. There are three blessed ways to do this:
- (Preferred) Write a FakeTensor kernel using the torch.library.register_fake / torch.library.impl_abstract API from Python
- Write a C++ Meta kernel
- If your operator is registered to CompositeImplicitAutograd, it will automatically decompose and we require its constituents to support FakeTensor.
Context: meta kernel vs FakeTensor kernels
In order for your operator to work with PT2, it must have FakeTensor support. That is, we must know how to run the operator on “Fake” input Tensors that do not have storage, but have sizes/strides/device.
Adding a meta kernel (a function that describes how an operator works with device="meta") will automatically generate FakeTensor support. However, meta kernels don’t support the full range of things FakeTensors do. For example, operators with data-dependent output shapes (think torch.nonzero) and operators with cross-device semantics (like Tensor.to(device="cuda")) cannot be described with meta functions.
Instead of writing a meta function, our recommendation is to write FakeTensor kernels, which are a generalization of meta functions to support FakeTensors. Writing a FakeTensor kernel is very similar to writing a meta kernel: in most cases, the meta kernel can be re-used as the FakeTensor kernel impl.
NB: We also sometimes use “abstract impl” to refer to a “FakeTensor kernel”. These are the same thing.
How to write a Python FakeTensor kernel
There are two parts to this:
- In Python, use torch.library.register_fake (PyTorch 2.4+) or torch.library.impl_abstract (PyTorch <= 2.3) to provide the FakeTensor kernel for an operator
- In the Python program that uses the custom operator, import the module that includes the FakeTensor kernel registration from step 1.
If you work at Meta, please see https://fburl.com/python_meta_example for an example.
Step 1: Use torch.library.register_fake
A “FakeTensor kernel” specifies the behavior of this operator on Tensors that carry no data. Given some input Tensors with certain properties (sizes/strides/storage_offset/device), it specifies what the properties of the output Tensors are.
The FakeTensor kernel has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor kernel, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor kernel must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).
# your_module.py
import torch
torch.ops.load_library("path/to/shared/lib/that/has/your_cpp_file")
# Write the FakeTensor kernel
@torch.library.register_fake("your_namespace::sin")
def sin_abstract(x):
    # torch.empty_like(x) returns a Tensor on the same device as `x`.
    # If you instead want to hardcode the device of the output regardless of
    # the device of the input, specify it manually, e.g.
    # torch.empty_like(x, device="cuda")
    return torch.empty_like(x)
Step 2: Import the module that contains the torch.library.register_fake call
import torch
torch.ops.import_module("your_module")
@torch.compile(backend="eager")
def f(x):
    return torch.ops.your_namespace.sin(x)
x = torch.randn(3)
f(x)
(PyTorch <=2.3 only) Add an abstract impl pystub (if one doesn’t already exist)
The operator will complain during testing if it needs an impl_abstract pystub. In that case, add an `m.impl_abstract_pystub` call to the TORCH_LIBRARY block in which the operator was defined (i.e., the block containing its m.def(...) call).
// your_cpp_file.cpp
TORCH_LIBRARY(your_namespace, m) {
// Leave a stub that tells the C++ PyTorch Dispatcher that the
// abstract impl exists in a given Python module.
// This will prevent desync issues (where someone loads the operator
// without actually loading the Python module).
//
// impl_abstract_pystub(module_name, buck_target):
// - module name: the name of the Python module the abstract impl resides in
// - buck_target (optional): If you're using a buck-based build system,
// then you can include the name of the buck target that includes
// the module here, otherwise it is optional.
// We use the module name and the buck target to give better error messages
m.impl_abstract_pystub("your_module", "//your_module:custom_ops");
m.def("sin(Tensor x) -> Tensor");
}
The pystub applies to all operators registered in the given TORCH_LIBRARY block.
This is no longer needed in PyTorch 2.4+.
How to write FakeTensor kernel for operator with data-dependent output shape
Use torch.library.get_ctx().new_dynamic_size() to allocate data-dependent output sizes. For example, a “nonzero” operator returns a Tensor with shape (number_of_nonzero_elements, dim):
@torch.library.register_fake("your_namespace::nonzero")
def nonzero_abstract(x):
    nnz = torch.library.get_ctx().new_dynamic_size()
    return x.new_empty(nnz, x.dim(), dtype=torch.long)
How to write C++ meta kernel
You should seriously consider writing a Python FakeTensor kernel instead: the API is more general (e.g. it supports cross-device and data-dependent output shapes), has fewer footguns (it is automatically SymInt-ified), lets you skip recompile cycles, and the resulting kernel is easier to debug.
C++ meta kernel example:
Tensor sin_meta(const Tensor& x) {
return at::empty_like(x);
}
TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
m.impl("sin", &sin_meta);
}
FakeTensor/meta kernels for CompositeImplicitAutograd operators
If your custom op consists only of calls to the PyTorch dispatcher (at:: operators in native_functions.yaml and custom op calls via TypedOperatorHandle.call, as in How to invoke the custom op from C++), then you can SymInt’ify your operator according to The dynamic shapes manual and register your op to the CompositeImplicitAutograd dispatch key.
The advantage is that you do not need to define an autograd function (like in How to add an autograd formula) or a meta implementation.
Before:
TORCH_LIBRARY(your_namespace, m) {
m.def("my_op(Tensor x, int[] shape) -> Tensor");
}
Tensor my_op_impl(const Tensor& x, IntArrayRef shape) {
Tensor y = at::sin(x);
// suppose custom op my_op2 has signature
// (Tensor x, int[] shape) -> Tensor
return my_op2(y, shape);
}
TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
m.impl("my_op", &my_op_impl);
}
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
m.impl("my_op", &my_op_impl);
}
After:
TORCH_LIBRARY(your_namespace, m) {
m.def("my_op(Tensor x, SymInt[] shape) -> Tensor");
}
Tensor my_op_impl(const Tensor& x, SymIntArrayRef shape) {
Tensor y = at::sin(x);
// suppose custom op my_op2 has also been SymInt'ified, i.e. it has signature
// (Tensor x, SymInt[] shape) -> Tensor
static auto my_op2_op = torch::Dispatcher::singleton()
.findSchemaOrThrow("your_namespace::my_op2", "")
.typed<decltype(my_op2)>();
return my_op2_op.call(y, shape);
}
TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
m.impl("my_op", &my_op_impl);
}
How to add an autograd formula
In order for your custom operator to work with training, it must have an autograd formula. You can register this from either Python or C++; we recommend doing this from Python.
(Recommended) [In Python] Adding an autograd formula (PyTorch 2.4+ only)
Use torch.library.register_autograd to add an autograd formula for an operator. Please see the documentation for torch.library.register_autograd for more information.
[In C++] Adding an autograd formula
WARNING: this approach has a lot of footguns! Use with caution; incorrect usage will result in silent incorrectness (in both eager-mode PyTorch and with torch.compile).
To add training support, please:
- Construct a C++ torch::autograd::Function with a forward() and a backward() pass:
- The forward pass must (1) save Tensors/data for backward and (2) re-dispatch to the operator (please see example)
- If your operator's backward pass is a custom kernel, then it should be invoked through a custom operator.
- Register this torch::autograd::Function to DispatchKey::Autograd. It is an error (and will be silently incorrect) if you register it to DispatchKey::CPU/CUDA/anything else.
Below is an example of a custom sin operator:
#include <torch/library.h>
// Declare the operator
TORCH_LIBRARY(your_namespace, m) {
m.def("sin(Tensor x) -> Tensor");
}
// Add the CPU implementation for the operator
Tensor custom_sin_cpu(const Tensor& x) {
// Replace this with at::sin if you want to test it out.
return my_custom_sin_implementation_on_cpu(x);
}
TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
m.impl("sin", &custom_sin_cpu);
}
// Add the CUDA implementation for the operator
Tensor custom_sin_cuda(const Tensor& x) {
// Replace this with at::sin if you want to test it out.
return my_custom_sin_implementation_on_cuda(x);
}
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
m.impl("sin", &custom_sin_cuda);
}
Now, let’s add a backward formula:
// To register a backward formula for it, we need to construct a
// torch::autograd::Function.
class CustomSin : public torch::autograd::Function<CustomSin> {
public:
static variable_list forward(
AutogradContext* ctx,
const Tensor& x) {
// It is important that the forward looks like the following.
// If you do anything else, the operator may be silently incorrect!
// (1) You must construct this guard and then invoke the Dispatcher
// on this operator.
// We refer to this sometimes as a "redispatch".
// The following lines will Dispatch past Autograd and call a backend
// implementation for your_namespace::sin, i.e. either custom_sin_cpu
// or custom_sin_cuda.
at::AutoDispatchBelowADInplaceOrView guard;
static auto custom_sin_op = torch::Dispatcher::singleton()
.findSchemaOrThrow("your_namespace::sin", "")
.typed<decltype(custom_sin_cpu)>();
Tensor result = custom_sin_op.call(x);
// (2) You may save Tensors or other data (like the shape of the Tensor)
// for backwards via one call to ctx->save_for_backward and
// one or more calls to ctx->saved_data
ctx->save_for_backward({x});
// (3) Finally, return the result of the operator computed in step 1
// as a flat list of Tensors. You may ONLY RETURN the results
// computed in step 1, if you return anything else (like
// a subsequent call to result.sum()), then the gradients may be
// silently incorrect.
return {result};
// (4) Nothing else must be in the forward()! That is, there must
// be NO CALLS to other PyTorch operators (e.g. at::sum)
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
const Tensor& grad = grad_output[0];
auto saved_tensors = ctx->get_saved_variables();
// The backward pass must only consist of invocations to
// the PyTorch dispatcher.
// That is, you may (1) call at::{API} operators where API
// is in native_functions.yaml
//https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
// and (2) you may call custom operators via the dispatcher
// by invoking TypedOperatorHandle.call:
// static auto custom_sin_op = torch::Dispatcher::singleton()
// .findSchemaOrThrow("your_namespace::sin", "")
// .typed<decltype(custom_sin)>();
// Tensor result = custom_sin_op.call(x);
//
// Anything else may run into problems.
return grad * at::cos(saved_tensors[0]);
}
};
Tensor custom_sin_autograd(const Tensor& x) {
return CustomSin::apply(x);
}
// It is important that you register a call to the autograd::Function::apply
// to DispatchKey::Autograd. PLEASE DO NOT REGISTER IT TO THE CPU/CUDA/any
// other key; that leads to silent incorrectness issues.
TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
m.impl("sin", &custom_sin_autograd);
}
[In C++] How to add an autograd formula: some nitty gritty details
The above example is helpful to get a basic understanding for how to write an autograd function with PT2 support. However, more complicated custom autograd ops may require additional changes. Consider a custom autograd op that does the following:
- Has a complicated input (list of tensors, non-tensors, etc.)
- Has a complicated return (multiple tensors, non-tensors, etc.)
- Saves a computed tensor for backward
- Saves computed integers for backward
This op is harder to support because there are a number of restrictions we need to be aware of:
- (Array of) integer input may be symbolic
- Returns can only be (tuple of) tensor
- All implementations (CPU, CUDA, Meta, Autograd) must return the same kind of output
- Tensors must have same metadata, same number of tensors returned (in the case of tuple return)
In order to work around these restrictions, we:
- Change integer inputs to be symbolic integers (symint), only if necessary
- Create an “implementation” custom autograd op, as above:
- The CPU/CUDA implementation should pack non-tensor data into tensors
- The CPU/CUDA implementation must return both tensors intended for actual output, and tensors that are saved
- The Autograd implementation saves tensors required for backward and must return them as well
- Create a backward op that will be called by the “implementation” op - the backward op does not need to be an autograd op.
- The “real” custom autograd op simply dispatches to the “implementation” op and returns only the tensors we want to output.
  - The “real” op should be registered to the CompositeImplicitAutograd dispatch key if both the CPU and CUDA implementations of the “implementation” op go through the same autograd implementation of that op.
  - Otherwise, you will have to register it to the AutogradCPU/AutogradCUDA keys, and you may need to define a Meta function.
#include <torch/library.h>
// Declare the real and implementation operators
TORCH_LIBRARY(your_namespace, m) {
m.def("custom_op(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
m.def("custom_op_impl(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
m.def("custom_op_backward_impl(Tensor[] grad_out, SymInt[] ints) -> Tensor[]");
}
// needs to have the same signature as forward, except for AutogradContext
variable_list custom_op_fwd_impl(
const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
...
... b[0].as_int_unchecked() ...
Tensor real_output;
Tensor for_backward;
int data;
Tensor saved_data = ???
vector<Tensor> output{real_output, for_backward, saved_data};
return output;
}
// does NOT need to have the same signature as backward
variable_list custom_op_bwd_impl(
TensorList grad_out, SymIntArrayRef ints) {
vector<Tensor> output;
...
// does NOT need to return the same as backward
return output;
}
class CustomOp : public torch::autograd::Function<CustomOp> {
public:
static variable_list forward(
AutogradContext* ctx,
const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
at::AutoDispatchBelowADInplaceOrView guard;
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("your_namespace::custom_op_impl", "")
.typed<decltype(custom_op_fwd_impl)>();
auto result = op.call(x, y, a, b);
// save_for_backward only accepts Tensors; non-Tensor data goes in ctx->saved_data
ctx->save_for_backward({result[1], result[2]});
// could make additional computations here if they're "simple"
// i.e. whatever is allowed according to the previous autograd example
int simple_data;
...
ctx->saved_data["simple"] = simple_data;
// must return the same thing as custom_op_fwd_impl!
return result;
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("your_namespace::custom_op_backward_impl", "")
.typed<decltype(custom_op_bwd_impl)>();
// can create custom_op_bwd_impl arguments here
auto grad_out = op.call(...);
// e.g. may have to add additional gradients
grad_out.push_back({});
return grad_out;
}
};
variable_list custom_op_impl_meta(
const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
vector<Tensor> output;
...
return output;
}
variable_list custom_op_bwd_impl_meta(
TensorList grad_out, SymIntArrayRef ints) {
vector<Tensor> output;
...
return output;
}
variable_list custom_op_impl_autograd(
const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
return CustomOp::apply(x, y, a, b);
}
// can have a different signature than custom_op_fwd_impl
variable_list custom_op(
const Tensor& x, TensorList y) {
...
int a = ...;
vector<SymInt> b = ...;
...
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("your_namespace::custom_op_impl", "")
.typed<decltype(custom_op_fwd_impl)>();
auto result = op.call(x, y, a, b);
// we can return only the tensors we want to return
return {result[0]};
}
// Register the same kernels for CPU as well
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
m.impl("custom_op_impl", &custom_op_fwd_impl);
m.impl("custom_op_backward_impl", &custom_op_bwd_impl);
}
TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
m.impl("custom_op_impl", &custom_op_impl_meta);
m.impl("custom_op_backward_impl", &custom_op_bwd_impl_meta);
}
TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
m.impl("custom_op_impl", &custom_op_impl_autograd);
}
TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
m.impl("custom_op", &custom_op);
}
See https://github.com/pytorch/FBGEMM/pull/2076/files for a real example.
Missing Meta function in PyTorch core
Occasionally, you may run into missing meta functions for operators that are defined in core PyTorch rather than as custom operators. In this case, add your meta registration to torch/_meta_registrations.py, following the pattern of the other registrations in that file.
Appendix
Bugs related to non-contiguous tensors
You may encounter a CUDA illegal memory access error. If this error only occurs in the AOT dispatch opcheck tests, it may be caused by a backward kernel handling non-contiguous tensors improperly. To fix this, coerce the inputs to the backward kernel to be contiguous, i.e., using Tensor.contiguous() (example: https://github.com/pytorch/FBGEMM/pull/2093/files).
Interaction with torch.compile
Some tests use torch.compile/torch._dynamo.optimize. generate_opcheck_tests does not interact well with torch.compile, so Dynamo runs in disabled mode (TORCHDYNAMO_DISABLE=1).
AutogradCPU and AutogradCUDA
It is possible to have a custom operator with different autograd behavior per device. We generally do not recommend this; a better paradigm is to have two different custom operators, one for each device.
If you must have this feature, then continue reading.
The AutogradCPU key is a combination of the Autograd and CPU keys - it registers a function as a CPU implementation and an autograd implementation, but only for CPU inputs. AutogradCUDA is analogous, but for CUDA. This is useful in the case where even the autograd implementation is different between CPU and CUDA. Note that the meta implementation must be separately registered.
See https://github.com/pytorch/FBGEMM/pull/2076/files for an example. In the example, the CPU implementation is a composition of autograd ops, while the CUDA implementation is a custom autograd op.