"""OpenAI-compatible HTTP API serving a Hugging Face causal language model.

Exposes ``/v1/chat/completions`` and ``/v1/completions`` (non-streaming,
``n=1`` only) plus ``/health``. The model is selected with the ``MODEL_ID``
environment variable and loaded once at application startup.
"""

import os
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = os.getenv("MODEL_ID", "LiquidAI/LFM2-1.2B")
DEFAULT_MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
# Defaults applied only when a request omits the field (None). An explicit
# value such as temperature=0 is honored, not replaced.
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.95
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

app = FastAPI(title=f"OpenAI-compatible API for {MODEL_ID}")

# Populated by load_model() at startup; None until loading has finished.
tokenizer = None
model = None


def get_dtype() -> torch.dtype:
    """Pick the inference dtype for the available hardware.

    Returns bfloat16 on CUDA devices that support it, float16 on other CUDA
    devices, and float32 when running on CPU.
    """
    if not torch.cuda.is_available():
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


@app.on_event("startup")
def load_model() -> None:
    """Load the tokenizer and model into the module-level globals.

    Note: ``trust_remote_code=True`` executes Python code shipped with the
    model repository, so only point ``MODEL_ID`` at repositories you trust.
    """
    global tokenizer, model
    dtype = get_dtype()
    tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    # Ensure an EOS token exists, borrowing SEP/PAD if necessary. Never install
    # an empty-string EOS: it maps to no real token and breaks stopping.
    if tok.eos_token is None:
        fallback_eos = tok.sep_token or tok.pad_token
        if fallback_eos:
            tok.eos_token = fallback_eos
    # generate() wants a pad id; reuse EOS when the tokenizer defines no PAD.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    # Publish both globals together so requests never see a half-loaded state.
    tokenizer, model = tok, mdl


class ChatMessage(BaseModel):
    """One chat turn in OpenAI format."""

    role: str  # "system", "user", "assistant"; other roles are passed through
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions``."""

    model: Optional[str] = Field(default=MODEL_ID)
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None  # None -> DEFAULT_MAX_TOKENS
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class CompletionRequest(BaseModel):
    """Request body for ``POST /v1/completions``."""

    model: Optional[str] = Field(default=MODEL_ID)
    prompt: str | List[str]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    max_tokens: Optional[int] = None  # None -> DEFAULT_MAX_TOKENS
    stop: Optional[List[str] | str] = None
    n: Optional[int] = 1


class Usage(BaseModel):
    """Token accounting returned in every completion response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


def _model_to_dict(obj: BaseModel) -> Dict[str, Any]:
    """Serialize a pydantic model on pydantic v2 (``model_dump``) or v1 (``dict``)."""
    dump = getattr(obj, "model_dump", None)
    return dump() if callable(dump) else obj.dict()


def build_chat_prompt(messages: List[ChatMessage]) -> str:
    """Render messages as a plain "Role: content" transcript ending in "Assistant:".

    Only the last system message is used (DEFAULT_SYSTEM_PROMPT when there is
    none); earlier system messages are dropped. Used as a fallback when the
    tokenizer provides no chat template.
    """
    system_prompts = [m.content for m in messages if m.role == "system"]
    system_prefix = system_prompts[-1] if system_prompts else DEFAULT_SYSTEM_PROMPT
    role_labels = {"user": "User", "assistant": "Assistant"}
    lines: List[str] = [f"System: {system_prefix}"]
    lines.extend(
        f"{role_labels.get(m.role, m.role.capitalize())}: {m.content}"
        for m in messages
        if m.role != "system"
    )
    lines.append("Assistant:")
    return "\n".join(lines)


def _render_chat(messages: List[ChatMessage]) -> Tuple[str, bool]:
    """Return ``(prompt, add_special_tokens)`` for a chat request.

    Prefers the tokenizer's own chat template so the prompt matches the format
    the model was tuned on. A template-rendered string already contains the
    special tokens the template emits, so it must be tokenized with
    ``add_special_tokens=False``. Otherwise this falls back to
    build_chat_prompt().
    """
    if getattr(tokenizer, "chat_template", None):
        prompt = tokenizer.apply_chat_template(
            [{"role": m.role, "content": m.content} for m in messages],
            tokenize=False,
            add_generation_prompt=True,
        )
        return prompt, False
    return build_chat_prompt(messages), True


def apply_stop_sequences(text: str, stop: Optional[List[str] | str]) -> str:
    """Truncate *text* at the earliest occurrence of any stop sequence.

    Empty stop strings are ignored. The matched stop sequence is not included
    in the result.
    """
    if stop is None:
        return text
    stops = stop if isinstance(stop, list) else [stop]
    hits = [idx for idx in (text.find(s) for s in stops if s) if idx != -1]
    return text[: min(hits)] if hits else text


