import pickle
import zipfile
from operator import itemgetter

import cv2
import gradio as gr
import kornia.filters  # currently unused; kept from the disabled reconstruction branch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# from skimage.transform import resize
from torchvision import transforms, models

from get_models import Resnet_with_skip

# Retrieval figure layout: a 2 x 5 grid holding the 10 nearest items.
_FIGURE_ROWS = 2
_FIGURE_COLS = 5
_MAX_RESULTS = _FIGURE_ROWS * _FIGURE_COLS
_DATASET_ARCHIVE = 'dataset.zip'


def _placeholder_image(text="file not found", size=224):
    """Return a white ``size`` x ``size`` 3-channel image with *text* centred in black."""
    image = np.ones((size, size, 3), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    font_thickness = 1
    text_color = (0, 0, 0)  # Black color
    (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
    # Centre the text: x from its width, y (baseline) from its height.
    text_x = (image.shape[1] - text_width) // 2
    text_y = (image.shape[0] + text_height) // 2
    cv2.putText(image, text, (text_x, text_y), font, font_scale, text_color, font_thickness)
    return image


def _read_archive_image(archive, member_path):
    """Decode *member_path* from the open zip *archive* as a BGR image.

    Falls back to a "file not found" placeholder when the member is missing
    or cannot be decoded, so one bad entry never aborts the whole figure.
    """
    try:
        raw = archive.read(member_path)
        image = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    except Exception:  # best effort: missing member, corrupt entry, empty buffer, ...
        image = None
    if image is None:  # cv2.imdecode returns None (it does not raise) on undecodable data
        image = _placeholder_image()
    return image


def create_retrieval_figure(res):
    """Plot the top retrieved items and list their item UUIDs.

    Args:
        res: Ordered mapping (or iterable) of retrieved image paths, most
            similar first; only the keys are used and at most 10 are shown.
            Each path is assumed to have at least five ``/``-separated
            components, where components 3 and 4 are the folder and file name
            inside ``dataset.zip``, and the file name contains
            ``itemUUID_<uuid>_photoUUID`` -- TODO confirm against the paths
            stored in ``query_images_paths.pkl``.

    Returns:
        ``(fig, names)``: the matplotlib figure with the thumbnails, and one
        ``"Top <i> item UUID is <uuid>"`` line per shown item (``i`` is 0-based).

    Raises:
        FileNotFoundError / zipfile.BadZipFile: if ``dataset.zip`` cannot be opened.
    """
    fig = plt.figure(figsize=[10 * 3, 10 * 3], facecolor='white')
    # Full-width axes over the top half; used only to carry the overall title.
    ax_query = fig.add_subplot(_FIGURE_ROWS, 1, 1)
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar items', fontsize=40)

    top_paths = list(res)[:_MAX_RESULTS]
    if not top_paths:
        return fig, ""

    lines = []
    # Open the archive once per figure and close it deterministically.
    with zipfile.ZipFile(_DATASET_ARCHIVE, 'r') as archive:
        for rank, image_path in enumerate(top_paths):
            parts = image_path.split("/")
            current_image_path = "dataset/" + parts[3] + "/" + parts[4]
            image = _read_archive_image(archive, current_image_path)

            ax = fig.add_subplot(_FIGURE_ROWS, _FIGURE_COLS, rank + 1)
            ax.axis('off')
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            ax.set_title(f'Top {rank}', fontsize=40)

            file_name = current_image_path.split("/")[2]
            item_uuid = file_name.split("_photoUUID")[0].split("itemUUID_")[1]
            lines.append(f"Top {rank} item UUID is {item_uuid}\n")
    return fig, "".join(lines)


def knn_calc(image_name, query_feature, features):
    """Return the negated mean cosine similarity between the query and one candidate.

    Args:
        image_name: Key of the candidate in *features*.
        query_feature: Embedding tensor of the query image.
        features: Mapping from image path to a precomputed embedding tensor
            that can be compared with *query_feature* along dim 1.

    Returns:
        ``-mean(cosine_similarity(query_feature, features[image_name], dim=1))``
        as a Python float, so an ascending sort puts the most similar items first.
    """
    current_image_feature = features[image_name]
    similarity = F.cosine_similarity(query_feature, current_image_feature, dim=1).mean()
    return -similarity.item()


checkpoint_path = "multi_label.pth.tar"

# Rebuild the fine-tuned network: ResNet-101 with a 13-way head, wrapped by
# Resnet_with_skip, then load the checkpoint weights on CPU.
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet)
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()
# Embedding network: every child of the model except the last one
# (presumably the classification head -- verify against Resnet_with_skip).
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))

# Precomputed candidate paths and their embeddings (trusted local pickles).
with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Inverse of the Normalize step above; currently unused (reconstruction output is disabled).
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])


def predict(inp):
    """Retrieve the items most similar to an uploaded image.

    Args:
        inp: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        ``(fig, names)`` from ``create_retrieval_figure`` for the candidates in
        ``query_images_paths``, ranked by ascending ``knn_calc`` distance
        (most similar first).
    """
    image_tensor = transform(inp)
    with torch.no_grad():
        feature = embedding_model_test(image_tensor.unsqueeze(0))
        dists = {
            image_name: knn_calc(image_name, feature, features)
            for image_name in query_images_paths
        }
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    return create_retrieval_figure(res)


# NOTE: the classification, reconstruction and period-prediction outputs of the
# original demo are disabled; only retrieval (plot + UUID list) is served.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
    description="Easily classify artifacts, retrieve similar ones and generate drawings. "
                "https://arxiv.org/abs/2211.09480.",
    outputs=['plot', 'text'],
)
demo.launch(share=True)