# chatgpt-oasis / main.py
from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import io
import base64
from typing import List, Dict, Any
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
title="ChatGPT Oasis Model Inference API",
description="FastAPI inference server for Oasis and ViT models",
version="1.0.0"
)
# Global variables to store loaded models
oasis_model = None
oasis_processor = None
vit_model = None
vit_processor = None
class InferenceRequest(BaseModel):
image: str # Base64 encoded image
model_name: str = "oasis500m" # Default to oasis model
class InferenceResponse(BaseModel):
predictions: List[Dict[str, Any]]
model_used: str
confidence_scores: List[float]
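# Illustrative helper (an assumption for documentation purposes, not used by the
# server itself): one way a client could build the base64 "image" field expected
# by InferenceRequest. The function name and default path are placeholders.
def _example_encode_image(path: str = "example.jpg") -> str:
    """Return the file at `path` as a base64 string suitable for the /inference payload."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")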
def load_models():
"""Load both models into memory"""
global oasis_model, oasis_processor, vit_model, vit_processor
try:
logger.info("Loading Oasis 500M model...")
# Load Oasis model
oasis_processor = AutoImageProcessor.from_pretrained("microsoft/oasis-500m")
oasis_model = AutoModelForImageClassification.from_pretrained("microsoft/oasis-500m")
oasis_model.eval()
logger.info("Loading ViT-L-20 model...")
# Load ViT model
vit_processor = AutoImageProcessor.from_pretrained("google/vit-large-patch16-224")
vit_model = AutoModelForImageClassification.from_pretrained("google/vit-large-patch16-224")
vit_model.eval()
logger.info("All models loaded successfully!")
except Exception as e:
logger.error(f"Error loading models: {e}")
raise e
@app.on_event("startup")
async def startup_event():
"""Load models when the application starts"""
load_models()
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"message": "ChatGPT Oasis Model Inference API",
"version": "1.0.0",
"available_models": ["oasis500m", "vit-l-20"],
"endpoints": {
"health": "/health",
"inference": "/inference",
"upload_inference": "/upload_inference"
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
models_status = {
"oasis500m": oasis_model is not None,
"vit-l-20": vit_model is not None
}
return {
"status": "healthy",
"models_loaded": models_status
}
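# Quick smoke test once the server is running (illustrative; assumes the default
# host/port from the __main__ block at the bottom of this file):
#   curl http://localhost:8000/health
#   -> {"status": "healthy", "models_loaded": {"oasis500m": true, "vit-l-20": true}}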
def process_image_with_model(image: Image.Image, model_name: str):
    """Process the image with the specified model and return its top-5 predictions"""
    # Select the requested model/processor pair
    if model_name == "oasis500m":
        model, processor, not_loaded = oasis_model, oasis_processor, "Oasis model not loaded"
    elif model_name == "vit-l-20":
        model, processor, not_loaded = vit_model, vit_processor, "ViT model not loaded"
    else:
        raise HTTPException(status_code=400, detail=f"Unknown model: {model_name}")
    if model is None or processor is None:
        raise HTTPException(status_code=500, detail=not_loaded)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=-1)
    # Get top-5 predictions
    top_probs, top_indices = torch.topk(probabilities, 5)
    predictions = []
    for i in range(top_indices.shape[1]):
        predictions.append({
            "label": model.config.id2label[top_indices[0][i].item()],
            "confidence": top_probs[0][i].item()
        })
    return predictions
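# For reference, the returned predictions look roughly like
#   [{"label": "Egyptian cat", "confidence": 0.83}, ...]
# (five entries, ordered by confidence). The exact label strings depend on the
# id2label mapping of whichever checkpoint was loaded, so the example label above
# is purely illustrative.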
@app.post("/inference", response_model=InferenceResponse)
async def inference(request: InferenceRequest):
"""Inference endpoint using base64 encoded image"""
    try:
        # Decode base64 image
        image_data = base64.b64decode(request.image)
image = Image.open(io.BytesIO(image_data)).convert('RGB')
# Process with model
predictions = process_image_with_model(image, request.model_name)
# Extract confidence scores
confidence_scores = [pred["confidence"] for pred in predictions]
return InferenceResponse(
predictions=predictions,
model_used=request.model_name,
confidence_scores=confidence_scores
)
    except HTTPException:
        # Preserve intentional HTTP errors (e.g. unknown model) instead of
        # converting them into 500 responses
        raise
    except Exception as e:
        logger.error(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
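# Sketch of a client call to /inference (not executed here; assumes the server is
# reachable on localhost:8000 and the `requests` package is installed; file name
# and model choice are placeholders):
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       payload = {"image": base64.b64encode(f.read()).decode("utf-8"),
#                  "model_name": "oasis500m"}
#   print(requests.post("http://localhost:8000/inference", json=payload).json())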
@app.post("/upload_inference", response_model=InferenceResponse)
async def upload_inference(
file: UploadFile = File(...),
model_name: str = "oasis500m"
):
"""Inference endpoint using file upload"""
try:
# Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
raise HTTPException(status_code=400, detail="File must be an image")
# Read and process image
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert('RGB')
# Process with model
predictions = process_image_with_model(image, model_name)
# Extract confidence scores
confidence_scores = [pred["confidence"] for pred in predictions]
return InferenceResponse(
predictions=predictions,
model_used=model_name,
confidence_scores=confidence_scores
)
    except HTTPException:
        # Re-raise validation errors (bad content type, unknown model) unchanged
        raise
    except Exception as e:
        logger.error(f"Upload inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
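# The same inference via multipart upload (illustrative; the file path is a
# placeholder, and model_name is passed as a query parameter because it is
# declared as a plain str parameter above):
#
#   curl -X POST "http://localhost:8000/upload_inference?model_name=vit-l-20" \
#        -F "file=@example.jpg"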
@app.get("/models")
async def list_models():
"""List available models and their status"""
return {
"available_models": [
{
"name": "oasis500m",
"description": "Oasis 500M vision model",
"loaded": oasis_model is not None
},
{
"name": "vit-l-20",
"description": "Vision Transformer Large model",
"loaded": vit_model is not None
}
]
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
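# Alternatively, since this file is main.py, the server can be started directly
# with the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 8000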