# STABLE ARCHITECTURE
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from typing import Optional, Tuple
from dataclasses import dataclass
import tiktoken
import os
import json
import gradio as gr
import uvicorn
import logging
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ------------------------------------------------------------------------
# 1) CONFIGURATION
# ------------------------------------------------------------------------
@dataclass
class MiniMaxConfig:
# Basic GPT parameters
n_layer: int = 12
n_head: int = 8
n_embd: int = 512
vocab_size: int = 200000
block_size: int = 1024
dropout: float = 0.1
pad_token_id: int = 0
bias: bool = False
tie_word_embeddings: bool = True
# Memory & training
use_checkpoint: bool = True
layer_norm_eps: float = 1e-5
init_scale: float = 0.02
# XPos / Rotary
rope_base: int = 10000
rope_scale_base: float = 512.0
adaptive_xpos: bool = True
use_adaptive_router: bool = False
# Attention enhancements
use_hybrid_attn: bool = True
lightning_ratio: int = 7
lightning_block_size: int = 256
use_flash_attn: bool = True
kv_cache: bool = False
# MoE settings
use_moe: bool = True
num_experts: int = 4
moe_top_k: int = 2
moe_capacity_factor: float = 1.2
moe_balance_factor: float = 0.1
diversity_factor: float = 0.01
expert_dropout: float = 0.1
z_loss_factor: float = 1e-4
use_global_router: bool = False # placeholder for global routing improvements
# Normalization style: if True, use Post-LayerNorm (with DeepNorm scaling below)
use_post_layernorm: bool = True
# Hybrid attention: every X layers, use full softmax-based attention instead of lightning
hybrid_attention_interval: int = 8
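
# Illustrative sketch (not part of the original architecture): list which layers use
# full softmax attention for a given config. With the defaults (n_layer=12,
# hybrid_attention_interval=8) only layer index 7 is a "softmax" layer; the rest use
# lightning attention. This mirrors the block construction in EnhancedMiniMaxGPT below.
def _attention_schedule(cfg: MiniMaxConfig) -> list:
    return ["softmax" if (i + 1) % cfg.hybrid_attention_interval == 0 else "lightning"
            for i in range(cfg.n_layer)]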
# ------------------------------------------------------------------------
# 2) Enhanced RMSNorm with FP16 Safety
# ------------------------------------------------------------------------
class EnhancedRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
if x.dtype == torch.float16:
x = x.float()
normed = x * torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
normed = normed.to(orig_dtype)
return self.weight * normed
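
# For reference, the normalization above computes weight * x / sqrt(mean(x**2) + eps);
# the temporary cast to float32 avoids overflow in x * x when activations arrive in
# fp16, and the result is cast back to the original dtype before scaling.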
# ------------------------------------------------------------------------
# 3) Adaptive XPos Rotary Embedding
# ------------------------------------------------------------------------
class AdaptiveXPosRotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000, scale_base=512.0, adaptive=True):
super().__init__()
assert dim % 2 == 0, "XPos dimension must be even."
self.dim = dim
self.base = base
self.scale_base = scale_base
self.adaptive = adaptive
inv_freq = 1.0 / (base ** (torch.arange(0, dim // 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seq_len, device, layer_depth=None, dtype=torch.float32):
t = torch.arange(seq_len, device=device, dtype=dtype)
scale = self.scale_base ** (t / self.scale_base)
        if self.adaptive and layer_depth is not None:
            # layer_depth arrives as a plain int (the layer index), so use math.exp;
            # torch.exp would require a tensor argument here.
            scale = scale * math.exp(-layer_depth / self.scale_base)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
scaled_freqs = freqs * scale.unsqueeze(-1)
emb = torch.cat([scaled_freqs, scaled_freqs], dim=-1)
return emb.cos().unsqueeze(0).unsqueeze(0), emb.sin().unsqueeze(0).unsqueeze(0)
def rotate_half(x: torch.Tensor):
half_dim = x.shape[-1] // 2
x1 = x[..., :half_dim]
x2 = x[..., half_dim:]
return torch.cat([-x2, x1], dim=-1)
def apply_xpos_rotary_pos_emb(q, k, cos, sin):
B, nh, T, hd = q.shape
cos = cos[:, :, :T, :hd]
sin = sin[:, :, :T, :hd]
def rope(x):
return x * cos + rotate_half(x) * sin
return rope(q), rope(k)
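
# Minimal shape-check sketch for the rotary helpers above (illustrative only; not used
# in the serving path). q and k are (B, n_head, T, head_dim) and the embedding returns
# cos/sin of shape (1, 1, T, head_dim).
def _xpos_shape_check():
    rot = AdaptiveXPosRotaryEmbedding(dim=64)
    q = torch.randn(2, 4, 16, 64)
    k = torch.randn(2, 4, 16, 64)
    cos, sin = rot(seq_len=16, device=q.device)
    q_rot, k_rot = apply_xpos_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape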
# ------------------------------------------------------------------------
# 4) Optimized Lightning Attention Module
# ------------------------------------------------------------------------
class OptimizedLightningAttention(nn.Module):
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.config = config
assert config.n_embd % config.n_head == 0
self.n_head = config.n_head
self.head_dim = config.n_embd // config.n_head
self.dropout = config.dropout
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
self.attn_dropout = nn.Dropout(config.dropout)
self.resid_dropout = nn.Dropout(config.dropout)
self.use_flash = config.use_flash_attn and hasattr(F, 'scaled_dot_product_attention')
self.kv_cache_enabled = config.kv_cache
        # Keep the cache as a plain attribute: it stores a (k, v) tuple, which
        # nn.Module would refuse to assign to a registered buffer slot.
        self.kv_cache = None
if config.adaptive_xpos:
self.xpos = AdaptiveXPosRotaryEmbedding(
dim=self.head_dim,
base=config.rope_base,
scale_base=config.rope_scale_base,
adaptive=config.use_adaptive_router
)
else:
self.xpos = None
def _shape_heads(self, x: torch.Tensor, B: int, T: int):
return x.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, layer_idx: Optional[int] = None):
B, T, C = x.shape
qkv = self.c_attn(x)
q, k, v = qkv.split(C, dim=2)
q = self._shape_heads(q, B, T)
k = self._shape_heads(k, B, T)
v = self._shape_heads(v, B, T)
if layer_past is not None and self.kv_cache_enabled:
pk, pv = layer_past
k = torch.cat((pk, k), dim=2)
v = torch.cat((pv, v), dim=2)
if self.kv_cache_enabled:
self.kv_cache = (k, v)
if self.xpos is not None:
cos, sin = self.xpos(seq_len=T, device=x.device, layer_depth=layer_idx)
q, k = apply_xpos_rotary_pos_emb(q, k, cos, sin)
        if mask is not None:
            # Combine the (B, T) padding mask with a lower-triangular causal mask so
            # both the flash and fallback paths see one boolean mask (True = attend).
            # Note: no kv-cache offset handling here, matching the default kv_cache=False.
            pad_mask = mask.bool().unsqueeze(1).unsqueeze(2)        # (B, 1, 1, T)
            causal_mask = torch.tril(
                torch.ones(T, k.size(-2), dtype=torch.bool, device=x.device)
            )                                                        # (T, T)
            mask = pad_mask & causal_mask                            # (B, 1, T, T)
        if self.use_flash:
            # SDPA rejects an explicit attn_mask together with is_causal=True,
            # so causality is folded into the mask whenever one is supplied.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=(mask is None)
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            if mask is None:
                mask = torch.tril(
                    torch.ones(T, k.size(-2), dtype=torch.bool, device=x.device)
                )
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
            attn_probs = F.softmax(attn_scores, dim=-1)
            attn_probs = self.attn_dropout(attn_probs)
            y = torch.matmul(attn_probs, v)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.resid_dropout(self.c_proj(y))
return y
# ------------------------------------------------------------------------
# 5) Enhanced Expert Block (for MoE experts)
# ------------------------------------------------------------------------
class EnhancedExpertBlock(nn.Module):
def __init__(self, hidden_dim: int, dropout: float = 0.1):
super().__init__()
self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)
self.dropout = nn.Dropout(dropout)
with torch.no_grad():
nn.init.orthogonal_(self.fc1.weight, gain=math.sqrt(2))
nn.init.orthogonal_(self.fc2.weight, gain=math.sqrt(2))
if self.fc1.bias is not None:
nn.init.zeros_(self.fc1.bias)
if self.fc2.bias is not None:
nn.init.zeros_(self.fc2.bias)
self.layer_scale = nn.Parameter(torch.ones(1, 1, hidden_dim) * 0.1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
r = x
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = x * self.layer_scale
return r + x
# ------------------------------------------------------------------------
# 6) Memory-Efficient MoE with Capacity-Limited Top-k Dispatch
# ------------------------------------------------------------------------
class MemoryEfficientMoE(nn.Module):
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.moe_top_k
self.capacity_factor = config.moe_capacity_factor
self.balance_factor = config.moe_balance_factor
self.diversity_factor = config.diversity_factor
self.z_loss_factor = config.z_loss_factor
self.hidden_dim = config.n_embd
self.dropout = config.expert_dropout
self.experts = nn.ModuleList([
EnhancedExpertBlock(self.hidden_dim, self.dropout) for _ in range(self.num_experts)
])
self.router = nn.Linear(self.hidden_dim, self.num_experts)
self.register_buffer('aux_loss', torch.zeros(1))
self.register_buffer('diversity_loss', torch.zeros(1))
def compute_diversity_loss(self):
param_vecs = []
for e in self.experts:
pvec = []
for p in e.parameters():
pvec.append(p.flatten())
param_vecs.append(torch.cat(pvec, dim=0))
div_loss = 0.0
for i in range(self.num_experts):
for j in range(i+1, self.num_experts):
cos_sim = F.cosine_similarity(
param_vecs[i].unsqueeze(0),
param_vecs[j].unsqueeze(0)
)
div_loss += cos_sim ** 2
return div_loss * self.diversity_factor
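
    # With num_experts = 4 the loop above sums 6 pairwise squared cosine similarities;
    # scaling by diversity_factor nudges expert weight vectors toward orthogonality
    # rather than enforcing it as a hard constraint.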
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
N = B * T
E = self.num_experts
device = x.device
router_logits = self.router(x.view(N, C))
router_probs = F.softmax(router_logits, dim=-1)
z_loss = self.z_loss_factor * (router_logits ** 2).mean()
importance = router_probs.mean(dim=0)
target = torch.ones_like(importance) / E
balance = F.mse_loss(importance, target, reduction='sum') * self.balance_factor
top_vals, top_inds = torch.topk(router_probs, self.top_k, dim=-1)
top_vals = top_vals / (top_vals.sum(dim=-1, keepdim=True) + 1e-9)
capacity = int(self.capacity_factor * (N // E + 1))
out = torch.zeros_like(x.view(N, C), device=device)
used_slots = torch.zeros(E, dtype=torch.int32, device=device)
for i_k in range(self.top_k):
w = top_vals[:, i_k]
e_idx = top_inds[:, i_k]
mask = w > 1e-9
if not mask.any():
continue
valid_idx = mask.nonzero(as_tuple=True)[0]
for eid in range(E):
mask_eid = (e_idx[valid_idx] == eid)
count_e = mask_eid.sum().item()
if count_e == 0:
continue
c_before = used_slots[eid].item()
c_after = c_before + count_e
if c_before >= capacity:
continue
if c_after > capacity:
allowed = capacity - c_before
selected = mask_eid.nonzero(as_tuple=True)[0][:allowed]
real_idx = valid_idx[selected]
used_slots[eid] = capacity
else:
selected = mask_eid.nonzero(as_tuple=True)[0]
real_idx = valid_idx[selected]
used_slots[eid] += count_e
if len(real_idx) == 0:
continue
tokens = x.view(N, C)[real_idx]
y_ = self.experts[eid](tokens)
y_ = y_.view(len(real_idx), -1)
w_ = w[real_idx].view(-1, 1)
out[real_idx] += w_ * y_
self.aux_loss = balance + z_loss
self.diversity_loss = self.compute_diversity_loss()
return out.view(B, T, C)
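
# Worked example of the capacity computation above (illustrative numbers): with
# B=2, T=128 and the defaults num_experts=4, capacity_factor=1.2, there are N=256
# tokens, N // E + 1 = 65, and capacity = int(1.2 * 65) = 78 slots per expert; tokens
# routed beyond that budget receive no contribution from that expert.
def _moe_capacity(num_tokens: int, num_experts: int, capacity_factor: float) -> int:
    return int(capacity_factor * (num_tokens // num_experts + 1))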
# ------------------------------------------------------------------------
# 7) Enhanced Transformer Block with Hybrid Attention & DeepNorm
# ------------------------------------------------------------------------
class EnhancedHybridBlock(nn.Module):
"""
Transformer block with hybrid attention and DeepNorm residual scaling.
Depending on `attn_type`, it uses either lightning attention or (placeholder) softmax attention.
"""
def __init__(self, config: MiniMaxConfig, layer_idx: int, attn_type: str = "lightning"):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attn_type = attn_type
# Choose attention module.
# For softmax, you might replace this with a dedicated softmax attention module.
if attn_type == "softmax":
self.attn = OptimizedLightningAttention(config) # Placeholder for a softmax variant.
else:
self.attn = OptimizedLightningAttention(config)
# MoE or standard FFN
if config.use_moe:
self.mlp = MemoryEfficientMoE(config)
else:
self.mlp = EnhancedExpertBlock(config.n_embd, config.dropout)
self.ln_1 = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
self.ln_2 = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
self.use_checkpoint = config.use_checkpoint
# DeepNorm scaling factors for residual connections.
self.alpha_attn = nn.Parameter(torch.ones(1) * 0.5)
self.alpha_mlp = nn.Parameter(torch.ones(1) * 0.5)
def _forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
if self.config.use_post_layernorm:
a_out = self.attn(x, mask, layer_idx=self.layer_idx)
x = x + self.alpha_attn * a_out
x = self.ln_1(x)
m_out = self.mlp(x)
x = x + self.alpha_mlp * m_out
x = self.ln_2(x)
else:
a = self.ln_1(x)
a_out = self.attn(a, mask, layer_idx=self.layer_idx)
x = x + self.alpha_attn * a_out
m = self.ln_2(x)
m_out = self.mlp(m)
x = x + self.alpha_mlp * m_out
return x
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
if self.use_checkpoint and self.training:
            return checkpoint(self._forward, x, mask, use_reentrant=False)
else:
return self._forward(x, mask)
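
# Quick smoke-test sketch for a single block (illustrative; tiny config with MoE and
# checkpointing disabled so it runs standalone on CPU). Call it manually if needed.
def _block_smoke_test():
    cfg = MiniMaxConfig(n_layer=2, n_head=4, n_embd=128, vocab_size=1000,
                        block_size=64, use_moe=False, use_checkpoint=False)
    blk = EnhancedHybridBlock(cfg, layer_idx=0)
    x = torch.randn(2, 16, cfg.n_embd)
    y = blk(x)
    assert y.shape == x.shape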
# ------------------------------------------------------------------------
# 8) Full Model: EnhancedMiniMaxGPT with Hybrid Attention
# ------------------------------------------------------------------------
class EnhancedMiniMaxGPT(nn.Module):
def __init__(self, config: MiniMaxConfig):
super().__init__()
self.config = config
# Embeddings
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.block_size, config.n_embd)
self.drop = nn.Dropout(config.dropout)
# Build transformer blocks, alternating attention type based on hybrid_attention_interval.
self.blocks = nn.ModuleList()
interval = config.hybrid_attention_interval
for layer_idx in range(config.n_layer):
if (layer_idx + 1) % interval == 0:
attn_type = "softmax"
else:
attn_type = "lightning"
blk = EnhancedHybridBlock(config, layer_idx, attn_type=attn_type)
self.blocks.append(blk)
self.ln_f = EnhancedRMSNorm(config.n_embd, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.apply(self._init_weights)
if config.tie_word_embeddings:
self.lm_head.weight = self.wte.weight
print(f"[EnhancedMiniMaxGPT] #params (non-embeddings): {self.get_num_params(non_embedding=True)/1e6:.2f}M")
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_scale)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.init_scale)
def get_num_params(self, non_embedding=True):
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= self.wte.weight.numel()
n_params -= self.wpe.weight.numel()
return n_params
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None):
B, T = input_ids.shape
device = input_ids.device
if attention_mask is None:
attention_mask = (input_ids != self.config.pad_token_id).long()
if T > self.config.block_size:
raise ValueError(f"Seq length {T} > block_size {self.config.block_size}")
pos_ids = torch.arange(T, device=device).unsqueeze(0)
x = self.wte(input_ids) + self.wpe(pos_ids)
x = self.drop(x)
for layer_idx, blk in enumerate(self.blocks):
x = blk(x, mask=attention_mask)
x = self.ln_f(x)
logits = self.lm_head(x)
loss = None
if targets is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_targets = targets[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
shift_targets.view(-1),
ignore_index=self.config.pad_token_id)
return logits, loss
@torch.no_grad()
def generate(self, idx: torch.Tensor, max_new_tokens: int = 50, temperature: float = 1.0,
top_k: Optional[int] = None, top_p: Optional[float] = None):
device = idx.device
generated = idx
for _ in range(max_new_tokens):
idx_cond = generated[:, -self.config.block_size:]
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
logits = torch.nan_to_num(logits, nan=float('-inf'))
if top_k is not None:
vals, _ = torch.topk(logits, top_k)
logits[logits < vals[:, [-1]]] = float('-inf')
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
remove_mask = cum_probs > top_p
remove_mask[..., 1:] = remove_mask[..., :-1].clone()
remove_mask[..., 0] = False
sorted_logits[remove_mask] = float('-inf')
logits = torch.zeros_like(logits).scatter(1, sorted_indices, sorted_logits)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated = torch.cat([generated, next_token], dim=1)
return generated
# ------------------------------------------------------------------------
# Example Usage:
# ------------------------------------------------------------------------
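
# Minimal end-to-end usage sketch (illustrative only: tiny config, random weights).
# It is defined but not executed on import; call _demo_generate() manually to try it.
def _demo_generate():
    cfg = MiniMaxConfig(n_layer=2, n_head=4, n_embd=128, vocab_size=1000,
                        block_size=64, use_moe=False, use_checkpoint=False)
    gpt = EnhancedMiniMaxGPT(cfg)
    prompt = torch.randint(1, cfg.vocab_size, (1, 8))  # avoid the pad id (0)
    out = gpt.generate(prompt, max_new_tokens=5, temperature=1.0, top_k=10)
    print(out.shape)  # expected: torch.Size([1, 13])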
# Load Model
# ---------------------------
model = None
# ---------------------------
# Tokenizer Setup
# ---------------------------
special_tokens_dict = {
"<|user|>": 50257,
"<|assistant|>": 50258,
"<|pad|>": 50259,
"<|endoftext|>": 50260,
}
# Initialize the tokenizer
base_enc = tiktoken.encoding_for_model("gpt2")
encoding = tiktoken.Encoding(
name="gpt-4o-custom",
pat_str=base_enc._pat_str,
mergeable_ranks=base_enc._mergeable_ranks,
special_tokens={**base_enc._special_tokens, **special_tokens_dict},
)
pad_token_id = special_tokens_dict["<|pad|>"]
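
# Sanity-check sketch for the extended tokenizer (illustrative): encoding a marker as
# an allowed special token should yield its assigned id first, e.g.
#   encoding.encode("<|user|> hi", allowed_special={"<|user|>"})
# should start with 50257, the id mapped to <|user|> above.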
"""def load_model(model_dir="./"):
global model
if model is not None:
return model
model_config = MiniMaxConfig(
vocab_size=encoding.n_vocab,
block_size=256,
n_layer=8,
n_head=4,
n_embd=384,
dropout=0.1,
)
model = EnhancedMiniMaxGPT(model_config)
model.load_state_dict(torch.load("pytorch_model.bin", map_location=torch.device("cpu")))
model.eval()
return model
model = load_model()
# ------------------------------------------------------------------------
# API Setup
# ------------------------------------------------------------------------
app = FastAPI()
class ChatRequest(BaseModel):
messages: list[dict] # List of messages with 'role' and 'content'
class ChatResponse(BaseModel):
response: str
def build_prompt(conversation_history):
prompt = ""
for turn in conversation_history:
if turn["role"] == "user":
prompt += f"<|user|> {turn['content'].strip()}\n"
else:
prompt += f"<|assistant|> {turn['content'].strip()}\n"
prompt += "<|assistant|> "
return prompt
def generate_response(conversation_history):
prompt_text = build_prompt(conversation_history)
input_ids = torch.tensor(
encoding.encode(prompt_text, allowed_special=set(special_tokens_dict.keys())),
dtype=torch.long,
).unsqueeze(0)
if input_ids.size(1) > model.config.block_size:
input_ids = input_ids[:, -model.config.block_size:]
generated_ids = model.generate(
idx=input_ids,
max_new_tokens=100,
temperature=0.8,
top_k=50,
top_p=0.95,
)
new_tokens = generated_ids[0].tolist()[len(input_ids[0]):]
response_text = encoding.decode(new_tokens).strip()
return response_text
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
try:
response_text = generate_response(request.messages)
return ChatResponse(response=response_text)
except Exception as e:
return {"error": str(e)}"""
# ---------------------------
def load_model(model_dir="./"):
global model
if model is not None:
return model
config_path = os.path.join(model_dir, "config.json")
weights_path = os.path.join(model_dir, "pytorch_model.bin")
    with open(config_path, "r") as f:
        config = json.load(f)  # read for reference; the explicit fields below take precedence
model_config = MiniMaxConfig(
vocab_size=encoding.n_vocab,
block_size=512,
n_layer=12,
n_head=8,
n_embd=512,
dropout=0.1,
#tie_word_embeddings=True,
#adaptive_xpos=True,
hybrid_attention_interval=4,
        num_experts=2,
)
model = EnhancedMiniMaxGPT(model_config)
state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict, strict=False)
model.eval()
return model
def load_model_weights(checkpoint_path, config, device):
"""
    Load only the model weights from a checkpoint file (e.g. pytorch_model.bin).
    Ensures compatibility by loading only parameters whose shapes match
    (reshaping scalar aux_loss buffers if needed).
"""
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
model = EnhancedMiniMaxGPT(config).to(device)
state_dict = torch.load(checkpoint_path, map_location=device)
# Check for shape mismatches and fix aux_loss shape if necessary
model_state_dict = model.state_dict()
compatible_state_dict = {}
for k, v in state_dict.items():
if k in model_state_dict:
if v.shape == model_state_dict[k].shape:
compatible_state_dict[k] = v
elif "aux_loss" in k and v.shape == torch.Size([]):
compatible_state_dict[k] = v.unsqueeze(0) # Convert scalar to tensor
print(f"Fixed shape for {k}")
else:
print(f"Skipping {k} due to shape mismatch.")
# Load compatible weights
model.load_state_dict(compatible_state_dict, strict=False)
model.eval()
print(f"✅ Loaded model weights from {checkpoint_path}")
return model
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def get_device():
"""Return GPU device if available, else CPU."""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global model configuration and model handle
model_config = MiniMaxConfig(
vocab_size=encoding.n_vocab,
block_size=512,
n_layer=12,
n_head=8,
n_embd=512,
dropout=0.1,
#tie_word_embeddings=True,
#adaptive_xpos=True,
hybrid_attention_interval=4,
    num_experts=2,
)
model = None
checkpoint_path = "pytorch_model.bin"
device = get_device()
model = load_model_weights(checkpoint_path, model_config, device)
class ChatMessage(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: list[ChatMessage]
class ChatResponse(BaseModel):
response: str
status: str = "success"
async def ensure_model_loaded():
"""Ensure model is loaded before processing requests"""
global model
if model is None:
try:
logger.info("Loading model...")
model = load_model()
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Model initialization failed"
)
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
try:
await ensure_model_loaded()
logger.info(f"Received chat request with {len(request.messages)} messages")
        # Convert the Pydantic messages to plain role/content dicts
        # (equivalently: [msg.model_dump() for msg in request.messages]).
        conversation = [{"role": msg.role, "content": msg.content} for msg in request.messages]
response_text = generate_response(conversation)
logger.info("Response generated successfully")
return ChatResponse(response=response_text)
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e)
)
@app.get("/api/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
# ---------------------------
def build_prompt(conversation_history):
prompt = ""
for turn in conversation_history:
if turn["role"] == "user":
prompt += f"<|user|> {turn['content'].strip()}\n"
else:
prompt += f"<|assistant|> {turn['content'].strip()}\n"
prompt += "<|assistant|> "
return prompt
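
# For example (a single user turn), build_prompt([{"role": "user", "content": "Hi"}])
# returns "<|user|> Hi\n<|assistant|> ", which the model is asked to continue.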
def generate_response(conversation_history):
# automatically set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt_text = build_prompt(conversation_history)
input_ids = torch.tensor(
encoding.encode(prompt_text, allowed_special=set(special_tokens_dict.keys())),
dtype=torch.long,
device=device,
).unsqueeze(0)
if input_ids.size(1) > model.config.block_size:
input_ids = input_ids[:, -model.config.block_size:]
generated_ids = model.generate(
idx=input_ids,
max_new_tokens=100,
temperature=1.2,
top_k=50,
top_p=0.95,
)
new_tokens = generated_ids[0].tolist()[len(input_ids[0]):]
response_text = encoding.decode(new_tokens).strip()
return response_text
def chatbot_fn(user_input):
response = generate_response([{"role": "user", "content": user_input}])
return response
# Expose Gradio as an API instead of UI
iface = gr.Interface(fn=chatbot_fn, inputs="text", outputs="text")
# The standalone Gradio launch below is disabled; the interface is mounted onto the FastAPI app instead.
#iface.launch(server_name="0.0.0.0", server_port=7860)
# The magic: mount Gradio onto the FastAPI app at "/"
app = gr.mount_gradio_app(app, iface, path="/")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)