# The Custom Operators Manual

# Read Me First

This manual is a comprehensive reference for all things related to PyTorch Custom Operators. We recommend that you [first read one of the focused tutorials listed on our landing page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) and then refer to this document as a manual for edge cases or less-recommended approaches.

The landing page: [https://pytorch.org/tutorials/advanced/custom\_ops\_landing\_page.html](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) 

# What is an operator?

A **kernel** is a function that accepts Tensors and/or raw pointers to memory and performs a useful computation (for example, matrix multiplication, attention, etc).

An **operator** is glue code for the PyTorch runtime that tells it about the computation. A single operator can be associated with multiple kernels (for example, torch.add has a kernel for CPU and a kernel for CUDA). The glue code is necessary to get PyTorch subsystems (like torch.compile and torch.autograd) to compose with the computation.

Standalone kernels may work directly with PyTorch but will not compose with the majority of PyTorch subsystems. In order to get them to compose, please register an operator for them.

# How to make existing operators work with torch.compile.

## TL;DR

Call an operator **pt2\_compliant** if it works with the new PyTorch compilation APIs (torch.compile, torch.export, etc) introduced in PyTorch 2.x.

This is a three-step process:

Step 1: Test the custom op with torch.library.opcheck.  
Step 2: Fix all problems with the custom op (until the opcheck passes)  
Step 3: Mark the custom op as PT2 compliant.

## Step 1: How to use opcheck to test the custom op

You have two options: manually use opcheck to test the custom op, or (Meta-only) use \`generate\_opcheck\_tests\` to automatically test the custom op. If the custom op is already covered by one of these two mechanisms, skip ahead to step 2\.

### Step 1a: How to manually use opcheck to test the custom op

Please call opcheck multiple times with different representative sample inputs:

- If your operator works on CPU and CUDA, please pass a set of sample inputs on CPU and a set of sample inputs on CUDA  
- If your operator supports training, please pass some sample inputs with requires\_grad=True.

Using the operator torch.ops.aten.sin.default as an example:

```py
import torch
import unittest
from torch.library import opcheck

def sin_sample_inputs():
    sample_inputs = [
        (torch.randn(3, requires_grad=True, device='cpu'),),
        (torch.randn(3, requires_grad=True, device='cuda'),),
    ]
    return sample_inputs

class TestOps(unittest.TestCase):
    def test_sin(self):
        sample_inputs = sin_sample_inputs()
        for i in range(len(sample_inputs)):
            opcheck(torch.ops.aten.sin.default, sample_inputs[i])
```

### Step 1b: How to automatically use opcheck to test the custom op

**Please only use this if you work at Meta. While this API is included with PyTorch, we do not guarantee BC for it.**

Use this approach (generate\_opcheck\_tests) only if:

- You have a large collection of existing tests that exercise multiple custom ops that you would like to be tested  
- You are willing to put up with the sharp edges in this API

Please see [Working with generate\_opcheck\_tests](https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit#heading=h.m2x0ozzf1awo) for more details, and [https://github.com/pytorch/FBGEMM/pull/2050/files\#diff-c1c25e22107028a66ff548c7042ba3f39bcc009db9348825e46b15f60754cbffR2452-R2467](https://github.com/pytorch/FBGEMM/pull/2050/files#diff-c1c25e22107028a66ff548c7042ba3f39bcc009db9348825e46b15f60754cbffR2452-R2467) for an example.

## Step 2: How to fix failing opcheck tests

If opcheck fails, please work through the failing tests in the following order.

### 1\. test\_schema fails

```py
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_schema")
```

This means that the schema of the operator is wrong and will lead to silent incorrectness issues. All operators have a schema string ([example](https://github.com/pytorch/pytorch/blob/050c56d0a5c0982e13255447f5fb7d2949c02407/aten/src/ATen/native/native_functions.yaml#L5886)) that specifies the types of the inputs and outputs as well as some “aliasing information”:

- if any outputs are views of the inputs  
- if any inputs are mutated in-place

The fix is usually to update the schema to include the aliasing information. See [https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md\#annotations](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations) for more details.
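
For concreteness, here is a hedged sketch (the operator is hypothetical) of how a mutation shows up in a schema string: the "(a!)" annotation declares that the input Tensor is mutated in-place and that the output aliases it.

```py
import torch

# Wrong: this schema claims the op is functional even though it writes into x.
# torch.library.define("your_namespace::my_sin_(Tensor x) -> Tensor")

# Right: "(a!)" declares that x is mutated in-place and aliased by the output.
torch.library.define("your_namespace::my_sin_(Tensor(a!) x) -> Tensor(a!)")
```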

### 2\. test\_autograd\_registration fails

```py
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_autograd_registration")
```

This means that the autograd registration is incorrect and will lead to silent incorrectness issues. Some common symptoms are correct gradients without using torch.compile but incorrect gradients using torch.compile.

Please see [How to add an autograd formula](#how-to-add-an-autograd-formula) for what the autograd registration should look like.

### 3\. test\_faketensor fails

```py
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_faketensor")
```

PT2 compilation APIs use “Fake Tensors”, Tensors without storage, to propagate metadata. Every operator needs a “Fake Tensor kernel” which is like a shape formula: given the shapes (and other metadata) of the inputs, it should specify the shapes (and other metadata) of the outputs.

This opcheck test could fail for two reasons:

- UnsupportedFakeTensorException: you didn’t write an abstract\_impl or a meta formula. Please see [How to add abstract impl / meta formula](#how-to-add-faketensor-support-\(abstract-impl;-meta-kernel\)) for how to write one.  
- Your abstract\_impl/meta formula is wrong. Please debug it.

### 4\. test\_aot\_autograd\_static fails

```py
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_autograd_static")
```

If this test succeeds, then the operator works with torch.compile(dynamic=False).  
If the operator in question returns Tensors with data-dependent shapes, then this test is expected to fail.

Otherwise, if it fails, there are some common reasons, listed below. Please match your error message to one of them.

#### No DispatchKey::Functionalize (functionalization) kernel

torch.compile backends only support functional operators (that is, operators that do not mutate inputs and do not return views of the inputs). If your operator is not functional (indicated by an “(a)” in the schema string), then we need to teach torch.compile how to functionalize your operator.

For now, come find us for help if this is the case.

#### There is a .item call somewhere

When PT2 APIs like torch.compile see a C++ .item() call, they don’t know what to do.

Please either:

- rewrite your custom operator to not use the offending .item() call  
- hide the .item() call in a new custom operator (see the sketch below)
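
Below is a minimal sketch of the second option, assuming the data-dependent value is an index read from a Tensor; the operator name and its FakeTensor kernel here are hypothetical.

```py
import torch

# Hypothetical op: the data-dependent .item() call now lives inside the custom
# operator, which PT2 treats as an opaque callable.
@torch.library.custom_op("your_namespace::select_row", mutates_args=())
def select_row(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    i = int(idx.item())   # the offending call is hidden inside the op
    return x[i].clone()   # clone so the output does not alias the input

# FakeTensor kernel: the output has the metadata of one row of x.
@torch.library.register_fake("your_namespace::select_row")
def _(x, idx):
    return torch.empty_like(x[0])
```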

#### Backward formula isn’t traceable

PyTorch assumes that the backward pass of your operator only consists of invocations to the PyTorch dispatcher.

That is, the backward pass may call:

1. **Built-in PyTorch operators**.  
   1. In Python  
      1. torch.\* APIs  
   2. In C++  
      1. at::{API} operators where API is in [native\_functions.yaml](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)   
      2. Tensor metadata operations. E.g. Tensor::sizes() / Tensor::strides()  
2. **Custom Operators**. All calls to custom operators must be made through the PyTorch dispatcher.  
   1. If they are being invoked from Python, you don’t need to do anything.  
   2. If they are being invoked from C++, they must query the PyTorch Dispatcher for a TypedOperatorHandle and invoke TypedOperatorHandle::call:

```c
static auto custom_sin_op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("custom::sin", "")
    .typed<decltype(custom_sin)>();
Tensor result = custom_sin_op.call(x);
```

Please see [How to add an autograd formula](#how-to-add-an-autograd-formula) for more details.

### 5\. test\_aot\_autograd\_dynamic fails

```py
opcheck(torch.ops.aten.sin.default, sample_inputs[i], test_utils="test_aot_autograd_dynamic")
```

This generally means that your operator doesn’t support Dynamic Shapes, especially if “test\_aot\_autograd\_static” succeeds.

Please see [The dynamic shapes manual](https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.dpddt0fbqk6u) for what to do.

### Detailed description of the opcheck tests

#### test\_schema

We test that the schema matches the implementation of the operator. For example: if the schema specifies a Tensor is mutated, then we check the implementation mutates the Tensor. If the schema specifies that we return a new Tensor, then we check that the implementation returns a new Tensor (instead of an existing one or a view of an existing one).

Note that the schema language might look simple, but it encodes a lot of information (about mutations, new tensors, and aliases) that is easy to get wrong.

#### test\_autograd\_registration

If the operator supports training (autograd): we check that its autograd formula is registered via torch.library.register\_autograd or a manual registration to one or more DispatchKey::Autograd keys. Any other DispatchKey-based registrations may lead to undefined behavior.

#### test\_faketensor

We check that a FakeTensor kernel (also sometimes known as a meta kernel) was registered for the operator and that it is correct. This test takes the result of running the operator on real tensors and the result of running the operator on FakeTensors and checks that they have the same Tensor metadata (sizes/strides/dtype/device/etc).
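
Conceptually, the check resembles the sketch below; FakeTensorMode lives in the private torch.\_subclasses module, so treat this as an illustration of the idea rather than a stable API.

```py
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Run the operator on a real tensor and on a fake tensor, then compare the
# output metadata. The opcheck test does a more thorough version of this.
x = torch.randn(3)
real_out = torch.ops.aten.sin.default(x)

with FakeTensorMode() as mode:
    fake_x = mode.from_tensor(x)
    fake_out = torch.ops.aten.sin.default(fake_x)

assert real_out.shape == fake_out.shape
assert real_out.dtype == fake_out.dtype
assert real_out.device == fake_out.device
```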

#### test\_aot\_dispatch\_dynamic

This test checks multiple things:

- We check that the operator supports functionalization. That is, it is functional or can automatically be functionalized by torch.compile.  
- If the operator supports training, we check that the backward pass supports FakeTensor and functionalization.

This test is effectively an e2e run with torch.compile(backend=”aot\_eager”) (through Dynamo+AOTDispatcher) that does something like:

```py
outs = op(*args)
outs_compiled = torch.compile(op, backend="aot_eager")(*args)
torch.testing.assert_close(outs_compiled, outs)

if supports_training(op):
    grad_args = torch.autograd.grad(outs, args, ...)
    grad_args_compiled = torch.autograd.grad(outs_compiled, args, ...)
    torch.testing.assert_close(grad_args_compiled, grad_args)
```

## Step 3: How to mark the op as PT2 compliant

### If the operator was defined in Python

#### Using torch.library.custom\_op

Custom ops created with torch.library.custom\_op automatically get this tag.

#### Using torch.library.define

```py
lib.define("sin(Tensor x) -> Tensor", tags=[torch.Tag.pt2_compliant_tag]);
```

### If the operator was defined in C++

Where the operator was defined, add the at::Tag::pt2\_compliant\_tag:

```c
m.def("sin(Tensor x) -> Tensor", {at::Tag::pt2_compliant_tag});
```

This documents the operator as PT2 compliant. Please only add this tag if the operator passes the opcheck tests.

# Writing a new Custom Operator

## What is a “kernel”? What is an “operator”?

A **kernel** is a function that accepts Tensors and/or raw pointers to memory and performs a useful computation (for example, matrix multiplication, attention, etc).

An **operator** is glue code for the PyTorch runtime that tells it about the computation. A single operator can be associated with multiple kernels (for example, torch.add has a kernel for CPU and a kernel for CUDA). The glue code is necessary to get PyTorch subsystems (like torch.compile and torch.autograd) to compose with the computation.

## When should I create a Custom Operator?

You may wish to create a Custom Operator for two reasons:

- You have some custom CPU/CUDA/other backend kernel that you'd like to integrate with PyTorch  
- You have some code that you want PyTorch to treat as an opaque callable (as a black-box).

For example, you may want to call out to some low-level third-party library like LAPACK or CUBLAS, or you may have written a bunch of CUDA kernels in .cu files.

Your custom operator kernels should include as few PyTorch built-in operators as possible. Wrapping built-in PyTorch operators inside a C++ custom operator hides them from PyTorch subsystems like torch.compile, which costs you optimization opportunities.

**If what you are trying to do is expressible as a composition of built-in PyTorch operators** (and does not involve low-level C/C++/CUDA code or third-party Python libraries), then please write your routine as a Python function and call it instead of creating a custom operator.
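
For example, a routine like the following (a hypothetical helper) needs no custom operator; torch.compile can see through the plain Python function and optimize the built-in operators it calls:

```py
import torch

# A plain composition of built-in operators: no custom operator needed.
def scaled_sin(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.sin(x) * scale
```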

## Python or C++?

You can define custom operators in both Python and C++. These registration APIs may be mixed: for example, one can define an operator’s CPU kernel from Python and CUDA kernel from C++.

Our general guidance is:

- If you care about AOTInductor (and being able to run in a Python-less environment), you should define the operator and add backend kernels in C++.  
- Otherwise, it is generally easier to use the Python custom operator registration APIs.

## How to define a custom operator

To define an operator, you must tell us:

- The name of the operator  
- Some metadata around the acceptable input/output types of the operator and if any inputs are being mutated

### \[From Python\] How to define a custom operator

#### \[PyTorch \>= 2.4\] Using torch.library.custom\_op

Use torch.library.custom\_op to decorate a function to turn it into a custom operator. The function must have type annotations, and you must correctly annotate any inputs that are mutated (via mutates\_args).

```py
import numpy as np
import torch

@torch.library.custom_op("your_namespace::sin", mutates_args=())
def sin(x: torch.Tensor) -> torch.Tensor:
    return torch.from_numpy(np.sin(x.numpy(force=True)))
```

#### \[PyTorch \< 2.4\] Using torch.library.define

To define an operator, you must tell us:

- The name of the operator  
- The [schema string](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func) of the operator. The spec for this schema is defined [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func), with multiple examples in [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml).

```py
torch.library.define("your_namespace::sin(Tensor x) -> Tensor")
```

### \[From C++\] How to define a custom operator

To define an operator, you must tell us:

- The name of the operator  
- The [schema string](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func) of the operator. The spec for this schema is defined [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func), with multiple examples in [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml).

Let’s go through an example of a custom sin operator.

```c
#include <torch/library.h>

// Define the operator
TORCH_LIBRARY(your_namespace, m) {
    m.def("sin(Tensor x) -> Tensor");
}
```

If you define operator schemas in multiple places, use TORCH\_LIBRARY\_FRAGMENT instead of TORCH\_LIBRARY.

## How to add CPU/CUDA/Backend implementations

### \[From Python\] How to add CPU/CUDA/Backend implementations

#### \[PyTorch \>= 2.4\]

Use torch.library.register\_kernel.

```py
@torch.library.register_kernel("your_namespace::sin", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CPU implementation
    ...

@torch.library.register_kernel("your_namespace::sin", "cuda")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CUDA implementation
    ...
```

#### \[PyTorch \< 2.4\] 

Use torch.library.impl

```py
@torch.library.impl("your_namespace::sin", "cpu")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CPU implementation
    ...

@torch.library.impl("your_namespace::sin", "cuda")
def _(x: torch.Tensor) -> torch.Tensor:
    # your CUDA implementation
    ...
```

### \[From C++\] How to add CPU/CUDA/Backend implementations

To provide backend-specific implementations for an operator, use TORCH\_LIBRARY\_IMPL.

```c

Tensor custom_sin_cpu(const Tensor& x) {
    // Replace this with at::sin if you want to test it out.
    return my_custom_sin_implementation_on_cpu(x);
}

// Register the CPU implementation for the operator
TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
    m.impl("sin", &custom_sin_cpu);
}

Tensor custom_sin_cuda(const Tensor& x) {
    // Replace this with at::sin if you want to test it out.
    return my_custom_sin_implementation_on_cuda(x);
}

// Register the CUDA implementation for the operator
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
    m.impl("sin", &custom_sin_cuda);
}
```

## How to invoke a custom operator

### How to invoke a custom op defined in Python from Python

When you created a custom operator, you gave it a name. The custom operator is findable under torch.ops:

```py
x = torch.randn(3)
y = torch.ops.your_namespace.sin(x)
```

### How to invoke a custom op defined in C++ from C++ {#how-to-invoke-a-custom-op-defined-in-c++-from-c++}

```c
static auto custom_sin_op = torch::Dispatcher::singleton()
    .findSchemaOrThrow("your_namespace::sin", "")
    .typed<decltype(custom_sin_cpu)>();
Tensor result = custom_sin_op.call(x);
```

In order to invoke the custom operator, we must first query it from the PyTorch dispatcher and then invoke it.

### How to invoke a custom op defined in C++ from Python

The C++ custom op gets compiled into a shared library. Use torch.ops.load\_library(path\_to\_shared\_library) to load the shared library.

Once the shared library has loaded, the custom op is available from the torch.ops namespace:

```py
x = torch.randn(3)
y = torch.ops.your_namespace.sin(x)
assert torch.allclose(y, torch.sin(x))
```

## How to add FakeTensor support (abstract impl; meta kernel) {#how-to-add-faketensor-support-(abstract-impl;-meta-kernel)}

In order for your operator to work with PT2, it must have FakeTensor support. There are three blessed ways to do this:

- (Preferred) Write a FakeTensor kernel using the torch.library.register\_fake / torch.library.impl\_abstract API from Python  
- Write a C++ Meta kernel  
- If your operator is registered to CompositeImplicitAutograd, it will automatically decompose and we require its constituents to support FakeTensor.

### Context: meta kernel vs FakeTensor kernels

In order for your operator to work with PT2, it must have FakeTensor support. That is, we must know how to run the operator on “Fake” input Tensors that do not have storage, but have sizes/strides/device.

Adding a **meta kernel** (a function that describes how an operator works with device=’meta’) will automatically generate FakeTensor support. However, meta kernels don’t support the full range of things FakeTensors do. For example, operators with data-dependent output shape (think torch.nonzero) and operators with cross-device semantics (like Tensor.to(device=”cuda”)) are not describable with meta functions.

Instead of writing a meta function, **our recommendation is to write FakeTensor kernels**, which are a generalization of meta functions to support FakeTensors. Writing a FakeTensor kernel is very similar to writing a meta kernel: in most cases, the meta kernel can be re-used as the FakeTensor kernel impl.

NB: We also sometimes use “abstract impl” to refer to a “FakeTensor kernel”. These are the same thing.

### How to write a Python FakeTensor kernel {#how-to-write-a-python-faketensor-kernel}

There are two parts to this:

1. In Python, use torch.library.register\_fake (PyTorch 2.4+) or torch.library.impl\_abstract (PyTorch \<= 2.3) to provide the FakeTensor kernel for an operator  
2. In the Python program that uses the custom operator, import the module that includes the FakeTensor kernel registration from step 1\.

If you work at Meta, please see [https://fburl.com/python\_meta\_example](https://fburl.com/python_meta_example) for an example.

#### Step 1: Use torch.library.register\_fake

An "FakeTensor kernel" specifies the behavior of this operator on Tensors that carry no data. Given some input Tensors with certain properties (sizes/strides/storage\_offset/device), it specifies what the properties of the output Tensors are.

The FakeTensor kernel has the same signature as the operator. It is run for both FakeTensors and meta tensors. To write a FakeTensor kernel, assume that all Tensor inputs to the operator are regular CPU/CUDA/Meta tensors that do not have storage, and that you are trying to return regular CPU/CUDA/Meta tensor(s) as output. The FakeTensor kernel must consist of only PyTorch operations (and may not directly access the storage or data of any input or intermediate Tensors).

```py
# your_module.py
import torch
torch.ops.load_library("path/to/shared/lib/that/has/your_cpp_file")

# Write the FakeTensor kernel
@torch.library.register_fake("your_namespace::sin")
def sin_abstract(x):
    # torch.empty_like(x) returns a Tensor on the same device as `x`.
    # If you instead want to hardcode the device of the output, no matter the
    # device of the input,
    # manually specify it like torch.empty_like(x, device="cuda")
    return torch.empty_like(x)
```

#### Step 2: Import the module that contains the torch.library.register\_fake call

```py
import torch
torch.ops.import_module("your_module")

@torch.compile(backend="eager")
def f(x):
    return torch.ops.your_namespace.sin(x)

x = torch.randn(3)
f(x)
```

#### (PyTorch \<=2.3 only) Add an abstract impl pystub (if one doesn’t already exist)

The operator will complain during testing if it needs an impl abstract pystub. In that case, add an \`m.impl\_abstract\_pystub\` call to the TORCH\_LIBRARY block in which the operator was defined (i.e., next to the \`m.def\` call).

```c
// your_cpp_file.cpp
TORCH_LIBRARY(your_namespace, m) {
  // Leave a stub that tells the C++ PyTorch Dispatcher that the
  // abstract impl exists in a given Python module.
  // This will prevent desync issues (where someone loads the operator
  // without actually loading the Python module).
  // 
  // impl_abstract_pystub(module_name, buck_target):
  // - module name: the name of the Python module the abstract impl resides in
  // - buck_target (optional): If you're using a buck-based build system,
  //   then you can include the name of the buck target that includes
  //   the module here, otherwise it is optional.
  // We use the module name and the buck target to give better error messages
  m.impl_abstract_pystub("your_module", "//your_module:custom_ops");
  m.def("sin(Tensor x) -> Tensor");
}
```

The pystub applies to all operators registered in the given TORCH\_LIBRARY block.

This is no longer needed in PyTorch 2.4+.

### How to write FakeTensor kernel for operator with data-dependent output shape

Use torch.library.get\_ctx().new\_dynamic\_size() to allocate data-dependent output sizes. For example, a “nonzero” operator returns a Tensor with shape (number\_of\_nonzero\_elements, dim):

```py
@torch.library.register_fake("your_namespace::nonzero")
def nonzero_abstract(x):
    nnz = torch.library.get_ctx().new_dynamic_size()
    return x.new_empty(nnz, x.dim(), dtype=torch.long)
```

### How to write C++ meta kernel

You should seriously consider writing a [Python FakeTensor kernel](#how-to-write-a-python-faketensor-kernel) instead; the API is more general (e.g. it supports cross-device semantics and data-dependent output shapes), has fewer footguns (it is automatically SymInt-ified), you skip C++ recompile cycles, and the resulting kernel is easier to debug.

C++ meta kernel example:

```c
Tensor sin_meta(const Tensor& x) {
  return at::empty_like(x);
}

TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
    m.impl("sin", &sin_meta);
}

```

### FakeTensor/meta kernels for CompositeImplicitAutograd operators

If your custom op consists only of calls to the PyTorch dispatcher (at:: operators in native\_functions.yaml and custom op calls via TypedOperatorHandle.call, as in [How to invoke the custom op from C++](#how-to-invoke-a-custom-op-defined-in-c++-from-c++)), then you can SymInt’ify your operator according to [The dynamic shapes manual](https://docs.google.com/document/u/0/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit) and register your op to the CompositeImplicitAutograd dispatch key.

The advantage is that you do not need to define an autograd function (like in [How to add an autograd formula](#how-to-add-an-autograd-formula)) or a meta implementation.

Before:

```c
TORCH_LIBRARY(your_namespace, m) {
    m.def("my_op(Tensor x, int[] shape) -> Tensor");
}

Tensor my_op_impl(const Tensor& x, IntArrayRef shape) {
    Tensor y = at::sin(x);
    // suppose custom op my_op2 has signature
    // (Tensor x, int[] shape) -> Tensor
    return my_op2(y, shape);
}

TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
    m.impl("my_op", &my_op_impl);
}

TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
    m.impl("my_op", &my_op_impl);
}
```

After:

```c
TORCH_LIBRARY(your_namespace, m) {
    m.def("my_op(Tensor x, SymInt[] shape) -> Tensor");
}

Tensor my_op_impl(const Tensor& x, SymIntArrayRef shape) {
    Tensor y = at::sin(x);
    // suppose custom op my_op2 has signature
    // (Tensor x, int[] shape) -> Tensor
    static auto my_op2_op = torch::Dispatcher::singleton()
        .findSchemaOrThrow("your_namespace::my_op2", "")
        .typed<decltype(my_op2)>();
    return my_op2_op.call(y, shape);
}

TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
    m.impl("my_op", &my_op_impl);
}
```

## How to add an autograd formula {#how-to-add-an-autograd-formula}

In order for your custom operator to work with training, it must have an autograd formula. **You can register this from either Python or C++; we recommend doing this from Python.**

### (Recommended) \[In Python\] Adding an autograd formula (PyTorch 2.4+ only)

Use torch.library.register\_autograd to add an autograd formula for an operator. Please see the documentation for torch.library.register\_autograd for more information.
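
As a minimal sketch (reusing the hypothetical "your\_namespace::sin" custom op from the earlier sections), the registration looks roughly like this:

```py
import torch

@torch.library.custom_op("your_namespace::sin", mutates_args=())
def sin(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x)

@torch.library.register_fake("your_namespace::sin")
def _(x):
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    # Save whatever the backward pass needs.
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    # d/dx sin(x) = cos(x)
    (x,) = ctx.saved_tensors
    return grad * x.cos()

torch.library.register_autograd(
    "your_namespace::sin", backward, setup_context=setup_context
)
```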

### \[In C++\] Adding an autograd formula

WARNING: this approach has a lot of footguns\! Use with caution; incorrect usage will result in silent incorrectness (in both eager-mode PyTorch and with torch.compile).

To add training support, please:

- Construct a C++ torch::autograd::Function with a forward() and a backward() pass:  
  - The forward pass must (1) save Tensors/data for backward and (2) re-dispatch to the operator (please see example)  
  - If your operator's backward pass is a custom kernel, then it should be invoked through a custom operator.  
- Register this torch::autograd::Function to DispatchKey::Autograd. **It is an error (and will be silently incorrect) if you register it to DispatchKey::CPU/CUDA/anything else**.

Below is an example of a custom sin operator: 

```c
#include <torch/library.h>

// Declare the operator
TORCH_LIBRARY(your_namespace, m) {
    m.def("sin(Tensor x) -> Tensor");
}

// Add the CPU implementation for the operator
Tensor custom_sin_cpu(const Tensor& x) {
    // Replace this with at::sin if you want to test it out.
    return my_custom_sin_implementation_on_cpu(x);
}

TORCH_LIBRARY_IMPL(your_namespace, CPU, m) {
    m.impl("sin", &custom_sin_cpu);
}

// Add the CUDA implementation for the operator
Tensor custom_sin_cuda(const Tensor& x) {
    // Replace this with at::sin if you want to test it out.
    return my_custom_sin_implementation_on_cuda(x);
}

TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
    m.impl("sin", &custom_sin_cuda);
}
```

Now, let’s add a backward formula:

```c
// To register a backward formula for it, we need to construct a
// torch::autograd::Function.
class CustomSin : public torch::autograd::Function<CustomSin> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      const Tensor& x) {
    // It is important that the forward looks like the following.
    // If you do anything else, the operator may be silently incorrect!

    // (1) You must construct this guard and then invoke the Dispatcher
    // on this operator.
    // We refer to this sometimes as a "redispatch".
    // The following lines will Dispatch past Autograd and call a backend 
    // implementation for custom::sin, i.e. either custom_sin_cpu
    // or custom_sin_cuda.
    at::AutoDispatchBelowADInplaceOrView guard;
    static auto custom_sin_op = torch::Dispatcher::singleton()
       .findSchemaOrThrow("your_namespace::sin", "")
       .typed<decltype(custom_sin_cpu)>();
    Tensor result = custom_sin_op.call(x);

    // (2) You may save Tensors or other data (like the shape of the Tensor)
    // for backwards via one call to ctx->save_for_backward and
    // one or more calls to ctx->saved_data
    ctx->save_for_backward({x});

    // (3) Finally, return the result of the operator computed in step 1
    // as a flat list of Tensors. You may ONLY RETURN the results
    // computed in step 1, if you return anything else (like
    // a subsequent call to result.sum()), then the gradients may be
    // silently incorrect.
    return {result};

    // (4) Nothing else must be in the forward()! That is, there must
    // be NO CALLS to other PyTorch operators (e.g. at::sum)
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    const Tensor& grad = grad_output[0];
    auto saved_tensors = ctx->get_saved_variables();

    // The backward pass must only consist of invocations to
    // the PyTorch dispatcher.
    // That is, you may (1) call at::{API} operators where API
    // is in native_functions.yaml 
   //https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
    // and (2) you may call custom operators via the dispatcher
    // by invoking TypedOperatorHandle.call:
    // static auto custom_sin_op = torch::Dispatcher::singleton()
    //   .findSchemaOrThrow("your_namespace::sin", "")
    //   .typed<decltype(custom_sin)>();
    // Tensor result = custom_sin_op.call(x);
    //
    // Anything else may run into problems.
    return {grad * at::cos(saved_tensors[0])};
  }
};

Tensor custom_sin_autograd(const Tensor& x) {
  return CustomSin::apply(x);
}

// It is important that you register a call to the autograd::Function::apply
// to DispatchKey::Autograd. PLEASE DO NOT REGISTER IT TO THE CPU/CUDA/any
// other key; that leads to silent incorrectness issues.
TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
    m.impl("sin", &custom_sin_autograd);
}
```

### \[In C++\] How to add an autograd formula: some nitty gritty details

The above example is helpful to get a basic understanding for how to write an autograd function with PT2 support. However, more complicated custom autograd ops may require additional changes. Consider a custom autograd op that does the following:

* Has a complicated input (list of tensors, non-tensors, etc.)  
* Has a complicated return (multiple tensors, non-tensors, etc.)  
* Saves a computed tensor for backward  
* Saves computed integers for backward

This op is harder to support because there are a number of restrictions we need to be aware of:

* (Array of) integer input may be symbolic  
* Returns can only be (tuple of) tensor  
* All implementations (CPU, CUDA, Meta, Autograd) must return the same kind of output  
  * Tensors must have same metadata, same number of tensors returned (in the case of tuple return)

In order to work around these restrictions, we:

* Change integer inputs to be symbolic integers (symint), only if necessary  
* Create a “implementation” custom autograd op, as above:  
  * The CPU/CUDA implementation should pack non-tensor data into tensors  
  * The CPU/CUDA implementation must return both tensors intended for actual output, and tensors that are saved  
  * The Autograd implementation saves tensors required for backward and must return them as well  
* Create a backward op that will be called by the “implementation” op \- the backward op does not need to be an autograd op.  
* The “real” custom autograd op simply dispatches to the “implementation” op and returns only the tensors we want to output  
* The “real” op should be registered to the CompositeImplicitAutograd dispatch key if both CPU and CUDA implementations of the “implementation” op go through the same autograd implementation of that op  
  * Otherwise, you will have to register to the AutogradCPU/AutogradCUDA keys and you may need to define a Meta function

```c
#include <torch/library.h>

// Declare the real and implementation operators
TORCH_LIBRARY(your_namespace, m) {
    m.def("custom_op(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
    m.def("custom_op_impl(Tensor x, Tensor[] y, int a, SymInt[] b) -> Tensor[]");
    m.def("custom_op_backward_impl(Tensor[] grad_out, SymInt[] ints) -> Tensor[]");
}

// needs to have the same signature as forward, except for AutogradContext
variable_list custom_op_fwd_impl(
    const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
  ...
  ... b[0].as_int_unchecked() ...
  Tensor real_output;
  Tensor for_backward;
  int data;
  Tensor saved_data = ???
  vector<Tensor> output{real_output, for_backward, saved_data};
  return output;
}

// does NOT need to have the same signature as backward
variable_list custom_op_bwd_impl(
    TensorList grad_out, SymIntArrayRef ints) {
  vector<Tensor> output;
  ...
  // does NOT need to return the same as backward
  return output;
}

class CustomOpImpl : public torch::autograd::Function<CustomOpImpl> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
    at::AutoDispatchBelowADInplaceOrView guard;
    static auto op = torch::Dispatcher::singleton()
       .findSchemaOrThrow("your_namespace::custom_op_impl", "")
       .typed<decltype(custom_op_fwd_impl)>();
    auto result = op.call(x, y, a, b);
    // save_for_backward is for Tensors only; non-Tensor data such as `a`
    // goes into ctx->saved_data (see below).
    ctx->saved_data["a"] = a;

    // could make additional computations here if they're "simple"
    // i.e. whatever is allowed according to the previous autograd example
    int simple_data;
    ...
    ctx->saved_data["simple"] = simple_data;
    // must return the same thing as custom_op_fwd_impl!
    return result;
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output) {
    static auto op = torch::Dispatcher::singleton()
       .findSchemaOrThrow("your_namespace::custom_op_backward_impl", "")
       .typed<decltype(custom_op_bwd_impl)>();
    // can create custom_op_bwd_impl arguments here
    auto grad_out = op.call(...);
    // e.g. may have to add additional gradients
    grad_out.push_back({});
    return grad_out;
  }
};

variable_list custom_op_impl_meta(
    const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
  vector<Tensor> output;
  ...
  return output;
}

variable_list custom_op_bwd_impl_meta(
    TensorList grad_out, SymIntArrayRef ints) {
  vector<Tensor> output;
  ...
  return output;
}

variable_list custom_op_impl_autograd(
    const Tensor& x, TensorList y, int a, SymIntArrayRef b) {
  return CustomOpImpl::apply(x, y, a, b);
}

// can have a different signature than custom_op_fwd_impl
variable_list custom_op(
    const Tensor& x, TensorList y) {
  ...
  int a = ...;
  vector<SymInt> b = ...;
  ...
  static auto op = torch::Dispatcher::singleton()
       .findSchemaOrThrow("your_namespace::custom_op_impl", "")
       .typed<decltype(custom_op_fwd_impl)>();
  auto result = op.call(x, y, a, b);
  // we can return only the tensors we want to return
  return {result[0]};
}

// do also for CPU
TORCH_LIBRARY_IMPL(your_namespace, CUDA, m) {
    m.impl("custom_op_impl", &custom_op_fwd_impl);
    m.impl("custom_op_backward_impl", &custom_op_bwd_impl);
}

TORCH_LIBRARY_IMPL(your_namespace, Meta, m) {
    m.impl("custom_op_impl", &custom_op_impl_meta);
    m.impl("custom_op_backward_impl", &custom_op_bwd_impl_meta);
}

TORCH_LIBRARY_IMPL(your_namespace, Autograd, m) {
    m.impl("custom_op_impl", &custom_op_impl_autograd);
}

TORCH_LIBRARY_IMPL(your_namespace, CompositeImplicitAutograd, m) {
    m.impl("custom_op", &custom_op);
}
```

See [https://github.com/pytorch/FBGEMM/pull/2076/files](https://github.com/pytorch/FBGEMM/pull/2076/files) for a real example.

# Missing Meta function in PyTorch core

Occasionally, you may run into missing meta functions for operators that are defined in core PyTorch rather than as custom operators. In this case, add your meta registration to torch/\_meta\_registrations.py, following the pattern of the other entries in the file.
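
The entries in that file follow a decorator pattern roughly like the sketch below; the operator and its shape logic are hypothetical, so use an existing registration for a similar operator as your real template.

```py
# Inside torch/_meta_registrations.py (rough sketch of the pattern only;
# aten.my_op and its output shape are made up for illustration).
@register_meta(aten.my_op.default)
def meta_my_op(x):
    # Describe output metadata only; never touch real data here.
    return x.new_empty(x.shape)
```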

# Appendix

## Bugs related to non-contiguous tensors

You may encounter a CUDA illegal memory access error. If this error only occurs in the AOT dispatch opcheck tests, then it may be caused by a backward kernel handling non-contiguous tensors improperly. To fix this, coerce the inputs to the backward kernel to be contiguous, e.g. by calling Tensor.contiguous() (example: [https://github.com/pytorch/FBGEMM/pull/2093/files](https://github.com/pytorch/FBGEMM/pull/2093/files)).

## Interaction with torch.compile

Some tests use torch.compile/torch.\_dynamo.optimize. generate\_opcheck\_tests does not interact well with torch.compile, so Dynamo runs in disabled mode (TORCHDYNAMO\_DISABLE=1) during these tests.

## AutogradCPU and AutogradCUDA

It is possible to have a custom operator whose autograd behavior differs per-device. We generally recommend against this; a better paradigm is to have two different custom operators, one for each device.

If you must have this feature, then continue reading.

The AutogradCPU key is a combination of the Autograd and CPU keys \- it registers a function as a CPU implementation and an autograd implementation, but only for CPU inputs. AutogradCUDA is analogous, but for CUDA. This is useful in the case where even the autograd implementation is different between CPU and CUDA. Note that the meta implementation must be separately registered.

See [https://github.com/pytorch/FBGEMM/pull/2076/files](https://github.com/pytorch/FBGEMM/pull/2076/files) for an example. In the example, the CPU implementation is a composition of autograd ops, while the CUDA implementation is a custom autograd op.