Upload 4 files
- models/common.py +32 -0
- models/layers.py +158 -0
- models/losses.py +101 -0
- models/sparse_embedding.py +132 -0
models/common.py
ADDED
@@ -0,0 +1,32 @@
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch nn.init.trunc_normal_ is not mathematically correct: the requested std is not the actual std dev of the initialized tensor.
    # This function is a PyTorch port of the JAX truncated normal init (the default init method in flax).
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * lower ** 2)
            pdf_l = c * math.exp(-0.5 * upper ** 2)
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
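For reference, trunc_normal_init_ fills the given tensor in place and also returns it, so it can be used inline when constructing parameters. A minimal sketch with illustrative shapes, assuming only PyTorch:

# Hypothetical example: truncated LeCun normal init (std = 1 / sqrt(fan_in)).
import torch
from models.common import trunc_normal_init_

fan_in, fan_out = 512, 256
w = trunc_normal_init_(torch.empty(fan_out, fan_in), std=1.0 / (fan_in ** 0.5))
# Values come from a truncated normal, clipped to roughly +/-2 (corrected) std devs.
print(w.std())  # close to 1 / sqrt(512)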
models/layers.py
ADDED
@@ -0,0 +1,158 @@
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # type: ignore[import]
except ImportError:
    # Fallback to FlashAttention 2
    from flash_attn import flash_attn_func  # type: ignore[import]

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    # Round a up to the nearest multiple of b (ceil division, then multiply back)
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Different from the paper: uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, hidden_size]
        qkv = self.qkv_proj(hidden_states)

        # Split heads
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
        if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility
            attn_output = attn_output[0]

        # attn_output: [batch_size, seq_len, num_heads, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
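For reference, a shape-check sketch of how these layers compose. It assumes a CUDA device and an installed flash-attn package (flash_attn_func only accepts fp16/bf16 GPU tensors); the block wiring is illustrative, not the repository's actual model definition:

# Hypothetical smoke test for the attention / SwiGLU layers.
import torch
from models.layers import Attention, RotaryEmbedding, SwiGLU, rms_norm

hidden_size, head_dim, num_heads, seq_len = 512, 64, 8, 128

rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len, base=10000.0).cuda()
attn = Attention(hidden_size, head_dim, num_heads, num_key_value_heads=num_heads, causal=False).cuda()
mlp = SwiGLU(hidden_size, expansion=4.0).cuda()

x = torch.randn(2, seq_len, hidden_size, dtype=torch.bfloat16, device="cuda")
x = rms_norm(x + attn(rope(), x), variance_epsilon=1e-5)  # self-attention with RoPE, post-norm
x = rms_norm(x + mlp(x), variance_epsilon=1e-5)           # SwiGLU MLP, post-norm
print(x.shape)  # torch.Size([2, 128, 512])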
models/losses.py
ADDED
@@ -0,0 +1,101 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn


IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    valid_mask = labels != ignore_index
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    # Cast logits to f32
    # Flatten logits
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)


class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        # Correctness
        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # FIXME: Assuming the batch is always full
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")

        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss)
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
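For reference, both token-level losses share the same interface and return per-position losses with ignored positions zeroed; stablemax replaces exp() with the piecewise rational s(x), which avoids overflow on extreme logits. A minimal sketch with made-up shapes:

# Hypothetical example: per-token losses on random logits.
import torch
from models.losses import IGNORE_LABEL_ID, softmax_cross_entropy, stablemax_cross_entropy

logits = torch.randn(2, 5, 10)              # [batch, seq_len, vocab]
labels = torch.randint(0, 10, (2, 5))
labels[:, -1] = IGNORE_LABEL_ID             # padding positions contribute zero loss

print(softmax_cross_entropy(logits, labels).shape)    # torch.Size([2, 5])
print(stablemax_cross_entropy(logits, labels).shape)  # torch.Size([2, 5]), float64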
models/sparse_embedding.py
ADDED
@@ -0,0 +1,132 @@
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_weights_grad is not None
            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            _sparse_emb_signsgd_dist(
                local_weights_grad,
                local_ids,
                weights,

                lr=group["lr"],
                weight_decay=group["weight_decay"],
                world_size=group["world_size"]
            )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
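For reference, a single-process sketch (world_size=1, so no process group is required). CastedSparseEmbedding stores its table in buffers rather than parameters, so the optimizer is given the module's buffers directly; it expects exactly the three buffers registered above and identifies them by requires_grad and ndim:

# Hypothetical example: one sparse SignSGD update on a toy embedding table.
import torch
from models.sparse_embedding import CastedSparseEmbedding, CastedSparseEmbeddingSignSGD_Distributed

emb = CastedSparseEmbedding(num_embeddings=1000, embedding_dim=64, batch_size=8,
                            init_std=1.0, cast_to=torch.bfloat16)
opt = CastedSparseEmbeddingSignSGD_Distributed(emb.buffers(), world_size=1, lr=1e-2)

puzzle_ids = torch.randint(0, 1000, (8,), dtype=torch.int32)
out = emb(puzzle_ids)                  # [8, 64] bf16; gradients flow into local_weights
out.float().pow(2).mean().backward()   # toy objective
opt.step()                             # sign-of-gradient update scattered back into emb.weights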