# STABLE ARCHITECTURE
import math
import os
import json
import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

import tiktoken
import gradio as gr
import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ------------------------------------------------------------------------
# 1) CONFIGURATION
# ------------------------------------------------------------------------
@dataclass
class MiniMaxConfig:
    # Basic GPT parameters
    n_layer: int = 12
    n_head: int = 8
    n_embd: int = 512
    vocab_size: int = 200000
    block_size: int = 1024
    dropout: float = 0.1
    pad_token_id: int = 0
    bias: bool = False
    tie_word_embeddings: bool = True
    # Memory & training
    use_checkpoint: bool = True
    layer_norm_eps: float = 1e-5
    init_scale: float = 0.02
    # XPos / Rotary
    rope_base: int = 10000
    rope_scale_base: float = 512.0
    adaptive_xpos: bool = True
    use_adaptive_router: bool = False
    # Attention enhancements
    use_hybrid_attn: bool = True
    lightning_ratio: int = 7
    lightning_block_size: int = 256
    use_flash_attn: bool = True
    kv_cache: bool = False
    # MoE settings
    use_moe: bool = True
    num_experts: int = 4
    moe_top_k: int = 2
    moe_capacity_factor: float = 1.2
    moe_balance_factor: float = 0.1
    diversity_factor: float = 0.01
    expert_dropout: float = 0.1
    z_loss_factor: float = 1e-4
    use_global_router: bool = False  # placeholder for global routing improvements
    # Normalization style: if True, use Post-LayerNorm (with DeepNorm scaling below)
    use_post_layernorm: bool = True
    # Hybrid attention: every `hybrid_attention_interval` layers, use full
    # softmax-based attention instead of lightning attention
    hybrid_attention_interval: int = 8
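
# Example (sketch, not executed at import time): instantiating a small config
# for local experiments. The field values below are illustrative only.
#
#   cfg = MiniMaxConfig(n_layer=2, n_head=2, n_embd=64,
#                       vocab_size=1000, block_size=128)
#   assert cfg.n_embd % cfg.n_head == 0   # per-head dim = 32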
# ------------------------------------------------------------------------
# 2) Enhanced RMSNorm with FP16 Safety
# ------------------------------------------------------------------------
class EnhancedRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.float()
        normed = x * torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
        normed = normed.to(orig_dtype)
        return self.weight * normed
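
# Example (sketch): EnhancedRMSNorm normalizes over the last dimension and
# rescales with a learned per-channel weight.
#
#   norm = EnhancedRMSNorm(dim=512)
#   x = torch.randn(2, 16, 512)
#   y = norm(x)               # same shape as x
#   print(y.shape)            # torch.Size([2, 16, 512])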
# ------------------------------------------------------------------------
# 3) Adaptive XPos Rotary Embedding
# ------------------------------------------------------------------------
class AdaptiveXPosRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, scale_base=512.0, adaptive=True):
        super().__init__()
        assert dim % 2 == 0, "XPos dimension must be even."
        self.dim = dim
        self.base = base
        self.scale_base = scale_base
        self.adaptive = adaptive
        inv_freq = 1.0 / (base ** (torch.arange(0, dim // 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len, device, layer_depth=None, dtype=torch.float32):
        t = torch.arange(seq_len, device=device, dtype=dtype)
        scale = self.scale_base ** (t / self.scale_base)
        if self.adaptive and layer_depth is not None:
            # layer_depth is a plain int, so use math.exp rather than torch.exp.
            scale = scale * math.exp(-layer_depth / self.scale_base)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        scaled_freqs = freqs * scale.unsqueeze(-1)
        emb = torch.cat([scaled_freqs, scaled_freqs], dim=-1)
        # Shapes: (1, 1, seq_len, dim), so they broadcast over (B, n_head, T, head_dim).
        return emb.cos().unsqueeze(0).unsqueeze(0), emb.sin().unsqueeze(0).unsqueeze(0)


def rotate_half(x: torch.Tensor):
    half_dim = x.shape[-1] // 2
    x1 = x[..., :half_dim]
    x2 = x[..., half_dim:]
    return torch.cat([-x2, x1], dim=-1)


def apply_xpos_rotary_pos_emb(q, k, cos, sin):
    B, nh, T, hd = q.shape
    cos = cos[:, :, :T, :hd]
    sin = sin[:, :, :T, :hd]

    def rope(x):
        return x * cos + rotate_half(x) * sin

    return rope(q), rope(k)
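
# Example (sketch): applying the rotary embedding to query/key tensors of
# shape (B, n_head, T, head_dim). head_dim must be even.
#
#   rot = AdaptiveXPosRotaryEmbedding(dim=64)
#   q = torch.randn(1, 8, 32, 64)
#   k = torch.randn(1, 8, 32, 64)
#   cos, sin = rot(seq_len=32, device=q.device)
#   q_rot, k_rot = apply_xpos_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged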
# ------------------------------------------------------------------------
# 4) Optimized Lightning Attention Module
# ------------------------------------------------------------------------
class OptimizedLightningAttention(nn.Module):
    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head
        self.dropout = config.dropout
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.use_flash = config.use_flash_attn and hasattr(F, 'scaled_dot_product_attention')
        self.kv_cache_enabled = config.kv_cache
        # Plain attribute (not a registered buffer): the cache holds a (k, v)
        # tuple, which register_buffer cannot store.
        self.kv_cache = None
        if config.adaptive_xpos:
            self.xpos = AdaptiveXPosRotaryEmbedding(
                dim=self.head_dim,
                base=config.rope_base,
                scale_base=config.rope_scale_base,
                adaptive=config.use_adaptive_router,
            )
        else:
            self.xpos = None

    def _shape_heads(self, x: torch.Tensor, B: int, T: int):
        return x.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                layer_idx: Optional[int] = None):
        B, T, C = x.shape
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        q = self._shape_heads(q, B, T)
        k = self._shape_heads(k, B, T)
        v = self._shape_heads(v, B, T)
        if layer_past is not None and self.kv_cache_enabled:
            pk, pv = layer_past
            k = torch.cat((pk, k), dim=2)
            v = torch.cat((pv, v), dim=2)
        if self.kv_cache_enabled:
            self.kv_cache = (k, v)
        if self.xpos is not None:
            cos, sin = self.xpos(seq_len=T, device=x.device, layer_depth=layer_idx)
            q, k = apply_xpos_rotary_pos_emb(q, k, cos, sin)
        S = k.size(2)
        attn_mask = None
        if mask is not None:
            # Combine the padding mask (B, 1, 1, S) with a causal mask (T, S) so the
            # flash and fallback paths apply identical masking. scaled_dot_product_attention
            # does not accept is_causal=True together with an explicit attn_mask.
            pad_mask = mask.bool().unsqueeze(1).unsqueeze(2)
            causal = torch.tril(torch.ones(T, S, dtype=torch.bool, device=x.device),
                                diagonal=S - T)
            attn_mask = pad_mask & causal
        if self.use_flash:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=(attn_mask is None),
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            if attn_mask is not None:
                attn_scores = attn_scores.masked_fill(~attn_mask, float('-inf'))
            attn_probs = F.softmax(attn_scores, dim=-1)
            attn_probs = self.attn_dropout(attn_probs)
            y = torch.matmul(attn_probs, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y
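
# Example (sketch): running the attention module on a random batch. This is a
# standalone smoke test; it assumes the default config fields above.
#
#   cfg = MiniMaxConfig(n_embd=128, n_head=4)
#   attn = OptimizedLightningAttention(cfg)
#   x = torch.randn(2, 16, 128)
#   mask = torch.ones(2, 16, dtype=torch.long)   # 1 = real token, 0 = padding
#   out = attn(x, mask=mask, layer_idx=0)        # -> (2, 16, 128)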
# ------------------------------------------------------------------------
# 5) Enhanced Expert Block (for MoE experts)
# ------------------------------------------------------------------------
class EnhancedExpertBlock(nn.Module):
    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        with torch.no_grad():
            nn.init.orthogonal_(self.fc1.weight, gain=math.sqrt(2))
            nn.init.orthogonal_(self.fc2.weight, gain=math.sqrt(2))
            if self.fc1.bias is not None:
                nn.init.zeros_(self.fc1.bias)
            if self.fc2.bias is not None:
                nn.init.zeros_(self.fc2.bias)
        self.layer_scale = nn.Parameter(torch.ones(1, 1, hidden_dim) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r = x
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = x * self.layer_scale
        return r + x
# ------------------------------------------------------------------------
# 6) Memory-Efficient MoE with Vectorized Dispatch
# ------------------------------------------------------------------------
class MemoryEfficientMoE(nn.Module):
    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.moe_top_k
        self.capacity_factor = config.moe_capacity_factor
        self.balance_factor = config.moe_balance_factor
        self.diversity_factor = config.diversity_factor
        self.z_loss_factor = config.z_loss_factor
        self.hidden_dim = config.n_embd
        self.dropout = config.expert_dropout
        self.experts = nn.ModuleList([
            EnhancedExpertBlock(self.hidden_dim, self.dropout) for _ in range(self.num_experts)
        ])
        self.router = nn.Linear(self.hidden_dim, self.num_experts)
        self.register_buffer('aux_loss', torch.zeros(1))
        self.register_buffer('diversity_loss', torch.zeros(1))

    def compute_diversity_loss(self):
        # Penalize pairwise cosine similarity between flattened expert parameters
        # to discourage experts from collapsing onto the same function.
        param_vecs = []
        for e in self.experts:
            pvec = [p.flatten() for p in e.parameters()]
            param_vecs.append(torch.cat(pvec, dim=0))
        div_loss = 0.0
        for i in range(self.num_experts):
            for j in range(i + 1, self.num_experts):
                cos_sim = F.cosine_similarity(
                    param_vecs[i].unsqueeze(0),
                    param_vecs[j].unsqueeze(0),
                )
                div_loss += cos_sim ** 2
        return div_loss * self.diversity_factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        N = B * T
        E = self.num_experts
        device = x.device
        router_logits = self.router(x.view(N, C))
        router_probs = F.softmax(router_logits, dim=-1)
        # Auxiliary losses: z-loss keeps router logits small; the balance term
        # pushes the mean routing probability toward the uniform distribution.
        z_loss = self.z_loss_factor * (router_logits ** 2).mean()
        importance = router_probs.mean(dim=0)
        target = torch.ones_like(importance) / E
        balance = F.mse_loss(importance, target, reduction='sum') * self.balance_factor
        top_vals, top_inds = torch.topk(router_probs, self.top_k, dim=-1)
        top_vals = top_vals / (top_vals.sum(dim=-1, keepdim=True) + 1e-9)
        capacity = int(self.capacity_factor * (N // E + 1))
        out = torch.zeros_like(x.view(N, C), device=device)
        used_slots = torch.zeros(E, dtype=torch.int32, device=device)
        for i_k in range(self.top_k):
            w = top_vals[:, i_k]
            e_idx = top_inds[:, i_k]
            mask = w > 1e-9
            if not mask.any():
                continue
            valid_idx = mask.nonzero(as_tuple=True)[0]
            for eid in range(E):
                mask_eid = (e_idx[valid_idx] == eid)
                count_e = mask_eid.sum().item()
                if count_e == 0:
                    continue
                c_before = used_slots[eid].item()
                c_after = c_before + count_e
                if c_before >= capacity:
                    continue
                if c_after > capacity:
                    # Expert is over capacity: keep only as many tokens as still fit.
                    allowed = capacity - c_before
                    selected = mask_eid.nonzero(as_tuple=True)[0][:allowed]
                    real_idx = valid_idx[selected]
                    used_slots[eid] = capacity
                else:
                    selected = mask_eid.nonzero(as_tuple=True)[0]
                    real_idx = valid_idx[selected]
                    used_slots[eid] += count_e
                if len(real_idx) == 0:
                    continue
                tokens = x.view(N, C)[real_idx]
                y_ = self.experts[eid](tokens)
                y_ = y_.view(len(real_idx), -1)
                w_ = w[real_idx].view(-1, 1)
                out[real_idx] += w_ * y_
        self.aux_loss = balance + z_loss
        self.diversity_loss = self.compute_diversity_loss()
        return out.view(B, T, C)
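
# Example (sketch): the MoE layer is a drop-in replacement for the FFN; after a
# forward pass, `aux_loss` and `diversity_loss` can be added to the training loss.
#
#   cfg = MiniMaxConfig(n_embd=128, num_experts=2, moe_top_k=1)
#   moe = MemoryEfficientMoE(cfg)
#   x = torch.randn(2, 16, 128)
#   y = moe(x)                                   # -> (2, 16, 128)
#   total_aux = moe.aux_loss + moe.diversity_loss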
# ------------------------------------------------------------------------
# 7) Enhanced Transformer Block with Hybrid Attention & DeepNorm
# ------------------------------------------------------------------------
class EnhancedHybridBlock(nn.Module):
    """
    Transformer block with hybrid attention and DeepNorm-style residual scaling.
    Depending on `attn_type`, it uses either lightning attention or (placeholder)
    softmax attention.
    """
    def __init__(self, config: MiniMaxConfig, layer_idx: int, attn_type: str = "lightning"):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attn_type = attn_type
        # Choose attention module.
        # For "softmax", this is currently the same class; swap in a dedicated
        # softmax attention module if one becomes available.
        if attn_type == "softmax":
            self.attn = OptimizedLightningAttention(config)  # Placeholder for a softmax variant.
        else:
            self.attn = OptimizedLightningAttention(config)
        # MoE or standard FFN
        if config.use_moe:
            self.mlp = MemoryEfficientMoE(config)
        else:
            self.mlp = EnhancedExpertBlock(config.n_embd, config.dropout)
        self.ln_1 = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
        self.ln_2 = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
        self.use_checkpoint = config.use_checkpoint
        # DeepNorm scaling factors for residual connections.
        self.alpha_attn = nn.Parameter(torch.ones(1) * 0.5)
        self.alpha_mlp = nn.Parameter(torch.ones(1) * 0.5)

    def _forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        if self.config.use_post_layernorm:
            # Post-LN: residual add first, then normalize.
            a_out = self.attn(x, mask, layer_idx=self.layer_idx)
            x = x + self.alpha_attn * a_out
            x = self.ln_1(x)
            m_out = self.mlp(x)
            x = x + self.alpha_mlp * m_out
            x = self.ln_2(x)
        else:
            # Pre-LN: normalize the input to each sub-layer.
            a = self.ln_1(x)
            a_out = self.attn(a, mask, layer_idx=self.layer_idx)
            x = x + self.alpha_attn * a_out
            m = self.ln_2(x)
            m_out = self.mlp(m)
            x = x + self.alpha_mlp * m_out
        return x

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        if self.use_checkpoint and self.training:
            # use_reentrant=False avoids the deprecation warning and handles
            # non-tensor arguments such as the mask more gracefully.
            return checkpoint(self._forward, x, mask, use_reentrant=False)
        else:
            return self._forward(x, mask)
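
# Example (sketch): a single block maps (B, T, n_embd) -> (B, T, n_embd).
# Gradient checkpointing only kicks in when the module is in training mode.
#
#   cfg = MiniMaxConfig(n_embd=128, n_head=4, num_experts=2)
#   block = EnhancedHybridBlock(cfg, layer_idx=0)
#   block.eval()
#   x = torch.randn(2, 16, 128)
#   mask = torch.ones(2, 16, dtype=torch.long)
#   y = block(x, mask=mask)                      # -> (2, 16, 128)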
# ------------------------------------------------------------------------
# 8) Full Model: EnhancedMiniMaxGPT with Hybrid Attention
# ------------------------------------------------------------------------
class EnhancedMiniMaxGPT(nn.Module):
    def __init__(self, config: MiniMaxConfig):
        super().__init__()
        self.config = config
        # Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        # Build transformer blocks, alternating attention type based on hybrid_attention_interval.
        self.blocks = nn.ModuleList()
        interval = config.hybrid_attention_interval
        for layer_idx in range(config.n_layer):
            if (layer_idx + 1) % interval == 0:
                attn_type = "softmax"
            else:
                attn_type = "lightning"
            blk = EnhancedHybridBlock(config, layer_idx, attn_type=attn_type)
            self.blocks.append(blk)
        self.ln_f = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.wte.weight
        print(f"[EnhancedMiniMaxGPT] #params (non-embeddings): {self.get_num_params(non_embedding=True)/1e6:.2f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_scale)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_scale)

    def get_num_params(self, non_embedding=True):
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.wte.weight.numel()
            n_params -= self.wpe.weight.numel()
        return n_params

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                targets: Optional[torch.Tensor] = None):
        B, T = input_ids.shape
        device = input_ids.device
        if attention_mask is None:
            attention_mask = (input_ids != self.config.pad_token_id).long()
        if T > self.config.block_size:
            raise ValueError(f"Seq length {T} > block_size {self.config.block_size}")
        pos_ids = torch.arange(T, device=device).unsqueeze(0)
        x = self.wte(input_ids) + self.wpe(pos_ids)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x, mask=attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Standard next-token shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_targets = targets[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                   shift_targets.view(-1),
                                   ignore_index=self.config.pad_token_id)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int = 50, temperature: float = 1.0,
                 top_k: Optional[int] = None, top_p: Optional[float] = None):
        generated = idx
        for _ in range(max_new_tokens):
            idx_cond = generated[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            logits = torch.nan_to_num(logits, nan=float('-inf'))
            if top_k is not None:
                # Top-k filtering: drop everything below the k-th largest logit.
                vals, _ = torch.topk(logits, top_k)
                logits[logits < vals[:, [-1]]] = float('-inf')
            if top_p is not None:
                # Nucleus (top-p) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                remove_mask = cum_probs > top_p
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = False
                sorted_logits[remove_mask] = float('-inf')
                logits = torch.zeros_like(logits).scatter(1, sorted_indices, sorted_logits)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
        return generated
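
# Example (sketch): building a tiny randomly initialized model and sampling a
# few tokens. With random weights the output is gibberish; this only checks
# that shapes and the sampling loop work.
#
#   cfg = MiniMaxConfig(vocab_size=1000, block_size=64, n_layer=2,
#                       n_head=2, n_embd=64, num_experts=2)
#   tiny = EnhancedMiniMaxGPT(cfg).eval()
#   prompt = torch.randint(0, 1000, (1, 8))
#   out = tiny.generate(prompt, max_new_tokens=5, temperature=1.0, top_k=10)
#   print(out.shape)                             # torch.Size([1, 13])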
# ------------------------------------------------------------------------
# Example Usage
# ------------------------------------------------------------------------
# ---------------------------
# Load Model
# ---------------------------
model = None
# ---------------------------
# Tokenizer Setup
# ---------------------------
special_tokens_dict = {
    "<|user|>": 50257,
    "<|assistant|>": 50258,
    "<|pad|>": 50259,
    "<|endoftext|>": 50260,
}
# Initialize the tokenizer: start from the GPT-2 BPE and append the chat
# special tokens above.
base_enc = tiktoken.encoding_for_model("gpt2")
encoding = tiktoken.Encoding(
    name="gpt-4o-custom",
    pat_str=base_enc._pat_str,
    mergeable_ranks=base_enc._mergeable_ranks,
    special_tokens={**base_enc._special_tokens, **special_tokens_dict},
)
pad_token_id = special_tokens_dict["<|pad|>"]
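
# Example (sketch): encoding/decoding with the extended tokenizer. Special
# tokens must be listed in allowed_special, otherwise tiktoken raises an error.
#
#   ids = encoding.encode("<|user|> hello\n<|assistant|> ",
#                         allowed_special=set(special_tokens_dict.keys()))
#   text = encoding.decode(ids)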
"""def load_model(model_dir="./"): | |
global model | |
if model is not None: | |
return model | |
model_config = MiniMaxConfig( | |
vocab_size=encoding.n_vocab, | |
block_size=256, | |
n_layer=8, | |
n_head=4, | |
n_embd=384, | |
dropout=0.1, | |
) | |
model = EnhancedMiniMaxGPT(model_config) | |
model.load_state_dict(torch.load("pytorch_model.bin", map_location=torch.device("cpu"))) | |
model.eval() | |
return model | |
model = load_model() | |
# ------------------------------------------------------------------------ | |
# API Setup | |
# ------------------------------------------------------------------------ | |
app = FastAPI() | |
class ChatRequest(BaseModel): | |
messages: list[dict] # List of messages with 'role' and 'content' | |
class ChatResponse(BaseModel): | |
response: str | |
def build_prompt(conversation_history): | |
prompt = "" | |
for turn in conversation_history: | |
if turn["role"] == "user": | |
prompt += f"<|user|> {turn['content'].strip()}\n" | |
else: | |
prompt += f"<|assistant|> {turn['content'].strip()}\n" | |
prompt += "<|assistant|> " | |
return prompt | |
def generate_response(conversation_history): | |
prompt_text = build_prompt(conversation_history) | |
input_ids = torch.tensor( | |
encoding.encode(prompt_text, allowed_special=set(special_tokens_dict.keys())), | |
dtype=torch.long, | |
).unsqueeze(0) | |
if input_ids.size(1) > model.config.block_size: | |
input_ids = input_ids[:, -model.config.block_size:] | |
generated_ids = model.generate( | |
idx=input_ids, | |
max_new_tokens=100, | |
temperature=0.8, | |
top_k=50, | |
top_p=0.95, | |
) | |
new_tokens = generated_ids[0].tolist()[len(input_ids[0]):] | |
response_text = encoding.decode(new_tokens).strip() | |
return response_text | |
@app.post("/api/chat", response_model=ChatResponse) | |
async def chat_endpoint(request: ChatRequest): | |
try: | |
response_text = generate_response(request.messages) | |
return ChatResponse(response=response_text) | |
except Exception as e: | |
return {"error": str(e)}""" | |
# ---------------------------
def load_model(model_dir="./"):
    global model
    if model is not None:
        return model
    config_path = os.path.join(model_dir, "config.json")
    weights_path = os.path.join(model_dir, "pytorch_model.bin")
    # config.json is loaded for reference only; the fields below are set explicitly.
    with open(config_path, "r") as f:
        config = json.load(f)
    model_config = MiniMaxConfig(
        vocab_size=encoding.n_vocab,
        block_size=512,
        n_layer=12,
        n_head=8,
        n_embd=512,
        dropout=0.1,
        hybrid_attention_interval=4,
        num_experts=2,
    )
    model = EnhancedMiniMaxGPT(model_config)
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    # strict=False tolerates missing/unexpected keys (e.g. MoE aux-loss buffers).
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
def load_model_weights(checkpoint_path, config, device):
    """
    Load only the model weights from a checkpoint file.
    Ensures compatibility by loading only matched layers.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    model = EnhancedMiniMaxGPT(config).to(device)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Check for shape mismatches and fix aux_loss shape if necessary
    model_state_dict = model.state_dict()
    compatible_state_dict = {}
    for k, v in state_dict.items():
        if k in model_state_dict:
            if v.shape == model_state_dict[k].shape:
                compatible_state_dict[k] = v
            elif "aux_loss" in k and v.shape == torch.Size([]):
                compatible_state_dict[k] = v.unsqueeze(0)  # Convert scalar to 1-element tensor
                print(f"Fixed shape for {k}")
            else:
                print(f"Skipping {k} due to shape mismatch.")
    # Load compatible weights
    model.load_state_dict(compatible_state_dict, strict=False)
    model.eval()
    print(f"✅ Loaded model weights from {checkpoint_path}")
    return model
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_device():
    """Return GPU device if available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global model configuration and weights (loaded once at startup).
model_config = MiniMaxConfig(
    vocab_size=encoding.n_vocab,
    block_size=512,
    n_layer=12,
    n_head=8,
    n_embd=512,
    dropout=0.1,
    hybrid_attention_interval=4,
    num_experts=2,
)
checkpoint_path = "pytorch_model.bin"
device = get_device()
model = load_model_weights(checkpoint_path, model_config, device)
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[ChatMessage]

class ChatResponse(BaseModel):
    response: str
    status: str = "success"

async def ensure_model_loaded():
    """Ensure the model is loaded before processing requests."""
    global model
    if model is None:
        try:
            logger.info("Loading model...")
            model = load_model()
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model initialization failed",
            )

@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        await ensure_model_loaded()
        logger.info(f"Received chat request with {len(request.messages)} messages")
        # Convert pydantic messages to plain dicts with role & content.
        conversation = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        response_text = generate_response(conversation)
        logger.info("Response generated successfully")
        return ChatResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )

@app.get("/api/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}
# ---------------------------
def build_prompt(conversation_history):
    """Flatten a list of {role, content} turns into the chat prompt format."""
    prompt = ""
    for turn in conversation_history:
        if turn["role"] == "user":
            prompt += f"<|user|> {turn['content'].strip()}\n"
        else:
            prompt += f"<|assistant|> {turn['content'].strip()}\n"
    prompt += "<|assistant|> "
    return prompt
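
# Example (sketch): the prompt format produced by build_prompt.
#
#   build_prompt([{"role": "user", "content": "Hi"},
#                 {"role": "assistant", "content": "Hello!"},
#                 {"role": "user", "content": "How are you?"}])
#   # -> "<|user|> Hi\n<|assistant|> Hello!\n<|user|> How are you?\n<|assistant|> "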
def generate_response(conversation_history):
    # Uses the module-level `device` so inputs match the device the model was loaded onto.
    prompt_text = build_prompt(conversation_history)
    input_ids = torch.tensor(
        encoding.encode(prompt_text, allowed_special=set(special_tokens_dict.keys())),
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)
    # Truncate from the left so the prompt fits the model's context window.
    if input_ids.size(1) > model.config.block_size:
        input_ids = input_ids[:, -model.config.block_size:]
    generated_ids = model.generate(
        idx=input_ids,
        max_new_tokens=100,
        temperature=1.2,
        top_k=50,
        top_p=0.95,
    )
    # Decode only the newly generated continuation.
    new_tokens = generated_ids[0].tolist()[len(input_ids[0]):]
    response_text = encoding.decode(new_tokens).strip()
    return response_text

def chatbot_fn(user_input):
    response = generate_response([{"role": "user", "content": user_input}])
    return response
# Expose Gradio as a simple text-in/text-out interface on top of the model.
iface = gr.Interface(fn=chatbot_fn, inputs="text", outputs="text")

# Standalone launch (not used here, since the interface is mounted on FastAPI below):
# iface.launch(server_name="0.0.0.0", server_port=7860)

# Mount the Gradio interface onto the FastAPI app at "/".
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)