# image-detector / app.py
# Scraped file-page header: author truens66, commit a518ecb ("Create app.py", verified), 2.93 kB.
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
import os
import numpy as np
import random
from resnet import resnet50
def seed_torch(seed=1029):
    """Seed every RNG in use and force deterministic cuDNN settings.

    Seeds the process environment variable PYTHONHASHSEED, Python's ``random``,
    NumPy, and PyTorch (CPU, current CUDA device, all CUDA devices). It then
    turns cuDNN autotuning off, requests deterministic kernels, and disables
    the cuDNN backend entirely.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so assigning it here
    # only influences child processes, not the current one's str hashing.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    # NOTE(review): disabling cuDNN outright likely slows GPU convolutions;
    # deterministic=True may be sufficient here — confirm before changing.
    cudnn.enabled = False
# Fix all RNG seeds at import time so repeated runs behave identically.
seed_torch(seed=100)
def load_model(model_path):
    """Build the single-output ResNet-50 detector and load its trained weights.

    Args:
        model_path: Path to a checkpoint holding a bare ``state_dict`` whose
            keys must match ``resnet50(num_classes=1)`` exactly (strict loading).

    Returns:
        The network in eval mode. It is moved to the default CUDA device when
        one is available and left on the CPU otherwise.
    """
    network = resnet50(num_classes=1)
    # Deserialize onto the CPU first so a GPU-saved checkpoint loads anywhere.
    # torch.load unpickles the file: only point it at checkpoints you trust.
    weights = torch.load(model_path, map_location='cpu')
    network.load_state_dict(weights, strict=True)
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return network.to(target_device).eval()
# Built once at import time: the pipeline is deterministic and identical for
# every request, so there is no reason to reconstruct it per call.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalized single-image batch for the model.

    Args:
        image: PIL image. The 3-value Normalize step expects RGB; the caller in
            this file converts uploads to RGB first.

    Returns:
        Float tensor of shape ``(1, C, 224, 224)``. Resize fixes the spatial
        size, and ``unsqueeze(0)`` adds the batch dimension.
    """
    return _PREPROCESS(image).unsqueeze(0)
def predict_image(model, image):
if torch.cuda.is_available():
image = image.cuda()
with torch.no_grad():
output = model(image)
# Apply sigmoid to get probability between 0 and 1
prediction = torch.sigmoid(output).item()
# Clamp prediction between 0 and 1
prediction = max(0, min(prediction, 1))
# Convert to percentages
real_prob = round(prediction * 1, 2) # Rounded to 2 decimal places
fake_prob = round(1 - real_prob, 2) # Complementary probability
return real_prob, fake_prob
# def predict_image(model, image):
# if torch.cuda.is_available():
# image = image.cuda()
# with torch.no_grad():
# output = model(image)
# prediction = torch.sigmoid(output).item()
# real_prob = gr.number(min(max(prediction * 100, 0), 100)) # Convert to integer
# fake_prob = int(100 - real_prob) # Ensure complementary probability
# return real_prob, fake_prob
# Load the model once at import time so every request reuses the same weights.
# Set DEEPFAKE_MODEL_PATH to point at a different checkpoint; otherwise the
# default file name below is used (relative to the working directory).
model_path = os.environ.get("DEEPFAKE_MODEL_PATH", "model_epoch_last_3090.pth")
model = load_model(model_path)
def detect_deepfake(image):
    """Gradio callback: score an uploaded image as real vs. fake.

    Args:
        image: Array delivered by ``gr.Image(type="numpy")``, or ``None`` when
            the user submits without uploading anything.

    Returns:
        Dict mapping "Real Confidence" / "Fake Confidence" to the scores from
        ``predict_image``, in the form ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Without this guard Image.fromarray(None) fails with an opaque
        # internal error instead of a message the user can act on.
        raise gr.Error("Please upload an image.")
    # convert("RGB") normalizes grayscale/RGBA uploads to 3 channels.
    rgb_image = Image.fromarray(image).convert("RGB")
    batch = preprocess_image(rgb_image)
    real_prob, fake_prob = predict_image(model, batch)
    print("real_prob", real_prob)
    print("fake_prob", fake_prob)
    return {"Real Confidence": real_prob, "Fake Confidence": fake_prob}
# Single-image UI: one numpy image in, the two labelled confidence scores out.
iface = gr.Interface(
    fn=detect_deepfake,
    title="Deepfake Detection",
    description="Upload an image to determine its confidence scores for being real or fake.",
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Confidence Scores"),
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()