import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.phi3.configuration_phi3 import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM


class VectorMemoryHead(nn.Module):
    """
    A memory head that compresses a sequence of vectors into a fixed number of memory slots.
    It uses an encoder-decoder architecture with an attention-based memory compression mechanism.
    """
    def __init__(self, hidden_dim: int, num_memory_slots: int, num_heads: int, ff_dim: int, device=None, dtype=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_memory_slots = num_memory_slots

        # Single-layer Transformer encoder that contextualizes the incoming vectors.
        # All memory-head parameters are kept in float32 for numerical stability,
        # regardless of the dtype of the surrounding model (the `dtype` argument is
        # accepted only for interface compatibility).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=0.1, batch_first=True,
            device=device, dtype=torch.float32
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)
        # Learned queries, one per memory slot, that attend over the encoded sequence
        # to produce the compressed memory.
        self.memory_queries = nn.Parameter(torch.randn(1, num_memory_slots, hidden_dim, device=device, dtype=torch.float32))
        self.memory_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=torch.float32
        )
        self.memory_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=torch.float32)
        # Decoder path: reconstructs the encoded sequence from the compressed memory,
        # enabling an auxiliary reconstruction objective during training.
        self.decoder_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, dropout=0.1, batch_first=True,
            device=device, dtype=torch.float32
        )
        self.decoder_layernorm = nn.LayerNorm(hidden_dim, device=device, dtype=torch.float32)
        self.decoder_ffn = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim, device=device, dtype=torch.float32),
            nn.ReLU(),
            nn.Linear(ff_dim, hidden_dim, device=device, dtype=torch.float32)
        )

    def forward(self, memory_input_sequence: torch.Tensor):
        batch_size = memory_input_sequence.shape[0]
        # Contextualize the input vectors (cast to float32 to match the head's parameters).
        encoded_vectors = self.encoder(memory_input_sequence.to(torch.float32))
        # Compress: the learned slot queries cross-attend over the encoded sequence.
        queries = self.memory_queries.expand(batch_size, -1, -1)
        compressed_memory, _ = self.memory_attention(
            query=queries, key=encoded_vectors, value=encoded_vectors
        )
        compressed_memory = self.memory_layernorm(compressed_memory + queries)
        # Reconstruct: the encoded vectors attend back into the compressed memory.
        reconstructed, _ = self.decoder_attention(
            query=encoded_vectors, key=compressed_memory, value=compressed_memory
        )
        reconstructed_vectors = self.decoder_layernorm(reconstructed + encoded_vectors)
        reconstructed_vectors = self.decoder_ffn(reconstructed_vectors)
        return compressed_memory, reconstructed_vectors
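

# Illustrative only (not part of the original module): a minimal shape check for
# VectorMemoryHead. The function name and the example dimensions are arbitrary choices
# made for the demonstration.
def _demo_vector_memory_head():
    head = VectorMemoryHead(hidden_dim=64, num_memory_slots=8, num_heads=4, ff_dim=128)
    seq = torch.randn(2, 16, 64)                # (batch, seq_len, hidden_dim)
    compressed, reconstructed = head(seq)
    assert compressed.shape == (2, 8, 64)       # (batch, num_memory_slots, hidden_dim)
    assert reconstructed.shape == (2, 16, 64)   # (batch, seq_len, hidden_dim)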


class GCVectorMemoryLayer(nn.Module):
    """
    A self-correcting layer designed as a drop-in replacement for nn.Linear.
    It uses a VectorMemoryHead to generate corrections based on both
    local (layer input) and global (model input embeddings) context.
    """
    def __init__(self, input_dim: int, output_dim: int, global_input_dim: int,
                 memory_dim: int, num_memory_slots: int, memory_num_heads: int,
                 global_state_storage: Dict, device=None, dtype=None):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.memory_dim = memory_dim
        # Shared dict populated by the host model with the input embeddings under the
        # key 'embeds'; it provides the global context for the correction.
        self.global_state_storage = global_state_storage
        # The base projection this layer wraps (the original nn.Linear behaviour).
        self.linear = nn.Linear(input_dim, output_dim, bias=False, device=device, dtype=dtype)
        # Projections of local and global context into the shared memory space (float32).
        self.local_state_proj = nn.Linear(input_dim, memory_dim, device=device, dtype=torch.float32)
        self.global_state_proj = nn.Linear(global_input_dim, memory_dim, device=device, dtype=torch.float32)
        self.memory_head = VectorMemoryHead(
            hidden_dim=memory_dim, num_memory_slots=num_memory_slots,
            num_heads=memory_num_heads, ff_dim=memory_dim * 2, device=device
        )
        # Produces a multiplicative gate and an additive value, each of size output_dim.
        self.correction_head = nn.Linear(memory_dim, 2 * output_dim, device=device, dtype=torch.float32)
        # Tensors cached during training for auxiliary losses computed outside this layer.
        self.last_corrected_activation: Optional[torch.Tensor] = None
        self.last_additive_correction: Optional[torch.Tensor] = None
        self.last_memory_input: Optional[torch.Tensor] = None
        self.last_reconstructed_from_memory: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor):
        original_dtype = x.dtype
        base_output = self.linear(x)

        # Without the global context there is nothing to correct against; behave like nn.Linear.
        if 'embeds' not in self.global_state_storage:
            return base_output

        # If the cached embeddings do not match the current sequence length,
        # keep only the trailing positions.
        global_embeds = self.global_state_storage['embeds']
        if global_embeds.shape[1] != x.shape[1]:
            global_embeds = global_embeds[:, -x.shape[1]:, :]

        B, S, _ = x.shape
        # Run the correction pathway with gradients enabled even if the outer pass
        # executes under torch.no_grad().
        with torch.enable_grad():
            # Project local (layer input) and global (input embeddings) context into the
            # memory space and treat them as a length-2 sequence per token.
            proj_local = self.local_state_proj(x.to(torch.float32))
            proj_global = self.global_state_proj(global_embeds.to(torch.float32))
            memory_input = torch.stack([proj_global, proj_local], dim=2)
            memory_input_flat = memory_input.view(B * S, 2, self.memory_dim)
            compressed_mem_flat, recon_flat = self.memory_head(memory_input_flat)
            # Average the memory slots into a single "thought" vector per token.
            aggregated_thought_flat = compressed_mem_flat.mean(dim=1)
            aggregated_thought = aggregated_thought_flat.view(B, S, self.memory_dim)
            # Split the correction into a multiplicative gate and an additive term:
            # corrected = base * sigmoid(gate) + value.
            raw_correction = self.correction_head(aggregated_thought)
            gate, value = torch.chunk(raw_correction, 2, dim=-1)
            corrected_activation = base_output * torch.sigmoid(gate.to(original_dtype)) + value.to(original_dtype)

        # Cache intermediates so a training loop can apply auxiliary losses
        # (e.g. reconstruction between last_memory_input and last_reconstructed_from_memory).
        if self.training:
            self.last_corrected_activation = corrected_activation
            self.last_additive_correction = value
            self.last_memory_input = memory_input_flat
            self.last_reconstructed_from_memory = recon_flat

        return corrected_activation
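

# NOTE: Phi3WithVectorMemoryForCausalLM is referenced by the registration below, but its
# definition does not appear in this excerpt. The class that follows is a minimal,
# illustrative sketch of how such a wrapper could be wired so the module stays importable:
# it caches the input embeddings in a shared dict and swaps each decoder layer's
# mlp.down_proj for a GCVectorMemoryLayer. The choice of down_proj and the memory
# hyperparameters (memory_dim=64, num_memory_slots=8, memory_num_heads=4) are assumptions,
# not taken from the original code; checkpoint key remapping for the renamed weights is
# also left out of the sketch.
class Phi3WithVectorMemoryForCausalLM(Phi3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Shared storage read by every GCVectorMemoryLayer.
        self.global_state_storage: Dict = {}

        # Cache the input embeddings on every forward pass under the key 'embeds'.
        def _store_embeds(module, inputs, output):
            self.global_state_storage['embeds'] = output

        self.model.embed_tokens.register_forward_hook(_store_embeds)

        # Replace the down projection of each MLP with a self-correcting layer,
        # copying the existing weight so the base path starts out identical.
        for layer in self.model.layers:
            old_proj = layer.mlp.down_proj
            new_proj = GCVectorMemoryLayer(
                input_dim=config.intermediate_size,
                output_dim=config.hidden_size,
                global_input_dim=config.hidden_size,
                memory_dim=64, num_memory_slots=8, memory_num_heads=4,
                global_state_storage=self.global_state_storage,
                device=old_proj.weight.device, dtype=old_proj.weight.dtype,
            )
            new_proj.linear.weight.data.copy_(old_proj.weight.data)
            layer.mlp.down_proj = new_proj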


# Route Phi3Config through the memory-augmented class when it is loaded via the Auto API.
# exist_ok=True is required because Phi3Config is already mapped to the stock Phi3ForCausalLM.
AutoModelForCausalLM.register(Phi3Config, Phi3WithVectorMemoryForCausalLM, exist_ok=True)
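
# With the mapping registered, loading a Phi-3 checkpoint through the Auto API should
# instantiate the memory-augmented class (the checkpoint name below is only an example):
#   model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")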