import pickle
from operator import itemgetter
import cv2
import gradio as gr
import kornia.enhance  # kornia.enhance.invert is used in predict()
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import zipfile
# from skimage.transform import resize
from torchvision import transforms, models
from get_models import Resnet_with_skip
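
# Helper that renders the retrieval results. `res` is expected to be a dict of
# gallery image paths ordered from most to least similar (see predict() below);
# the images themselves are read out of the bundled dataset.zip archive.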
def create_retrieval_figure(res):
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
    names = ""
    for i, image in enumerate(res):
        current_image_path = image.split("/")[3] + "/" + image.split("/")[4]
        if i == 0:
            continue  # skip the first (closest) entry, presumably the query itself
        if i < 11:
            archive = zipfile.ZipFile('dataset.zip', 'r')
            current_image_path = current_image_path.split(".")[0] + ".gif"
            imgfile = archive.read(current_image_path)
            image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
            # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
            ax = fig.add_subplot(rows, cols, i)
            plt.axis('off')
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title('Top {}'.format(i), fontsize=40)
            names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
    return fig, names
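
# Retrieval distance: the negated mean cosine similarity between the query
# embedding and a precomputed gallery embedding, so an ascending sort of the
# distances ranks the most similar images first.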
def knn_calc(image_name, query_feature, features):
    current_image_feature = features[image_name]
    criterion = torch.nn.CosineSimilarity(dim=1)
    dist = criterion(query_feature, current_image_feature).mean()
    dist = -dist.item()  # negate so that smaller values mean "more similar"
    return dist
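
# Load the trained models (on CPU):
# - a ResNet-101 backbone wrapped by Resnet_with_skip, which returns both a
#   13-way multi-label classification and a reconstructed drawing,
# - the wrapped model with its last child module stripped (embedding_model_test),
#   used as a feature extractor for retrieval,
# - a separate ResNet-101 that predicts one of five archaeological periods.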
checkpoint_path = "multi_label.pth.tar"
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
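
# Precomputed retrieval index: the list of gallery image paths and a dict that
# maps each path to its embedding (both presumably produced offline with the
# same embedding model).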
with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)
model.eval()
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
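
# invTrans undoes the Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) step above so
# the reconstructed drawing can be displayed as a regular image.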
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])
labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
periods_model.eval()
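
# Main inference entry point used by the Gradio interface. Given a PIL image it
# returns: the multi-label attribute prediction, the period probabilities, the
# reconstructed drawing, a figure with the ten most similar catalogue images,
# and a text listing of their item UUIDs.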
def predict(inp):
    image_tensor = transform(inp)
    with torch.no_grad():
        classification, reconstruction = model(image_tensor.unsqueeze(0))
        periods_classification = periods_model(image_tensor.unsqueeze(0))
    # Post-process the reconstructed drawing for display: single channel -> RGB,
    # invert intensities and undo the input normalization.
    recon_tensor = reconstruction[0].repeat(3, 1, 1)
    recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
    w, h = inp.size
    # plot_recon = resize(plot_recon, (h, w))
    # Multi-label attributes: sigmoid scores thresholded at 0.8.
    m = nn.Sigmoid()
    y = m(classification)
    preds = []
    for sample in y:
        for i in sample:
            if i >= 0.8:
                preds.append(1)
            else:
                preds.append(0)
    # Report all detected attributes as a single combined label, e.g. "lion&snake".
    confidences = {}
    true_labels = ""
    for i in range(len(labels)):
        if preds[i] == 1:
            if true_labels == "":
                true_labels = true_labels + labels[i]
            else:
                true_labels = true_labels + "&" + labels[i]
    confidences[true_labels] = torch.tensor(1.0)
    # Period prediction as a probability distribution over the five period labels.
    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
    # Retrieval: embed the query and rank the gallery by (negated) cosine similarity.
    feature = embedding_model_test(image_tensor.unsqueeze(0))
    dists = dict()
    with torch.no_grad():
        for image_name in query_images_paths:
            dist = knn_calc(image_name, feature, features)
            dists[image_name] = dist
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    fig, names = create_retrieval_figure(res)
    return confidences, periods_confidences, plot_recon, fig, names
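
# Gradio UI: a single image input mapped to the five outputs returned by predict().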
a = gr.Interface(fn=predict,
                 inputs=gr.Image(type="pil"),
                 title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
                 description="Easily classify artifacts, retrieve similar ones and generate drawings. "
                             "https://arxiv.org/abs/2211.09480.",
                 # examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
                 #           'lion.jpg', 'lion2.jpg', 'sa.jpg'],
                 outputs=[gr.Label(num_top_classes=3), gr.Label(num_top_classes=1), "image", "plot", "text"],
                 ).launch(share=True, enable_queue=True)