from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import base64
import io
import logging
from typing import List, Dict, Any

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ChatGPT Oasis Model Inference API",
    description="FastAPI inference server for Oasis and ViT models",
    version="1.0.0"
)

# Global variables holding the loaded models
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None

class InferenceRequest(BaseModel):
    image: str  # Base64-encoded image
    model_name: str = "oasis500m"  # Defaults to the Oasis model


class InferenceResponse(BaseModel):
    predictions: List[Dict[str, Any]]
    model_used: str
    confidence_scores: List[float]

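# Example request/response shapes for the endpoints below (illustrative only;
# the base64 payload is truncated and the label is hypothetical):
#
#   request:  {"image": "<base64...>", "model_name": "vit-l-20"}
#   response: {"predictions": [{"label": "tabby", "confidence": 0.91}, ...],
#              "model_used": "vit-l-20",
#              "confidence_scores": [0.91, ...]}
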
def load_models():
    """Load both models into memory."""
    global oasis_model, oasis_processor, vit_model, vit_processor
    try:
        # NOTE: from_pretrained() raises at startup if a checkpoint ID cannot
        # be resolved on the Hugging Face Hub.
        logger.info("Loading Oasis 500M model...")
        oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
        oasis_model = AutoModelForImageClassification.from_pretrained("microsoft/oasis-500m")
        oasis_model.eval()

        logger.info("Loading ViT-L-20 model...")
        vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
        vit_model = AutoModelForImageClassification.from_pretrained("google/vit-large-patch16-224")
        vit_model.eval()

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise

@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    load_models()

@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "ChatGPT Oasis Model Inference API",
        "version": "1.0.0",
        "available_models": ["oasis500m", "vit-l-20"],
        "endpoints": {
            "health": "/health",
            "inference": "/inference",
            "upload_inference": "/upload_inference"
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    models_status = {
        "oasis500m": oasis_model is not None,
        "vit-l-20": vit_model is not None
    }
    return {
        "status": "healthy",
        "models_loaded": models_status
    }

def _top_predictions(model, processor, image: Image.Image) -> List[Dict[str, Any]]:
    """Run a classifier over an image and return its top predictions."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = F.softmax(outputs.logits, dim=-1)
    # Top 5 classes, or fewer if the model has fewer labels
    k = min(5, probabilities.shape[-1])
    top_probs, top_indices = torch.topk(probabilities, k)
    return [
        {
            "label": model.config.id2label[top_indices[0][i].item()],
            "confidence": top_probs[0][i].item()
        }
        for i in range(top_indices.shape[1])
    ]


def process_image_with_model(image: Image.Image, model_name: str):
    """Process an image with the specified model."""
    if model_name == "oasis500m":
        if oasis_model is None or oasis_processor is None:
            raise HTTPException(status_code=500, detail="Oasis model not loaded")
        return _top_predictions(oasis_model, oasis_processor, image)
    elif model_name == "vit-l-20":
        if vit_model is None or vit_processor is None:
            raise HTTPException(status_code=500, detail="ViT model not loaded")
        return _top_predictions(vit_model, vit_processor, image)
    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

@app.post("/inference", response_model=InferenceResponse)
async def inference(request: InferenceRequest):
    """Inference endpoint using a base64-encoded image."""
    try:
        # Decode the base64 image
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Run the selected model
        predictions = process_image_with_model(image, request.model_name)
        confidence_scores = [pred["confidence"] for pred in predictions]
        return InferenceResponse(
            predictions=predictions,
            model_used=request.model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. unknown model) instead of
        # collapsing them into a 500
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/upload_inference", response_model=InferenceResponse)
async def upload_inference(
    file: UploadFile = File(...),
    model_name: str = "oasis500m"
):
    """Inference endpoint using file upload."""
    try:
        # Validate the file type (content_type may be None for some clients)
        if not file.content_type or not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")
        # Read and decode the image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Run the selected model
        predictions = process_image_with_model(image, model_name)
        confidence_scores = [pred["confidence"] for pred in predictions]
        return InferenceResponse(
            predictions=predictions,
            model_used=model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Re-raise deliberate HTTP errors (e.g. the 400 above) unchanged
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

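# Example call to the upload endpoint (a sketch; "cat.jpg" is a hypothetical
# local file, and model_name is passed as a query parameter):
#
#   curl -X POST "http://localhost:8000/upload_inference?model_name=vit-l-20" \
#        -F "file=@cat.jpg"
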
@app.get("/models")  # assumed route path
async def list_models():
    """List available models and their status."""
    return {
        "available_models": [
            {
                "name": "oasis500m",
                "description": "Oasis 500M vision model",
                "loaded": oasis_model is not None
            },
            {
                "name": "vit-l-20",
                "description": "Vision Transformer Large model",
                "loaded": vit_model is not None
            }
        ]
    }

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
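
# ---------------------------------------------------------------------------
# Example Python client for the base64 endpoint (a minimal sketch, assuming
# the server is reachable on localhost:8000 and "cat.jpg" is a hypothetical
# local image file):
#
#   import base64, requests
#
#   with open("cat.jpg", "rb") as f:
#       payload = {
#           "image": base64.b64encode(f.read()).decode("utf-8"),
#           "model_name": "oasis500m",
#       }
#   resp = requests.post("http://localhost:8000/inference", json=payload)
#   print(resp.json()["predictions"])
# ---------------------------------------------------------------------------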