eyad-silx committed
Commit f03ee14 · verified · 1 Parent(s): 8d93384

Upload 5 files

Files changed (5)
  1. __init__.py +0 -0
  2. lnn.py +511 -0
  3. model.py +114 -0
  4. moe.py +88 -0
  5. pmb.py +210 -0
__init__.py ADDED
File without changes
lnn.py ADDED
@@ -0,0 +1,511 @@
# Copyright 2024 Quasar AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import math
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.generic import ModelOutput
from typing import Optional, Tuple, List
from dataclasses import dataclass
from .pmb import ParameterMemoryBank
from .moe import MoELayer, Expert

from tqdm import tqdm

try:
    from torchdiffeq import odeint
except ImportError:
    raise ImportError("torchdiffeq is not installed. Please install it with `pip install torchdiffeq`")

# --- 1. Configuration Class ---
class LNNConfig(PretrainedConfig):
    """
    Configuration class for the Liquid Neural Network (LNN) model.
    Inherits from HuggingFace's PretrainedConfig.
    """
    model_type = "quasar"

    def __init__(
        self,
        vocab_size=151552,
        hidden_size=8192,
        num_hidden_layers=96,  # 96 layers to keep the active parameter count manageable
        activation='gelu',
        lambda_res=0.0,
        dt=0.2,  # Step size for the fixed-step Euler solver.
        initializer_range=0.02,
        dropout=0.1,
        use_pmb=False,
        pmb_num_blocks=1024,
        pmb_slots_per_block=4096,
        pmb_top_k=1,
        # MoE parameters
        use_moe: bool = False,
        num_experts: int = 407,  # 407 experts to reach 440B total parameters
        num_experts_per_tok: int = 4,  # 4 active experts per token to maintain 25B active params
        expert_dim: int = 32768,  # 32K expert dimension for capacity
        moe_load_balance_loss_weight: float = 0.01,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.lambda_res = lambda_res
        self.dt = dt
        self.activation = activation
        self.initializer_range = initializer_range
        self.dropout = dropout
        self.use_pmb = use_pmb
        self.pmb_num_blocks = pmb_num_blocks
        self.pmb_slots_per_block = pmb_slots_per_block
        self.pmb_top_k = pmb_top_k
        # MoE
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.expert_dim = expert_dim
        self.moe_load_balance_loss_weight = moe_load_balance_loss_weight
        super().__init__(**kwargs)

# --- 2. Custom Model Output ---
@dataclass
class LNNModelOutput(ModelOutput):
    """
    Base class for the LNN model's outputs, ensuring compatibility with HuggingFace.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    load_balancing_loss: Optional[torch.FloatTensor] = None


# --- 3. Core LNN Cell ---
class LNNCell(nn.Module):
    """A single Liquid Neural Network cell with continuous-time dynamics."""
    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.lambda_res = config.lambda_res

        # Core LNN parameters
        self.W = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.U = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.b = nn.Parameter(torch.empty(config.hidden_size))

        # Input-dependent dynamics
        self.tau_w_h = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_w_u = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_b = nn.Parameter(torch.empty(config.hidden_size))

        # Initialize weights
        nn.init.orthogonal_(self.W)  # Orthogonal init for recurrent weights
        nn.init.xavier_uniform_(self.U)
        nn.init.zeros_(self.b)
        self.tau_b.data.uniform_(-2, 2)

        self.sigma = nn.Tanh()  # Tanh for bounded output and stability

    def forward(self, h, u):
        """Core ODE dynamics calculation for a single discrete step."""
        # 1. Compute the input-dependent time constant (tau)
        tau_control = self.tau_w_h(h) + self.tau_w_u(u) + self.tau_b
        # The floor is raised from 0.01 to 1.0 to prevent division by a near-zero
        # number, which is a common cause of NaN in bf16.
        tau_positive = F.softplus(tau_control) + 1.0

        # 2. Compute the state update: dx/dt = -h / tau + sigma(W h + U u + b)
        decay_term = -h / tau_positive
        activation_input = F.linear(h, self.W) + F.linear(u, self.U) + self.b
        activation_output = self.sigma(activation_input)
        dx_dt = decay_term + activation_output

        if self.lambda_res > 0:
            dx_dt = dx_dt + self.lambda_res * u

        # 3. Stability: clip the derivative
        dx_dt = torch.clamp(dx_dt, -10, 10)
        return dx_dt

# --- 4. LNN Block (Layer + Residual) ---
class LNNBlock(nn.Module):
    """A single block of the LNN, using a fixed-step Euler loop."""
    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dt = config.dt
        self.cell = LNNCell(config)
        self.ln = nn.LayerNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Processes the entire sequence with a fixed-step Euler integration loop,
        starting from a given hidden state h.
        This version is optimized to be JIT-friendly by pre-allocating the output tensor.
        """
        seq_len = x.size(1)
        # Pre-allocate the output tensor (matching x's dtype) to avoid slow list appends.
        outputs = torch.empty(x.size(0), seq_len, self.hidden_size, device=x.device, dtype=x.dtype)

        for t in range(seq_len):
            u = x[:, t, :]
            dx_dt = self.cell(h, u)
            h = h + self.dt * dx_dt
            # Clamp the hidden state to prevent runaway values, a common
            # source of instability in recurrent models.
            h = torch.clamp(h, -100, 100)
            outputs[:, t, :] = h

        # Add residual connection and layer norm
        output = self.ln(outputs + x)
        return output, h

# --- 5. Full LNN Model ---
class LNNModel(PreTrainedModel, GenerationMixin):
    """
    The Liquid Neural Network model.
    This version restores the architecture from the high-performing `old_lnn.py`,
    using stacked LNNBlocks to process the sequence. The Transformer-based
    attention readout has been removed (see the comment in __init__).
    """
    config_class = LNNConfig

    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([LNNBlock(config) for _ in range(config.num_hidden_layers)])

        # JIT compilation of the LNNBlocks is disabled for now: torch.jit.script
        # can cause unexpected memory allocation issues with recurrent loops.
        # for i in range(len(self.blocks)):
        #     self.blocks[i] = torch.jit.script(self.blocks[i])

        self.ln_final = nn.LayerNorm(config.hidden_size, eps=1e-5)

        # The attention-based readout is removed to prevent the model from "cheating"
        # by using self-attention over the whole sequence instead of relying on its
        # recurrent state. This forces the LNN to learn more robust representations.
        # self.readout = nn.TransformerEncoderLayer(...)

        self.proj_out = nn.Linear(config.hidden_size, config.vocab_size)

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,  # Accept attention_mask
        **kwargs,  # Accept other arguments
    ) -> LNNModelOutput:
        """
        Processes a sequence, calculates the loss, and tolerates unexpected arguments.
        The `attention_mask` is accepted but not used, since the LNN processes
        the sequence recurrently.
        """
        # 1. Get embeddings
        x = self.embedding(input_ids)
        batch_size = input_ids.shape[0]

        # 2. Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=x.device)
                for _ in range(self.config.num_hidden_layers)
            ]

        # 3. Process the sequence through the LNN blocks
        new_hidden_states = []
        layer_output = x
        for i, block in enumerate(self.blocks):
            h_initial = hidden_states[i]
            layer_output, h_final = block(layer_output, h_initial)
            new_hidden_states.append(h_final)

        # 4. Final projection (without attention readout)
        final_output = self.ln_final(layer_output)
        logits = self.proj_out(final_output)

        # 5. Calculate the loss if labels are provided
        loss = None
        if labels is not None:
            # Shift so that logits at time t predict the token at time t+1,
            # the standard procedure for training causal language models.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # Flatten the tokens and compute the loss
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return LNNModelOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_output,
            hidden_states=tuple(new_hidden_states),
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        max_new_tokens: int = None,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        do_sample: bool = True,
        pad_token_id: int = None,
        eos_token_id: int = None,
        repetition_penalty: float = 1.0,
        **kwargs
    ) -> torch.LongTensor:
        """
        Generate text with the LNN model, with improved repetition handling.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Determine the actual maximum length
        if max_new_tokens is not None:
            max_length = input_ids.shape[1] + max_new_tokens

        # Initialize hidden states
        hidden_states = [
            torch.zeros(batch_size, self.config.hidden_size, device=device)
            for _ in range(self.config.num_hidden_layers)
        ]

        # Initialize the output with input_ids
        generated = input_ids.clone()

        # Set the model to evaluation mode
        self.eval()

        for step in range(max_length - input_ids.shape[1]):
            # Only pass the most recent tokens to avoid recomputing everything
            context_length = min(generated.shape[1], 512)  # Limit the context to prevent memory issues
            context_ids = generated[:, -context_length:]

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=context_ids,
                    hidden_states=hidden_states if step == 0 else None  # Only use the initial hidden states
                )

            # Get logits for the last token
            logits = outputs.logits[:, -1, :]  # Shape: [batch_size, vocab_size]

            # Apply the repetition penalty
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        # If the logit is positive, divide by the penalty; otherwise multiply
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Apply top-k filtering
            if top_k > 0:
                top_k_values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
                indices_to_remove = logits < top_k_values[..., -1, None]
                logits[indices_to_remove] = -float('inf')

            # Apply top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above the threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Convert back to the original indices
                indices_to_remove = sorted_indices_to_remove.gather(dim=-1, index=sorted_indices.argsort(dim=-1))
                logits[indices_to_remove] = -float('inf')

            # Sample the next token
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to the generated sequence
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for the EOS token
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def generate_simple(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        pad_token_id: int = None,
        eos_token_id: int = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        **kwargs
    ) -> torch.LongTensor:
        """
        Simple generation without top-k/top-p sampling, to avoid dimension issues.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        # Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=device)
                for _ in range(self.config.num_hidden_layers)
            ]

        # Initialize the output with input_ids
        generated = input_ids.clone()

        # Set the model to evaluation mode
        self.eval()

        for _ in range(max_length - input_ids.shape[1]):
            # Get the model output
            with torch.no_grad():
                outputs = self.forward(
                    input_ids=generated,
                    hidden_states=hidden_states
                )

            # Get logits for the last token
            logits = outputs.logits[:, -1, :]  # Shape: [batch_size, vocab_size]
            hidden_states = list(outputs.hidden_states)

            # Apply temperature
            if temperature != 1.0:
                logits = logits / temperature

            # Sample the next token
            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            # Append to the generated sequence
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for the EOS token
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        **kwargs
    ) -> dict:
        """
        Prepare inputs for generation. The LNN uses hidden_states instead of past_key_values.
        """
        # The LNN does not use past_key_values in the traditional sense;
        # it relies on the recurrent nature of the model instead.
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
        return model_inputs

    def _reorder_cache(self, past_key_values: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
        """
        Reorder hidden states for beam search.
        """
        if past_key_values is None:
            return None

        reordered_past = []
        for hidden_state in past_key_values:
            reordered_past.append(hidden_state.index_select(0, beam_idx))
        return reordered_past

# --- 6. For Causal LM compatibility ---
class LNNForCausalLM(LNNModel):
    """
    Wrapper class for compatibility with HuggingFace's CausalLM interface.
    """
    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.lm_head = self.proj_out  # Alias for compatibility

    @property
    def model(self):
        """Return self for compatibility with some HF utilities."""
        return self

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: bool = True,
        **kwargs,
    ) -> LNNModelOutput:
        """Forward pass compatible with the CausalLM interface."""
        return super().forward(
            input_ids=input_ids,
            labels=labels,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs
        )

# --- 7. Model registration ---
# Register the model with transformers
try:
    from transformers import AutoModel, AutoModelForCausalLM
    AutoModel.register(LNNConfig, LNNModel)
    AutoModelForCausalLM.register(LNNConfig, LNNForCausalLM)
except ImportError:
    pass  # transformers is not available or the version does not support registration
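
A minimal smoke-test sketch for the LNN stack above. It assumes the uploaded files are importable as a package (the package name `quasar` is an assumption, not confirmed by this commit) and that `torchdiffeq` is installed, since the module raises on import otherwise; the tiny configuration is chosen only so the check runs on CPU:

import torch
from quasar.lnn import LNNConfig, LNNForCausalLM  # package name is an assumption

# Deliberately tiny configuration (the real defaults are far larger).
config = LNNConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2, dt=0.2)
model = LNNForCausalLM(config)

# One forward pass with labels to confirm the shifted causal LM loss is produced.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # scalar loss, logits of shape (2, 16, 128)

# Greedy decoding via the simple recurrent generation loop.
generated = model.generate_simple(input_ids[:, :4], max_length=12, do_sample=False)
print(generated.shape)  # (2, 12)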
model.py ADDED
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import PreTrainedModel, PretrainedConfig
from tqdm import tqdm
from .moe import MoELayer

class QuasarConfig(PretrainedConfig):
    model_type = "quasar"

    def __init__(
        self,
        vocab_size=129280,
        embedding_dim=8192,
        num_hidden_layers=96,  # 96 layers to keep the active parameter count manageable
        num_attention_heads=64,
        num_experts=407,  # 407 experts to reach 440B total parameters
        expert_dim=32768,  # 32K expert dimension for capacity
        top_k=4,  # 4 active experts per token to maintain 25B active params
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_experts = num_experts
        self.expert_dim = expert_dim
        self.top_k = top_k
        super().__init__(**kwargs)

class SelfAttention(nn.Module):
    def __init__(self, config: QuasarConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.embedding_dim // self.num_heads
        self.q_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.k_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.v_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)
        self.out_proj = nn.Linear(config.embedding_dim, config.embedding_dim, bias=False)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(q, k, v)

        output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(output)

class QuasarBlock(nn.Module):
    def __init__(self, config: QuasarConfig):
        super().__init__()
        self.attention = SelfAttention(config)
        self.moe_layer = MoELayer(
            embedding_dim=config.embedding_dim,
            num_experts=config.num_experts,
            expert_dim=config.expert_dim,
            top_k=config.top_k
        )
        self.ln1 = nn.LayerNorm(config.embedding_dim)
        self.ln2 = nn.LayerNorm(config.embedding_dim)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        moe_out, lb_loss = self.moe_layer(self.ln2(x))
        x = x + moe_out
        return x, lb_loss

class Quasar(PreTrainedModel):
    config_class = QuasarConfig
    _supports_gradient_checkpointing = True

    def __init__(self, config: QuasarConfig):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        print(f"\nInitializing {config.num_hidden_layers} Quasar layers...")
        self.layers = nn.ModuleList([QuasarBlock(config) for _ in tqdm(range(config.num_hidden_layers), desc="Creating Quasar Layers")])
        self.final_ln = nn.LayerNorm(config.embedding_dim)
        self.output_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)

    def forward(self, input_ids, labels=None, **kwargs):
        x = self.embedding(input_ids)
        total_lb_loss = 0.0

        # Keep the config available via kwargs (currently not consumed by the layers).
        kwargs['config'] = self.config

        for layer in self.layers:
            if self.is_gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward
                x, lb_loss = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
            else:
                x, lb_loss = layer(x)
            total_lb_loss += lb_loss

        x = self.final_ln(x)
        logits = self.output_head(x)

        loss = None
        if labels is not None:
            main_loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1))
            loss = main_loss + total_lb_loss

        return {
            'loss': loss,
            'logits': logits,
            'lb_loss': total_lb_loss
        }
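
The same kind of hedged smoke test for the attention + MoE variant in model.py, again with a deliberately tiny configuration so the 96-layer, 407-expert defaults are never instantiated (the `quasar` package name remains an assumption):

import torch
from quasar.model import QuasarConfig, Quasar  # package name is an assumption

config = QuasarConfig(vocab_size=128, embedding_dim=64, num_hidden_layers=2,
                      num_attention_heads=4, num_experts=8, expert_dim=128, top_k=2)
model = Quasar(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids, labels=input_ids)

# The returned loss is the cross-entropy term plus the summed MoE load-balancing losses.
print(out['loss'], out['lb_loss'], out['logits'].shape)  # logits: (2, 16, 128)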
moe.py ADDED
@@ -0,0 +1,88 @@
# c:\quasarv4\quasar\moe.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class Expert(nn.Module):
    """An expert network. For Quasar, this could be an LNN layer followed by a feed-forward network."""
    def __init__(self, embedding_dim, expert_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, expert_dim),
            nn.GELU(),
            nn.Linear(expert_dim, embedding_dim)
        )

    def forward(self, x):
        return self.net(x)

class MoERouter(nn.Module):
    """A simple router that learns to dispatch tokens to experts."""
    def __init__(self, embedding_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(embedding_dim, num_experts)

    def forward(self, x):
        """Returns the top-k weights and indices for each token."""
        gate_logits = self.gate(x.reshape(-1, x.shape[-1]))
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float).to(x.dtype)
        return top_k_weights, top_k_indices

class MoELayer(nn.Module):
    """A Mixture-of-Experts layer."""
    def __init__(self, embedding_dim, num_experts, expert_dim, top_k=2):
        super().__init__()
        self.router = MoERouter(embedding_dim, num_experts, top_k)
        self.num_experts = num_experts

        # Create the experts.
        # A generator expression avoids building a temporary list of all experts in memory.
        self.experts = nn.ModuleList(Expert(embedding_dim, expert_dim) for _ in range(self.num_experts))

    def forward(self, x):
        """Forward pass for the MoE layer."""
        original_shape = x.shape
        flat_x = x.reshape(-1, x.shape[-1])

        # Create the final output tensor on the correct device, avoiding meta-device issues.
        final_output = torch.zeros(flat_x.shape, dtype=x.dtype, device=self.router.gate.weight.device)

        # Get routing decisions from the router
        top_k_weights, top_k_indices = self.router(x)

        # Calculate the load-balancing loss, using one_hot to stay meta-tensor compatible.
        # Note: the router-probability term is approximated from the gate weights rather
        # than from the per-token router distribution.
        num_tokens = top_k_indices.size(0)
        one_hot_indices = F.one_hot(top_k_indices, num_classes=self.num_experts).float()
        tokens_per_expert = one_hot_indices.sum(dim=[0, 1])
        router_probs_per_expert = torch.mean(F.softmax(self.router.gate.weight, dim=0), dim=1)
        load_balancing_loss = self.num_experts * torch.dot(tokens_per_expert / num_tokens, router_probs_per_expert)

        # Dispatch tokens to experts and aggregate the outputs
        for i in range(self.num_experts):
            # Find which tokens are routed to this expert
            expert_mask = (top_k_indices == i).any(dim=1)
            expert_indices_for_expert = torch.where(expert_mask)[0]

            if expert_indices_for_expert.numel() == 0:
                continue

            # Get the tokens for this expert
            expert_tokens = flat_x[expert_indices_for_expert]

            # Find this expert's routing weight for each of its tokens
            top_k_weights_for_expert = top_k_weights[expert_indices_for_expert]
            is_expert_in_top_k = (top_k_indices[expert_indices_for_expert] == i)
            weights_for_expert = torch.sum(top_k_weights_for_expert * is_expert_in_top_k, dim=1, keepdim=True)

            # Process with the expert and apply the routing weight
            expert_output = self.experts[i](expert_tokens)
            weighted_output = expert_output * weights_for_expert

            # Add the weighted output to the final output tensor at the correct positions
            final_output.index_add_(0, expert_indices_for_expert, weighted_output)

        return final_output.reshape(original_shape), load_balancing_loss
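
A small standalone check of the routing path in MoELayer; the sizes are arbitrary and chosen only to keep it fast (package name assumed as before):

import torch
from quasar.moe import MoELayer  # package name is an assumption

layer = MoELayer(embedding_dim=32, num_experts=4, expert_dim=64, top_k=2)
x = torch.randn(2, 10, 32)  # (batch, seq_len, embedding_dim)

out, lb_loss = layer(x)
print(out.shape)       # torch.Size([2, 10, 32]) -- same shape as the input
print(float(lb_loss))  # scalar load-balancing penalty to add to the training loss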
pmb.py ADDED
@@ -0,0 +1,210 @@
import torch
import hashlib
import numpy as np

class ParameterMemoryBank:
    """
    Parameter Memory Bank (PMB) for effectively unbounded, queryable memory.

    This implementation uses a two-level hashing scheme for constant-time
    direct access and supports semantic similarity search.

    - Level 1: a list of 'blocks'.
    - Level 2: each block is a dictionary-like structure mapping slots to items.

    For simplicity, Python lists and dictionaries are used. A production system
    would use a more optimized backend (e.g., Redis or a custom memory store).
    """
    def __init__(self, num_blocks=1024, slots_per_block=4096, embedding_dim=None):
        self.num_blocks = num_blocks
        self.slots_per_block = slots_per_block
        self.embedding_dim = embedding_dim

        # The PMB is a list of blocks, where each block is a list of slots.
        # Each slot can hold a tuple: (id, key_embedding, value)
        self.pmb = [[None] * slots_per_block for _ in range(num_blocks)]

        # For semantic search, a separate structure holds all keys.
        # This is a trade-off for efficient similarity search.
        self.all_keys = []
        self.key_locations = []  # Stores (block_idx, slot_idx) for each key

    def _hash_fn(self, s, salt=""):
        """A simple, salted hash function."""
        return int(hashlib.sha256((str(s) + salt).encode()).hexdigest(), 16)

    def _get_hash_indices(self, item_id):
        """
        Calculates the block and slot indices for a given item ID using
        the two-level hashing scheme.
        """
        block_hash = self._hash_fn(item_id, salt="block")
        block_idx = block_hash % self.num_blocks

        slot_hash = self._hash_fn(item_id, salt=f"slot_{block_idx}")
        slot_idx = slot_hash % self.slots_per_block

        return block_idx, slot_idx

    def store(self, item_id, key_embedding, value):
        """
        Stores a key-value pair in the PMB using its ID.

        Args:
            item_id (str or int): A unique identifier for the data.
            key_embedding (torch.Tensor): The embedding vector (k_i,j).
            value (any): The data to store (v_i,j), e.g., text or metadata.
        """
        if not isinstance(key_embedding, torch.Tensor):
            raise TypeError("key_embedding must be a torch.Tensor")

        block_idx, slot_idx = self._get_hash_indices(item_id)

        # Store the item in the hash-based location.
        # Note: this simple implementation does not handle hash collisions.
        # A real system would need a collision resolution strategy (e.g., cuckoo hashing or chaining).
        if self.pmb[block_idx][slot_idx] is not None:
            # Handle the collision by updating the existing entry or finding an empty slot
            pass  # For now, just overwrite

        self.pmb[block_idx][slot_idx] = (item_id, key_embedding.detach().cpu(), value.detach().cpu() if isinstance(value, torch.Tensor) else value)

        # Also store the key for semantic search
        self.all_keys.append(key_embedding.detach().cpu())
        self.key_locations.append((block_idx, slot_idx))

    def retrieve_direct(self, item_id):
        """
        Retrieves a value directly by its ID in O(1) time.

        Args:
            item_id (str or int): The unique identifier of the item.

        Returns:
            The stored value, or None if not found.
        """
        block_idx, slot_idx = self._get_hash_indices(item_id)
        item = self.pmb[block_idx][slot_idx]

        # Check that the found item ID matches, since collisions are not handled
        if item and item[0] == item_id:
            return item[2]  # Return the value
        return None

    def retrieve_by_indices(self, indices):
        """
        Retrieves items by their indices in the `all_keys` list.

        Args:
            indices (list or torch.Tensor): A list of indices.

        Returns:
            A list of the retrieved values.
        """
        results = []
        for idx in indices:
            if idx < len(self.key_locations):
                block_idx, slot_idx = self.key_locations[idx]
                item = self.pmb[block_idx][slot_idx]
                if item:
                    value = item[2]  # Get the value
                    # Return the value directly if it was stored as a tensor
                    if isinstance(value, torch.Tensor):
                        results.append(value)
                    else:
                        # If the value is not a tensor, create a zero tensor of appropriate size
                        if self.embedding_dim:
                            results.append(torch.zeros(self.embedding_dim))
                        else:
                            # Fallback: use the key embedding as the value
                            results.append(item[1])
                else:
                    # No item found, append a zero tensor
                    if self.embedding_dim:
                        results.append(torch.zeros(self.embedding_dim))
                    else:
                        results.append(torch.zeros_like(self.all_keys[0]) if self.all_keys else torch.zeros(1))
            else:
                # Index out of range
                if self.embedding_dim:
                    results.append(torch.zeros(self.embedding_dim))
                else:
                    results.append(torch.zeros_like(self.all_keys[0]) if self.all_keys else torch.zeros(1))
        return results

    def retrieve_semantic(self, query_embeddings, top_k=1):
        """
        Retrieves the top_k most semantically similar items for a batch of query embeddings.

        Args:
            query_embeddings (torch.Tensor): Query vectors of shape (batch_size, embedding_dim) or (batch_size, seq_len, embedding_dim).
            top_k (int): The number of similar items to return for each query.

        Returns:
            A tensor of the aggregated retrieved values with the same shape as query_embeddings.
        """
        if not self.all_keys or top_k == 0:
            return torch.zeros_like(query_embeddings)

        if not isinstance(query_embeddings, torch.Tensor):
            raise TypeError("query_embeddings must be a torch.Tensor")

        # Store the original shape and device
        original_shape = query_embeddings.shape
        device = query_embeddings.device

        # Flatten the query embeddings to 2D for processing
        if query_embeddings.dim() > 2:
            query_flat = query_embeddings.view(-1, original_shape[-1])
        else:
            query_flat = query_embeddings

        # Handle an empty memory bank
        if not self.all_keys:
            return torch.zeros_like(query_embeddings)

        try:
            # Stack all keys into a single tensor
            all_keys_tensor = torch.stack(self.all_keys, dim=0).to(device)

            # Compute cosine similarity
            query_norm = torch.nn.functional.normalize(query_flat, p=2, dim=-1)
            keys_norm = torch.nn.functional.normalize(all_keys_tensor, p=2, dim=-1)

            # Compute similarities: (batch_size, num_keys)
            similarities = torch.mm(query_norm, keys_norm.T)

            # Get the top_k results for each query
            k = min(top_k, len(self.all_keys))
            if k > 0:
                top_k_scores, top_k_indices = torch.topk(similarities, k=k, dim=1)

                # Retrieve the corresponding values
                batch_results = []
                for i in range(query_flat.size(0)):
                    retrieved_values = self.retrieve_by_indices(top_k_indices[i].cpu().tolist())

                    if retrieved_values:
                        # Stack and move to the correct device
                        stacked_values = torch.stack(retrieved_values, dim=0).to(device)
                        # Average the top_k retrieved values
                        aggregated_value = torch.mean(stacked_values, dim=0)
                        batch_results.append(aggregated_value)
                    else:
                        # No valid retrievals, use a zero tensor
                        batch_results.append(torch.zeros(original_shape[-1], device=device))

                # Stack all batch results
                if batch_results:
                    result = torch.stack(batch_results, dim=0)
                    # Reshape back to the original shape
                    return result.view(original_shape)
                else:
                    return torch.zeros_like(query_embeddings)
            else:
                return torch.zeros_like(query_embeddings)

        except Exception as e:
            print(f"Error in PMB retrieve_semantic: {e}")
            return torch.zeros_like(query_embeddings)

    def __len__(self):
        return len(self.all_keys)
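
Finally, a hedged sketch of the ParameterMemoryBank API as defined above: hashed direct lookup by ID plus cosine-similarity retrieval over the stored keys (package name assumed as before):

import torch
from quasar.pmb import ParameterMemoryBank  # package name is an assumption

bank = ParameterMemoryBank(num_blocks=64, slots_per_block=128, embedding_dim=16)

# Store a few items whose key embedding and value are the same vector.
for i in range(10):
    vec = torch.randn(16)
    bank.store(item_id=f"doc-{i}", key_embedding=vec, value=vec)

print(len(bank))                      # 10 stored keys
print(bank.retrieve_direct("doc-3"))  # O(1) lookup by ID

# Semantic retrieval: averages the values of the top_k nearest keys per query.
queries = torch.randn(4, 16)
retrieved = bank.retrieve_semantic(queries, top_k=2)
print(retrieved.shape)                # torch.Size([4, 16])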