Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image | |
| from fastapi import FastAPI, HTTPException,Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| from pydantic import BaseModel | |
| import base64 | |
| import io | |
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Use bfloat16 only when the GPU actually supports it; everything else
# (CPU, or a CUDA device without bf16 support) runs in float32.
torch_dtype = (
    torch.bfloat16
    if DEVICE == "cuda" and torch.cuda.is_bf16_supported()
    else torch.float32
)

# Single source of truth for the checkpoint so the processor and the model
# cannot drift apart.
MODEL_ID = "HuggingFaceTB/SmolVLM-500M-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    # Reuse the dtype chosen above instead of re-deriving it without the
    # bf16-support check.
    torch_dtype=torch_dtype,
    # FlashAttention-2 only works with half-precision weights, so it is enabled
    # only when bf16 was selected; the float32 paths use the eager kernel.
    _attn_implementation="flash_attention_2" if torch_dtype == torch.bfloat16 else "eager",
).to(DEVICE)
# ASGI application object; the CORS middleware below is attached to it.
app = FastAPI()

# Fully permissive CORS policy: any origin, any method, any header, with
# credentials allowed.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True. Left unchanged here because
# existing clients may rely on it; confirm whether credentials are needed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
class PredictRequest(BaseModel):
    # Request body consumed by the `predict` handler.

    # Text prompt that is sent to the model together with the image.
    instruction: str

    # Image as a string whose base64 payload follows the first comma
    # (data-URL style, e.g. "data:image/png;base64,<payload>").
    imageBase64URL: str
def _decode_image(data_url: str) -> Image.Image:
    """Decode a base64 image string into a fully loaded PIL image.

    Accepts either a data URL (``data:image/png;base64,<payload>``) or a bare
    base64 payload without a header.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: the decoded bytes are not an image Pillow can read.
    """
    head, sep, payload = data_url.partition(",")
    if not sep:
        # No comma means there is no data-URL header; the whole string is the payload.
        payload = head
    image = Image.open(io.BytesIO(base64.b64decode(payload)))
    # Image.open is lazy: force the decode now so corrupt data fails here,
    # where it can be reported as a client error.
    image.load()
    return image


# NOTE(review): no route registration was present for this handler, which left
# it unreachable. "/predict" is chosen to match the function name; adjust it if
# clients expect a different path.
@app.post("/predict")
async def predict(request: PredictRequest):
    """Run the vision-language model on one image and one instruction.

    Args:
        request: Body carrying the text instruction and the base64 image.

    Returns:
        ``{"response": <text>}`` where ``<text>`` is the first decoded
        sequence produced by the model.

    Raises:
        HTTPException: 400 when the image string cannot be decoded into an
            image; 500 for any failure during prompt building or generation.
    """
    # Bad input is the caller's fault, so report it as 400 rather than 500.
    try:
        image = _decode_image(request.imageBase64URL)
    except (ValueError, OSError) as e:
        print(f"Error durante la predicción: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid image data: {e}") from e

    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": request.instruction},
                ],
            },
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=prompt, images=[image], return_tensors="pt").to(DEVICE)
        # NOTE(review): generate() is a synchronous call inside an async
        # handler, so the event loop is blocked while it runs. Consider
        # offloading it to a worker thread if concurrent requests matter.
        generated_ids = model.generate(**inputs, max_new_tokens=500)
        generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
        return {"response": generated_texts[0]}
    except Exception as e:
        print(f"Error durante la predicción: {e}")
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}") from e
# NOTE(review): no route registration was present for this handler, which left
# it unreachable. "/" is assumed from the name `read_root`; adjust if needed.
@app.get("/")
async def read_root(request: Request):
    """Log the requested path and return a fixed status message.

    Args:
        request: Incoming request; only its URL path is read.

    Returns:
        A dict with a single ``"message"`` key.
    """
    print(f"Received GET request at path: {request.url.path}")
    return {"message": "SmolVLM-500M API is running!"}