File size: 3,891 Bytes
68c0ad7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import io
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
from typing import List

# Module-level ASGI application object; the route decorators below register
# their endpoints on it.
app: FastAPI = FastAPI()

# Load the pre-trained SafetyCNN model
class SafetyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(157 * 157 * 24, 64)
        self.fc2 = nn.Linear(64, 1)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Build the network and restore its trained parameters. Checkpoint tensors are
# mapped onto the CPU as they are loaded.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# checkpoint (consider weights_only=True if the installed torch supports it).
safety_model = SafetyCNN()
safety_model.load_state_dict(
    torch.load("safety_model.pth", map_location="cpu")
)
# Switch to inference mode so the dropout layer is disabled.
safety_model.eval()

# Test-time augmentation views, applied to the PIL image before preprocessing:
#   0. identity (empty pipeline)
#   1. deterministic horizontal flip
#   2. random rotation within +/-30 degrees
#   3. random crop of 80-100% of the area, resized to 640x640
# Views 2 and 3 draw fresh random parameters on every call, so repeated
# classification of the same image is not bit-for-bit deterministic.
tta_transforms = [
    transforms.Compose(steps)
    for steps in (
        [],
        [transforms.RandomHorizontalFlip(p=1.0)],
        [transforms.RandomRotation(degrees=30)],
        [transforms.RandomResizedCrop(size=(640, 640), scale=(0.8, 1.0))],
    )
]

# Utility function to classify an image with test-time augmentation (TTA)
def classify_image_with_tta(
    image: Image.Image,
    model,
    tta_transforms: List[transforms.Compose],
    num_tta=4,
    threshold=0.5,
):
    """Classify ``image`` by averaging the model's sigmoid output over TTA views.

    Each of the ``num_tta`` passes does the following:

    1. Applies one transform from ``tta_transforms``, cycling through the list
       when ``num_tta`` exceeds its length.
    2. Resizes the result to 640x640.
    3. Converts it to a tensor normalized with mean/std 0.5 per channel.
    4. Runs ``model`` on it.

    The per-pass sigmoid probabilities are then averaged.

    Args:
        image: Input PIL image. Normalization uses three channel statistics,
            so it presumably must be RGB (the ``/classify`` endpoint converts
            uploads to RGB before calling this).
        model: Callable returning a single logit for a ``(1, 3, 640, 640)``
            batch, e.g. ``SafetyCNN``. This function does not call
            ``model.eval()``; the caller is responsible for that.
        tta_transforms: Non-empty list of PIL-image transforms.
        num_tta: Number of augmented passes; must be >= 1.
        threshold: Average probability above which the image is labelled 1.

    Returns:
        ``(prediction, avg_prob)``. ``prediction`` is 1 if
        ``avg_prob > threshold``, else 0. ``avg_prob`` is a plain Python
        float: the mean of the per-pass sigmoid outputs.

    Raises:
        ValueError: If ``num_tta < 1`` or ``tta_transforms`` is empty.
    """
    if num_tta < 1:
        # Otherwise np.mean([]) would silently yield NaN.
        raise ValueError(f"num_tta must be >= 1, got {num_tta}")
    if not tta_transforms:
        # Otherwise the modulo below would raise ZeroDivisionError.
        raise ValueError("tta_transforms must contain at least one transform")

    # Model-input preprocessing is identical for every pass, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize((640, 640)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    probabilities = []
    with torch.no_grad():
        for i in range(num_tta):
            tta_transform = tta_transforms[i % len(tta_transforms)]
            # Add a batch dimension: (1, 3, 640, 640).
            input_tensor = preprocess(tta_transform(image)).unsqueeze(0)
            logit = model(input_tensor).squeeze(1)
            probabilities.append(torch.sigmoid(logit).item())

    avg_prob = float(np.mean(probabilities))
    prediction = 1 if avg_prob > threshold else 0
    return prediction, avg_prob

@app.get("/")
async def docs():
    """Send requests for the root URL to the interactive API docs page."""
    target = "/docs"
    return RedirectResponse(target)

# FastAPI endpoint to upload an image and classify it
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """Classify an uploaded image as "Safe" or "Unsafe".

    The upload is decoded with Pillow and converted to RGB. It is then scored
    by ``classify_image_with_tta`` with 4 TTA passes.

    Returns:
        JSON with ``prediction`` (``"Unsafe"`` when the model predicts 1,
        otherwise ``"Safe"``) and ``probability`` (the averaged sigmoid
        output, as a float).

    Raises:
        HTTPException: 400 if the upload cannot be read or decoded as an image.
    """
    # Local import: only this endpoint needs the thread-pool helper.
    from fastapi.concurrency import run_in_threadpool

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        # Keep the original error as the cause so server logs show why decoding failed.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Model inference is synchronous and CPU-bound. Calling it directly inside
    # this async handler would block the event loop and stall every other
    # request, so run it in the worker thread pool instead.
    prediction, avg_prob = await run_in_threadpool(
        classify_image_with_tta, image, safety_model, tta_transforms, num_tta=4
    )

    result = {
        "prediction": "Unsafe" if prediction == 1 else "Safe",
        "probability": float(avg_prob),
    }
    return JSONResponse(content=result)

# Optional: Run with uvicorn if needed
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)

# To run the server:
# uvicorn safety_classifier:app --reload