Spaces:
Running
Running
File size: 3,891 Bytes
68c0ad7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import io
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from typing import List
# Create the FastAPI application object; the @app.get / @app.post decorators
# further down in this file register their handlers on it.
app = FastAPI()
# Network architecture of the SafetyCNN classifier (weights are loaded separately)
class SafetyCNN(nn.Module):
    """Small convolutional network that emits one raw logit per input image.

    Layout: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then a 64-unit
    hidden layer with dropout and a single-output linear head. The first
    linear layer expects 24 feature maps of 157x157, which is what the two
    stages produce for a 640x640 input (640 -> 636 -> 318 -> 314 -> 157).
    """

    def __init__(self):
        super().__init__()
        # Attribute names are also the state-dict key prefixes, so they must
        # stay exactly as they are for saved checkpoints to load.
        self.conv1 = nn.Conv2d(3, 12, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(12, 24, kernel_size=5, stride=1)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(24 * 157 * 157, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        """Return the logit tensor of shape (batch, 1) for image batch *x*."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Keep the batch dimension, collapse channels and spatial dims.
        flat = torch.flatten(features, start_dim=1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)
# Build the network, restore its trained parameters onto the CPU, and switch
# it to evaluation mode so dropout is inactive during inference.
# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may unpickle arbitrary objects -- only ship a trusted
# checkpoint file next to this module.
safety_model = SafetyCNN()
safety_model.load_state_dict(
    torch.load("safety_model.pth", map_location=torch.device('cpu')),
)
safety_model.eval()
# Test-time augmentation pipelines, one Compose per variant:
# untouched image, forced horizontal flip, random rotation within 30 degrees,
# and a random 80-100% area crop resized to 640x640.
tta_transforms = [
    transforms.Compose(steps)
    for steps in (
        [],
        [transforms.RandomHorizontalFlip(p=1.0)],
        [transforms.RandomRotation(degrees=30)],
        [transforms.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0))],
    )
]
# Utility function to classify an image with test-time augmentation (TTA)
def classify_image_with_tta(image: Image.Image, model, tta_transforms: List[transforms.Compose], num_tta=4):
    """Classify *image* by averaging sigmoid outputs over augmented copies.

    For each of ``num_tta`` passes, one augmentation is picked from
    ``tta_transforms`` (cycling if there are fewer augmentations than
    passes), the result is resized to 640x640, converted to a tensor,
    normalised with mean 0.5 / std 0.5 per channel, and fed to ``model``.

    Args:
        image: PIL image to classify. The 3-value normalisation implies a
            3-channel image (the endpoint in this file converts to RGB).
        model: Callable taking a ``(1, C, 640, 640)`` tensor and returning a
            tensor with exactly one element (a raw logit).
        tta_transforms: Augmentations applied to the PIL image before
            preprocessing. An empty sequence means "no augmentation".
        num_tta: Number of forward passes to average; must be >= 1.

    Returns:
        ``(prediction, avg_prob)`` where ``avg_prob`` is the mean sigmoid
        probability as a plain ``float`` and ``prediction`` is ``1`` if
        ``avg_prob > 0.5`` else ``0``.

    Raises:
        ValueError: If ``num_tta`` is less than 1 (the mean of zero
            predictions would be NaN).
    """
    if num_tta < 1:
        raise ValueError("num_tta must be at least 1")

    # With no augmentations supplied, fall back to a single identity step
    # instead of failing on the modulo below.
    augmentations = list(tta_transforms) or [lambda img: img]

    # Preprocessing is the same for every pass, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    augmented_predictions = []
    with torch.no_grad():
        for i in range(num_tta):
            augment = augmentations[i % len(augmentations)]
            # unsqueeze(0) adds the batch dimension expected by the model.
            input_tensor = preprocess(augment(image)).unsqueeze(0)
            logit = model(input_tensor).squeeze(1)
            augmented_predictions.append(torch.sigmoid(logit).item())

    # Convert to a built-in float so callers never receive a NumPy scalar.
    avg_prob = float(np.mean(augmented_predictions))
    # Threshold of 0.5 for the binary decision.
    prediction = 1 if avg_prob > 0.5 else 0
    return prediction, avg_prob
@app.get("/")
async def docs():
    """Redirect requests for the root path to the ``/docs`` URL."""
    target = "/docs"
    return RedirectResponse(url=target)
# FastAPI endpoint to upload an image and classify it
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Safe" or "Unsafe".

    Args:
        file: The uploaded image file (required multipart form field).

    Returns:
        A JSON response with ``"prediction"`` (``"Unsafe"`` when the
        classifier returns label 1, otherwise ``"Safe"``) and
        ``"probability"`` (the averaged sigmoid output).

    Raises:
        HTTPException: 400 with detail "Invalid image file" if the upload
            cannot be read or decoded as an image.
    """
    # Imported locally so this handler does not depend on edits elsewhere.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Read the uploaded bytes and decode them as an RGB image.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # Broad catch is deliberate at this request boundary: any decode
        # failure is a client error. Chain the cause for server-side logs.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Inference is synchronous and CPU-bound; running it in the thread pool
    # keeps the event loop free to serve other requests meanwhile.
    prediction, avg_prob = await run_in_threadpool(
        classify_image_with_tta, image, safety_model, tta_transforms, num_tta=4
    )

    result = {
        "prediction": "Unsafe" if prediction == 1 else "Safe",
        # Coerce to a built-in float so the value is always JSON-encodable.
        "probability": float(avg_prob),
    }
    return JSONResponse(content=result)
# Optional: run with uvicorn directly from this module
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# To run the server from the command line (module name must match this file's name):
# uvicorn safety_classifier:app --reload
|