smallagent / app.py
EnzGamers's picture
Update app.py
fb7cb35 verified
raw
history blame
2.77 kB
import time
import uuid

import torch
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Hugging Face Hub id of a small instruct-tuned code model.
MODEL_ID = "deepseek-ai/deepseek-coder-1.3b-instruct"
# Inference device; this deployment assumes no GPU is available.
DEVICE = "cpu"

# --- Model and tokenizer loading (runs once, at import time) ---
print(f"Début du chargement du modèle : {MODEL_ID}")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # bfloat16 halves memory vs float32. NOTE(review): bfloat16 on CPU can be
    # slow or unsupported on older hardware — confirm the target CPU handles it.
    torch_dtype=torch.bfloat16,
    device_map=DEVICE
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
print("Modèle et tokenizer chargés avec succès sur le CPU.")
# --- API application ---
app = FastAPI()

# --- Data models for OpenAI API compatibility ---
class ChatMessage(BaseModel):
    """A single chat message in the OpenAI chat format."""
    # Sender role, e.g. "user" or "assistant".
    role: str
    # Message text.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body of the /v1/chat/completions endpoint."""
    # Model name sent by the client; echoed back in the response
    # (the locally loaded model is what actually generates).
    model: str
    # Conversation history, oldest message first.
    messages: list[ChatMessage]
    # Maximum number of NEW tokens to generate (forwarded as max_new_tokens).
    max_tokens: int = 250
class ChatCompletionResponseChoice(BaseModel):
    """One generated completion in the OpenAI response schema."""
    # Position in the choices list; always 0 here (single completion).
    index: int = 0
    # The assistant message produced by the model.
    message: ChatMessage
    # Always "stop". NOTE(review): never set to "length", even when the
    # max_tokens budget is exhausted — confirm clients don't rely on it.
    finish_reason: str = "stop"
class ChatCompletionResponse(BaseModel):
    """Top-level envelope matching OpenAI's chat.completion object."""
    # Unique completion id, e.g. "chatcmpl-<uuid4>".
    id: str
    object: str = "chat.completion"
    # Unix timestamp (seconds) when the completion was created.
    created: int
    # Echo of the model name from the request.
    model: str
    choices: list[ChatCompletionResponseChoice]
# --- OpenAI-compatible API definition ---
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI Chat Completions-compatible endpoint.

    Runs the full conversation through the model's chat template and returns
    the generated answer wrapped in the OpenAI response schema.

    Raises:
        HTTPException: 400 if the request contains no non-empty user message.
    """
    # Fix: the original forwarded only the LAST user message, silently dropping
    # all prior conversation context. Forward the whole history instead.
    messages_for_model = [{"role": m.role, "content": m.content} for m in request.messages]
    if not any(m.role == "user" and m.content for m in request.messages):
        # Fix: return a proper 400 instead of a 200 body with an "error" key,
        # so OpenAI-style clients see a real HTTP error.
        raise HTTPException(status_code=400, detail="No user prompt found")

    # Tokenize via the model's chat template; add_generation_prompt appends the
    # assistant-turn marker so the model continues as the assistant.
    inputs = tokenizer.apply_chat_template(
        messages_for_model, add_generation_prompt=True, return_tensors="pt"
    ).to(DEVICE)

    # Generation: sampling with low temperature for mostly-deterministic answers.
    outputs = model.generate(
        inputs,
        max_new_tokens=request.max_tokens,
        do_sample=True,
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (slice off the prompt prefix).
    response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

    # Wrap the answer in the OpenAI response format.
    response_message = ChatMessage(role="assistant", content=response_text)
    choice = ChatCompletionResponseChoice(message=response_message)
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4()}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
    )
@app.get("/")
def root():
    """Liveness probe: confirm the API is up and report which model is served."""
    payload = dict(status="API compatible OpenAI en ligne", model_id=MODEL_ID)
    return payload