# safety_classifier / safety_classifier.py
# Uploaded by riu-rd in commit 68c0ad7 (verified), message: "Upload 8 files".
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import io
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from typing import List
# ASGI application object; served with `uvicorn safety_classifier:app`.
app: FastAPI = FastAPI()
# Load the pre-trained SafetyCNN model
class SafetyCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1)
self.dropout = nn.Dropout(0.5)
self.fc1 = nn.Linear(157 * 157 * 24, 64)
self.fc2 = nn.Linear(64, 1)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
# Build the network, put it in inference mode (disables dropout), and restore
# the trained weights mapped onto the CPU.
# NOTE(review): the checkpoint path is resolved against the process's working
# directory, and torch.load unpickles the file. That is acceptable only because
# the file ships with the app; consider weights_only=True (torch >= 1.13).
safety_model = SafetyCNN().eval()
safety_model.load_state_dict(torch.load("safety_model.pth", map_location="cpu"))
# Test-time augmentation (TTA) views. Each entry is applied to the PIL image
# before the model's resize/normalise preprocessing.
# NOTE(review): RandomRotation and RandomResizedCrop draw fresh random
# parameters on every call, so the averaged probability for the same image
# can differ between requests.
tta_transforms = [
    transforms.Compose(steps)
    for steps in (
        [],  # identity: the image as uploaded
        [transforms.RandomHorizontalFlip(p=1.0)],  # always mirrored
        [transforms.RandomRotation(degrees=30)],  # random angle in [-30, 30]
        [transforms.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0))],
    )
]
# Shared preprocessing applied to every TTA view: resize to the model's
# 640x640 input, convert to a float tensor in [0, 1], then map each channel
# to [-1, 1]. Built once at import instead of once per loop iteration.
_PREPROCESS = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


# Utility function to classify image with TTA
def classify_image_with_tta(image: Image.Image, model, tta_transforms: List[transforms.Compose], num_tta=4, threshold: float = 0.5):
    """Classify ``image`` by averaging sigmoid scores over TTA views.

    ``num_tta`` views are produced by cycling through ``tta_transforms`` (so
    ``num_tta`` may exceed its length). Each view is preprocessed with
    ``_PREPROCESS``, and all views are scored by ``model`` in one batched
    forward pass. Random transforms make the result vary between calls.

    Args:
        image: Input image. Presumably already converted to 3-channel RGB,
            since the normalisation assumes 3 channels.
        model: Network expected to return one logit per sample, shape (N, 1).
        tta_transforms: Augmentations applied to the PIL image before
            preprocessing.
        num_tta: Number of augmented views to score; must be >= 1.
        threshold: Decision threshold on the averaged probability.

    Returns:
        ``(prediction, avg_prob)``. ``prediction`` is 1 when
        ``avg_prob > threshold`` and 0 otherwise. ``avg_prob`` is the mean
        probability as a Python float.

    Raises:
        ValueError: If ``num_tta < 1`` or ``tta_transforms`` is empty. Without
            this check, an empty average would become NaN and silently map
            to 0.
    """
    if num_tta < 1:
        raise ValueError(f"num_tta must be >= 1, got {num_tta}")
    if not tta_transforms:
        raise ValueError("tta_transforms must contain at least one transform")

    # Cycle through the augmentations and stack every view into one batch of
    # shape (num_tta, 3, 640, 640). Memory grows linearly with num_tta.
    batch = torch.stack([
        _PREPROCESS(tta_transforms[i % len(tta_transforms)](image))
        for i in range(num_tta)
    ])

    with torch.no_grad():
        logits = model(batch).squeeze(1)  # (num_tta, 1) -> (num_tta,)
    probs = torch.sigmoid(logits).tolist()

    avg_prob = float(np.mean(probs))
    prediction = 1 if avg_prob > threshold else 0
    return prediction, avg_prob
@app.get("/")
async def docs():
    """Redirect the bare root URL to FastAPI's interactive docs page."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
# FastAPI endpoint to upload an image and classify it
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Safe" or "Unsafe".

    The upload is decoded with Pillow, converted to RGB, and scored by
    ``classify_image_with_tta`` using four augmented views.

    Returns:
        JSON ``{"prediction": "Safe" | "Unsafe", "probability": float}``.
        ``probability`` is the averaged model score; values above 0.5 are
        reported as "Unsafe".

    Raises:
        HTTPException: 400 when the upload cannot be read or decoded as an
            image.
    """
    # Local import keeps this block self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        contents = await file.read()
        # The context manager closes Pillow's handle on the buffer. convert()
        # forces a full decode, so corrupt data fails inside this try.
        with Image.open(io.BytesIO(contents)) as raw:
            image = raw.convert("RGB")
    except Exception as exc:  # Pillow raises several exception types for bad input
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # CNN inference is CPU-bound and synchronous. Running it in the thread
    # pool stops one request from blocking the event loop for all others.
    prediction, avg_prob = await run_in_threadpool(
        classify_image_with_tta, image, safety_model, tta_transforms, num_tta=4
    )

    result = {
        "prediction": "Unsafe" if prediction == 1 else "Safe",
        # Plain float so JSON serialisation never depends on NumPy scalar types.
        "probability": float(avg_prob),
    }
    return JSONResponse(content=result)
# Optional: Run with uvicorn if needed
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=8000)
# To run the server:
# uvicorn safety_classifier:app --reload