Tarive committed (verified)
Commit 67e8c5e · 1 Parent(s): 26d8a81

Upload 4 files

Files changed (4):
  1. models/common.py +32 -0
  2. models/layers.py +158 -0
  3. models/losses.py +101 -0
  4. models/sparse_embedding.py +132 -0
models/common.py ADDED
@@ -0,0 +1,32 @@
import math

import torch
from torch import nn


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    # NOTE: PyTorch's nn.init.trunc_normal_ is not mathematically exact: its std argument is the std
    # of the parent normal, not the std of the initialized (truncated) tensor.
    # This function is a PyTorch port of the JAX truncated normal init (the default init in flax):
    # https://github.com/jax-ml/jax/blob/main/jax/_src/random.py#L807-L848
    # https://github.com/jax-ml/jax/blob/main/jax/_src/nn/initializers.py#L162-L199

    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            # Rescaling factor so that, after truncation at [lower, upper] stds, the tensor's std matches `std`.
            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * lower ** 2)
            pdf_l = c * math.exp(-0.5 * upper ** 2)
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling: uniform in erf-space, erfinv, then rescale and clip.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
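
As a quick sanity check of the compensation (not part of the commit; a minimal sketch assuming the file above is importable as models.common):

import torch
from torch import nn

from models.common import trunc_normal_init_

# With the default cutoffs (±2 std), the compensated init recovers the requested std,
# while PyTorch's built-in truncation shrinks it to about 0.88.
t = trunc_normal_init_(torch.empty(2048, 2048), std=1.0)
print(round(t.std().item(), 3))   # ~1.0

u = nn.init.trunc_normal_(torch.empty(2048, 2048), mean=0.0, std=1.0, a=-2.0, b=2.0)
print(round(u.std().item(), 3))   # ~0.88
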
models/layers.py ADDED
@@ -0,0 +1,158 @@
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # type: ignore[import]
except ImportError:
    # Fallback to FlashAttention 2
    from flash_attn import flash_attn_func  # type: ignore[import]

from models.common import trunc_normal_init_


CosSin = Tuple[torch.Tensor, torch.Tensor]


def _find_multiple(a, b):
    return (-(a // -b)) * b


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [bs, seq_len, num_heads, head_dim]
    # cos, sin: [seq_len, head_dim]
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool):
        super().__init__()
        # Truncated LeCun normal init
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            # Zero init bias
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(input.dtype), bias=self.bias.to(input.dtype) if self.bias is not None else None)


class CastedEmbedding(nn.Module):
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 init_std: float,
                 cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Truncated LeCun normal init
        self.embedding_weight = nn.Parameter(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.embedding(input, self.embedding_weight.to(self.cast_to))


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings, base, device=None):
        super().__init__()

        # RoPE
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached
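
A quick shape check for the RoPE helpers above (a sketch, not part of the uploaded file; importing models.layers requires flash-attn to be installed since it is imported at module load, and nn.Buffer needs a recent PyTorch):

import torch

from models.layers import RotaryEmbedding, apply_rotary_pos_emb

bs, seq_len, num_heads, head_dim = 2, 16, 4, 32

rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=seq_len, base=10000.0)
cos, sin = rope()                                  # each [seq_len, head_dim]

q = torch.randn(bs, seq_len, num_heads, head_dim)
k = torch.randn(bs, seq_len, num_heads, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

print(q_rot.shape, k_rot.shape)                    # torch.Size([2, 16, 4, 32]) twice
# RoPE only rotates pairs of channels, so per-head norms are preserved.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4))   # True
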


class Attention(nn.Module):
    def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.causal = causal

        self.qkv_proj = CastedLinear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, bias=False)
        self.o_proj = CastedLinear(self.output_size, self.hidden_size, bias=False)

    def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape

        # hidden_states: [bs, seq_len, hidden_size] -> fused QKV projection
        qkv = self.qkv_proj(hidden_states)

        # Split heads: [bs, seq_len, num_heads + 2 * num_key_value_heads, head_dim]
        qkv = qkv.view(batch_size, seq_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads: self.num_heads + self.num_key_value_heads]
        value = qkv[:, :, self.num_heads + self.num_key_value_heads:]

        # RoPE
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # flash attn
        attn_output = flash_attn_func(q=query, k=key, v=value, causal=self.causal)
        if isinstance(attn_output, tuple):  # fa2 and fa3 compatibility
            attn_output = attn_output[0]

        # attn_output: [batch_size, seq_len, num_heads, head_dim] -> merge heads
        attn_output = attn_output.view(batch_size, seq_len, self.output_size)  # type: ignore
        return self.o_proj(attn_output)


class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        # The 2/3 factor keeps the 3-matrix SwiGLU comparable in parameters to a 2-matrix MLP
        # with the same expansion; the width is rounded up to a multiple of 256.
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
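
The dtype-casting layers and the FFN pieces can be smoke-tested on CPU the same way (again a sketch, not part of the commit; Attention.forward itself needs CUDA plus fp16/bf16 inputs because it goes through flash-attn):

import torch

from models.layers import CastedLinear, CastedEmbedding, SwiGLU, rms_norm

hidden_size = 64
x = torch.randn(2, 16, hidden_size)

# Pre-norm -> SwiGLU -> residual: the usual transformer FFN wiring.
mlp = SwiGLU(hidden_size=hidden_size, expansion=4.0)
y = x + mlp(rms_norm(x, variance_epsilon=1e-5))
print(y.shape)                                     # torch.Size([2, 16, 64])

# The "casted" layers keep fp32 master weights and cast on the fly.
proj = CastedLinear(hidden_size, 32, bias=True)
print(proj(x.to(torch.bfloat16)).dtype)            # torch.bfloat16

emb = CastedEmbedding(num_embeddings=100, embedding_dim=hidden_size, init_std=0.02, cast_to=torch.bfloat16)
print(emb(torch.randint(0, 100, (2, 16))).dtype)   # torch.bfloat16
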
models/losses.py ADDED
@@ -0,0 +1,101 @@
from typing import Any, Tuple, Dict, Sequence, Optional

import torch
import torch.nn.functional as F
from torch import nn


IGNORE_LABEL_ID = -100


def s(x, epsilon=1e-30):
    # Stablemax transform: a surrogate for exp that grows only linearly for x >= 0
    # and stays positive (decaying toward 0) for x < 0.
    return torch.where(
        x < 0,
        1 / (1 - x + epsilon),
        x + 1
    )


def log_stablemax(x, dim=-1):
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))


def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    valid_mask = labels != ignore_index
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1).squeeze(-1)

    return -torch.where(valid_mask, prediction_logprobs, 0)


def softmax_cross_entropy(logits, labels, ignore_index: int = -100):
    # Cast logits to f32 and flatten before computing the per-token cross entropy.
    return F.cross_entropy(logits.to(torch.float32).view(-1, logits.shape[-1]), labels.to(torch.long).view(-1), ignore_index=ignore_index, reduction="none").view(labels.shape)
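
A small numerical comparison of the two per-token losses above (a sketch, not part of the commit):

import torch

from models.losses import IGNORE_LABEL_ID, softmax_cross_entropy, stablemax_cross_entropy

torch.manual_seed(0)
logits = torch.randn(2, 5, 11)                     # [batch, seq_len, vocab]
labels = torch.randint(0, 11, (2, 5))
labels[0, :2] = IGNORE_LABEL_ID                    # ignored positions contribute zero loss

ce_soft = softmax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)
ce_stable = stablemax_cross_entropy(logits, labels, ignore_index=IGNORE_LABEL_ID)

print(ce_soft.shape, ce_stable.shape)              # both per-token: torch.Size([2, 5])
print(ce_soft[0, :2], ce_stable[0, :2])            # zeros at the ignored positions
# stablemax replaces exp(x) with s(x), which grows linearly for x >= 0, and computes
# the loss in float64; softmax_cross_entropy is the standard f32 cross entropy.
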

class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore

    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor]:
        # Model logits
        # B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        # Correctness
        with torch.no_grad():
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics (halted)
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),

                "accuracy": torch.where(valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0).sum(),
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),

                "q_halt_accuracy": (valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }

        # Losses
        # FIXME: Assuming the batch is always full
        lm_loss = (self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID) / loss_divisor).sum()
        q_halt_loss = F.binary_cross_entropy_with_logits(outputs["q_halt_logits"], seq_is_correct.to(outputs["q_halt_logits"].dtype), reduction="sum")

        metrics.update({
            "lm_loss": lm_loss.detach(),
            "q_halt_loss": q_halt_loss.detach(),
        })

        # Q continue (bootstrapping target loss)
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum")

            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return new_carry, lm_loss + 0.5 * (q_halt_loss + q_continue_loss), metrics, detached_outputs, new_carry.halted.all()
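
To illustrate the interface ACTLossHead expects from the wrapped model, here is a minimal sketch with a hypothetical DummyModel and DummyCarry (neither exists in this commit; the real model supplies its own carry with current_data, halted and steps fields):

from dataclasses import dataclass
from typing import Dict

import torch
from torch import nn

from models.losses import ACTLossHead, IGNORE_LABEL_ID


@dataclass
class DummyCarry:
    # Mirrors the fields ACTLossHead reads from the carry.
    current_data: Dict[str, torch.Tensor]
    halted: torch.Tensor
    steps: torch.Tensor


class DummyModel(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)
        self.q_head = nn.Linear(hidden_size, 1)

    def initial_carry(self, batch):
        bs = batch["inputs"].shape[0]
        return DummyCarry(current_data=batch,
                          halted=torch.ones(bs, dtype=torch.bool),
                          steps=torch.ones(bs, dtype=torch.int32))

    def forward(self, carry, batch):
        h = self.embed(batch["inputs"])
        outputs = {"logits": self.lm_head(h),
                   "q_halt_logits": self.q_head(h[:, 0]).squeeze(-1)}
        return carry, outputs


batch = {"inputs": torch.randint(0, 16, (2, 5)),
         "labels": torch.randint(0, 16, (2, 5))}
batch["labels"][0, 0] = IGNORE_LABEL_ID            # one ignored position

head = ACTLossHead(DummyModel(), loss_type="softmax_cross_entropy")
carry = head.initial_carry(batch)
carry, loss, metrics, outs, all_halted = head(return_keys=["logits"], carry=carry, batch=batch)

loss.backward()
print(float(loss), {k: float(v) for k, v in metrics.items()}, bool(all_halted))
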
models/sparse_embedding.py ADDED
@@ -0,0 +1,132 @@
from typing import Union

import torch
from torch import nn
import torch.distributed as dist
from torch.optim.optimizer import Optimizer, ParamsT

from models.common import trunc_normal_init_


class CastedSparseEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):
        super().__init__()
        self.cast_to = cast_to

        # Real Weights
        # Truncated LeCun normal init
        self.weights = nn.Buffer(
            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
        )

        # Local weights and IDs
        # Local embeddings, with gradient, not persistent
        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
        # Local embedding IDs, not persistent
        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Test mode, no gradient
            return self.weights[inputs].to(self.cast_to)

        # Training mode, fill puzzle embedding from weights
        with torch.no_grad():
            self.local_weights.copy_(self.weights[inputs])
            self.local_ids.copy_(inputs)

        return self.local_weights.to(self.cast_to)


class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
    def __init__(
        self,
        params: ParamsT,

        world_size: int,
        lr: Union[float, torch.Tensor] = 1e-3,
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            world_size=world_size
        )
        super().__init__(params, defaults)

    @torch.no_grad
    def step(self, closure=None):  # type: ignore
        for group in self.param_groups:
            # Find the sparse embedding weights
            local_weights_grad = None
            local_ids = None
            weights = None

            assert len(group["params"]) == 3
            for p in group["params"]:
                if p.requires_grad:
                    local_weights_grad = p.grad
                elif p.ndim == 1:
                    local_ids = p
                elif p.ndim == 2:
                    weights = p
                else:
                    assert False

            assert local_weights_grad is not None
            assert local_ids is not None
            assert weights is not None

            # Apply SignSGD
            # Adam ≈ SignSGD if gradient is very sparse
            _sparse_emb_signsgd_dist(
                local_weights_grad,
                local_ids,
                weights,

                lr=group["lr"],
                weight_decay=group["weight_decay"],
                world_size=group["world_size"]
            )


def _sparse_emb_signsgd_dist(
    local_weights_grad: torch.Tensor,
    local_ids: torch.Tensor,
    weights: torch.Tensor,

    lr: float,
    weight_decay: float,
    world_size: int
) -> None:
    N, D = local_weights_grad.shape

    # All-gather
    all_weights_grad = local_weights_grad
    all_ids = local_ids

    if world_size > 1:
        all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)

        dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
        dist.all_gather_into_tensor(all_ids, local_ids)

    # Unique
    grad_ids, inv = all_ids.unique(return_inverse=True)

    grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device)
    grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)

    # SignSGD with decoupled weight decay
    p = weights[grad_ids]

    p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)

    # Write updated slices back
    weights[grad_ids] = p
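
End to end, the module and the optimizer are meant to be used together: the three buffers (weights, local_weights, local_ids) form the optimizer's single parameter group. A single-process sketch (world_size=1, so no torch.distributed calls are issued; the training script itself is not part of this commit):

import torch

from models.sparse_embedding import CastedSparseEmbedding, CastedSparseEmbeddingSignSGD_Distributed

batch_size, num_emb, dim = 4, 10, 8
emb = CastedSparseEmbedding(num_emb, dim, batch_size=batch_size, init_std=0.02, cast_to=torch.float32)

# step() tells the three buffers apart by requires_grad / ndim, so the whole buffer
# list is passed as one parameter group.
opt = CastedSparseEmbeddingSignSGD_Distributed(emb.buffers(), world_size=1, lr=1e-2, weight_decay=0.1)

emb.train()
ids = torch.tensor([0, 3, 3, 7], dtype=torch.int32)
out = emb(ids)                                   # [batch_size, dim]; grads flow into local_weights
out.square().mean().backward()

touched_before = emb.weights[[0, 3, 7]].clone()
untouched_before = emb.weights[1].clone()
opt.step()                                       # sign-SGD update on the rows seen in this batch only
opt.zero_grad()

print(not torch.equal(emb.weights[[0, 3, 7]], touched_before))   # True: selected rows moved
print(torch.equal(emb.weights[1], untouched_before))             # True: other rows untouched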