MekkCyber committed
Commit b9b616a · 1 Parent(s): 74526fe
Files changed (1):
  1. tests/kernels/test_activation.py +26 -0
tests/kernels/test_activation.py CHANGED
@@ -55,6 +55,14 @@ def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
     d = x.shape[-1] // 2
     return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
 
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return F.gelu(x)
+
+def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
+    return F.gelu(x, approximate="tanh")
+
+def silu(x: torch.Tensor) -> torch.Tensor:
+    return F.silu(x)
 
 @pytest.mark.parametrize(
     "activation_name", ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"]
 
@@ -145,6 +153,24 @@ def test_act_and_mul(
             activation.ops.gelu_quick,
             activation.layers.QuickGELU,
         ),
+        (
+            gelu_tanh,
+            activation.gelu_tanh,
+            activation.ops.gelu_tanh,
+            activation.layers.GeluTanh,
+        ),
+        (
+            silu,
+            activation.silu,
+            activation.ops.silu,
+            activation.layers.Silu,
+        ),
+        (
+            gelu,
+            activation.gelu,
+            activation.ops.gelu,
+            activation.layers.Gelu
+        ),
     ],
 )
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
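Each new tuple wires a pure-PyTorch reference to the same activation's module-level function, raw op, and layer class. A minimal, self-contained sketch of how a parametrized test typically consumes such tuples — the stand-in layer, input shape, and assertion below are assumptions for illustration, not taken from this file:

import pytest
import torch
import torch.nn.functional as F

def silu_ref(x: torch.Tensor) -> torch.Tensor:
    # Pure-PyTorch reference, mirroring the helpers added in the first hunk.
    return F.silu(x)

class SiluLayer(torch.nn.Module):
    # Hypothetical stand-in for a kernel-backed activation.layers.Silu.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.silu(x)

@pytest.mark.parametrize(
    "ref_fn, layer_cls",
    [
        (silu_ref, SiluLayer),
    ],
)
def test_activation_matches_reference(ref_fn, layer_cls):
    x = torch.randn(7, 512)
    # The (stand-in) kernel output must match the reference within default tolerances.
    torch.testing.assert_close(layer_cls()(x), ref_fn(x))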