# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# 1/sqrt(2*pi) -> 0.3989423
# 1/sqrt(2)    -> 0.70710678
# sqrt(2/pi)   -> 0.79788456
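
# Quick numeric check of the constants above (illustrative addition, not part of the
# original file; uses the `math` import, which is otherwise unused here).
assert abs(1.0 / math.sqrt(2.0 * math.pi) - 0.3989423) < 1e-7
assert abs(1.0 / math.sqrt(2.0) - 0.70710678) < 1e-8
assert abs(math.sqrt(2.0 / math.pi) - 0.79788456) < 1e-8
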
# bias_gelu is the tanh approximation of gelu applied to (y + bias); the exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
def bias_gelu(y, bias):
    x = bias + y
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)

# gradient of the tanh approximation of gelu; the gradient of the exact gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
def bias_gelu_back(g, y, bias):
    """Assume that y has shape (B, D) and bias has shape (D)"""
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    grad_y = ff * g
    # gradient w.r.t. the input, and gradient w.r.t. the bias (summed over the batch dimension)
    return grad_y.to(dtype=y.dtype), grad_y.sum(dim=0, dtype=bias.dtype)

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # bias_gelu_back already returns (grad_input, grad_bias), one gradient per forward input
        return bias_gelu_back(grad_output, input, bias)


bias_gelu_impl = GeLUFunction.apply
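
# Illustrative sanity check (not part of the original file; assumes PyTorch >= 1.12 for
# F.gelu(..., approximate="tanh")): bias_gelu_impl should closely match the built-in
# tanh-approximate GELU applied to (input + bias).
def _check_bias_gelu_impl():
    inp = torch.randn(8, 64, requires_grad=True)
    bias = torch.randn(64, requires_grad=True)
    out = bias_gelu_impl(inp, bias)
    ref = F.gelu(inp + bias, approximate="tanh")
    torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
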
# gelu_fwd is the tanh approximation of gelu; the exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
def gelu_fwd(x):
    return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)


# gradient of the tanh approximation of gelu; the gradient of the exact gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
def gelu_bwd(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return (ff * g).to(dtype=x.dtype)

class FastGeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return gelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        tmp = gelu_bwd(grad_output, input)
        return tmp


fast_gelu_impl = FastGeLUFunction.apply
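
# Illustrative gradient check (not part of the original file): torch.autograd.gradcheck
# compares the analytic gradient from gelu_bwd against finite differences; float64 inputs
# are used so the numerical comparison is tight.
def _check_fast_gelu_impl():
    x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
    torch.autograd.gradcheck(fast_gelu_impl, (x,), eps=1e-6, atol=1e-4)
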
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
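
# Illustrative sketch (not part of the original file): a hypothetical SqReLUFunction that
# wires sqrelu_fwd / sqrelu_bwd into autograd, mirroring FastGeLUFunction above.
class SqReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return sqrelu_fwd(input)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return sqrelu_bwd(grad_output, input)


sqrelu_impl = SqReLUFunction.apply  # hypothetical alias, for illustration only
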
# Fused SwiGLU via the CUDA jiterator: swiglu_fwd(x, y) = x * sigmoid(x) * y = silu(x) * y.
swiglu_fwd_codestring = """
template <typename T> T swiglu_fwd(T x, T y) {
    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
}
"""
swiglu_bwd_codestring = """
template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
    dy = float(x) * x_sigmoid * float(g);
}
"""

swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)

class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return swiglu_fwd(x, y)

    @staticmethod
    def backward(ctx, dout):
        x, y = ctx.saved_tensors
        # the multi-output jiterator function returns (dx, dy), one gradient per forward input
        return swiglu_bwd(x, y, dout)


swiglu = SwiGLUFunction.apply
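
# Illustrative usage (not part of the original file): the jiterator kernels only run on CUDA
# tensors, so this sketch needs a GPU. swiglu(x, y) should match F.silu(x) * y.
def _check_swiglu():
    if not torch.cuda.is_available():
        return
    x = torch.randn(4, 256, device="cuda", requires_grad=True)
    y = torch.randn(4, 256, device="cuda", requires_grad=True)
    out = swiglu(x, y)
    torch.testing.assert_close(out, F.silu(x) * y, rtol=1e-4, atol=1e-4)
    out.sum().backward()  # populates x.grad and y.grad via SwiGLUFunction.backward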