MekkCyber committed · Commit 679d8ed · Parent: 81d263b

fixing bindings

activation/activation_kernels.cu CHANGED

@@ -226,16 +226,15 @@ void gelu_quick(torch::Tensor& out, // [..., d]
 }
 
 void gelu(torch::Tensor& out, // [..., d]
-          torch::Tensor& input,
-          std::string approximation) // [..., d]
+          torch::Tensor& input) // [..., d]
 {
-  if (approximation == "none") {
-    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
-  } else if (approximation == "tanh") {
-    LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
-  } else {
-    throw std::invalid_argument("Invalid approximation");
-  }
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_kernel);
+}
+
+void gelu_tanh(torch::Tensor& out, // [..., d]
+               torch::Tensor& input) // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_tanh_kernel);
 }
 
 void silu(torch::Tensor& out, // [..., d]
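The change replaces the single `gelu(out, input, approximation)` entry point with two dedicated launchers, `gelu` and `gelu_tanh`. For orientation, here is a minimal PyTorch sketch of the two formulas the underlying kernels are expected to implement (reference math only, not the CUDA source):

```python
# Reference formulas for the two GELU variants (sketch, not the CUDA kernels).
import math
import torch

def gelu_exact_ref(x: torch.Tensor) -> torch.Tensor:
    # Exact (erf-based) GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.randn(4, 16)
torch.testing.assert_close(gelu_exact_ref(x), torch.nn.functional.gelu(x, approximate="none"))
torch.testing.assert_close(gelu_tanh_ref(x), torch.nn.functional.gelu(x, approximate="tanh"))
```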
torch-ext/activation/__init__.py CHANGED

@@ -30,8 +30,8 @@ def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0)
     return out
 
 
-def gelu(out: torch.Tensor, x: torch.Tensor, approximation: str = "none") -> None:
-    ops.gelu(out, x, approximation)
+def gelu(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu(out, x)
     return out
 
 def silu(out: torch.Tensor, x: torch.Tensor) -> None:

@@ -39,6 +39,11 @@ def silu(out: torch.Tensor, x: torch.Tensor) -> None:
     return out
 
 
+def gelu_tanh(out: torch.Tensor, x: torch.Tensor) -> None:
+    ops.gelu_tanh(out, x)
+    return out
+
+
 def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
     ops.gelu_fast(out, x)
     return out

@@ -56,11 +61,15 @@ def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
 
 __all__ = [
     "silu_and_mul",
+    "mul_and_silu",
     "gelu_and_mul",
     "gelu_tanh_and_mul",
     "fatrelu_and_mul",
     "gelu_fast",
     "gelu_new",
     "gelu_quick",
+    "gelu_tanh",
+    "silu",
+    "gelu",
    "layers",
 ]
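Call-site impact of the wrapper changes, as a small sketch (assumes the built extension imports as `activation` and a CUDA device is available; the exact package name depends on how the kernel is built and installed):

```python
import torch
from activation import gelu, gelu_tanh, silu  # assumed import path for the built extension

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)

out = torch.empty_like(x)
gelu(out, x)          # exact GELU; no more `approximation=` argument

out_t = torch.empty_like(x)
gelu_tanh(out_t, x)   # tanh-approximated GELU, now its own function

out_s = torch.empty_like(x)
silu(out_s, x)        # SiLU, newly exported in __all__
```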
torch-ext/activation/layers.py CHANGED

@@ -52,11 +52,29 @@ class Gelu(nn.Module):
 
     can_torch_compile: bool = True
 
-    def forward(self, x: torch.Tensor, approximation: str = "none"):
+    def forward(self, x: torch.Tensor):
         out = torch.empty_like(x)
-        ops.gelu(out, x, approximation)
+        ops.gelu(out, x)
         return out
 
+class GeluTanh(nn.Module):
+    """An activation function for GELU with `tanh` approximation.
+
+    The function computes x -> gelu_tanh(x).
+
+    Shapes:
+        x: (num_tokens, d) or (batch_size, seq_len, d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    can_torch_compile: bool = True
+
+    def forward(self, x: torch.Tensor):
+        out = torch.empty_like(x)
+        ops.gelu_tanh(out, x)
+        return out
+
+
 class MulAndSilu(nn.Module):
     """An activation function for SwiGLU.
 
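A short usage sketch for the layer classes after this change (module path assumed from the repo layout; shapes follow the docstring):

```python
import torch
from activation.layers import Gelu, GeluTanh  # assumed import path for the built extension

x = torch.randn(2, 128, 4096, device="cuda")

y_exact = Gelu()(x)      # forward() now takes only x and dispatches to ops.gelu
y_tanh = GeluTanh()(x)   # new module, dispatches to ops.gelu_tanh
```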
torch-ext/torch_binding.cpp CHANGED

@@ -35,6 +35,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quick GELU implementation.
   ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
+
+  // GELU with `tanh` approximation.
+  ops.def("gelu_tanh(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_tanh", torch::kCUDA, &gelu_tanh);
+
+  // SiLU implementation.
+  ops.def("silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("silu", torch::kCUDA, &silu);
+
+  // GELU with `none` approximation (exact GELU).
+  ops.def("gelu(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu", torch::kCUDA, &gelu);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
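In these schemas, `Tensor! out` marks the first argument as mutated in place and `-> ()` means the op returns nothing, which is why callers (e.g. the layer classes above) pre-allocate `out` with `torch.empty_like` before invoking the op. A dispatcher-level sketch (the namespace below is a placeholder; the real one comes from `TORCH_EXTENSION_NAME` at build time):

```python
import torch

x = torch.randn(4, 512, device="cuda")
out = torch.empty_like(x)

# `_activation` is a hypothetical stand-in for the build-time extension namespace.
torch.ops._activation.gelu_tanh(out, x)  # writes into `out`; returns None per the "-> ()" schema
```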
torch-ext/torch_binding.h CHANGED

@@ -18,3 +18,9 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input);
 void gelu_fast(torch::Tensor &out, torch::Tensor &input);
 
 void gelu_quick(torch::Tensor &out, torch::Tensor &input);
+
+void gelu_tanh(torch::Tensor &out, torch::Tensor &input);
+
+void silu(torch::Tensor &out, torch::Tensor &input);
+
+void gelu(torch::Tensor &out, torch::Tensor &input);
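With the header declarations, kernel definitions, and bindings now in sync, a quick end-to-end check against PyTorch's own GELU is a reasonable smoke test after rebuilding (sketch; assumes the extension imports as `activation` and a CUDA device):

```python
import torch
import torch.nn.functional as F
from activation import gelu, gelu_tanh  # assumed import path for the built extension

x = torch.randn(16, 2048, device="cuda")

out = torch.empty_like(x)
gelu(out, x)
torch.testing.assert_close(out, F.gelu(x, approximate="none"))

out = torch.empty_like(x)
gelu_tanh(out, x)
torch.testing.assert_close(out, F.gelu(x, approximate="tanh"))
print("gelu / gelu_tanh match the PyTorch reference")
```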