import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.generic import ModelOutput

from tqdm import tqdm

from .pmb import ParameterMemoryBank
from .moe import MoELayer, Expert

try:
    from torchdiffeq import odeint
except ImportError:
    raise ImportError("torchdiffeq is not installed. Please install it with `pip install torchdiffeq`")


class LNNConfig(PretrainedConfig):
    """
    Configuration class for the Liquid Neural Network (LNN) model.

    Inherits from HuggingFace's PretrainedConfig.
    """

    model_type = "quasar"

    def __init__(
        self,
        vocab_size=151552,
        hidden_size=8192,
        num_hidden_layers=96,
        activation='gelu',
        lambda_res=0.0,
        dt=0.2,
        initializer_range=0.02,
        dropout=0.1,
        use_pmb=False,
        pmb_num_blocks=1024,
        pmb_slots_per_block=4096,
        pmb_top_k=1,
        use_moe: bool = False,
        num_experts: int = 407,
        num_experts_per_tok: int = 4,
        expert_dim: int = 32768,
        moe_load_balance_loss_weight: float = 0.01,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.lambda_res = lambda_res
        self.dt = dt
        self.activation = activation
        self.initializer_range = initializer_range
        self.dropout = dropout

        self.use_pmb = use_pmb
        self.pmb_num_blocks = pmb_num_blocks
        self.pmb_slots_per_block = pmb_slots_per_block
        self.pmb_top_k = pmb_top_k

        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.expert_dim = expert_dim
        self.moe_load_balance_loss_weight = moe_load_balance_loss_weight

        super().__init__(**kwargs)
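

# Illustrative usage sketch (not part of the original module): building a small
# configuration for quick local experiments. The reduced sizes below are
# assumptions chosen for demonstration, not the model's defaults.
def _demo_small_config() -> LNNConfig:
    return LNNConfig(
        vocab_size=32_000,     # hypothetical tokenizer size
        hidden_size=256,       # much smaller than the 8192 default
        num_hidden_layers=2,   # two stacked LNNBlocks
        dt=0.2,                # Euler step size used by LNNBlock
    )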


@dataclass
class LNNModelOutput(ModelOutput):
    """
    Base class for the LNN model's outputs, ensuring compatibility with HuggingFace.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    load_balancing_loss: Optional[torch.FloatTensor] = None


class LNNCell(nn.Module):
    """A single Liquid Neural Network cell with continuous-time dynamics."""

    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.lambda_res = config.lambda_res

        # Recurrent (W) and input (U) weights plus bias for the activation term.
        self.W = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.U = nn.Parameter(torch.empty(config.hidden_size, config.hidden_size))
        self.b = nn.Parameter(torch.empty(config.hidden_size))

        # State- and input-dependent time constants (the "liquid" part).
        self.tau_w_h = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_w_u = nn.Linear(config.hidden_size, config.hidden_size)
        self.tau_b = nn.Parameter(torch.empty(config.hidden_size))

        nn.init.orthogonal_(self.W)
        nn.init.xavier_uniform_(self.U)
        nn.init.zeros_(self.b)
        self.tau_b.data.uniform_(-2, 2)

        self.sigma = nn.Tanh()

    def forward(self, h, u):
        """Core ODE dynamics calculation for a single discrete step."""
        # Time constants are conditioned on the current state and input, then
        # mapped through softplus (+1) so they stay strictly positive.
        tau_control = self.tau_w_h(h) + self.tau_w_u(u) + self.tau_b
        tau_positive = F.softplus(tau_control) + 1.0

        # dh/dt = -h / tau + sigma(W h + U u + b), optionally with a scaled
        # residual contribution from the input.
        decay_term = -h / tau_positive
        activation_input = F.linear(h, self.W) + F.linear(u, self.U) + self.b
        activation_output = self.sigma(activation_input)
        dx_dt = decay_term + activation_output

        if self.lambda_res > 0:
            dx_dt = dx_dt + self.lambda_res * u

        # Clamp the derivative to keep the explicit Euler integration stable.
        dx_dt = torch.clamp(dx_dt, -10, 10)
        return dx_dt
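

# Illustrative sketch (tensor shapes are assumptions for demonstration only):
# one explicit Euler step with an LNNCell, mirroring the update performed
# inside LNNBlock, i.e. h <- h + dt * dh/dt.
def _demo_cell_step(config: LNNConfig) -> torch.Tensor:
    cell = LNNCell(config)
    h = torch.zeros(2, config.hidden_size)   # batch of 2, zero initial state
    u = torch.randn(2, config.hidden_size)   # one embedded token per batch row
    dx_dt = cell(h, u)                       # bounded derivative from the cell
    return h + config.dt * dx_dt             # next hidden state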


class LNNBlock(nn.Module):
    """A single block of the LNN, using a fixed-step Euler loop."""

    def __init__(self, config: LNNConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.dt = config.dt
        self.cell = LNNCell(config)
        self.ln = nn.LayerNorm(config.hidden_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Processes the entire sequence with a fixed-step Euler integration loop,
        starting from a given hidden state h.

        This version is optimized to be JIT-friendly by pre-allocating the
        output tensor instead of appending to a Python list.
        """
        seq_len = x.size(1)
        outputs = torch.empty(x.size(0), seq_len, self.hidden_size, device=x.device, dtype=x.dtype)

        for t in range(seq_len):
            u = x[:, t, :]
            dx_dt = self.cell(h, u)
            h = h + self.dt * dx_dt

            # Clamp the state to guard against numerical blow-up over long sequences.
            h = torch.clamp(h, -100, 100)
            outputs[:, t, :] = h

        # Residual connection followed by layer normalization.
        output = self.ln(outputs + x)
        return output, h
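

# Illustrative sketch (tensor sizes are assumptions for demonstration): running
# one LNNBlock over a short random sequence from a zero initial state. The
# block returns both the per-step outputs and the final hidden state.
def _demo_block_forward(config: LNNConfig) -> Tuple[torch.Tensor, torch.Tensor]:
    block = LNNBlock(config)
    x = torch.randn(2, 8, config.hidden_size)   # (batch, seq_len, hidden)
    h0 = torch.zeros(2, config.hidden_size)     # zero initial state
    output, h_final = block(x, h0)              # output: (2, 8, hidden)
    return output, h_final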


class LNNModel(PreTrainedModel, GenerationMixin):
    """
    The Liquid Neural Network Model.

    This version restores the architecture from the high-performing `old_lnn.py`:
    stacked LNNBlocks process the sequence recurrently, followed by a final
    layer normalization and a linear projection onto the vocabulary.
    """

    config_class = LNNConfig

    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([LNNBlock(config) for _ in range(config.num_hidden_layers)])

        self.ln_final = nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.proj_out = nn.Linear(config.hidden_size, config.vocab_size)

    def get_input_embeddings(self):
        return self.embedding

    def set_input_embeddings(self, value):
        self.embedding = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> LNNModelOutput:
        """
        Processes a sequence, calculates the loss, and tolerates unexpected
        keyword arguments. The `attention_mask` is accepted but not used, as
        the LNN processes the sequence recurrently.
        """
        x = self.embedding(input_ids)
        batch_size = input_ids.shape[0]

        # Start every layer from a zero state unless states are provided.
        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=x.device)
                for _ in range(self.config.num_hidden_layers)
            ]

        # Run the stacked LNN blocks, collecting each layer's final hidden state.
        new_hidden_states = []
        layer_output = x
        for i, block in enumerate(self.blocks):
            h_initial = hidden_states[i]
            layer_output, h_final = block(layer_output, h_initial)
            new_hidden_states.append(h_final)

        final_output = self.ln_final(layer_output)
        logits = self.proj_out(final_output)

        loss = None
        if labels is not None:
            # Standard causal LM objective: predict token t+1 from positions <= t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return LNNModelOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=final_output,
            hidden_states=tuple(new_hidden_states),
        )

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        max_new_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate text with the LNN model, using a repetition penalty and
        top-k / top-p (nucleus) filtering.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        if max_new_tokens is not None:
            max_length = input_ids.shape[1] + max_new_tokens

        hidden_states = [
            torch.zeros(batch_size, self.config.hidden_size, device=device)
            for _ in range(self.config.num_hidden_layers)
        ]

        generated = input_ids.clone()
        self.eval()

        for step in range(max_length - input_ids.shape[1]):
            # Re-encode a sliding window of the most recent 512 tokens each step.
            context_length = min(generated.shape[1], 512)
            context_ids = generated[:, -context_length:]

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=context_ids,
                    hidden_states=hidden_states if step == 0 else None,
                )

            logits = outputs.logits[:, -1, :]

            # Penalize tokens that already appear in the generated sequence.
            if repetition_penalty != 1.0:
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        if logits[i, token_id] > 0:
                            logits[i, token_id] /= repetition_penalty
                        else:
                            logits[i, token_id] *= repetition_penalty

            if temperature != 1.0:
                logits = logits / temperature

            # Top-k filtering: keep only the k highest-scoring tokens.
            if top_k > 0:
                top_k_values, _ = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1)
                indices_to_remove = logits < top_k_values[..., -1, None]
                logits[indices_to_remove] = -float('inf')

            # Top-p (nucleus) filtering: keep the smallest set of tokens whose
            # cumulative probability exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                # Map the removal mask back from sorted order to vocabulary order.
                indices_to_remove = sorted_indices_to_remove.gather(dim=-1, index=sorted_indices.argsort(dim=-1))
                logits[indices_to_remove] = -float('inf')

            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            # Stop early once every sequence in the batch has emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def generate_simple(
        self,
        input_ids: torch.LongTensor,
        max_length: int = 100,
        temperature: float = 1.0,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Simple generate method without top-k/top-p sampling to avoid dimension
        issues. The prompt is encoded once; afterwards only the newly sampled
        token is fed back in, with the recurrent hidden states carried between
        steps so the prefix is not re-processed.
        """
        batch_size = input_ids.shape[0]
        device = input_ids.device

        if hidden_states is None:
            hidden_states = [
                torch.zeros(batch_size, self.config.hidden_size, device=device)
                for _ in range(self.config.num_hidden_layers)
            ]

        generated = input_ids.clone()
        next_input = generated
        self.eval()

        for _ in range(max_length - input_ids.shape[1]):
            with torch.no_grad():
                outputs = self.forward(
                    input_ids=next_input,
                    hidden_states=hidden_states,
                )

            logits = outputs.logits[:, -1, :]
            hidden_states = list(outputs.hidden_states)

            if temperature != 1.0:
                logits = logits / temperature

            if do_sample:
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)
            # Only the new token needs to be processed on the next step: the
            # carried hidden states already summarize everything before it.
            next_input = next_token

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        **kwargs,
    ) -> dict:
        """
        Prepare inputs for generation. For the LNN, recurrent hidden states play
        the role of past_key_values, so no key/value cache is assembled here.
        """
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
        return model_inputs

    def _reorder_cache(self, past_key_values: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
        """
        Reorder hidden states for beam search.
        """
        if past_key_values is None:
            return None

        reordered_past = []
        for hidden_state in past_key_values:
            reordered_past.append(hidden_state.index_select(0, beam_idx))
        return reordered_past
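

# Illustrative sketch (uses the hypothetical small config helper defined
# above): a forward pass with labels, which yields both logits and the shifted
# cross-entropy loss. The random token ids are for demonstration only.
def _demo_model_forward() -> LNNModelOutput:
    config = _demo_small_config()
    model = LNNModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))   # (batch, seq_len)
    outputs = model(input_ids=input_ids, labels=input_ids)     # labels are shifted internally
    # outputs.logits has shape (2, 16, vocab_size); outputs.loss is a scalar.
    return outputs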


class LNNForCausalLM(LNNModel):
    """
    Wrapper class for compatibility with HuggingFace's CausalLM interface.
    """

    def __init__(self, config: LNNConfig):
        super().__init__(config)
        self.lm_head = self.proj_out

    @property
    def model(self):
        """Return self for compatibility with some HF utilities."""
        return self

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        hidden_states: Optional[List[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: bool = True,
        **kwargs,
    ) -> LNNModelOutput:
        """
        Forward pass compatible with the CausalLM interface. `past_key_values`
        and `use_cache` are accepted for API compatibility but ignored.
        """
        return super().forward(
            input_ids=input_ids,
            labels=labels,
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )


# Register with the Auto classes so the model can be loaded via AutoModel /
# AutoModelForCausalLM when this package is available.
try:
    from transformers import AutoModel, AutoModelForCausalLM

    AutoModel.register(LNNConfig, LNNModel)
    AutoModelForCausalLM.register(LNNConfig, LNNForCausalLM)
except ImportError:
    pass
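

# Illustrative smoke test (assumptions: tiny hypothetical sizes, random token
# ids, greedy decoding). Because of the relative imports above, run it with
# package context, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    cfg = LNNConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2)
    model = LNNForCausalLM(cfg)
    prompt = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model.generate_simple(prompt, max_length=16, do_sample=False)
    print("generated shape:", tuple(out.shape))   # expected: (1, 16)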