def generate_once(
    prompt: str,
    temperature: float,
    top_p: float,
    max_new_tokens: int,
    add_special_tokens: bool = True,
) -> Dict[str, Any]:
    """Generate a continuation of *prompt* and decode only the new tokens.

    Args:
        prompt: Fully rendered prompt text.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling mass, used only when sampling (falsy -> 1.0).
        max_new_tokens: Upper bound on the number of generated tokens.
        add_special_tokens: Forwarded to the tokenizer; pass False for prompts
            rendered by a chat template.

    Returns:
        Dict with ``text``, ``prompt_tokens``, ``completion_tokens`` and
        ``finish_reason``. ``finish_reason`` is ``"length"`` when the budget
        was exhausted without emitting EOS, otherwise ``"stop"``.

    Raises:
        HTTPException: 503 if the model has not finished loading.
    """
    if tokenizer is None or model is None:
        # Explicit check rather than assert: asserts vanish under ``python -O``.
        raise HTTPException(status_code=503, detail="Model not loaded")

    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=add_special_tokens
    ).to(model.device)
    prompt_len = int(inputs["input_ids"].shape[-1])

    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    temp = float(temperature or 0.0)
    if temp > 0:
        gen_kwargs.update(do_sample=True, temperature=temp, top_p=float(top_p or 1.0))
    else:
        # Greedy: omit temperature/top_p, which are meaningless without sampling.
        gen_kwargs["do_sample"] = False

    gen_ids = model.generate(**inputs, **gen_kwargs)
    new_ids = gen_ids[0][prompt_len:]
    completion_tokens = int(new_ids.shape[0])
    eos_id = tokenizer.eos_token_id
    hit_eos = completion_tokens > 0 and eos_id is not None and int(new_ids[-1]) == eos_id
    finish_reason = "stop" if hit_eos or completion_tokens < max_new_tokens else "length"
    return {
        "text": tokenizer.decode(new_ids, skip_special_tokens=True),
        "prompt_tokens": prompt_len,
        "completion_tokens": completion_tokens,
        "finish_reason": finish_reason,
    }


def _resolve_sampling(
    temperature: Optional[float], top_p: Optional[float], max_tokens: Optional[int]
) -> Tuple[float, float, int]:
    """Apply defaults to omitted sampling fields and reject invalid values (HTTP 400)."""
    temp = DEFAULT_TEMPERATURE if temperature is None else temperature
    nucleus = DEFAULT_TOP_P if top_p is None else top_p
    max_new = DEFAULT_MAX_TOKENS if max_tokens is None else max_tokens
    if temp < 0:
        raise HTTPException(status_code=400, detail="temperature must be >= 0.")
    if not 0 < nucleus <= 1:
        raise HTTPException(status_code=400, detail="top_p must be in (0, 1].")
    if max_new < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be >= 1.")
    return float(temp), float(nucleus), int(max_new)


def _run_request(
    prompt: str,
    req: ChatCompletionRequest | CompletionRequest,
    add_special_tokens: bool = True,
) -> Tuple[str, str, Dict[str, int]]:
    """Validate sampling fields, generate, and post-process for one request.

    Returns:
        ``(text, finish_reason, usage_dict)``.
    """
    temperature, top_p, max_new = _resolve_sampling(req.temperature, req.top_p, req.max_tokens)
    g = generate_once(prompt, temperature, top_p, max_new, add_special_tokens=add_special_tokens)
    text = apply_stop_sequences(g["text"], req.stop)
    # Output cut by a stop sequence is a "stop" finish regardless of the budget.
    finish_reason = "stop" if len(text) < len(g["text"]) else g["finish_reason"]
    usage = Usage(
        prompt_tokens=g["prompt_tokens"],
        completion_tokens=g["completion_tokens"],
        total_tokens=g["prompt_tokens"] + g["completion_tokens"],
    )
    return text, finish_reason, _model_to_dict(usage)


@app.get("/")
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.get("/health")
def health():
    """Liveness probe reporting the configured model id."""
    return {"status": "ok", "model": MODEL_ID}


@app.post("/v1/chat/completions")
def chat_completions(req: ChatCompletionRequest):
    """OpenAI-style chat completion (non-streaming, ``n=1`` only)."""
    if req.n is not None and req.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must contain at least one message.")
    prompt, add_special = _render_chat(req.messages)
    text, finish_reason, usage = _run_request(prompt, req, add_special_tokens=add_special)
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:24]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": finish_reason,
            }
        ],
        "usage": usage,
    }


@app.post("/v1/completions")
def completions(req: CompletionRequest):
    """OpenAI-style text completion (non-streaming, single prompt, ``n=1``)."""
    if req.n is not None and req.n != 1:
        raise HTTPException(status_code=400, detail="Only n=1 is supported in this simple server.")
    prompts = req.prompt if isinstance(req.prompt, list) else [req.prompt]
    if len(prompts) != 1:
        raise HTTPException(status_code=400, detail="Only a single prompt is supported in this simple server.")
    text, finish_reason, usage = _run_request(prompts[0], req)
    return {
        "id": f"cmpl-{uuid.uuid4().hex[:24]}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "text": text,
                "finish_reason": finish_reason,
                "logprobs": None,
            }
        ],
        "usage": usage,
    }


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    # Pass the app object instead of the "app:app" import string: it works
    # whatever this file is named and avoids importing the module a second
    # time (an import string is only required for reload/workers).
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)