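# Gradio demo for "ArcAid: Analysis of Archaeological Artifacts using
# Drawings" (https://arxiv.org/abs/2211.09480). Given an artifact photo, the
# app predicts iconographic motifs and a period, generates a drawing-style
# reconstruction, and retrieves the most similar items from a local dataset.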
import pickle
from operator import itemgetter
import cv2
import gradio as gr
import kornia.enhance  # kornia.enhance.invert is used in predict()
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import zipfile
# from skimage.transform import resize
from torchvision import transforms, models
from get_models import Resnet_with_skip


def create_retrieval_figure(res):
    """Plot the top-10 most similar dataset images and collect their item UUIDs."""
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
    names = ""
    archive = zipfile.ZipFile('dataset.zip', 'r')  # open once, not per image
    for i, image in enumerate(res):
        # Keep only the "<item dir>/<file name>" suffix of the stored path.
        current_image_path = image.split("/")[3] + "/" + image.split("/")[4]
        if i == 0:  # skip the closest match, presumably the query itself
            continue
        if i < 11:
            imgfile = archive.read(current_image_path)
            image = cv2.imdecode(np.frombuffer(imgfile, np.uint8), 1)
            # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
            ax = fig.add_subplot(rows, cols, i)
            plt.axis('off')
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            item_uuid = current_image_path.split("/")[1].split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title('Top {}'.format(i), fontsize=40)
            names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
    return fig, names


def knn_calc(image_name, query_feature, features):
    """Negative mean cosine similarity between the query and a stored feature.

    Negated so that sorting distances in ascending order ranks the most
    similar images first.
    """
    current_image_feature = features[image_name]
    criterion = torch.nn.CosineSimilarity(dim=1)
    dist = criterion(query_feature, current_image_feature).mean()
    return -dist.item()
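

# multi_label.pth.tar holds a ResNet-101 wrapped in Resnet_with_skip: a 13-way
# multi-label head for iconographic motifs plus a branch that reconstructs a
# drawing of the input. periods.pth.tar holds a plain ResNet-101 fine-tuned to
# classify the artifact's period (5 classes).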
checkpoint_path = "multi_label.pth.tar"
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)  # one logit per motif in `labels` below
model = Resnet_with_skip(resnet)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
# Everything but the final head, used to embed images for retrieval.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))

periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)  # one logit per period in `periods_labels`
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
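
# Precomputed retrieval index. query_images_paths.pkl is assumed to list the
# dataset image paths, and features.pkl to map each path to the embedding
# produced by embedding_model_test with the checkpoint loaded above.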
with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

model.eval()
periods_model.eval()
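
# Both models expect 224x224 three-channel grayscale input normalized to
# [-1, 1]; invTrans undoes that normalization so the reconstruction can be
# displayed as a regular image.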
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])

labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross',
          'duck', 'head', 'ibex', 'lion', 'sa', 'snake']
periods_labels = ["MB1", "MB2", "LB", "Iron1", "Iron2"]
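

# predict() returns, in order: motif confidences, period confidences, the
# reconstructed drawing, a retrieval figure, and the retrieved item UUIDs,
# matching the five outputs declared in the Gradio interface below.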
def predict(inp):
    image_tensor = transform(inp)
    with torch.no_grad():
        classification, reconstruction = model(image_tensor.unsqueeze(0))
        periods_classification = periods_model(image_tensor.unsqueeze(0))

    # The reconstruction is single-channel: repeat it to three channels,
    # invert it, and undo the input normalization for display.
    recon_tensor = reconstruction[0].repeat(3, 1, 1)
    recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()
    w, h = inp.size
    # plot_recon = resize(plot_recon, (h, w))

    # Multi-label classification: a motif counts as present if its sigmoid
    # score is at least 0.8.
    y = torch.sigmoid(classification)
    preds = [1 if score >= 0.8 else 0 for score in y[0]]

    # gr.Label expects a {name: confidence} dict; all detected motifs are
    # joined into one "&"-separated entry with confidence 1.0.
    true_labels = "&".join(labels[i] for i in range(len(labels)) if preds[i] == 1)
    confidences = {true_labels: torch.tensor(1.0)}

    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = {periods_labels[i]: periods_prediction[i]
                           for i in range(len(periods_labels))}

    # Retrieval: embed the query and rank all dataset images by negative
    # cosine similarity; ascending order puts the most similar first.
    with torch.no_grad():
        feature = embedding_model_test(image_tensor.unsqueeze(0))
        dists = {image_name: knn_calc(image_name, feature, features)
                 for image_name in query_images_paths}
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    fig, names = create_retrieval_figure(res)
    return confidences, periods_confidences, plot_recon, fig, names
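

# share=True creates a temporary public link; enable_queue=True queues
# requests so long-running predictions do not time out (a Gradio 3.x launch
# flag; newer Gradio versions configure this via Interface.queue()).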
gr.Interface(fn=predict,
             inputs=gr.Image(type="pil"),
             title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
             description="Easily classify artifacts, retrieve similar ones and generate drawings. "
                         "https://arxiv.org/abs/2211.09480.",
             # examples=['anth.jpg', 'beetle_snakes.jpg', 'bird.jpg', 'cross.jpg', 'ibex.jpg',
             #           'lion.jpg', 'lion2.jpg', 'sa.jpg'],
             outputs=[gr.Label(num_top_classes=1), gr.Label(num_top_classes=1),
                      "image", "plot", "text"],
             ).launch(share=True, enable_queue=True)