import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------
# Utility Layers
# ---------------------

class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(norm + self.eps)
        return self.weight * x


class FeedForward(nn.Module):
    def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner = d_model * mult
        self.net = nn.Sequential(
            nn.Linear(d_model, inner * 2),  # project to 2*inner so GLU can gate
            nn.GLU(dim=-1),                 # sigmoid-gated linear unit
            nn.Linear(inner, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


# ---------------------
# SSW Components
# ---------------------

class LocalTextureConv(nn.Module):
    """Depthwise 1D conv + GLU gate. Causal padding. O(n * d * k) with small k."""

    def __init__(self, d_model: int, kernel_size: int = 7):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size should be odd for simple causal pad"
        self.dw = nn.Conv1d(d_model, d_model, kernel_size,
                            groups=d_model, padding=kernel_size - 1)
        self.pw = nn.Conv1d(d_model, 2 * d_model, 1)

    def forward(self, x):
        # x: (B, T, C)
        x_c = x.transpose(1, 2)         # (B, C, T)
        y = self.dw(x_c)
        T = x.size(1)
        y = y[..., :T]                  # causal crop: drop right-side padding
        y = self.pw(y).transpose(1, 2)  # (B, T, 2C)
        y = F.glu(y, dim=-1)            # (B, T, C)
        return y


class GlobalStatePropagation(nn.Module):
    """Simplified selective SSM-like recurrence (toy, readable)."""

    def __init__(self, d_model: int, state_size: int = 128):
        super().__init__()
        self.state_size = state_size
        self.inp = nn.Linear(d_model, state_size * 3)
        self.out = nn.Linear(state_size, d_model)

    def forward(self, x):
        B, T, _ = x.size()
        u, f, r = self.inp(x).chunk(3, dim=-1)
        f = torch.sigmoid(f)  # forget gate
        r = torch.sigmoid(r)  # read/output gate
        u = torch.tanh(u)     # candidate update
        h = torch.zeros(B, self.state_size, device=x.device, dtype=x.dtype)
        outs = []
        for t in range(T):
            h = f[:, t] * h + (1 - f[:, t]) * u[:, t]
            outs.append(r[:, t] * h)
        y = torch.stack(outs, dim=1)  # (B, T, S)
        return self.out(y)            # (B, T, C)


class ContentBasedSummarizer(nn.Module):
    """Top-k sparse attention over history (causal)."""

    def __init__(self, d_model: int, top_k: int = 8):
        super().__init__()
        self.k = top_k
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.kv = nn.Linear(d_model, 2 * d_model, bias=False)
        self.scale = 1.0 / math.sqrt(d_model)
        self.scorer = nn.Linear(d_model, 1, bias=False)

    def forward(self, x):
        B, T, C = x.size()
        q = self.q(x)
        k, v = self.kv(x).chunk(2, dim=-1)
        imp = self.scorer(x).squeeze(-1)  # (B, T) importance scores
        out = torch.zeros_like(x)
        # Per-timestep loop: readable rather than fast; each query attends
        # only to the top-k most "important" positions in its causal prefix.
        for t in range(T):
            topk = min(self.k, t + 1)
            _, idx = torch.topk(imp[:, :t + 1], k=topk, dim=-1)
            k_sel = torch.gather(k[:, :t + 1, :], 1, idx.unsqueeze(-1).expand(-1, -1, C))
            v_sel = torch.gather(v[:, :t + 1, :], 1, idx.unsqueeze(-1).expand(-1, -1, C))
            q_t = q[:, t:t + 1, :]
            att = torch.matmul(q_t, k_sel.transpose(1, 2)) * self.scale
            att = F.softmax(att, dim=-1)
            out[:, t:t + 1, :] = torch.matmul(att, v_sel)
        return out


class WeaverBlock(nn.Module):
    def __init__(self, d_model: int, ltc_kernel: int, gsp_state: int,
                 cbs_topk: int, dropout: float):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.ltc = LocalTextureConv(d_model, kernel_size=ltc_kernel)
        self.gsp = GlobalStatePropagation(d_model, state_size=gsp_state)
        self.cbs = ContentBasedSummarizer(d_model, top_k=cbs_topk)
        self.mix = nn.Linear(d_model * 3, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, mult=4, dropout=dropout)
    def forward(self, x):
        h = self.norm1(x)
        a = self.ltc(h)  # local texture
        b = self.gsp(h)  # global state
        c = self.cbs(h)  # content-based summary
        h = self.mix(torch.cat([a, b, c], dim=-1))
        x = x + self.dropout(h)
        x = x + self.ff(self.norm2(x))
        return x


class SSWLM(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 512, n_layers: int = 8,
                 ltc_kernel: int = 7, gsp_state: int = 128, cbs_topk: int = 8,
                 dropout: float = 0.1, max_seq_len: int = 1024):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            WeaverBlock(d_model, ltc_kernel, gsp_state, cbs_topk, dropout)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_len = max_seq_len

    def forward(self, input_ids: torch.Tensor):
        B, T = input_ids.size()
        assert T <= self.max_seq_len, "sequence too long"
        pos = torch.arange(T, device=input_ids.device)
        x = self.tok_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.layers:
            x = blk(x)
        x = self.norm(x)
        return self.head(x)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        eos_token_id: Optional[int] = None,
    ):
        self.eval()
        for _ in range(max_new_tokens):
            inp = input_ids[:, -self.max_seq_len:]
            logits = self.forward(inp)[:, -1, :] / max(1e-6, temperature)

            # repetition penalty (simple): downweight already-seen token logits.
            # Divide positive logits and multiply negative ones so the penalty
            # always lowers the token's probability.
            if repetition_penalty and repetition_penalty > 1.0:
                for b in range(input_ids.size(0)):
                    seen = torch.bincount(input_ids[b], minlength=logits.size(-1)).bool()
                    logits[b, seen] = torch.where(
                        logits[b, seen] > 0,
                        logits[b, seen] / repetition_penalty,
                        logits[b, seen] * repetition_penalty,
                    )

            # top-k filter
            if top_k and top_k > 0:
                k = min(top_k, logits.size(-1))
                topk_vals, topk_idx = torch.topk(logits, k=k, dim=-1)
                mask = torch.full_like(logits, float("-inf"))
                logits = mask.scatter(1, topk_idx, topk_vals)

            # nucleus (top-p) filter
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumsum = torch.cumsum(probs, dim=-1)
                cutoff = cumsum > top_p
                # shift right so the token that crosses the threshold is kept
                cutoff[..., 1:] = cutoff[..., :-1].clone()
                cutoff[..., 0] = False  # always keep at least one token
                sorted_logits[cutoff] = float("-inf")
                # unsort back to vocabulary order
                inv_idx = torch.argsort(sorted_idx, dim=-1)
                logits = torch.gather(sorted_logits, 1, inv_idx)

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return input_ids
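
# ---------------------
# Usage sketch (illustrative)
# ---------------------
# A minimal smoke test showing how the pieces above fit together. The
# vocabulary size, model dimensions, and sampling settings below are
# arbitrary assumptions for demonstration, not values from any training
# setup described elsewhere.

if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 1000  # hypothetical toy vocabulary

    model = SSWLM(vocab_size=vocab_size, d_model=128, n_layers=2,
                  gsp_state=64, cbs_topk=4, max_seq_len=256)

    # Forward pass: (B, T) token ids -> (B, T, vocab_size) logits.
    input_ids = torch.randint(0, vocab_size, (2, 16))
    logits = model(input_ids)
    print("logits:", tuple(logits.shape))  # (2, 16, 1000)

    # Sampling: append up to 20 new tokens to each prompt.
    generated = model.generate(input_ids, max_new_tokens=20,
                               temperature=0.8, top_p=0.9, top_k=50)
    print("generated:", tuple(generated.shape))  # (2, 36) when no EOS is hit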