"""Streamlit demo for the Image Comparator Network (ICN).

The user uploads two images; the model classifies the pair as unmanipulated,
manipulated, or unrelated, and for the manipulated case a heatmap derived from
the model output is overlaid on both images.
"""

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from huggingface_hub import PyTorchModelHubMixin
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from model import ICN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def mask_processing(x):
    """Map one 8-bit mask pixel (0-255) to an alpha value for the overlay.

    Values above 90 become semi-opaque (140), values below 80 become fully
    transparent (0), and the 80-90 band becomes fully opaque (255) --
    presumably to draw a solid contour around highlighted regions.
    """
    if x > 90:
        return 140
    elif x < 80:
        return 0
    else:
        return 255


def grid_to_heatmap(grid, size=1024):
    """Upsample a 7x7 score grid into a colored overlay and its alpha mask.

    Args:
        grid: tensor with 49 elements, expected to hold values in [0, 1];
            it is reshaped to 7x7.
        size: side length in pixels of the square output images.

    Returns:
        ``(heatmap, mask)``: an RGBA ``PIL.Image`` colored with the "Wistia"
        colormap, and the single-channel ``PIL.Image`` used as paste mask.
    """
    # to_pil_image scales a float tensor by 255 and casts it to 8-bit.
    # reshape (unlike view) also accepts non-contiguous tensors.
    mask = to_pil_image(grid.reshape(7, 7))
    mask = mask.resize((size, size), Image.BICUBIC)
    mask = Image.eval(mask, mask_processing)
    colormap = plt.get_cmap("Wistia")
    # Integer inputs index the colormap's lookup table directly; bytes=True
    # yields uint8 RGBA, equivalent to (colors * 255).astype(np.uint8).
    heatmap = Image.fromarray(colormap(np.asarray(mask), bytes=True))
    return heatmap, mask


def summary_image(img, fake, prediction):
    """Overlay the predicted heatmap on both images of the pair.

    Args:
        img: first ``PIL.Image`` of the pair.
        fake: second ``PIL.Image`` of the pair.
        prediction: tensor of 49 raw heatmap scores (reshaped to 7x7 by
            ``grid_to_heatmap``). It is not modified.

    Returns:
        Two new 1024x1024 ``PIL.Image`` objects with the overlay pasted in;
        the input images are not modified (``resize`` returns copies).
    """
    size = 1024

    # Min-max normalise into [0, 1] out of place, so the caller's tensor
    # (a view into the model output) is left untouched.
    prediction = prediction.detach()
    prediction = prediction - prediction.min()
    peak = prediction.max()
    # A constant map (e.g. the zeroed heatmap used when no manipulation is
    # reported) would otherwise become NaN via 0 / 0; leaving it at zero
    # makes every mask pixel transparent, so no overlay is drawn.
    if peak > 0:
        prediction = prediction / peak

    img1 = img.resize((size, size))
    img2 = fake.resize((size, size))
    heatmap, mask = grid_to_heatmap(prediction, size=size)
    img1.paste(heatmap, (0, 0), mask)
    img2.paste(heatmap, (0, 0), mask)
    return img1, img2


@st.cache_resource
def load_model():
    """Load the TorchScript-traced model from ``traced_model.pt`` (cached).

    Alternative to ``load_hub_model``; not used by default.
    """
    # map_location lets a module traced on GPU load on a CPU-only host.
    model = torch.jit.load("traced_model.pt", map_location=device)
    model.eval().to(device)
    return model


@st.cache_resource
def load_hub_model():
    """Build the ICN from the Hugging Face Hub checkpoint, once per process.

    Streamlit re-executes this whole script on every interaction; without
    caching, the model would be re-instantiated on each rerun.
    """
    return ICN.from_pretrained("AlexBlck/image-comparator").eval().to(device)


model = load_hub_model()

st.title("Image Comparator Network")
st.write("## Upload a pair of images")

cols = st.columns(2)
with cols[0]:
    im1 = st.file_uploader("Image 1", type=["jpg", "jpeg", "png"])
with cols[1]:
    im2 = st.file_uploader("Image 2", type=["jpg", "jpeg", "png"])

# Halt this run until both files are uploaded and the user presses "Run".
if not (im1 and im2):
    st.stop()

btn = st.button("Run")
if not btn:
    st.stop()

im1 = Image.open(im1).convert("RGB")
im2 = Image.open(im2).convert("RGB")

tr = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Each image becomes a 3x224x224 tensor; stacking along dim 0 and adding a
# batch axis gives the 1x6x224x224 pair input passed to the model.
img = torch.vstack((tr(im1), tr(im2))).unsqueeze(0)

# Inference only: skip building an autograd graph.
with torch.no_grad():
    heatmap, cl = model(img.to(device))
confs = torch.softmax(cl, dim=1)
pred = torch.argmax(confs, dim=1).item()

# Class indices as interpreted by the messages below: 0 = no manipulation,
# 1 = manipulation, anything else = unrelated images. The heatmap is only
# shown for class 1; otherwise it is zeroed so no overlay is drawn.
if pred == 0:
    st.success("No Manipulation Detected")
    heatmap = torch.zeros_like(heatmap)
elif pred == 1:
    st.warning("Manipulation Detected!")
else:
    st.error("Images are not related.")
    heatmap = torch.zeros_like(heatmap)

img1, img2 = summary_image(im1, im2, heatmap[0])

cols = st.columns(2)
with cols[0]:
    st.image(img1)
with cols[1]:
    st.image(img2)