from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import base64
import io
import logging
from typing import List, Dict, Any

import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ChatGPT Oasis Model Inference API",
    description="FastAPI inference server for Oasis and ViT models",
    version="1.0.0"
)

# Global variables to store loaded models
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None


class InferenceRequest(BaseModel):
    image: str  # Base64 encoded image
    model_name: str = "oasis500m"  # Default to oasis model


class InferenceResponse(BaseModel):
    predictions: List[Dict[str, Any]]
    model_used: str
    confidence_scores: List[float]


def load_models():
    """Load both models into memory."""
    global oasis_model, oasis_processor, vit_model, vit_processor
    try:
        logger.info("Loading Oasis 500M model...")
        oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
        oasis_model = AutoModelForImageClassification.from_pretrained("microsoft/oasis-500m")
        oasis_model.eval()

        logger.info("Loading ViT-L-20 model...")
        vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
        vit_model = AutoModelForImageClassification.from_pretrained("google/vit-large-patch16-224")
        vit_model.eval()

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


@app.on_event("startup")
async def startup_event():
    """Load models when the application starts."""
    load_models()


@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "ChatGPT Oasis Model Inference API",
        "version": "1.0.0",
        "available_models": ["oasis500m", "vit-l-20"],
        "endpoints": {
            "health": "/health",
            "inference": "/inference",
            "upload_inference": "/upload_inference"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    models_status = {
        "oasis500m": oasis_model is not None,
        "vit-l-20": vit_model is not None
    }
    return {
        "status": "healthy",
        "models_loaded": models_status
    }


def process_image_with_model(image: Image.Image, model_name: str):
    """Run the image through the requested model and return the top-5 predictions."""
    # Select the model/processor pair; both branches share the same inference path
    if model_name == "oasis500m":
        model, processor = oasis_model, oasis_processor
    elif model_name == "vit-l-20":
        model, processor = vit_model, vit_processor
    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")

    if model is None or processor is None:
        raise HTTPException(status_code=500, detail=f"Model '{model_name}' not loaded")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=-1)

    # Get the top-5 predictions with their confidence scores
    top_probs, top_indices = torch.topk(probabilities, 5)
    predictions = []
    for i in range(top_indices.shape[1]):
        predictions.append({
            "label": model.config.id2label[top_indices[0][i].item()],
            "confidence": top_probs[0][i].item()
        })
    return predictions


@app.post("/inference", response_model=InferenceResponse)
async def inference(request: InferenceRequest):
    """Inference endpoint using a base64 encoded image."""
    try:
        # Decode the base64 image payload
        image_data = base64.b64decode(request.image)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Run inference with the requested model
        predictions = process_image_with_model(image, request.model_name)
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=request.model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        # Preserve explicit HTTP errors (e.g. unknown model) instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/upload_inference", response_model=InferenceResponse)
async def upload_inference(
    file: UploadFile = File(...),
    model_name: str = "oasis500m"
):
    """Inference endpoint using file upload."""
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and decode the uploaded image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        # Run inference with the requested model
        predictions = process_image_with_model(image, model_name)
        confidence_scores = [pred["confidence"] for pred in predictions]

        return InferenceResponse(
            predictions=predictions,
            model_used=model_name,
            confidence_scores=confidence_scores
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/models")
async def list_models():
    """List available models and their status."""
    return {
        "available_models": [
            {
                "name": "oasis500m",
                "description": "Oasis 500M vision model",
                "loaded": oasis_model is not None
            },
            {
                "name": "vit-l-20",
                "description": "Vision Transformer Large model",
                "loaded": vit_model is not None
            }
        ]
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
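
# Example client usage (illustrative sketch only, not part of the server). It assumes
# the server is running locally on port 8000, a local file `test.jpg` exists, and the
# `requests` package is installed; adjust host, port, and file name as needed.
#
#   import base64, requests
#
#   # /inference expects a JSON body with a base64-encoded image
#   with open("test.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode(), "model_name": "vit-l-20"}
#   print(requests.post("http://localhost:8000/inference", json=payload).json())
#
#   # /upload_inference expects a multipart file upload; model_name is a query parameter
#   with open("test.jpg", "rb") as f:
#       files = {"file": ("test.jpg", f, "image/jpeg")}
#       print(requests.post("http://localhost:8000/upload_inference",
#                           files=files, params={"model_name": "oasis500m"}).json())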