Upload 5 files

- __init__.py +0 -0
- lnn.py +511 -0
- model.py +114 -0
- moe.py +88 -0
- pmb.py +210 -0

__init__.py
ADDED
File without changes (an empty package marker)
lnn.py
ADDED
@@ -0,0 +1,511 @@
# Copyright 2024 Quasar AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import math
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.generic import ModelOutput
from typing import Optional, Tuple, List
from dataclasses import dataclass
from .pmb import ParameterMemoryBank
from .moe import MoELayer, Expert

from tqdm import tqdm

try:
    from torchdiffeq import odeint
except ImportError:
    raise ImportError("torchdiffeq is not installed. Please install it with `pip install torchdiffeq`")

# --- 1. Configuration Class ---
class LNNConfig(PretrainedConfig):
    """
    Configuration class for the Liquid Neural Network (LNN) model.
    Inherits from HuggingFace's PretrainedConfig.
    """
    model_type = "quasar"

    def __init__(
        self,
        vocab_size=151552,
        hidden_size=8192,
        num_hidden_layers=96,  # 96 layers to keep active parameters manageable
        activation='gelu',
        lambda_res=0.0,
        dt=0.2,  # Step size for the fixed-step Euler solver.
        initializer_range=0.02,
        dropout=0.1,
        use_pmb=False,
        pmb_num_blocks=1024,
        pmb_slots_per_block=4096,
        pmb_top_k=1,
        # MoE parameters
        use_moe: bool = False,
        num_experts: int = 407,  # 407 experts to reach 440B total parameters
        num_experts_per_tok: int = 4,  # 4 active experts per token to maintain 25B active params
        expert_dim: int = 32768,  # 32K expert dimension for capacity
        moe_load_balance_loss_weight: float = 0.01,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.lambda_res = lambda_res
        self.dt = dt
        self.activation = activation
        self.initializer_range = initializer_range
        self.dropout = dropout
        self.use_pmb = use_pmb
        self.pmb_num_blocks = pmb_num_blocks
        self.pmb_slots_per_block = pmb_slots_per_block
        self.pmb_top_k = pmb_top_k
        # MoE
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.expert_dim = expert_dim
        self.moe_load_balance_loss_weight = moe_load_balance_loss_weight
        super().__init__(**kwargs)

# --- 2. Custom Model Output ---
@dataclass
class LNNModelOutput(ModelOutput):
    """
    Base class for LNN model's outputs, ensuring compatibility with HuggingFace.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    load_balancing_loss: Optional[torch.FloatTensor] = None

# --- 3. Core LNN Cell ---
class LNNCell(nn.Module):
    """A single Liquid Neural Network cell with continuous-time dynamics."""
    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.lambda_res = config.lambda_res

        # Core LNN parameters
        self.W = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.U = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.b = nn.Parameter(torch.empty(config.hidden_size))

        # Input-Dependent Dynamics
        self.tau_w_h = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_w_u = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_b = nn.Parameter(torch.empty(config.hidden_size))

        # Initialize weights
        nn.init.orthogonal_(self.W)  # Orthogonal init for recurrent weights
        nn.init.xavier_uniform_(self.U)
        nn.init.zeros_(self.b)
        self.tau_b.data.uniform_(-2, 2)

        self.sigma = nn.Tanh()  # Use Tanh for bounded output and stability

    def forward(self, h, u):
        """Core ODE dynamics calculation for a single discrete step."""
        # 1. Compute Input-Dependent Time Constant (tau)
        tau_control = self.tau_w_h(h) + self.tau_w_u(u) + self.tau_b
        # Increased the floor from 0.01 to 1.0 to prevent division by a near-zero
        # number, which is a common cause of NaN in bf16.
        tau_positive = F.softplus(tau_control) + 1.0

        # 2. Compute State Update
        decay_term = -h / tau_positive
        activation_input = F.linear(h, self.W) + F.linear(u, self.U) + self.b
        activation_output = self.sigma(activation_input)
        dx_dt = decay_term + activation_output

        if self.lambda_res > 0:
            dx_dt = dx_dt + self.lambda_res * u

        # 3. Stability: Clip the derivative
        dx_dt = torch.clamp(dx_dt, -10, 10)
        return dx_dt

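# For reference, the cell above implements the liquid time-constant dynamics
#     dh/dt = -h / tau(h, u) + tanh(W h + U u + b)  (+ lambda_res * u),
# with tau(h, u) = softplus(tau_w_h(h) + tau_w_u(u) + tau_b) + 1 and the
# derivative clamped to [-10, 10]; the block below integrates this with an
# explicit Euler step, h <- h + dt * dh/dt.
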
# --- 4. LNN Block (Layer + Residual) ---
class LNNBlock(nn.Module):
    """A single block of the LNN, using a fixed-step Euler loop."""
    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dt = config.dt
        self.cell = LNNCell(config)
        self.ln = nn.LayerNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Processes the entire sequence using a fixed-step Euler integration loop,
        starting from a given hidden state h.
        This version is optimized to be JIT-friendly by pre-allocating the output tensor.
        """
        seq_len = x.size(1)
        # Pre-allocate tensor for outputs to avoid slow list appends.
        # Match x's dtype as well as its device so mixed-precision (e.g. bf16)
        # runs do not silently allocate a float32 buffer here.
        outputs = torch.empty(x.size(0), seq_len, self.hidden_size, device=x.device, dtype=x.dtype)

        for t in range(seq_len):
            u = x[:, t, :]
            dx_dt = self.cell(h, u)
            h = h + self.dt * dx_dt
            # Clamp the hidden state to prevent runaway values, a common
            # source of instability in recurrent models.
            h = torch.clamp(h, -100, 100)
            outputs[:, t, :] = h

        # Add residual connection and layer norm
        output = self.ln(outputs + x)
        return output, h

# --- 5. Full LNN Model ---
class LNNModel(PreTrainedModel, GenerationMixin):
    """
    The Liquid Neural Network Model.
    This version restores the architecture from the high-performing `old_lnn.py`,
    using stacked LNNBlocks to process the sequence. The Transformer-based
    attention readout has been removed (see the comment below), so prediction
    relies purely on the recurrent state.
    """
    config_class = LNNConfig

    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([LNNBlock(config) for _ in range(config.num_hidden_layers)])

        # JIT-compile the LNNBlocks for a significant performance boost
        # Disabling JIT as a test, as it can sometimes cause unexpected memory allocation issues with recurrent loops.
        # for i in range(len(self.blocks)):
        #     self.blocks[i] = torch.jit.script(self.blocks[i])

        self.ln_final = nn.LayerNorm(config.hidden_size, eps=1e-5)

        # The attention-based readout is removed to prevent the model from "cheating"
        # by using self-attention on the whole sequence instead of relying on its
        # recurrent state. This forces the LNN to learn more robust representations.
        # self.readout = nn.TransformerEncoderLayer(...)

        self.proj_out = nn.Linear(config.hidden_size, config.vocab_size)

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,  # Accept attention_mask
        **kwargs,  # Accept other arguments
    ) -> LNNModelOutput:
        """
        Processes a sequence, calculates loss, and handles unexpected arguments.
        The `attention_mask` is accepted but not used, as the LNN processes
        the sequence recurrently.
        """
        # 1. Get Embeddings
        x = self.embedding(input_ids)
        batch_size = input_ids.shape[0]

        # 2. Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=x.device)
                for _ in range(self.config.num_hidden_layers)
            ]

        # 3. Process sequence through LNN blocks
        new_hidden_states = []
        layer_output = x
        for i, block in enumerate(self.blocks):
            h_initial = hidden_states[i]
            layer_output, h_final = block(layer_output, h_initial)
            new_hidden_states.append(h_final)

        # 4. Final Projection (without attention readout)
        final_output = self.ln_final(layer_output)
        logits = self.proj_out(final_output)

        # 5. Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Shift so that logits at time t predict the token at time t+1.
            # This is the standard procedure for training causal language models.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # Flatten the tokens and compute loss
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return LNNModelOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_output,
            hidden_states=tuple(new_hidden_states),
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        max_new_tokens: int = None,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        do_sample: bool = True,
        pad_token_id: int = None,
        eos_token_id: int = None,
        repetition_penalty: float = 1.0,
        **kwargs
    ) -> torch.LongTensor:
        """
        Generate text using the LNN model with improved repetition handling.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Determine actual max length
        if max_new_tokens is not None:
            max_length = input_ids.shape[1] + max_new_tokens

        # Initialize hidden states
        hidden_states = [
            torch.zeros(batch_size, self.config.hidden_size, device=device)
            for _ in range(self.config.num_hidden_layers)
        ]

        # Initialize output with input_ids
        generated = input_ids.clone()

        # Set model to evaluation mode
        self.eval()

        for step in range(max_length - input_ids.shape[1]):
            # Get model output - only pass the last few tokens to avoid recomputing everything
            context_length = min(generated.shape[1], 512)  # Limit context to prevent memory issues
            context_ids = generated[:, -context_length:]

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=context_ids,
                    hidden_states=hidden_states if step == 0 else None  # Only use initial hidden states
                )

            # Get logits for the last token
            logits = outputs.logits[:, -1, :]  # Shape: [batch_size, vocab_size]

            # Apply repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        # If logit is positive, divide by penalty, else multiply
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Apply top-k filtering
            if top_k > 0:
                top_k_values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
                indices_to_remove = logits < top_k_values[..., -1, None]
                logits[indices_to_remove] = -float('inf')

            # Apply top-p filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Convert back to original indices
                indices_to_remove = sorted_indices_to_remove.gather(dim=-1, index=sorted_indices.argsort(dim=-1))
                logits[indices_to_remove] = -float('inf')

            # Sample next token
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to generated sequence
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS token
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def generate_simple(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        pad_token_id: int = None,
        eos_token_id: int = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        **kwargs
    ) -> torch.LongTensor:
        """
        Simple generate method without top-k/top-p sampling to avoid dimension issues.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=device)
                for _ in range(self.config.num_hidden_layers)
            ]

        # Initialize output with input_ids
        generated = input_ids.clone()

        # Set model to evaluation mode
        self.eval()

        for _ in range(max_length - input_ids.shape[1]):
            # Get model output
            with torch.no_grad():
                outputs = self.forward(
                    input_ids=generated,
                    hidden_states=hidden_states
                )

            # Get logits for the last token
            logits = outputs.logits[:, -1, :]  # Shape: [batch_size, vocab_size]
            hidden_states = list(outputs.hidden_states)

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Sample next token
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to generated sequence
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS token
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        **kwargs
    ) -> dict:
        """
        Prepare inputs for generation. For LNN, we use hidden_states instead of past_key_values.
        """
        # For LNN, we don't use past_key_values in the traditional sense.
        # Instead, we rely on the recurrent nature of the model.
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
        return model_inputs

    def _reorder_cache(self, past_key_values: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
        """
        Reorder hidden states for beam search.
        """
        if past_key_values is None:
            return None

        reordered_past = []
        for hidden_state in past_key_values:
            reordered_past.append(hidden_state.index_select(0, beam_idx))
        return reordered_past

# --- 6. For Causal LM compatibility ---
class LNNForCausalLM(LNNModel):
    """
    Wrapper class for compatibility with HuggingFace's CausalLM interface.
    """
    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.lm_head = self.proj_out  # Alias for compatibility

    @property
    def model(self):
        """Return self for compatibility with some HF utilities."""
        return self

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: bool = True,
        **kwargs,
    ) -> LNNModelOutput:
        """Forward pass that's compatible with the CausalLM interface."""
        return super().forward(
            input_ids=input_ids,
            labels=labels,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs
        )

# --- 7. Model registration ---
# Register the model with transformers
try:
    from transformers import AutoModel, AutoModelForCausalLM
    AutoModel.register(LNNConfig, LNNModel)
    AutoModelForCausalLM.register(LNNConfig, LNNForCausalLM)
except ImportError:
    pass  # transformers not available or version doesn't support registration
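
A quick sanity check of lnn.py, as a hedged sketch: the `quasar.lnn` import path and the tiny sizes below are assumptions for illustration (the defaults above describe a vastly larger model), and note that `torchdiffeq` must be installed for the module to import at all, even though the fixed-step Euler loop never calls `odeint`.

    # Hypothetical smoke test; sizes are illustrative, not the shipped defaults.
    import torch
    from quasar.lnn import LNNConfig, LNNForCausalLM  # assumed package layout

    config = LNNConfig(vocab_size=256, hidden_size=64, num_hidden_layers=2)
    model = LNNForCausalLM(config)

    input_ids = torch.randint(0, 256, (1, 8))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape)  # torch.Size([1, 8, 256])
    print(out.loss)          # scalar next-token cross-entropy

    tokens = model.generate_simple(input_ids, max_length=16, do_sample=False)
    print(tokens.shape)      # torch.Size([1, 16])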
model.py
ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
from transformers import PreTrainedModel, PretrainedConfig
from tqdm import tqdm

from .moe import MoELayer

class QuasarConfig(PretrainedConfig):
    model_type = "quasar"

    def __init__(
        self,
        vocab_size=129280,
        embedding_dim=8192,
        num_hidden_layers=96,  # 96 layers to keep active parameters manageable
        num_attention_heads=64,
        num_experts=407,  # 407 experts to reach 440B total parameters
        expert_dim=32768,  # 32K expert dimension for capacity
        top_k=4,  # 4 active experts per token to maintain 25B active params
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_experts = num_experts
        self.expert_dim = expert_dim
        self.top_k = top_k
        super().__init__(**kwargs)

class SelfAttention(nn.Module):
    def __init__(self, config: QuasarConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.embedding_dim // self.num_heads
        self.q_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.k_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.v_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.out_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Note: no causal mask is applied here; autoregressive use would
        # require passing is_causal=True (or an explicit attn_mask).
        attn_output = F.scaled_dot_product_attention(q, k, v)

        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output)

class QuasarBlock(nn.Module):
    def __init__(self, config: QuasarConfig):
        super().__init__()
        self.attention = SelfAttention(config)
        self.moe_layer = MoELayer(
            embedding_dim=config.embedding_dim,
            num_experts=config.num_experts,
            expert_dim=config.expert_dim,
            top_k=config.top_k
        )
        self.ln1 = nn.LayerNorm(config.embedding_dim)
        self.ln2 = nn.LayerNorm(config.embedding_dim)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        moe_out, lb_loss = self.moe_layer(self.ln2(x))
        x = x + moe_out
        return x, lb_loss

class Quasar(PreTrainedModel):
    config_class = QuasarConfig
    _supports_gradient_checkpointing = True

    def __init__(self, config: QuasarConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        print(f"\nInitializing {config.num_hidden_layers} Quasar layers...")
        self.layers = nn.ModuleList([QuasarBlock(config) for _ in tqdm(range(config.num_hidden_layers), desc="Creating Quasar Layers")])
        self.final_ln = nn.LayerNorm(config.embedding_dim)
        self.output_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, labels=None, **kwargs):
        x = self.embedding(input_ids)
        total_lb_loss = 0.0

        # Stash the config in kwargs (note: kwargs are not currently
        # forwarded to the layers, so this is effectively unused).
        kwargs['config'] = self.config

        for layer in self.layers:
            if self.is_gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                x, lb_loss = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
            else:
                x, lb_loss = layer(x)
            total_lb_loss += lb_loss

        x = self.final_ln(x)
        logits = self.output_head(x)

        loss = None
        if labels is not None:
            # Note: labels are consumed as-is (no internal shift), unlike
            # LNNModel.forward above; callers must pre-shift for next-token training.
            main_loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1))
            loss = main_loss + total_lb_loss

        return {
            'loss': loss,
            'logits': logits,
            'lb_loss': total_lb_loss
        }
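
A minimal sketch of exercising the dense-attention Quasar variant, assuming the file imports as `quasar.model`; the sizes are toy placeholders. Note the two behaviors flagged in the comments above: `scaled_dot_product_attention` runs without a causal mask, and labels are consumed without an internal shift.

    import torch
    from quasar.model import QuasarConfig, Quasar  # assumed package layout

    config = QuasarConfig(vocab_size=128, embedding_dim=32, num_hidden_layers=2,
                          num_attention_heads=4, num_experts=4, expert_dim=64, top_k=2)
    model = Quasar(config)

    input_ids = torch.randint(0, 128, (2, 10))
    out = model(input_ids, labels=input_ids)
    print(out['logits'].shape)          # torch.Size([2, 10, 128])
    print(out['loss'], out['lb_loss'])  # total loss and load-balancing term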
moe.py
ADDED
@@ -0,0 +1,88 @@
# c:\quasarv4\quasar\moe.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class Expert(nn.Module):
    """An expert network. For Quasar, this could be an LNN layer followed by a feed-forward network."""
    def __init__(self, embedding_dim, expert_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, expert_dim),
            nn.GELU(),
            nn.Linear(expert_dim, embedding_dim)
        )

    def forward(self, x):
        return self.net(x)

class MoERouter(nn.Module):
    """A simple router that learns to dispatch tokens to experts."""
    def __init__(self, embedding_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(embedding_dim, num_experts)

    def forward(self, x):
        """Returns the top-k weights and indices for each token."""
        gate_logits = self.gate(x.reshape(-1, x.shape[-1]))
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float).to(x.dtype)
        return top_k_weights, top_k_indices

class MoELayer(nn.Module):
    """A Mixture of Experts layer."""
    def __init__(self, embedding_dim, num_experts, expert_dim, top_k=2):
        super().__init__()
        self.router = MoERouter(embedding_dim, num_experts, top_k)
        self.num_experts = num_experts

        # Create experts.
        # Use a generator expression to avoid creating a temporary list of all experts in memory.
        self.experts = nn.ModuleList(Expert(embedding_dim, expert_dim) for _ in range(self.num_experts))

    def forward(self, x):
        """Forward pass for the MoE layer."""
        original_shape = x.shape
        flat_x = x.reshape(-1, x.shape[-1])

        # Create the final output tensor on the correct device, avoiding meta-device issues.
        final_output = torch.zeros(flat_x.shape, dtype=x.dtype, device=self.router.gate.weight.device)

        # Get routing decisions from the router
        top_k_weights, top_k_indices = self.router(x)

        # Calculate load balancing loss using one_hot to be meta-tensor compatible
        num_tokens = top_k_indices.size(0)
        one_hot_indices = F.one_hot(top_k_indices, num_classes=self.num_experts).float()
        tokens_per_expert = one_hot_indices.sum(dim=[0, 1])
        router_probs_per_expert = torch.mean(F.softmax(self.router.gate.weight, dim=0), dim=1)
        load_balancing_loss = self.num_experts * torch.dot(tokens_per_expert / num_tokens, router_probs_per_expert)

        # Dispatch tokens to experts and aggregate outputs
        for i in range(self.num_experts):
            # Find which tokens are routed to this expert
            expert_mask = (top_k_indices == i).any(dim=1)
            expert_indices_for_expert = torch.where(expert_mask)[0]

            if expert_indices_for_expert.numel() == 0:
                continue

            # Get the tokens for this expert
            expert_tokens = flat_x[expert_indices_for_expert]

            # Find the specific weight for this expert for each token
            top_k_weights_for_expert = top_k_weights[expert_indices_for_expert]
            is_expert_in_top_k = (top_k_indices[expert_indices_for_expert] == i)
            weights_for_expert = torch.sum(top_k_weights_for_expert * is_expert_in_top_k, dim=1, keepdim=True)

            # Process with expert and apply routing weight
            expert_output = self.experts[i](expert_tokens)
            weighted_output = expert_output * weights_for_expert

            # Add the weighted output to the final output tensor at the correct positions
            final_output.index_add_(0, expert_indices_for_expert, weighted_output)

        return final_output.reshape(original_shape), load_balancing_loss
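
For orientation: the auxiliary term above is a Switch-Transformer-style balance loss, roughly N_e * sum_e f_e * P_e, where f_e counts routing assignments per expert normalized by the token count; this file approximates P_e from a softmax over the gate's weight matrix rather than from per-token router probabilities (per the comment, to stay meta-tensor compatible). A standalone smoke test, with toy sizes and an assumed import path:

    import torch
    from quasar.moe import MoELayer  # assumed package layout

    layer = MoELayer(embedding_dim=16, num_experts=4, expert_dim=32, top_k=2)
    x = torch.randn(2, 5, 16)  # (batch, seq, dim)

    out, lb_loss = layer(x)
    print(out.shape)  # torch.Size([2, 5, 16]) - same shape as the input
    print(lb_loss)    # scalar load-balancing penalty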
pmb.py
ADDED
@@ -0,0 +1,210 @@
import torch
import hashlib
import numpy as np

class ParameterMemoryBank:
    """
    Parameter Memory Bank (PMB) for infinite, queryable memory.

    This implementation uses a two-level hashing system for constant-time
    direct access and supports semantic similarity search.

    - Level 1: A list of 'blocks'.
    - Level 2: Each block is a dictionary-like structure mapping slots to items.

    For simplicity, we use Python lists and dictionaries. A production system
    would use a more optimized backend (e.g., Redis, custom memory store).
    """
    def __init__(self, num_blocks=1024, slots_per_block=4096, embedding_dim=None):
        self.num_blocks = num_blocks
        self.slots_per_block = slots_per_block
        self.embedding_dim = embedding_dim

        # PMB is a list of blocks, where each block is a list of slots.
        # Each slot can hold a tuple: (id, key_embedding, value)
        self.pmb = [[None] * slots_per_block for _ in range(num_blocks)]

        # For semantic search, we need a separate structure to hold all keys.
        # This is a trade-off for efficient similarity search.
        self.all_keys = []
        self.key_locations = []  # Stores (block_idx, slot_idx) for each key

    def _hash_fn(self, s, salt=""):
        """A simple, salted hash function."""
        return int(hashlib.sha256((str(s) + salt).encode()).hexdigest(), 16)

    def _get_hash_indices(self, item_id):
        """
        Calculates the block and slot indices for a given item ID using
        the two-level hashing scheme.
        """
        block_hash = self._hash_fn(item_id, salt="block")
        block_idx = block_hash % self.num_blocks

        slot_hash = self._hash_fn(item_id, salt=f"slot_{block_idx}")
        slot_idx = slot_hash % self.slots_per_block

        return block_idx, slot_idx

    def store(self, item_id, key_embedding, value):
        """
        Stores a key-value pair in the PMB using its ID.

        Args:
            item_id (str or int): A unique identifier for the data.
            key_embedding (torch.Tensor): The embedding vector (k_i,j).
            value (any): The data to store (v_i,j), e.g., text, metadata.
        """
        if not isinstance(key_embedding, torch.Tensor):
            raise TypeError("key_embedding must be a torch.Tensor")

        block_idx, slot_idx = self._get_hash_indices(item_id)

        # Store the item in the hash-based location.
        # Note: This simple implementation doesn't handle hash collisions.
        # A real system would need a collision resolution strategy (e.g., cuckoo hashing, chaining).
        if self.pmb[block_idx][slot_idx] is not None:
            # Handle collision by updating the existing entry or finding an empty slot
            pass  # For now, just overwrite

        self.pmb[block_idx][slot_idx] = (item_id, key_embedding.detach().cpu(), value.detach().cpu() if isinstance(value, torch.Tensor) else value)

        # Also store the key for semantic search
        self.all_keys.append(key_embedding.detach().cpu())
        self.key_locations.append((block_idx, slot_idx))

    def retrieve_direct(self, item_id):
        """
        Retrieves a value directly using its ID in O(1) time.

        Args:
            item_id (str or int): The unique identifier of the item.

        Returns:
            The stored value, or None if not found.
        """
        block_idx, slot_idx = self._get_hash_indices(item_id)
        item = self.pmb[block_idx][slot_idx]

        # Check if the found item ID matches, in case of no collision handling
        if item and item[0] == item_id:
            return item[2]  # Return the value
        return None

    def retrieve_by_indices(self, indices):
        """
        Retrieves items by their indices in the `all_keys` list.

        Args:
            indices (list or torch.Tensor): A list of indices.

        Returns:
            A list of the retrieved values.
        """
        results = []
        for idx in indices:
            if idx < len(self.key_locations):
                block_idx, slot_idx = self.key_locations[idx]
                item = self.pmb[block_idx][slot_idx]
                if item:
                    value = item[2]  # Get the value
                    # Convert back to tensor if it was stored as tensor
                    if isinstance(value, torch.Tensor):
                        results.append(value)
                    else:
                        # If value is not a tensor, create a zero tensor of appropriate size
                        if self.embedding_dim:
                            results.append(torch.zeros(self.embedding_dim))
                        else:
                            # Fallback: use the key embedding as value
                            results.append(item[1])  # Use key embedding
                else:
                    # No item found, append zero tensor
                    if self.embedding_dim:
                        results.append(torch.zeros(self.embedding_dim))
                    else:
                        results.append(torch.zeros_like(self.all_keys[0]) if self.all_keys else torch.zeros(1))
            else:
                # Index out of range
                if self.embedding_dim:
                    results.append(torch.zeros(self.embedding_dim))
                else:
                    results.append(torch.zeros_like(self.all_keys[0]) if self.all_keys else torch.zeros(1))
        return results

    def retrieve_semantic(self, query_embeddings, top_k=1):
        """
        Retrieves the top_k most semantically similar items for a batch of query embeddings.

        Args:
            query_embeddings (torch.Tensor): Query vectors (batch_size, embedding_dim) or (batch_size, seq_len, embedding_dim).
            top_k (int): The number of similar items to return for each query.

        Returns:
            A tensor of the aggregated retrieved values with the same shape as query_embeddings.
        """
        if not self.all_keys or top_k == 0:
            return torch.zeros_like(query_embeddings)

        if not isinstance(query_embeddings, torch.Tensor):
            raise TypeError("query_embeddings must be a torch.Tensor")

        # Store original shape and device
        original_shape = query_embeddings.shape
        device = query_embeddings.device

        # Flatten query embeddings to 2D for processing
        if query_embeddings.dim() > 2:
            query_flat = query_embeddings.view(-1, original_shape[-1])
        else:
            query_flat = query_embeddings

        try:
            # Stack all keys into a single tensor
            all_keys_tensor = torch.stack(self.all_keys, dim=0).to(device)

            # Compute cosine similarity
            query_norm = torch.nn.functional.normalize(query_flat, p=2, dim=-1)
            keys_norm = torch.nn.functional.normalize(all_keys_tensor, p=2, dim=-1)

            # Compute similarities: (batch_size, num_keys)
            similarities = torch.mm(query_norm, keys_norm.T)

            # Get top_k results for each query
            k = min(top_k, len(self.all_keys))
            if k > 0:
                top_k_scores, top_k_indices = torch.topk(similarities, k=k, dim=1)

                # Retrieve the corresponding values
                batch_results = []
                for i in range(query_flat.size(0)):
                    retrieved_values = self.retrieve_by_indices(top_k_indices[i].cpu().tolist())

                    if retrieved_values:
                        # Stack and move to correct device
                        stacked_values = torch.stack(retrieved_values, dim=0).to(device)
                        # Average the top_k retrieved values
                        aggregated_value = torch.mean(stacked_values, dim=0)
                        batch_results.append(aggregated_value)
                    else:
                        # No valid retrievals, use zero tensor
                        batch_results.append(torch.zeros(original_shape[-1], device=device))

                # Stack all batch results
                if batch_results:
                    result = torch.stack(batch_results, dim=0)
                    # Reshape back to original shape
                    return result.view(original_shape)
                else:
                    return torch.zeros_like(query_embeddings)
            else:
                return torch.zeros_like(query_embeddings)

        except Exception as e:
            print(f"Error in PMB retrieve_semantic: {e}")
            return torch.zeros_like(query_embeddings)

    def __len__(self):
        return len(self.all_keys)
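
A short usage sketch of the memory bank, with toy sizes and an assumed import path. Addresses come from two salted SHA-256 hashes of the item ID: block_idx = H(id + "block") mod num_blocks, then slot_idx = H(id + "slot_{block_idx}") mod slots_per_block.

    import torch
    from quasar.pmb import ParameterMemoryBank  # assumed package layout

    pmb = ParameterMemoryBank(num_blocks=8, slots_per_block=16, embedding_dim=4)

    key = torch.tensor([1.0, 0.0, 0.0, 0.0])
    pmb.store("doc-1", key_embedding=key, value=torch.tensor([1.0, 2.0, 3.0, 4.0]))

    print(len(pmb))                        # 1
    print(pmb.retrieve_direct("doc-1"))    # the stored value tensor
    print(pmb.retrieve_direct("missing"))  # None

    query = torch.tensor([[0.9, 0.1, 0.0, 0.0]])   # (batch, dim)
    print(pmb.retrieve_semantic(query, top_k=1))   # nearest stored value, shape (1, 4)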