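"""FastAPI service that classifies uploaded images as "Safe" or "Unsafe".

Predictions come from the ``SafetyCNN`` model defined below, averaged over
several test-time-augmented copies of the uploaded image.
"""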
import io
import os
from typing import List

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, RedirectResponse
from PIL import Image
from torchvision import transforms
# Create the FastAPI application object that the ASGI server (e.g. uvicorn) serves
app = FastAPI()
# Architecture of the SafetyCNN classifier (weights are loaded separately)
class SafetyCNN(nn.Module):
    """Small convolutional network that emits one raw logit per image.

    ``fc1`` expects ``157 * 157 * 24`` flattened features. That is what two
    rounds of 5x5 convolution followed by 2x2 max-pooling produce for a
    3x640x640 input: 640 -> 636 -> 318 -> 314 -> 157, with 24 channels.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the state-dict keys, so they must
        # stay exactly as they are for saved checkpoints to load.
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(12, 24, kernel_size=5, stride=1)
        self.dropout = nn.Dropout(p=0.5)
        flat_features = 157 * 157 * 24
        self.fc1 = nn.Linear(flat_features, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        """Return logits of shape (batch, 1); no sigmoid is applied here."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Keep the batch dimension, collapse channels and spatial dims.
        flat = features.flatten(start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
# Build the network, restore the saved weights onto the CPU, and switch to
# evaluation mode (which disables the dropout layer) for inference.
# NOTE(review): torch.load unpickles the file; if the installed torch version
# supports it, consider passing weights_only=True -- confirm before changing.
safety_model = SafetyCNN()
safety_model.load_state_dict(
    torch.load("safety_model.pth", map_location="cpu"),
)
safety_model.eval()
# Test-time augmentation (TTA) pipelines, one per augmented view of the image
tta_transforms = [
    transforms.Compose(steps)
    for steps in (
        # No augmentation: the image as given
        [],
        # p=1.0 means the horizontal flip is always applied
        [transforms.RandomHorizontalFlip(p=1.0)],
        # Rotation by a random angle within +/- 30 degrees
        [transforms.RandomRotation(degrees=30)],
        # Random crop of 80-100% of the image area, resized to 640x640
        [transforms.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0))],
    )
]
# Utility function to classify image with TTA
def classify_image_with_tta(image: Image.Image, model, tta_transforms: List[transforms.Compose], num_tta=4):
    """Classify *image* by averaging predictions over augmented copies.

    Args:
        image: PIL image to classify. The normalisation below uses three
            channel means/stds, so a 3-channel (e.g. RGB) image is expected.
        model: Callable that maps the preprocessed ``(1, 3, 640, 640)`` tensor
            to a single logit (a tensor with exactly one element after
            ``squeeze(1)``).
        tta_transforms: Augmentations applied to the PIL image before
            preprocessing. They are cycled through when ``num_tta`` is larger
            than the number of transforms.
        num_tta: Number of augmented forward passes to average.

    Returns:
        ``(prediction, avg_prob)`` where ``avg_prob`` is the mean sigmoid
        probability as a plain ``float`` and ``prediction`` is ``1`` when
        ``avg_prob > 0.5``, otherwise ``0``.

    Raises:
        ValueError: If ``tta_transforms`` is empty or ``num_tta`` is below 1.
            Without this check an empty list fails with ZeroDivisionError and
            ``num_tta <= 0`` silently yields ``(0, nan)``.
    """
    if not tta_transforms:
        raise ValueError("tta_transforms must contain at least one transform")
    if num_tta < 1:
        raise ValueError("num_tta must be at least 1")

    # The preprocessing is identical for every pass, so build it once
    # instead of once per loop iteration.
    preprocess = transforms.Compose([
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    probabilities = []
    # Inference only: no autograd bookkeeping for any of the passes.
    with torch.no_grad():
        for i in range(num_tta):
            augment = tta_transforms[i % len(tta_transforms)]
            # unsqueeze(0) adds the batch dimension the model expects
            input_tensor = preprocess(augment(image)).unsqueeze(0)
            logit = model(input_tensor).squeeze(1)
            # Sigmoid turns the logit into a probability
            probabilities.append(torch.sigmoid(logit).item())

    # Average over all TTA passes; float() yields a plain Python float
    avg_prob = float(np.mean(probabilities))
    # Threshold of 0.5 for the binary decision
    prediction = 1 if avg_prob > 0.5 else 0
    return prediction, avg_prob
# Without a route decorator this handler is never reachable. The root path
# is a review choice (no route was declared in the file); it is hidden from
# the OpenAPI schema because it only redirects.
@app.get("/", include_in_schema=False)
async def docs():
    """Redirect the caller to the interactive API documentation at ``/docs``."""
    return RedirectResponse(url="/docs")
# FastAPI endpoint to upload an image and classify it.
# NOTE(review): no route was declared for this handler, so it was never
# reachable; "/classify" is a review choice -- adjust to whatever clients use.
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Safe" or "Unsafe".

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        JSON object with ``prediction`` ("Unsafe" or "Safe") and
        ``probability`` (the averaged sigmoid output of the model).

    Raises:
        HTTPException: 400 if the upload cannot be read or decoded as an image.
    """
    try:
        # Read the uploaded file and decode it; convert() forces the decode,
        # so corrupt data is reported here rather than during inference.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        # Chain the cause so the original decoding error stays in the traceback
        raise HTTPException(status_code=400, detail="Invalid image file") from e

    # Inference is CPU-bound; run it in the thread pool so the event loop
    # can keep serving other requests in the meantime.
    prediction, avg_prob = await run_in_threadpool(
        classify_image_with_tta, image, safety_model, tta_transforms, num_tta=4
    )

    # Create a response message
    result = {
        "prediction": "Unsafe" if prediction == 1 else "Safe",
        "probability": float(avg_prob),
    }
    return JSONResponse(content=result)
# Optional: run with uvicorn directly from this file
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# To run the server from the command line:
# uvicorn safety_classifier:app --reload