# arc-aid / app.py
# Last change by offry: "Update app.py" (commit 7c5a026, verified)
import pickle
from operator import itemgetter
import cv2
import gradio as gr
import kornia.filters
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import zipfile
# from skimage.transform import resize
from torchvision import transforms, models
from get_models import Resnet_with_skip
def _not_found_image(size=224):
    """Return a white ``size`` x ``size`` BGR image with "file not found" centred on it."""
    image = np.ones((size, size, 3), dtype=np.uint8) * 255
    text = "file not found"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 1
    text_color = (0, 0, 0)  # Black color
    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    # putText anchors at the text's bottom-left corner, hence "+ text_height".
    text_x = (image.shape[1] - text_width) // 2
    text_y = (image.shape[0] + text_height) // 2
    cv2.putText(image, text, (text_x, text_y), font, font_scale, text_color, font_thickness)
    return image


def _read_archive_image(archive, member):
    """Decode ``member`` of the open zip ``archive`` as a BGR image.

    Falls back to the "file not found" placeholder when the member is missing,
    unreadable, or not a decodable image.
    """
    try:
        data = archive.read(member)
        image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    except Exception:
        # Deliberate best-effort: a bad gallery entry must not break the UI.
        return _not_found_image()
    # imdecode signals undecodable data by returning None rather than raising.
    if image is None:
        return _not_found_image()
    return image


def create_retrieval_figure(res):
    """Render the ten best retrieval matches as a grid and list their item UUIDs.

    Args:
        res: Iterable of image paths, ordered from most to least similar
            (``predict`` passes the keys of a sorted dict). Only the first 10
            are used. Each path is split on "/" and components 3 and 4 locate
            the image inside ``dataset.zip`` as ``dataset/<c3>/<c4>``; component
            4 must contain ``itemUUID_<uuid>`` (optionally followed by
            ``_photoUUID...``), otherwise an IndexError is raised.

    Returns:
        ``(fig, names)``: a matplotlib figure with a 2x5 grid of thumbnails
        titled "Top 0" ... "Top 9", and a text with one
        "Top <i> item UUID is <uuid>" line per shown match.
    """
    max_results = 10
    cols = 5
    rows = 2
    # rcParams are read when a figure is created, so set this first.
    plt.rcParams['figure.facecolor'] = 'white'
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    # Full-width axes over the top row, used only to carry the overall title.
    ax_query = fig.add_subplot(rows, 1, 1)
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar items', fontsize=40)

    lines = []
    # Open the archive once for all thumbnails and close it deterministically.
    with zipfile.ZipFile('dataset.zip', 'r') as archive:
        for i, image_path in enumerate(res):
            if i >= max_results:
                break
            parts = image_path.split("/")
            member = "dataset/" + parts[3] + "/" + parts[4]
            image = _read_archive_image(archive, member)

            ax = fig.add_subplot(rows, cols, i + 1)
            ax.axis('off')
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            ax.set_title('Top {}'.format(i), fontsize=40)

            item_uuid = parts[4].split("_photoUUID")[0].split("itemUUID_")[1]
            lines.append(f"Top {i} item UUID is {item_uuid}\n")

    # Unregister from pyplot so figures don't pile up across requests; the
    # Figure object itself remains usable (e.g. for savefig) by the caller.
    plt.close(fig)
    return fig, "".join(lines)
def knn_calc(image_name, query_feature, features):
    """Return the retrieval distance between the query and one gallery image.

    The stored embedding ``features[image_name]`` is compared with
    ``query_feature`` by cosine similarity along dimension 1. The result is
    averaged over all remaining elements, and its negation is returned as a
    Python float, so that an ascending sort puts the most similar image first.
    """
    gallery_feature = features[image_name]
    similarity = torch.nn.functional.cosine_similarity(query_feature, gallery_feature, dim=1)
    return -similarity.mean().item()
# --- Model and retrieval-index setup (runs once, at import time) ---
checkpoint_path = "multi_label.pth.tar"
# ResNet-101 with a 13-output head, wrapped by Resnet_with_skip and restored
# from the fine-tuned checkpoint on CPU.
# NOTE(review): pretrained=True downloads ImageNet weights, but the strict
# load_state_dict below overwrites every parameter of `model`; the download
# looks unnecessary -- confirm before switching it off.
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()
# Embedding extractor: all top-level children of `model` except the last.
# It shares its modules (and therefore eval mode) with `model`.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
# NOTE: an earlier version also loaded a period classifier ("periods.pth.tar")
# and named the 13 multi-label classes; those features are currently disabled.

# Precomputed retrieval index. `query_images_paths` is iterated by predict and
# each entry is used as a key into `features` (presumably path -> embedding
# tensor; verify against the script that produced the pickles).
# NOTE: pickle.load can execute arbitrary code -- these files must be trusted.
with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

# Input preprocessing: 224x224, grayscale replicated to 3 channels, then
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Inverse of the normalization above: y -> y / 2 + 0.5, i.e. [-1, 1] -> [0, 1].
invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                    std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
                               transforms.Normalize(mean=[-0.5, -0.5, -0.5],
                                                    std=[1., 1., 1.]),
                               ])
def predict(inp):
    """Retrieve the gallery items most similar to an uploaded image.

    Args:
        inp: The uploaded image as a PIL image (the Gradio input uses
            ``type="pil"``), or ``None`` when the user submitted nothing.

    Returns:
        ``(fig, names)`` from ``create_retrieval_figure`` for the gallery
        images ordered from most to least similar.

    Raises:
        gr.Error: If no image was provided.
    """
    if inp is None:
        # Without this guard, transform(None) fails with an opaque error.
        raise gr.Error("Please upload an image.")
    image_tensor = transform(inp)
    with torch.no_grad():
        # ToTensor yields (C, H, W); unsqueeze(0) adds a leading batch dim.
        feature = embedding_model_test(image_tensor.unsqueeze(0))
        dists = {image_name: knn_calc(image_name, feature, features)
                 for image_name in query_images_paths}
    # knn_calc returns negated similarity, so ascending order = best first.
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    return create_retrieval_figure(res)
# Build the web UI around `predict` and launch it immediately on import;
# share=True additionally requests a public share link.
# NOTE(review): the description still mentions classification and drawing
# generation, but only retrieval outputs (plot + text) are wired up.
demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
                    description="Easily classify artifacts, retrieve similar ones and generate drawings. "
                                "https://arxiv.org/abs/2211.09480.",
                    outputs=['plot', 'text'], )
# `a` holds whatever launch() returns, as before.
a = demo.launch(share=True)