Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Extra | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import time | |
| import uuid | |
| import json | |
| from typing import Optional, List, Union, Dict, Any | |
| # --- Configuration --- | |
| MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" | |
| DEVICE = "cpu" | |
| # --- Chargement du modèle --- | |
| print(f"Début du chargement du modèle : {MODEL_ID}") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype=torch.bfloat16, | |
| device_map=DEVICE | |
| ) | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| print("Modèle et tokenizer chargés avec succès sur le CPU.") | |
| # --- Création de l'application API --- | |
| app = FastAPI() | |
| # --- Modèles de données pour accepter la structure complexe de l'extension --- | |
| class ContentPart(BaseModel): | |
| type: str | |
| text: str | |
| class ChatMessage(BaseModel): | |
| role: str | |
| content: Union[str, List[ContentPart]] | |
| class ChatCompletionRequest(BaseModel): | |
| model: Optional[str] = None | |
| messages: List[ChatMessage] | |
| stream: Optional[bool] = False | |
| class Config: | |
| extra = Extra.ignore | |
| class ModelData(BaseModel): | |
| id: str | |
| object: str = "model" | |
| owned_by: str = "user" | |
| class ModelList(BaseModel): | |
| object: str = "list" | |
| data: List[ModelData] | |
| # --- Définition des API --- | |
| async def list_models(): | |
| """Répond à la requête GET /models pour satisfaire l'extension.""" | |
| return ModelList(data=[ModelData(id=MODEL_ID)]) | |
| async def create_chat_completion(request: ChatCompletionRequest): | |
| """Endpoint principal qui gère la génération de texte en streaming.""" | |
| # On extrait le prompt de l'utilisateur de la structure complexe | |
| user_prompt = "" | |
| last_message = request.messages[-1] | |
| if isinstance(last_message.content, list): | |
| for part in last_message.content: | |
| if part.type == 'text': | |
| user_prompt += part.text + "\n" | |
| elif isinstance(last_message.content, str): | |
| user_prompt = last_message.content | |
| if not user_prompt: | |
| return {"error": "Prompt non trouvé."} | |
| # Préparation pour le modèle DeepSeek | |
| messages_for_model = [{'role': 'user', 'content': user_prompt}] | |
| inputs = tokenizer.apply_chat_template(messages_for_model, add_generation_prompt=True, return_tensors="pt").to(DEVICE) | |
| # Génération de la réponse complète | |
| outputs = model.generate(inputs, max_new_tokens=250, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id) | |
| response_text = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) | |
| # Fonction génératrice pour le streaming | |
| async def stream_generator(): | |
| response_id = f"chatcmpl-{uuid.uuid4()}" | |
| # On envoie la réponse caractère par caractère, au format attendu | |
| for char in response_text: | |
| chunk = { | |
| "id": response_id, | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": MODEL_ID, | |
| "choices": [{ | |
| "index": 0, | |
| "delta": {"content": char}, | |
| "finish_reason": None | |
| }] | |
| } | |
| yield f"data: {json.dumps(chunk)}\n\n" | |
| await asyncio.sleep(0.01) # Petite pause pour simuler un flux | |
| # On envoie le chunk final de fin | |
| final_chunk = { | |
| "id": response_id, | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": MODEL_ID, | |
| "choices": [{ | |
| "index": 0, | |
| "delta": {}, | |
| "finish_reason": "stop" | |
| }] | |
| } | |
| yield f"data: {json.dumps(final_chunk)}\n\n" | |
| # On envoie le signal [DONE] | |
| yield "data: [DONE]\n\n" | |
| # Si l'extension demande un stream, on renvoie le générateur | |
| if request.stream: | |
| return StreamingResponse(stream_generator(), media_type="text/event-stream") | |
| else: | |
| # Code de secours si le stream n'est pas demandé (peu probable) | |
| return {"choices": [{"message": {"role": "assistant", "content": response_text}}]} | |
| # ... (tout votre code existant reste inchangé) ... | |
| # Fonction génératrice pour le streaming | |
| async def stream_generator(): | |
| # ... (le contenu de cette fonction ne change pas) ... | |
| # Si l'extension demande un stream, on renvoie le générateur | |
| if request.stream: | |
| # ... (cette partie ne change pas) ... | |
| # =============================================================== | |
| # AJOUTEZ LE CODE CI-DESSOUS | |
| # =============================================================== | |
| async def spend_calculate(): | |
| """ | |
| Endpoint factice pour satisfaire le client qui essaie de calculer les coûts. | |
| Ne fait rien et renvoie une réponse de succès vide. | |
| """ | |
| return {} # Renvoie un JSON vide avec un statut 200 OK par défaut | |
| # =============================================================== | |
| # FIN DE L'AJOUT | |
| # =============================================================== | |
| def root(): | |
| return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID} | |
| # On a besoin de asyncio pour la pause dans le stream | |
| import asyncio | |
| def root(): | |
| return {"status": "API compatible OpenAI en ligne (avec streaming)", "model_id": MODEL_ID} | |
| # On a besoin de asyncio pour la pause dans le stream | |
| import asyncio |