|
import pickle |
|
from operator import itemgetter |
|
|
|
import cv2 |
|
import gradio as gr |
|
import kornia.filters |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import zipfile |
|
|
|
from torchvision import transforms, models |
|
from get_models import Resnet_with_skip |
|
|
|
|
|
def _placeholder_image(text="file not found", size=224):
    """Return a white ``size`` x ``size`` BGR image with ``text`` centred on it."""
    image = np.ones((size, size, 3), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 1
    text_color = (0, 0, 0)

    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, font_thickness)

    # Centre the text; putText anchors at the bottom-left corner of the string.
    text_x = (image.shape[1] - text_width) // 2
    text_y = (image.shape[0] + text_height) // 2

    cv2.putText(image, text, (text_x, text_y), font, font_scale, text_color, font_thickness)
    return image


def _read_archive_image(archive, member):
    """Decode ``member`` of the open zip ``archive`` as a BGR image.

    Falls back to the "file not found" placeholder when the member is missing
    or cannot be decoded, so a single bad entry never aborts the whole figure.
    """
    try:
        raw = archive.read(member)
        image = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    except Exception:
        # Deliberately best-effort: any read/decode failure is shown as a
        # placeholder tile instead of failing the request.
        image = None
    if image is None:
        # imdecode reports undecodable data by returning None, not by raising.
        image = _placeholder_image()
    return image


def create_retrieval_figure(res):
    """Plot the best-ranked retrieval results and list their item UUIDs.

    Only the first ``rows * cols`` (10) entries of ``res`` are drawn, in
    iteration order. Each path must have at least five ``/``-separated
    components: components 3 and 4 (zero-based) are used as the folder and
    file name of the member ``dataset/<folder>/<file>`` inside
    ``dataset.zip``. The file name must contain ``itemUUID_``; the UUID is the
    text after it, cut at ``_photoUUID`` when that marker is present.

    Args:
        res: Iterable of image path strings, best match first (``predict``
            passes a dict sorted by distance, so its keys are iterated).

    Returns:
        A ``(fig, names)`` tuple: the matplotlib figure with the image grid
        and a string with one ``"Top <rank> item UUID is <uuid>"`` line per
        plotted image (ranks start at 0).

    Raises:
        IndexError: If a path does not follow the layout described above.
        OSError: If ``dataset.zip`` cannot be opened (only attempted when
            there is at least one result to draw).
    """
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2

    # Full-width, axis-less subplot that only carries the headline.
    ax_query = fig.add_subplot(rows, 1, 1)
    # NOTE(review): this mutates global rcParams after `fig` already exists;
    # kept in place to preserve the original behaviour.
    plt.rcParams['figure.facecolor'] = 'white'
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar items', fontsize=40)

    top_paths = list(res)[:rows * cols]
    if not top_paths:
        # Nothing to draw: do not touch the archive at all.
        return fig, ""

    lines = []
    # Open the archive once for all tiles and make sure it is closed again.
    with zipfile.ZipFile('dataset.zip', 'r') as archive:
        for i, image_path in enumerate(top_paths):
            parts = image_path.split("/")
            folder = parts[3]
            file_name = parts[4]
            current_image_path = "dataset/" + folder + "/" + file_name

            image = _read_archive_image(archive, current_image_path)

            ax = fig.add_subplot(rows, cols, i + 1)
            ax.axis('off')
            # OpenCV decodes to BGR, matplotlib expects RGB.
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

            item_uuid = file_name.split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title('Top {}'.format(i), fontsize=40)
            lines.append("Top " + str(i) + " item UUID is " + item_uuid + "\n")

    return fig, "".join(lines)
|
|
|
def knn_calc(image_name, query_feature, features):
    """Return the negated mean cosine similarity between a query and a stored feature.

    Cosine similarity is taken along dimension 1 of both tensors and then
    averaged over the remaining elements. The sign is flipped so that a
    smaller value means a closer match, which lets callers sort ascending.

    Args:
        image_name: Key of the stored feature inside ``features``.
        query_feature: Tensor holding the query embedding.
        features: Mapping from image name to its stored feature tensor.

    Returns:
        The negated mean similarity as a Python float.

    Raises:
        KeyError: If ``image_name`` is not present in ``features``.
    """
    stored_feature = features[image_name]
    similarity = torch.nn.functional.cosine_similarity(
        query_feature, stored_feature, dim=1
    )
    return -similarity.mean().item()
|
|
|
# Weights file loaded into the wrapped network below.
checkpoint_path = "multi_label.pth.tar"

# ResNet-101 whose final fully connected layer is replaced by a 13-output
# linear layer, then wrapped by the project's Resnet_with_skip module.
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(in_features=num_ftrs, out_features=13)
model = Resnet_with_skip(resnet)

# Load the checkpoint on CPU and switch the network to evaluation mode.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()

# The embedding network is every child module of `model` except the last one,
# chained in order.
_model_children = list(model.children())
embedding_model_test = nn.Sequential(*_model_children[:-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Paths of the indexed images and the mapping used to look up their features
# (``features`` is indexed by image path in ``knn_calc``).
query_images_paths = _load_pickle('query_images_paths.pkl')
features = _load_pickle('features.pkl')
|
|
|
|
|
|
|
# Put the model in evaluation mode before it is used for inference.
model.eval()

# Preprocessing applied to every uploaded image: resize to 224x224, convert
# to 3-channel grayscale, turn into a tensor and normalise each channel with
# mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# Inverse of the normalisation above, done in two steps: first undo the
# division by std (divide by 1 / 0.5), then undo the mean subtraction
# (subtract -0.5).
_undo_scale = transforms.Normalize(mean=[0., 0., 0.],
                                   std=[1 / 0.5, 1 / 0.5, 1 / 0.5])
_undo_shift = transforms.Normalize(mean=[-0.5, -0.5, -0.5],
                                   std=[1., 1., 1.])
invTrans = transforms.Compose([_undo_scale, _undo_shift])
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict(inp):
    """Retrieve the indexed items most similar to an uploaded image.

    The image is preprocessed with the module-level ``transform``, embedded
    with ``embedding_model_test`` and scored with ``knn_calc`` against the
    stored feature of every path in ``query_images_paths``. The paths are then
    ordered by ascending score (closest first) and handed to
    ``create_retrieval_figure``.

    Args:
        inp: Image to search for, as accepted by ``transform`` (the interface
            is configured to pass a PIL image). ``None`` means no image was
            supplied.

    Returns:
        The ``(fig, names)`` tuple produced by ``create_retrieval_figure``.

    Raises:
        gr.Error: If ``inp`` is ``None``.
    """
    if inp is None:
        # Without this guard a missing upload fails deep inside the transform
        # with an error message that means nothing to the user.
        raise gr.Error("Please upload an image before submitting.")

    image_tensor = transform(inp)

    # A single no_grad scope covers both the embedding and the scoring.
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension before the forward pass.
        feature = embedding_model_test(image_tensor.unsqueeze(0))
        dists = {
            image_name: knn_calc(image_name, feature, features)
            for image_name in query_images_paths
        }

    # knn_calc returns negated similarity, so ascending order is best-first.
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    return create_retrieval_figure(res)
|
|
|
|
|
# Web UI: one image input, and a plot plus a text box for the outputs
# returned by `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=['plot', 'text'],
    title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
    description="Easily classify artifacts, retrieve similar ones and generate drawings. "
                "https://arxiv.org/abs/2211.09480.",
)

# Start serving with a public share link. `a` keeps holding the value
# returned by launch(), exactly as before.
a = demo.launch(share=True)
|
|
|
|
|
|