File size: 1,315 Bytes
bec1e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-

from .asm import fp32_to_tf32_asm
from .cumsum import (
    chunk_global_cumsum,
    chunk_global_cumsum_scalar,
    chunk_global_cumsum_scalar_kernel,
    chunk_global_cumsum_vector,
    chunk_global_cumsum_vector_kernel,
    chunk_local_cumsum,
    chunk_local_cumsum_scalar,
    chunk_local_cumsum_scalar_kernel,
    chunk_local_cumsum_vector,
    chunk_local_cumsum_vector_kernel
)
from .logcumsumexp import logcumsumexp_fwd_kernel
from .logsumexp import logsumexp_fwd, logsumexp_fwd_kernel
from .matmul import addmm, matmul, matmul_kernel
from .pooling import mean_pooling
from .softmax import softmax_bwd, softmax_bwd_kernel, softmax_fwd, softmax_fwd_kernel

__all__ = [
    'chunk_global_cumsum',
    'chunk_global_cumsum_scalar',
    'chunk_global_cumsum_scalar_kernel',
    'chunk_global_cumsum_vector',
    'chunk_global_cumsum_vector_kernel',
    'chunk_local_cumsum',
    'chunk_local_cumsum_scalar',
    'chunk_local_cumsum_scalar_kernel',
    'chunk_local_cumsum_vector',
    'chunk_local_cumsum_vector_kernel',
    'logcumsumexp_fwd_kernel',
    'logsumexp_fwd',
    'logsumexp_fwd_kernel',
    'addmm',
    'matmul',
    'matmul_kernel',
    'mean_pooling',
    'softmax_bwd',
    'softmax_bwd_kernel',
    'softmax_fwd',
    'softmax_fwd_kernel',
    'fp32_to_tf32_asm',
]