# Author: radames
# Commit 7fb1c45: "Update app.py"
import os
import sys
from pathlib import Path
import torch
from PIL import Image
import torchvision.transforms as transforms
from networks.drn_seg import DRNSub, DRNSeg
from utils.tools import *
from utils.visualize import *
import gradio as gr
from huggingface_hub import hf_hub_download
def load_classifier(type, model_path, device=torch.device("cpu")):
    """Build one of the two FAL detector networks and load its trained weights.

    Args:
        type: ``'global'`` builds ``DRNSub(1)``, whose single output is used
            by ``predict`` as a manipulation logit. ``'local'`` builds
            ``DRNSeg(2)``, whose two-channel output is used as a flow field.
            The name shadows the builtin but is kept for keyword callers.
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            network's state dict.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model in eval mode, with a ``device`` attribute set to ``device``.

    Raises:
        ValueError: If ``type`` is neither ``'global'`` nor ``'local'``.
    """
    if type == 'global':
        model = DRNSub(1)
    elif type == 'local':
        model = DRNSeg(2)
    else:
        # Fail fast. Otherwise ``model`` would be unbound below and the real
        # cause would be hidden behind an unrelated error.
        raise ValueError(
            f"Unknown classifier type {type!r}; expected 'global' or 'local'")

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (here they come from a fixed Hub repo).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    # Record the device on the module so callers can query where it lives.
    model.device = device
    # Inference only: put dropout/batch-norm layers into evaluation mode.
    model.eval()
    return model
# Fetch both FALdetector checkpoints from the Hugging Face Hub.
# ``token=True`` makes hf_hub_download use the locally stored Hub token.
_HF_REPO_ID = "radames/FALdetector"
local_model_file = hf_hub_download(
    repo_id=_HF_REPO_ID, filename="local.pth", token=True)
global_model_file = hf_hub_download(
    repo_id=_HF_REPO_ID, filename="global.pth", token=True)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate both networks on the selected device.
global_model = load_classifier("global", global_model_file, device)
local_model = load_classifier("local", local_model_file, device)
# dlib face-detector model file, passed to face_detection().
# The path is relative to the current working directory.
faces_model_file = 'utils/dlib_face_detector/mmod_human_face_detector.dat'

# Per-channel normalisation statistics. These match the standard ImageNet
# RGB mean/std.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Image -> normalised tensor transform applied to the face crop before
# it is fed to the networks.
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
def predict(img_path):
    """Run the FAL detector on the first face found in an image.

    Args:
        img_path: Filesystem path of the uploaded image (the input component
            uses ``type="filepath"``). ``None`` or empty if nothing was uploaded.

    Returns:
        A 4-tuple ``(label, modified, heat_map, reverse)``:

        - ``label``: ``{"Probability FAL": p}``, where ``p`` is the sigmoid of
          the global classifier's output.
        - ``modified``: the detected face crop, resized to the local model's
          output resolution.
        - ``heat_map``: a visualisation of the predicted flow magnitude over
          ``modified``.
        - ``reverse``: ``modified`` warped by the predicted flow (the
          suggested undo).

    Raises:
        gr.Error: If no image was supplied or dlib detects no face.
    """
    if not img_path:
        # Clicking "Run" with no upload passes None. Report that clearly
        # instead of failing deep inside the face detector.
        raise gr.Error("Please upload an image first")

    faces = face_detection(img_path, verbose=False,
                           model_file=faces_model_file)
    if len(faces) == 0:
        raise gr.Error("No face detected by dlib")
    # Only the first detection is analysed; its bounding box is not needed.
    face, _box = faces[0]
    # The first element of resize_shorter_side's result is the face image used below.
    face = resize_shorter_side(face, 400)[0]

    # Normalise the crop, add the batch dimension once, and share the batch
    # between both networks.
    batch = tf(face).unsqueeze(0).to(device)
    with torch.no_grad():
        # Global classifier: a single output squashed into a probability.
        prob = global_model(batch)[0].sigmoid().cpu().item()
        # Local model: flow field for the only batch element.
        flow = local_model(batch)[0].cpu().numpy()
    # (channels, H, W) -> (H, W, channels) for per-pixel indexing below.
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undo the warps: resize the crop to the flow's resolution and apply it.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse = Image.fromarray(warp(modified_np, flow))

    # Heat map of the per-pixel magnitude of the two flow components.
    flow_magn = np.hypot(flow[:, :, 0], flow[:, :, 1])
    # NOTE(review): 7 presumably bounds the magnitude used for colour
    # scaling; confirm in get_heatmap_cv.
    heat_map = Image.fromarray(
        get_heatmap_cv(modified_np, flow_magn, max_flow_mag=7))

    # Log the score to stdout.
    print(prob)
    return {"Probability FAL": prob}, modified, heat_map, reverse
# Gradio UI: image input and Run button, a probability label, and three
# output images (cropped face, flow heat map, suggested undo).
with gr.Blocks() as blocks:
    gr.Markdown("""
## Unofficial Demo
### Detecting Photoshopped Faces by Scripting Photoshop
#### FAL Detector Live Demo
* https://arxiv.org/abs/1906.05856
* https://peterwang512.github.io/FALdetector/
""")
    with gr.Row():
        with gr.Column():
            in_image = gr.Image(label="Input Image", type="filepath")
            # A Button's text is its ``value``; it has no ``label`` argument.
            run_btn = gr.Button("Run")
        with gr.Column():
            label = gr.Label(
                label="Probability being modified by Photoshop FAL")
    with gr.Row():
        cropped = gr.Image(label="Cropped Input Image")
        heatmap = gr.Image(label="Heatmap")
        warped = gr.Image(label="Suggested Undo")

    # Output order must match predict's return tuple.
    outputs = [label, cropped, heatmap, warped]
    run_btn.click(fn=predict, inputs=[in_image], outputs=outputs)
    # Sorted plain-string paths give a stable, filesystem-independent example
    # order. cache_examples=True runs predict on every example at startup.
    gr.Examples(fn=predict,
                examples=sorted(str(p) for p in Path("./examples").glob("*.png")),
                inputs=[in_image],
                outputs=outputs,
                cache_examples=True)
blocks.launch()