jbilcke-hf HF Staff committed on
Commit
23edbfb
·
1 Parent(s): bcafc05

let's simplify the demo

app.py CHANGED
@@ -45,8 +45,11 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         with gr.Tabs():
-            animation(gen, chunk_size, device)
+            #animation(gen, chunk_size, device)
+
+            # for this demo, let's only showcase img_edit
             img_edit(gen, device)
-            vid_edit(gen, chunk_size, device)
+
+            #vid_edit(gen, chunk_size, device)
 
 demo.launch(allowed_paths=["./data/source","./data/driving"])
docs/torch/Pytorch Custom Operators.md ADDED
@@ -0,0 +1,49 @@
1
+ PyTorch Custom Operators
2
+ =============================================================================
3
+
4
+ Created On: Jun 18, 2024 | Last Updated: Jul 31, 2025 | Last Verified: Nov 05, 2024
5
+
6
+ PyTorch offers a large library of operators that work on Tensors (e.g. `torch.add`, `torch.sum`, etc). However, you may wish to bring a new custom operation to PyTorch and get it to work with subsystems like `torch.compile`, autograd, and `torch.vmap`. In order to do so, you must register the custom operation with PyTorch via the Python [torch.library](https://pytorch.org/docs/stable/library.html) or C++ `TORCH_LIBRARY` APIs.
7
+
8
+ Authoring a custom operator from Python
9
+ -----------------------------------------------------------------------------------------------------------
10
+
11
+ Please see [Custom Python Operators](python_custom_ops.html#python-custom-ops-tutorial).
12
+
13
+ You may wish to author a custom operator from Python (as opposed to C++) if:
14
+
15
+ * you have a Python function you want PyTorch to treat as an opaque callable, especially with respect to `torch.compile` and `torch.export`.
16
+
17
+ * you have some Python bindings to C++/CUDA kernels and want those to compose with PyTorch subsystems (like `torch.compile` or `torch.autograd`)
18
+
19
+ * you are using Python (and not a C++-only environment like AOTInductor).
20
+
21
+
22
+ Integrating custom C++ and/or CUDA code with PyTorch
23
+ -----------------------------------------------------------------------------------------------------------------------------------
24
+
25
+ Please see [Custom C++ and CUDA Operators](cpp_custom_ops.html#cpp-custom-ops-tutorial).
26
+
27
+ Note
28
+
29
+ `SYCL` serves as the backend programming language for Intel GPUs. To integrate custom SYCL code, refer to [Custom SYCL Operators](cpp_custom_ops_sycl.html#cpp-custom-ops-tutorial-sycl).
30
+
31
+ You may wish to author a custom operator from C++ (as opposed to Python) if:
32
+
33
+ * you have custom C++ and/or CUDA code.
34
+
35
+ * you plan to use this code with `AOTInductor` to do Python-less inference.
36
+
37
+
38
+ The Custom Operators Manual
39
+ -----------------------------------------------------------------------------------
40
+
41
+ For information not covered in the tutorials and this page, please see [The Custom Operators Manual](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU) (we’re working on moving the information to our docs site). We recommend that you first read one of the tutorials above and then use the Custom Operators Manual as a reference; it is not meant to be read cover to cover.
42
+
43
+ ### When should I create a Custom Operator?
44
+
45
+ If your operation is expressible as a composition of built-in PyTorch operators, then please write it as a Python function and call it instead of creating a custom operator. Use the operator registration APIs to create a custom operator if you are calling into some library that PyTorch doesn’t understand (e.g. custom C/C++ code, a custom CUDA kernel, or Python bindings to C/C++/CUDA extensions).
46
+
47
+ ### Why should I create a Custom Operator?
48
+
49
+ It is possible to use a C/C++/CUDA kernel by grabbing a Tensor’s data pointer and passing it to a pybind’ed kernel. However, this approach doesn’t compose with PyTorch subsystems like autograd, torch.compile, vmap, and more. In order for an operation to compose with PyTorch subsystems, it must be registered via the operator registration APIs.
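+
+ As a rough illustration (using a hypothetical `mylib::numpy_sin` operator and a NumPy-backed kernel), the difference looks like this: the plain function below is invisible to PyTorch subsystems, while the registered version can compose with them.
+
+ ```py
+ import numpy as np
+ import torch
+
+ # Plain kernel wrapper: works in eager mode, but autograd, torch.compile,
+ # vmap, etc. cannot reason about it.
+ def numpy_sin_unregistered(x: torch.Tensor) -> torch.Tensor:
+     return torch.from_numpy(np.sin(x.numpy(force=True)))
+
+ # Registered custom operator: the same kernel, but PyTorch now knows its
+ # schema, so it can compose with subsystems (e.g. via register_fake and
+ # register_autograd).
+ @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
+ def numpy_sin(x: torch.Tensor) -> torch.Tensor:
+     return torch.from_numpy(np.sin(x.numpy(force=True)))
+ ```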
docs/torch/The Custom Operators Manual.md ADDED
@@ -0,0 +1,877 @@
1
+ # The Custom Operators Manual
2
+
3
+ # Read Me First
4
+
5
+ This manual is a comprehensive reference for all things related to PyTorch Custom Operators. We recommend that you [first read one of the focused tutorials listed on our landing page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) and then refer to this document as a manual for edge cases or less-recommended approaches.
6
+
7
+ The landing page: [https://pytorch.org/tutorials/advanced/custom\_ops\_landing\_page.html](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
8
+
9
+ # What is an operator?
10
+
11
+ A **kernel** is a function that accepts Tensors and/or raw pointers to memory and performs a useful computation (for example, matrix multiplication, attention, etc).
12
+
13
+ An **operator** is glue code for the PyTorch runtime that tells it about the computation. A single operator can be associated with multiple kernels (for example, torch.add has a kernel for CPU and a kernel for CUDA). The glue code is necessary to get PyTorch subsystems (like torch.compile and torch.autograd) to compose with the computation.
14
+
15
+ Standalone kernels may work directly with PyTorch but will not compose with the majority of PyTorch subsystems. In order to get them to compose, please register an operator for them.
16
+
17
+ # How to make existing operators work with torch.compile.
18
+
19
+ ## TL;DR
20
+
21
+ Call an operator **pt2\_compliant** if it works with the new PyTorch compilation APIs (torch.compile, torch.export, etc) introduced in PyTorch 2.x.
22
+
23
+ This is a three-step process:
24
+
25
+ - Step 1: Test the custom op with torch.library.opcheck.
+ - Step 2: Fix all problems with the custom op (until the opcheck passes).
+ - Step 3: Mark the custom op as PT2 compliant.
28
+
29
+ ## Step 1: How to use opcheck to test the custom op
30
+
31
+ You have two options: manually use opcheck to test the custom op, or (Meta-only) use \`generate\_opcheck\_tests\` to automatically test the custom op. If the custom op is already covered by one of these two mechanisms, skip ahead to Step 2.
32
+
33
+ ### Step 1a: How to manually use opcheck to test the custom op
34
+
35
+ Please call opcheck multiple times with different representative sample inputs:
36
+
37
+ - If your operator works on CPU and CUDA, please pass a set of sample inputs on CPU and a set of sample inputs on CUDA
38
+ - If your operator supports training, please pass some sample inputs with requires\_grad=True.
39
+
40
+ Using the operator torch.ops.aten.sin.default as an example:
41
+
42
+ ```py
+ import torch
+ import unittest
+ from torch.library import opcheck
+
+ def sin_sample_inputs():
+     # Return one set of sample inputs per configuration we care about.
+     return [
+         (torch.randn(3, requires_grad=True, device='cpu'),),
+         (torch.randn(3, requires_grad=True, device='cuda'),),
+     ]
+
+ class TestOps(unittest.TestCase):
+     def test_sin(self):
+         sample_inputs = sin_sample_inputs()
+         for i in range(len(sample_inputs)):
+             opcheck(torch.ops.aten.sin.default, sample_inputs[i])
+ ```
59
+
60
+ ### Step 1b: How to automatically use opcheck to test the custom op
61
+
62
+ **Please only use this if you work at Meta. While this API is included with PyTorch, we do not guarantee BC for it.**
63
+
64
+ Use this approach (generate\_opcheck\_tests) only if:
65
+
66
+ - You have a large collection of existing tests that exercise multiple custom ops that you would like to be tested
67
+ - You are willing to put up with the sharp edges in this API
68
+
69
+ Please see [Working with generate\_opcheck\_tests](https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit#heading=h.m2x0ozzf1awo) for more details, and [https://github.com/pytorch/FBGEMM/pull/2050/files\#diff-c1c25e22107028a66ff548c7042ba3f39bcc009db9348825e46b15f60754cbffR2452-R2467](https://github.com/pytorch/FBGEMM/pull/2050/files#diff-c1c25e22107028a66ff548c7042ba3f39bcc009db9348825e46b15f60754cbffR2452-R2467) for an example.
70
+
71
+ ## Step 2: How to fix failing opcheck tests
72
+
73
+ If opcheck fails, please try to get the failing tests passing in the following order.
74
+
75
+ ### 1\. test\_schema fails
76
+
77
+ ```py
78
+ opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_schema")
79
+ ```
80
+
81
+ This means that the schema of the operator is wrong and will lead to silent incorrectness issues. All operators have a schema string ([example](https://github.com/pytorch/pytorch/blob/050c56d0a5c0982e13255447f5fb7d2949c02407/aten/src/ATen/native/native_functions.yaml#L5886)) that specifies the types of the inputs and outputs as well as some “aliasing information”:
82
+
83
+ - if any outputs are views of the inputs
84
+ - if any inputs are mutated in-place
85
+
86
+ The fix is usually to update the schema to include the aliasing information. See [https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md\#annotations](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations) for more details.
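+
+ As a hedged illustration (using a hypothetical in-place `mylib::my_relu_` operator), an op that mutates its input needs the `(a!)` alias annotation in its schema; without it, test\_schema will flag the mismatch:
+
+ ```py
+ import torch
+
+ # Wrong: the schema claims a purely functional op, but the kernel mutates x.
+ # torch.library.define("mylib::my_relu_", "(Tensor x) -> Tensor")
+
+ # Right: "(a!)" declares that x is mutated in-place and aliased by the output.
+ torch.library.define("mylib::my_relu_", "(Tensor(a!) x) -> Tensor(a!)")
+ ```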
87
+
88
+ ### 2\. test\_autograd\_registration fails
89
+
90
+ ```py
91
+ opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_autograd_registration")
92
+ ```
93
+
94
+ This means that the autograd registration is incorrect and will lead to silent incorrectness issues. A common symptom is gradients that are correct in eager mode but incorrect under torch.compile.
95
+
96
+ Please see [How to add an autograd formula](#how-to-add-an-autograd-formula) for what the autograd registration should look like.
97
+
98
+ ### 3\. test\_faketensor fails
99
+
100
+ ```py
101
+ opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_faketensor")
102
+ ```
103
+
104
+ PT2 compilation APIs use “Fake Tensors”, Tensors without storage, to propagate metadata. Every operator needs a “Fake Tensor kernel” which is like a shape formula: given the shapes (and other metadata) of the inputs, it should specify the shapes (and other metadata) of the outputs.
105
+
106
+ This opcheck test could fail for two reasons:
107
+
108
+ - UnsupportedFakeTensorException: you didn’t write an abstract\_impl or a meta formula. Please see [How to add abstract impl / meta formula](#how-to-add-faketensor-support-\(abstract-impl;-meta-kernel\)) for how to write one.
109
+ - Your abstract\_impl/meta formula is wrong. Please debug it.
110
+
111
+ ### 4\. test\_aot\_dispatch\_static fails
112
+
113
+ ```py
114
+ opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_dispatch_static")
115
+ ```
116
+
117
+ If this test succeeds, then the operator works with torch.compile(dynamic=False).
118
+ If the operator in question returns Tensors with data-dependent shapes, then this test is expected to fail.
119
+
120
+ Otherwise, if it fails, there are some common reasons, listed below. Please match your error message to one of them.
121
+
122
+ #### No DispatchKey::Functionalize (functionalization) kernel
123
+
124
+ torch.compile backends only support functional operators (that is, operators that do not mutate inputs and do not return views of the inputs). If your operator is not functional (indicated by an “(a)” or “(a!)” annotation in the schema string), then we need to teach torch.compile how to functionalize your operator.
125
+
126
+ For now, come find us for help if this is the case.
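+
+ (In PyTorch 2.4+, a minimal sketch of the usual workaround, assuming a hypothetical in-place op: if you declare the mutated arguments via torch.library.custom\_op's mutates\_args, torch.compile can auto-functionalize the op for you.)
+
+ ```py
+ import torch
+
+ # Hypothetical in-place op; declaring mutates_args lets torch.compile
+ # auto-functionalize it instead of needing a hand-written
+ # functionalization kernel.
+ @torch.library.custom_op("mylib::inplace_scale_", mutates_args=("x",))
+ def inplace_scale_(x: torch.Tensor, scale: float) -> None:
+     x.mul_(scale)
+ ```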
127
+
128
+ #### There is a .item call somewhere
129
+
130
+ When PT2 APIs like torch.compile see a C++ .item() call, they don’t know what to do.
131
+
132
+ Please either:
133
+
134
+ - rewrite your custom operator to not use the offending .item() call
135
+ - hide the .item() call in a new custom operator
136
+
137
+ #### Backward formula isn’t traceable
138
+
139
+ PyTorch assumes that the backward pass of your operator only consists of invocations to the PyTorch dispatcher.
140
+
141
+ That is, the backward pass may call:
142
+
143
+ 1. **Built-in PyTorch operators**.
+     1. In Python
+         1. torch.\* APIs
+     2. In C++
+         1. at::{API} operators where API is in [native\_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)
+         2. Tensor metadata operations. E.g. Tensor::sizes() / Tensor::strides()
+ 2. **Custom Operators**. All calls to custom operators must be made through the PyTorch dispatcher.
+     1. If they are being invoked from Python, you don’t need to do anything.
+     2. If they are being invoked from C++, they must query the PyTorch Dispatcher for a TypedOperatorHandle and invoke TypedOperatorHandle::call:
+
+ ```c
+ static auto custom_sin_op = torch::Dispatcher::singleton()
+     .findSchemaOrThrow("custom::sin", "")
+     .typed<decltype(custom_sin)>();
+ Tensor result = custom_sin_op.call(x);
+ ```
159
+
160
+ Please see [How to add an autograd formula](#how-to-add-an-autograd-formula) for more details.
161
+
162
+ ### 5\. test\_aot\_dispatch\_dynamic fails
163
+
164
+ ```py
165
+ opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_dispatch_dynamic")
166
+ ```
167
+
168
+ This generally means that your operator doesn’t support Dynamic Shapes, especially if “test\_aot\_dispatch\_static” succeeds.
169
+
170
+ Please see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.dpddt0fbqk6u) for what to do.
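+
+ (A minimal sketch of the most common fix, using a hypothetical `mylib::repeat_rows` op: build output shapes from the input's possibly-symbolic sizes instead of converting them to Python ints, so the FakeTensor kernel works under dynamic shapes.)
+
+ ```py
+ import torch
+
+ # Hypothetical op: repeats the rows of a 2D tensor n times.
+ @torch.library.custom_op("mylib::repeat_rows", mutates_args=())
+ def repeat_rows(x: torch.Tensor, n: int) -> torch.Tensor:
+     return x.repeat(n, 1)
+
+ @torch.library.register_fake("mylib::repeat_rows")
+ def _(x, n):
+     # x.shape[0] may be a SymInt under dynamic shapes; avoid int(...) on it.
+     return x.new_empty((x.shape[0] * n, x.shape[1]))
+ ```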
171
+
172
+ ### Detailed description of the opcheck tests
173
+
174
+ #### test\_schema
175
+
176
+ We test that the schema matches the implementation of the operator. For example: if the schema specifies a Tensor is mutated, then we check the implementation mutates the Tensor. If the schema specifies that we return a new Tensor, then we check that the implementation returns a new Tensor (instead of an existing one or a view of an existing one).
177
+
178
+ Note that the schema language might look simple, but it encodes a lot of information (about mutations, new tensors, and aliases) that is easy to get wrong.
179
+
180
+ #### test\_autograd\_registration
181
+
182
+ If the operator supports training (autograd): we check that its autograd formula is registered via torch.library.register\_autograd or a manual registration to one or more DispatchKey::Autograd keys. Any other DispatchKey-based registrations may lead to undefined behavior.
183
+
184
+ #### test\_faketensor
185
+
186
+ We check that a FakeTensor kernel (also sometimes known as a meta kernel) was registered for the operator and that it is correct. This test takes the result of running the operator on real tensors and the result of running the operator on FakeTensors and checks that they have the same Tensor metadata (sizes/strides/dtype/device/etc).
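+
+ Roughly speaking (a sketch using the internal FakeTensorMode API, not the exact opcheck implementation), the comparison looks like this:
+
+ ```py
+ import torch
+ from torch._subclasses.fake_tensor import FakeTensorMode
+
+ x = torch.randn(3)
+ real_out = torch.ops.aten.sin.default(x)
+
+ with FakeTensorMode() as mode:
+     fake_out = torch.ops.aten.sin.default(mode.from_tensor(x))
+
+ # The FakeTensor output carries no data, but its metadata must match.
+ assert real_out.shape == fake_out.shape
+ assert real_out.dtype == fake_out.dtype
+ assert real_out.device == fake_out.device
+ ```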
187
+
188
+ #### test\_aot\_dispatch\_dynamic
189
+
190
+ This test checks multiple things:
191
+
192
+ - We check that the operator supports functionalization. That is, it is functional or can automatically be functionalized by torch.compile.
193
+ - If the operator supports training, we check that the backward pass supports FakeTensor and functionalization.
194
+
195
+ This test is effectively an e2e run with torch.compile(backend=”aot\_eager”) (through Dynamo+AOTDispatcher) that does something like:
196
+
197
+ ```py
+ outs = op(*args)
+ outs_compiled = torch.compile(op, backend="aot_eager")(*args)
+ torch.testing.assert_close(outs_compiled, outs)
+
+ if supports_training(op):
+     grad_args = torch.autograd.grad(outs, args, ...)
+     grad_args_compiled = torch.autograd.grad(outs_compiled, args, ...)
+     torch.testing.assert_close(grad_args_compiled, grad_args)
+ ```
207
+
208
+ ## Step 3: How to mark the op as PT2 compliant
209
+
210
+ ### If the operator was defined in Python
211
+
212
+ #### Using torch.library.custom\_op
213
+
214
+ Custom ops created with torch.library.custom\_op automatically get this tag.
215
+
216
+ #### Using torch.library.define
217
+
218
+ ```py
219
+ torch.library.define("your_namespace::sin", "(Tensor x) -> Tensor", tags=[torch.Tag.pt2_compliant_tag])
220
+ ```
221
+
222
+ ### If the operator was defined in C++
223
+
224
+ Where the operator was defined, add the at::Tag::pt2\_compliant\_tag:
225
+
226
+ ```c
227
+ m.def("sin(Tensor x) -> Tensor", {at::Tag::pt2_compliant_tag});
228
+ ```
229
+
230
+ This documents the operator as PT2 compliant. Please only add this tag if the operator passes the opcheck tests.
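+
+ (A quick way to sanity-check that the tag is attached, assuming the `your_namespace::sin` operator from above has been loaded:)
+
+ ```py
+ import torch
+
+ assert torch.Tag.pt2_compliant_tag in torch.ops.your_namespace.sin.default.tags
+ ```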
231
+
232
+ # Writing a new Custom Operator
233
+
234
+ ## What is a “kernel”? What is an “operator”?
235
+
236
+ A **kernel** is a function that accepts Tensors and/or raw pointers to memory and performs a useful computation (for example, matrix multiplication, attention, etc).
237
+
238
+ An **operator** is glue code for the PyTorch runtime that tells it about the computation. A single operator can be associated with multiple kernels (for example, torch.add has a kernel for CPU and a kernel for CUDA). The glue code is necessary to get PyTorch subsystems (like torch.compile and torch.autograd) to compose with the computation.
239
+
240
+ ## When should I create a Custom Operator?
241
+
242
+ You may wish to create a Custom Operator for two reasons:
243
+
244
+ - You have some custom CPU/CUDA/other backend kernel that you’d like to integrate with PyTorch
245
+ - You have some code that you want PyTorch to treat as an opaque callable (as a black-box).
246
+
247
+ For example, you may want to call out to some low-level third-party library like LAPACK or CUBLAS, or you may have written a bunch of CUDA kernels in .cu files.
248
+
249
+ Your custom operator kernels should include as few PyTorch built-in operators as possible. Including built-in PyTorch operators in a C++ custom operator hides them from PyTorch subsystems like torch.compile, which hides optimization opportunities.
250
+
251
+ **If what you are trying to do is expressible as a composition of built-in PyTorch operators** (and does not involve low-level C/C++/CUDA code or third-party Python libraries), then please write your routine as a Python function and call it instead of creating a custom operator.
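+
+ For example (a minimal sketch), something like this should stay a plain Python function rather than becoming a custom operator:
+
+ ```py
+ import torch
+
+ # Expressible entirely with built-in ops: torch.compile and autograd
+ # already see through it, so no custom operator is needed.
+ def gelu_residual(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     return torch.nn.functional.gelu(x) + y
+ ```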
252
+
253
+ ## Python or C++?
254
+
255
+ You can define custom operators in both Python and C++. These registration APIs may be mixed: for example, one can define an operator’s CPU kernel from Python and CUDA kernel from C++.
256
+
257
+ Our general guidance is:
258
+
259
+ - If you care about AOTInductor (and being able to run in a Python-less environment), you should define the operator and add backend kernels in C++.
260
+ - Otherwise, it is generally easier to use the Python custom operator registration APIs.
261
+
262
+ ## How to define a custom operator
263
+
264
+ To define an operator, you must tell us:
265
+
266
+ - The name of the operator
267
+ - Some metadata around the acceptable input/output types of the operator and if any inputs are being mutated
268
+
269
+ ### \[From Python\] How to define a custom operator
270
+
271
+ #### \[PyTorch \>= 2.4\] Using torch.library.custom\_op
272
+
273
+ Use torch.library.custom\_op to decorate a function to turn it into a custom operator. The function must have type annotations, and you must correctly annotate inputs that are being mutated.
274
+
275
+ ```py
+ import numpy as np
+ import torch
+
+ @torch.library.custom_op("your_namespace::sin", mutates_args=())
+ def sin(x: torch.Tensor) -> torch.Tensor:
+     return torch.from_numpy(np.sin(x.numpy(force=True)))
+ ```
280
+
281
+ #### \[PyTorch \< 2.4\] Using torch.library.define
282
+
283
+ To define an operator, you must tell us:
284
+
285
+ - The name of the operator
286
+ - The [schema string](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func) of the operator. The spec for this schema is defined [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func), with multiple examples in [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml).
287
+
288
+ ```py
289
+ torch.library.define("your_namespace::sin", "(Tensor x) -> Tensor")
290
+ ```
291
+
292
+ ### \[From C++\] How to define a custom operator
293
+
294
+ To define an operator, you must tell us:
295
+
296
+ - The name of the operator
297
+ - The [schema string](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func) of the operator. The spec for this schema is defined [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func), with multiple examples in [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml).
298
+
299
+ Let’s go through an example of a custom sin operator.
300
+
301
+ ```c
302
+ #include <torch/library.h>
303
+
304
+ // Define the operator
305
+ TORCH_LIBRARY(your_namespace, m) {
306
+ m.def("sin(Tensor x) -> Tensor");
307
+ }
308
+ ```
309
+
310
+ If you define operator schemas in multiple places, use TORCH\_LIBRARY\_FRAGMENT instead of TORCH\_LIBRARY.
311
+
312
+ ## How to add CPU/CUDA/Backend implementations
313
+
314
+ ### \[From Python\] How to add CPU/CUDA/Backend implementations
315
+
316
+ #### \[PyTorch \>= 2.4\]
317
+
318
+ Use torch.library.register\_kernel.
319
+
320
+ ```py
+ @torch.library.register_kernel("your_namespace::sin", "cpu")
+ def _(x: torch.Tensor) -> torch.Tensor:
+     # your CPU implementation
+     ...
+
+ @torch.library.register_kernel("your_namespace::sin", "cuda")
+ def _(x: torch.Tensor) -> torch.Tensor:
+     # your CUDA implementation
+     ...
+ ```
331
+
332
+ #### \[PyTorch \< 2.4\]
333
+
334
+ Use torch.library.impl
335
+
336
+ ```py
+ @torch.library.impl("your_namespace::sin", "cpu")
+ def _(x: torch.Tensor) -> torch.Tensor:
+     # your CPU implementation
+     ...
+
+ @torch.library.impl("your_namespace::sin", "cuda")
+ def _(x: torch.Tensor) -> torch.Tensor:
+     # your CUDA implementation
+     ...
+ ```
347
+
348
+ ### \[From C++\] How to add CPU/CUDA/Backend implementations
349
+
350
+ To provide backend-specific implementations for an operator, use TORCH\_LIBRARY\_IMPL.
351
+
352
+ ```c
353
+
354
+ Tensor custom_sin_cpu(const Tensor& x) {
355
+ // Replace this with at::sin if you want to test it out.
356
+ return my_custom_sin_implementation_on_cpu(x);
357
+ }
358
+
359
+ // Register the CPU implementation for the operator
360
+ TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
361
+ m.impl("sin", &custom_sin_cpu);
362
+ }
363
+
364
+ Tensor custom_sin_cuda(const Tensor& x) {
365
+ // Replace this with at::sin if you want to test it out.
366
+ return my_custom_sin_implementation_on_cuda(x);
367
+ }
368
+
369
+ // Register the CUDA implementation for the operator
370
+ TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
371
+ m.impl("sin", &custom_sin_cuda);
372
+ }
373
+ ```
374
+
375
+ ## How to invoke a custom operator
376
+
377
+ ### How to invoke a custom op defined in Python from Python
378
+
379
+ When you created a custom operator, you gave it a name. The custom operator is findable under torch.ops:
380
+
381
+ ```py
382
+ x = torch.randn(3)
383
+ y = torch.ops.your_namespace.sin(x)
384
+ ```
385
+
386
387
+
388
+ ### How to invoke a custom op defined in C++ from C++ {#how-to-invoke-a-custom-op-defined-in-c++-from-c++}
389
+
390
+ ```c
+ static auto custom_sin_op = torch::Dispatcher::singleton()
+     .findSchemaOrThrow("your_namespace::sin", "")
+     .typed<decltype(custom_sin_cpu)>();
+ Tensor result = custom_sin_op.call(x);
+ ```
396
+
397
+ In order to invoke the custom operator, we must first query it from the PyTorch dispatcher and then invoke it.
398
+
399
+ ### How to invoke a custom op defined in C++ from Python
400
+
401
+ The C++ custom op gets compiled into a shared library. Use torch.ops.load\_library(path\_to\_shared\_library) to load the shared library.
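+
+ For example (a minimal sketch; the path is whatever your build produces):
+
+ ```py
+ import torch
+
+ # Hypothetical path to the compiled extension
+ torch.ops.load_library("build/libyour_namespace_ops.so")
+ ```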
402
+
403
+ Once the shared library has loaded, the custom op is available from the torch.ops namespace:
404
+
405
+ ```py
406
+ x = torch.randn(3)
407
+ y = torch.ops.your_namespace.sin(x)
408
+ assert torch.allclose(y, torch.sin(x))
409
+ ```
410
+
411
+ ## How to add FakeTensor support (abstract impl; meta kernel) {#how-to-add-faketensor-support-(abstract-impl;-meta-kernel)}
412
+
413
+ In order for your operator to work with PT2, it must have FakeTensor support. There are three blessed ways to do this:
414
+
415
+ - (Preferred) Write a FakeTensor kernel using the torch.library.register\_fake / torch.library.impl\_abstract API from Python
416
+ - Write a C++ Meta kernel
417
+ - If your operator is registered to CompositeImplicitAutograd, it will automatically decompose and we require its constituents to support FakeTensor.
418
+
419
+ ### Context: meta kernel vs FakeTensor kernels
420
+
421
+ In order for your operator to work with PT2, it must have FakeTensor support. That is, we must know how to run the operator on “Fake” input Tensors that do not have storage, but have sizes/strides/device.
422
+
423
+ Adding a **meta kernel** (a function that describes how an operator works with device=’meta’) will automatically generate FakeTensor support. However, meta kernels don’t support the full range of things FakeTensors do. For example, operators with data-dependent output shape (think torch.nonzero) and operators with cross-device semantics (like Tensor.to(device=”cuda”)) are not describable with meta functions.
424
+
425
+ Instead of writing a meta function, **our recommendation is to write FakeTensor kernels**, which are a generalization of meta functions to support FakeTensors. Writing a FakeTensor kernel is very similar to writing a meta kernel: in most cases, the meta kernel can be re-used as the FakeTensor kernel impl.
426
+
427
+ NB: We also sometimes use “abstract impl” to refer to a “FakeTensor kernel”. These are the same thing.
428
+
429
+ ### How to write a Python FakeTensor kernel {#how-to-write-a-python-faketensor-kernel}
430
+
431
+ There are two parts to this:
432
+
433
+ 1. In Python, use torch.library.register\_fake (PyTorch 2.4+) or torch.library.impl\_abstract (PyTorch \<= 2.3) to provide the FakeTensor kernel for an operator
434
+ 2. In the Python program that uses the custom operator, import the module that includes the FakeTensor kernel registration from step 1\.
435
+
436
+ If you work at Meta, please see [https://fburl.com/python\_meta\_example](https://fburl.com/python_meta_example) for an example.
437
+
438
+ #### Step 1: Use torch.library.register\_fake
439
+
440
+ A "FakeTensor kernel" specifies the behavior of this operator on Tensors that carry no data. Given some input Tensors with certain properties (sizes/strides/storage\_offset/device), it specifies what the properties of the output Tensors are.
441
+
442
+ The FakeTensor kernel has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor kernel, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors, but they do not have storage, and you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor kernel must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).
443
+
444
+ ```py
+ # your_module.py
+ import torch
+ torch.ops.load_library("path/to/shared/lib/that/has/your_cpp_file")
+
+ # Write the FakeTensor kernel
+ @torch.library.register_fake("your_namespace::sin")
+ def sin_abstract(x):
+     # torch.empty_like(x) returns a Tensor on the same device as `x`.
+     # If you instead want to hardcode the device of the output, no matter
+     # the device of the input, manually specify it like
+     # torch.empty_like(x, device="cuda")
+     return torch.empty_like(x)
+ ```
458
+
459
+ #### Step 2: Import the module that contains the torch.library.register\_fake call
460
+
461
+ ```py
+ import torch
+ # Importing the module from Step 1 runs the register_fake registration.
+ import your_module
+
+ @torch.compile(backend="eager")
+ def f(x):
+     return torch.ops.your_namespace.sin(x)
+
+ x = torch.randn(3)
+ f(x)
+ ```
472
+
473
+ #### (PyTorch \<=2.3 only) Add an abstract impl pystub (if one doesn’t already exist)
474
+
475
+ The operator will complain during testing if it needs an impl abstract pystub. In that case, add an \`m.impl\_abstract\_pystub\` call to the TORCH\_LIBRARY block that the operator was defined in (i.e. next to the \`m.def(...)\` call).
476
+
477
+ ```c
478
+ // your_cpp_file.cpp
479
+ TORCH_LIBRARY(your_namespace, m) {
480
+ // Leave a stub that tells the C++ PyTorch Dispatcher that the
481
+ // abstract impl exists in a given Python module.
482
+ // This will prevent desync issues (where someone loads the operator
483
+ // without actually loading the Python module).
484
+ //
485
+ // impl_abstract_pystub(module_name, buck_target):
486
+ // - module name: the name of the Python module the abstract impl resides in
487
+ // - buck_target (optional): If you're using a buck-based build system,
488
+ // then you can include the name of the buck target that includes
489
+ // the module here, otherwise it is optional.
490
+ // We use the module name and the buck target to give better error messages
491
+ m.impl_abstract_pystub("your_module", "//your_module:custom_ops");
492
+ m.def("sin(Tensor x) -> Tensor");
493
+ }
494
+ ```
495
+
496
+ The pystub applies to all operators registered in the given TORCH\_LIBRARY block.
497
+
498
+ We removed the need for this in PyTorch 2.4+
499
+
500
+ ### How to write FakeTensor kernel for operator with data-dependent output shape
501
+
502
+ Use torch.library.get\_ctx().new\_dynamic\_size() to allocate data-dependent output sizes. For example, a “nonzero” operator returns a Tensor with shape (number\_of\_nonzero\_elements, dim):
503
+
504
+ ```py
+ @torch.library.register_fake("your_namespace::nonzero")
+ def nonzero_abstract(x):
+     nnz = torch.library.get_ctx().new_dynamic_size()
+     return x.new_empty(nnz, x.dim(), dtype=torch.long)
+ ```
510
+
511
+ ### How to write C++ meta kernel
512
+
513
+ You should seriously consider writing a [Python FakeTensor kernel](#how-to-write-a-python-faketensor-kernel) instead; the API is more generic (e.g. supports cross-device and data-dependent output shape), has fewer footguns (it is automatically symint-ified), you skip C++ recompile cycles, and the resulting kernel is easier to debug.
514
+
515
+ C++ meta kernel example:
516
+
517
+ ```c
+ Tensor sin_meta(const Tensor& x) {
+   return at::empty_like(x);
+ }
+
+ TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
+   m.impl("sin", &sin_meta);
+ }
+ ```
527
+
528
+ ### FakeTensor/meta kernels for CompositeImplicitAutograd operators
529
+
530
+ If your custom op consists only of calls to the PyTorch dispatcher (at:: operators in native\_functions.yaml and custom op calls via TypedOperatorHandle.call, as in [How to invoke the custom op from C++](#how-to-invoke-a-custom-op-defined-in-c++-from-c++)), then you can SymInt’ify your operator according to [The dynamic shapes manual](https://docs.google.com/document/u/0/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit) and register your op to the CompositeImplicitAutograd dispatch key.
531
+
532
+ The advantage is that you do not need to define an autograd function (like in [How to add an autograd formula](#how-to-add-an-autograd-formula)) or a meta implementation.
533
+
534
+ Before:
535
+
536
+ ```c
537
+ TORCH_LIBRARY(your_namespace, m) {
538
+ m.def("my_op(Tensor x, int[] shape) -> Tensor");
539
+ }
540
+
541
+ Tensor my_op_impl(const Tensor& x, IntArrayRef shape) {
542
+ Tensor y = at::sin(x);
543
+ // suppose custom op my_op2 has signature
544
+ // (Tensor x, int[] shape) -> Tensor
545
+ return my_op2(y, shape);
546
+ }
547
+
548
+ TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
549
+ m.impl("my_op", &my_op_impl);
550
+ }
551
+
552
+ TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
553
+ m.impl("my_op", &my_op_impl);
554
+ }
555
+ ```
556
+
557
+ After:
558
+
559
+ ```c
560
+ TORCH_LIBRARY(your_namespace, m) {
561
+ m.def("my_op(Tensor x, SymInt[] shape) -> Tensor");
562
+ }
563
+
564
+ Tensor my_op_impl(const Tensor& x, SymIntArrayRef shape) {
565
+ Tensor y = at::sin(x);
566
+ // suppose custom op my_op2 has signature
567
+ // (Tensor x, int[] shape) -> Tensor
568
+ static auto my_op2_op = torch::Dispatcher::singleton()
569
+ .findSchemaOrThrow("your_namespace::my_op2", "")
570
+ .typed<decltype(my_op2)>();
571
+ return my_op2_op.call(y, shape);
572
+ }
573
+
574
+ TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
575
+ m.impl("my_op", &my_op_impl);
576
+ }
577
+ ```
578
+
579
+ ## How to add an autograd formula {#how-to-add-an-autograd-formula}
580
+
581
+ In order for your custom operator to work with training, it must have an autograd formula. **You can register this from either Python or C++; we recommend doing this from Python.**
582
+
583
+ ### (Recommended) \[In Python\] Adding an autograd formula (PyTorch 2.4+ only)
584
+
585
+ Use torch.library.register\_autograd to add an autograd formula for an operator. Please see the documentation for torch.library.register\_autograd for more information.
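+
+ A minimal sketch for the `your_namespace::sin` operator from earlier (assuming we save the input for backward; the forward kernels are whatever you registered above):
+
+ ```py
+ import torch
+
+ def setup_context(ctx, inputs, output):
+     (x,) = inputs
+     ctx.save_for_backward(x)
+
+ def backward(ctx, grad):
+     (x,) = ctx.saved_tensors
+     # d/dx sin(x) = cos(x)
+     return grad * x.cos()
+
+ torch.library.register_autograd(
+     "your_namespace::sin", backward, setup_context=setup_context
+ )
+ ```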
586
+
587
+ ### \[In C++\] Adding an autograd formula
588
+
589
+ WARNING: this approach has a lot of footguns\! Use with caution; incorrect usage will result in silent incorrectness (in both eager-mode PyTorch and with torch.compile).
590
+
591
+ To add training support, please:
592
+
593
+ - Construct a C++ torch::autograd::Function with a forward() and a backward() pass:
+   - The forward pass must (1) save Tensors/data for backward and (2) re-dispatch to the operator (please see example)
+   - If your operator's backward pass is a custom kernel, then it should be invoked through a custom operator.
+ - Register this torch::autograd::Function to DispatchKey::Autograd. **It is an error (and will be silently incorrect) if you register it to DispatchKey::CPU/CUDA/anything else**.
597
+
598
+ Below is an example of a custom sin operator:
599
+
600
+ ```c
601
+ #include <torch/library.h>
602
+
603
+ // Declare the operator
604
+ TORCH_LIBRARY(your_namespace, m) {
605
+ m.def("sin(Tensor x) -> Tensor");
606
+ }
607
+
608
+ // Add the CPU implementation for the operator
609
+ Tensor custom_sin_cpu(const Tensor& x) {
610
+ // Replace this with at::sin if you want to test it out.
611
+ return my_custom_sin_implementation_on_cpu(x);
612
+ }
613
+
614
+ TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
615
+ m.impl("sin", &custom_sin_cpu);
616
+ }
617
+
618
+ // Add the CUDA implementation for the operator
619
+ Tensor custom_sin_cuda(const Tensor& x) {
620
+ // Replace this with at::sin if you want to test it out.
621
+ return my_custom_sin_implementation_on_cuda(x);
622
+ }
623
+
624
+ TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
625
+ m.impl("sin", &custom_sin_cuda);
626
+ }
627
+ ```
628
+
629
+ Now, let’s add a backward formula:
630
+
631
+ ```c
632
+ // To register a backward formula for it, we need to construct a
633
+ // torch::autograd::Function.
634
+ class CustomSin : public torch::autograd::Function<CustomSin> {
635
+ public:
636
+ static variable_list forward(
637
+ AutogradContext* ctx,
638
+ const Tensor& x) {
639
+ // It is important that the forward looks like the following.
640
+ // If you do anything else, the operator may be silently incorrect!
641
+
642
+ // (1) You must construct this guard and then invoke the Dispatcher
643
+ // on this operator.
644
+ // We refer to this sometimes as a "redispatch".
645
+ // The following lines will Dispatch past Autograd and call a backend
646
+ // implementation for custom::sin, i.e. either custom_sin_cpu
647
+ // or custom_sin_cuda.
648
+ at::AutoDispatchBelowADInplaceOrView guard;
649
+ static auto custom_sin_op = torch::Dispatcher::singleton()
650
+ .findSchemaOrThrow("your_namespace::sin", "")
651
+ .typed<decltype(custom_sin_cpu)>();
652
+ Tensor result = custom_sin_op.call(x);
653
+
654
+ // (2) You may save Tensors or other data (like the shape of the Tensor)
655
+ // for backwards via one call to ctx->save_for_backward and
656
+ // one or more calls to ctx->saved_data
657
+ ctx->save_for_backward({x});
658
+
659
+ // (3) Finally, return the result of the operator computed in step 1
660
+ // as a flat list of Tensors. You may ONLY RETURN the results
661
+ // computed in step 1, if you return anything else (like
662
+ // a subsequent call to result.sum()), then the gradients may be
663
+ // silently incorrect.
664
+ return {result};
665
+
666
+ // (4) Nothing else must be in the forward()! That is, there must
667
+ // be NO CALLS to other PyTorch operators (e.g. at::sum)
668
+ }
669
+
670
+ static variable_list backward(
671
+ AutogradContext* ctx,
672
+ variable_list grad_output) {
673
+ const Tensor& grad = grad_output[0];
674
+ auto saved_tensors = ctx->get_saved_variables();
675
+
676
+ // The backward pass must only consist of invocations to
677
+ // the PyTorch dispatcher.
678
+ // That is, you may (1) call at::{API} operators where API
679
+ // is in native_functions.yaml
680
+ //https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
681
+ // and (2) you may call custom operators via the dispatcher
682
+ // by invoking TypedOperatorHandle.call:
683
+ // static auto custom_sin_op = torch::Dispatcher::singleton()
684
+ // .findSchemaOrThrow("your_namespace::sin", "")
685
+ // .typed<decltype(custom_sin)>();
686
+ // Tensor result = custom_sin_op.call(x);
687
+ //
688
+ // Anything else may run into problems.
689
+ return {grad * at::cos(saved_tensors[0])};
690
+ }
691
+ };
692
+
693
+ Tensor custom_sin_autograd(const Tensor& x) {
694
+ return CustomSin::apply(x);
695
+ }
696
+
697
+ // It is important that you register a call to the autograd::Function::apply
698
+ // to DispatchKey::Autograd. PLEASE DO NOT REGISTER IT TO THE CPU/CUDA/any
699
+ // other key; that leads to silent incorrectness issues.
700
+ TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
701
+ m.impl("sin", &custom_sin_autograd);
702
+ }
703
+ ```
704
+
705
+ ### \[In C++\] How to add an autograd formula: some nitty gritty details
706
+
707
+ The above example is helpful to get a basic understanding for how to write an autograd function with PT2 support. However, more complicated custom autograd ops may require additional changes. Consider a custom autograd op that does the following:
708
+
709
+ * Has a complicated input (list of tensors, non-tensors, etc.)
710
+ * Has a complicated return (multiple tensors, non-tensors, etc.)
711
+ * Saves a computed tensor for backward
712
+ * Saves computed integers for backward
713
+
714
+ This op is harder to support because there are a number of restrictions we need to be aware of:
715
+
716
+ * (Array of) integer input may be symbolic
717
+ * Returns can only be (tuple of) tensor
718
+ * All implementations (CPU, CUDA, Meta, Autograd) must return the same kind of output
719
+ * The returned Tensors must have the same metadata, and the same number of Tensors must be returned (in the case of a tuple return)
720
+
721
+ In order to work around these restrictions, we:
722
+
723
+ * Change integer inputs to be symbolic integers (symint), only if necessary
+ * Create an “implementation” custom autograd op, as above:
+   * The CPU/CUDA implementation should pack non-tensor data into tensors
+   * The CPU/CUDA implementation must return both the tensors intended for actual output and the tensors that are saved
+   * The Autograd implementation saves tensors required for backward and must return them as well
+ * Create a backward op that will be called by the “implementation” op \- the backward op does not need to be an autograd op.
+ * The “real” custom autograd op simply dispatches to the “implementation” op and returns only the tensors we want to output.
+   * The “real” op should be registered to the CompositeImplicitAutograd dispatch key if both CPU and CUDA implementations of the “implementation” op go through the same autograd implementation of that op
+   * Otherwise, you will have to register to the AutogradCPU/AutogradCUDA keys and you may need to define a Meta function
731
+
732
+ ```c
733
+ #include <torch/library.h>
734
+
735
+ // Declare the real and implementation operators
736
+ TORCH_LIBRARY(your_namespace, m) {
737
+ m.def("custom_op(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
738
+ m.def("custom_op_impl(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
739
+ m.def("custom_op_backward_impl(Tensor[] grad_out, SymInt[] ints) -> Tensor[]");
740
+ }
741
+
742
+ // needs to have the same signature as forward, except for AutogradContext
743
+ variable_list custom_op_fwd_impl(
744
+ const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
745
+ ...
746
+ ... b[0].as_int_unchecked() ...
747
+ Tensor real_output;
748
+ Tensor for_backward;
749
+ int data;
750
+ Tensor saved_data = ???
751
+ vector<Tensor> output{real_output, for_backward, saved_data};
752
+ return output;
753
+ }
754
+
755
+ // does NOT need to have the same signature as backward
756
+ variable_list custom_op_bwd_impl(
757
+ TensorList grad_out, SymIntArrayRef ints) {
758
+ vector<Tensor> output;
759
+ ...
760
+ // does NOT need to return the same as backward
761
+ return output;
762
+ }
763
+
764
+ class CustomOpImpl : public torch::autograd::Function<CustomOpImpl> {
765
+ public:
766
+ static variable_list forward(
767
+ AutogradContext* ctx,
768
+ const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
769
+ at::AutoDispatchBelowADInplaceOrView guard;
770
+ static auto op = torch::Dispatcher::singleton()
771
+ .findSchemaOrThrow("your_namespace::custom_op_impl", "")
772
+ .typed<decltype(custom_op_fwd_impl)>();
773
+ auto result = op.call(x, y, a, b);
774
+ ctx->save_for_backward({a, b});
775
+
776
+ // could make additional computations here if they're "simple"
777
+ // i.e. whatever is allowed according to the previous autograd example
778
+ int simple_data;
779
+ ...
780
+ ctx->saved_data["simple"] = simple_data;
781
+ // must return the same thing as custom_op_fwd_impl!
782
+ return result;
783
+ }
784
+
785
+ static variable_list backward(
786
+ AutogradContext* ctx,
787
+ variable_list grad_output) {
788
+ static auto op = torch::Dispatcher::singleton()
789
+ .findSchemaOrThrow("your_namespace::custom_op_backward_impl", "")
790
+ .typed<decltype(custom_op_bwd_impl)>();
791
+ // can create custom_op_bwd_impl arguments here
792
+ auto grad_out = op.call(...);
793
+ // e.g. may have to add additional gradients
794
+ grad_out.push_back({});
795
+ return grad_out;
796
+ }
797
+ };
798
+
799
+ variable_list custom_op_impl_meta(
800
+ const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
801
+ vector<Tensor> output;
802
+ ...
803
+ return output;
804
+ }
805
+
806
+ variable_list custom_op_bwd_impl_meta(
807
+ TensorList grad_out, SymIntArrayRef ints) {
808
+ vector<Tensor> output;
809
+ ...
810
+ return output;
811
+ }
812
+
813
+ variable_list custom_op_impl_autograd(
814
+ const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
815
+ return CustomOpImpl::apply(x, y, a, b);
816
+ }
817
+
818
+ // can have a different signature than custom_op_fwd_impl
819
+ variable_list custom_op(
820
+ const Tensor& x, TensorList y) {
821
+ ...
822
+ int a = ...;
823
+ vector<SymInt> b = ...;
824
+ ...
825
+ static auto op = torch::Dispatcher::singleton()
826
+ .findSchemaOrThrow("your_namespace::custom_op_impl", "")
827
+ .typed<decltype(custom_op_fwd_impl)>();
828
+ auto result = op.call(x, y, a, b);
829
+ // we can return only the tensors we want to return
830
+ return {result[0]};
831
+ }
832
+
833
+ // do also for CPU
834
+ TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
835
+ m.impl("custom_op_impl", &custom_op_fwd_impl);
836
+ m.impl("custom_op_backward_impl", &custom_op_bwd_impl);
837
+ }
838
+
839
+ TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
840
+ m.impl("custom_op_impl", &custom_op_impl_meta);
841
+ m.impl("custom_op_backward_impl", &custom_op_bwd_impl_meta);
842
+ }
843
+
844
+ TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
845
+ m.impl("custom_op_impl", &custom_op_impl_autograd);
846
+ }
847
+
848
+ TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
849
+ m.impl("custom_op", &custom_op);
850
+ }
851
+ ```
852
+
853
+ See [https://github.com/pytorch/FBGEMM/pull/2076/files](https://github.com/pytorch/FBGEMM/pull/2076/files) for a real example.
854
+
855
+ # Missing Meta function in PyTorch core
856
+
857
+ Occasionally, you may run into missing meta functions for operators that are defined in core PyTorch rather than custom operators. In this case, add your meta registration to torch/\_meta\_registrations.py, following the pattern of other entries in the file.
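+
+ For orientation only (a hedged sketch of that pattern, using a hypothetical `aten.my_missing_op` overload; `aten` and `register_meta` are already defined inside torch/\_meta\_registrations.py):
+
+ ```py
+ # Sketch of an entry inside torch/_meta_registrations.py
+ @register_meta(aten.my_missing_op.default)
+ def meta_my_missing_op(x):
+     # Only metadata matters here: allocate an output with the right
+     # shape/dtype/device and no real storage.
+     return x.new_empty(x.shape)
+ ```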
858
+
859
+ # Appendix
860
+
861
+ ## Bugs related to non-contiguous tensors
862
+
863
+ You may encounter a CUDA illegal memory access error. If this error only occurs in the AOT dispatch opcheck tests, then it may be caused by a backwards kernel handling non-contiguous tensors improperly. To fix this, you should coerce the inputs to the backwards kernel to be contiguous, i.e., using Tensor.contiguous() (example: [https://github.com/pytorch/FBGEMM/pull/2093/files](https://github.com/pytorch/FBGEMM/pull/2093/files))
864
+
865
+ ## Interaction with torch.compile
866
+
867
+ Some tests use torch.compile/torch.\_dynamo.optimize. generate\_opcheck\_tests does not interact well with torch.compile, so dynamo will run in disabled mode (TORCHDYNAMO\_DISABLE=1)
868
+
869
+ ## AutogradCPU and AutogradCUDA
870
+
871
+ It is possible to have a custom operator that has different autograd behavior per device. We generally recommend against this; a better paradigm is to have two different custom operators, one for each device.
872
+
873
+ If you must have this feature, then continue reading.
874
+
875
+ The AutogradCPU key is a combination of the Autograd and CPU keys \- it registers a function as a CPU implementation and an autograd implementation, but only for CPU inputs. AutogradCUDA is analogous, but for CUDA. This is useful in the case where even the autograd implementation is different between CPU and CUDA. Note that the meta implementation must be separately registered.
876
+
877
+ See [https://github.com/pytorch/FBGEMM/pull/2076/files](https://github.com/pytorch/FBGEMM/pull/2076/files) for an example. In the example, the CPU implementation is a composition of autograd ops, while the CUDA implementation is a custom autograd op.