# arc-aid / use_gradio.py
# (Hugging Face Space upload by "offry", commit feaac02 — "py files")
import gradio as gr
import torch
import os
import kornia.filters
import torchvision.transforms.functional
import requests
from PIL import Image
from torchvision import transforms
from operator import itemgetter
import pickle
import io
from skimage.transform import resize
from utils_functions.imports import *
from util_models.resnet_with_skip import *
from util_models.densenet_with_skip import *
from util_models.glyphnet_with_skip import *
def create_retrieval_figure(res):
    """Build a figure showing the top-10 most similar gallery images.

    Args:
        res: mapping of image path -> distance, ordered most-similar first
            (insertion order is relied on). Entry 0 is skipped — presumably
            it is the query image itself; TODO confirm.

    Returns:
        (fig, names): the matplotlib Figure (1 title axis + 2x5 grid of
        results) and a newline-separated string of
        "Top i item UUID is ..." lines for ranks 1-10.
    """
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    plt.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
    names = ""
    for rank, current_image_path in enumerate(res):
        if rank == 0:
            continue  # skip the best match (the query itself)
        if rank > 10:
            break  # only ranks 1..10 are shown; no need to walk the rest
        image = cv2.imread(current_image_path)
        ax = fig.add_subplot(rows, cols, rank)
        plt.axis('off')
        # cv2 loads BGR; convert for correct matplotlib display.
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # Path layout assumed: .../itemUUID_<uuid>_photoUUID_... — TODO confirm
        # against the gallery directory structure.
        item_uuid = current_image_path.split("/")[4].split("_photoUUID")[0].split("itemUUID_")[1]
        ax.set_title('Top {}'.format(rank), fontsize=40)
        names = names + "Top " + str(rank) + " item UUID is " + item_uuid + "\n"
    return fig, names
def knn_calc(image_name, query_feature, features):
    """Return the negative mean cosine similarity between the query feature
    and the stored feature for ``image_name``.

    The negation makes smaller values mean "more similar", so sorting the
    returned distances ascending ranks the gallery by similarity.

    Args:
        image_name: key into ``features``.
        query_feature: tensor of shape (batch, dim) — assumed; TODO confirm.
        features: dict mapping image name -> stored feature tensor.

    Returns:
        float: -mean(cosine_similarity(query_feature, stored_feature)).
    """
    # Move the stored feature onto the query's device instead of relying on
    # a module-level `device` global (removes a hidden coupling; same result
    # when both tensors were already on the global device).
    current_image_feature = features[image_name].to(query_feature.device)
    criterion = torch.nn.CosineSimilarity(dim=1)
    dist = criterion(query_feature, current_image_feature).mean()
    return -dist.item()
def return_all_features(model_test, query_images_paths, glyph=False):
    """Embed every image in ``query_images_paths`` with ``model_test``.

    Args:
        model_test: feature-extraction model; moved to GPU when available.
        query_images_paths: iterable of image file paths readable by cv2.
        glyph: if True, reduce the input to a single grayscale channel
            before the forward pass (Glyphnet-style input).

    Returns:
        dict mapping image path -> feature tensor produced by the model.
    """
    model_test.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_test.to(device)
    transform = transforms.Compose([
        # RandomApply with p=1 applies ToPILImage unconditionally (cv2 gives
        # a numpy array; Resize here needs a PIL image or tensor).
        transforms.RandomApply([transforms.ToPILImage(), ], p=1),
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    gray_scale = transforms.Grayscale(num_output_channels=1)
    features = dict()
    with torch.no_grad():
        for i, image_path in enumerate(query_images_paths):
            print(i)  # progress trace for the (slow) embedding pass
            img = cv2.imread(image_path)
            img = transform(img)
            img = img.unsqueeze(0).contiguous().to(device)
            if glyph:
                img = gray_scale(img)
            current_image_features = model_test(img)
            features[image_path] = current_image_features
            # Free GPU memory eagerly; the gallery can be large.
            del current_image_features
            torch.cuda.empty_cache()
    return features
# ---- Model / data setup (runs at import time) ----
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = 'cpu'
experiment = "experiment_0"
# NOTE(review): this first checkpoint_path is dead — it is immediately
# overwritten two statements below; kept for reference only.
checkpoint_path = os.path.join("../shapes_classification/checkpoints/"
"50_50_pretrained_resnet101_experiment_0_train_images_with_drawings_batch_8_10:29:06/" +
"experiment_0_last_auto_model.pth.tar")
checkpoint_path = "multi_label.pth.tar"
# Multi-label shape classifier: ResNet-101 backbone with a 13-way head,
# wrapped by Resnet_with_skip (imported from util_models.resnet_with_skip).
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet).to(device)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
# Retrieval embedder: the same model with its last child module dropped.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
embedding_model_test.to(device)
# Period classifier: plain ResNet-101 with a 5-way head (see periods_labels).
periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
periods_model.to(device)
# Gallery of images to retrieve against.
data_dir = "../cssl_dataset/all_image_base/1/"
query_images_paths = []
for path in os.listdir(data_dir):
    query_images_paths.append(os.path.join(data_dir, path))
# One-time embedding pass used to produce features.pkl:
# features = return_all_features(embedding_model_test, query_images_paths)
# with open('features.pkl', 'wb') as fp:
#     pickle.dump(features, fp)
# Precomputed gallery embeddings. NOTE(review): pickle.load is only safe
# because this file is produced locally by the commented code above —
# never load a user-supplied pickle.
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)
model.eval()
# Preprocessing for predict(): 224x224, 3-channel grayscale, scaled to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Inverse of the Normalize above, used when displaying the reconstruction.
invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
                                                    std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
                               transforms.Normalize(mean=[-0.5, -0.5, -0.5],
                                                    std=[1., 1., 1.]),
                               ])
# Shape label names come from the dataset's directory listing, sorted so the
# order matches the 13-way classifier head — TODO confirm training used the
# same sorted order.
labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
periods_model.eval()
def predict(inp):
    """Run the full inference pipeline on one uploaded image.

    Args:
        inp: PIL.Image from the Gradio image input.

    Returns:
        (fig, names, plot_recon, confidences, periods_confidences):
        retrieval figure, retrieved-UUID text, reconstructed drawing as an
        (H, W, 3) numpy array at the input's resolution, multi-label shape
        prediction dict, and period classification score dict.
    """
    image_tensor = transform(inp).to(device)
    batch = image_tensor.unsqueeze(0)
    with torch.no_grad():
        classification, reconstruction = model(batch)
        periods_classification = periods_model(batch)
    # Reconstruction: 1 channel -> 3 channels, invert, undo normalization,
    # then resize back to the input's original width/height for display.
    recon_tensor = reconstruction[0].repeat(3, 1, 1)
    recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
    w, h = inp.size
    plot_recon = resize(plot_recon, (h, w))
    # Multi-label shape prediction: sigmoid + 0.8 threshold, vectorized
    # (replaces the original per-element Python loop over the tensor).
    y = torch.sigmoid(classification)
    preds = (y >= 0.8).int().flatten().tolist()
    # All predicted label names are joined with '&' into a single key that is
    # reported with confidence 1.0.
    true_labels = "&".join(labels[i] for i in range(len(labels)) if preds[i] == 1)
    # NOTE(review): original indentation was lost; this assignment is placed
    # after the label loop (one combined key) — confirm against upstream.
    confidences = {true_labels: torch.tensor(1.0).to(device)}
    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = {periods_labels[i]: periods_prediction[i]
                           for i in range(len(periods_labels))}
    # Retrieval: embed the query, rank the gallery by negative cosine
    # similarity (smaller = more similar), and plot the top matches.
    feature = embedding_model_test(batch).to(device)
    with torch.no_grad():
        dists = {image_name: knn_calc(image_name, feature, features)
                 for image_name in query_images_paths}
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    fig, names = create_retrieval_figure(res)
    return fig, names, plot_recon, confidences, periods_confidences
# Launch the Gradio UI: one image input; outputs are the retrieval plot,
# retrieved-UUID text, reconstruction image, and the two label predictions.
# share=True exposes a temporary public URL.
gr.Interface(fn=predict,
inputs=gr.Image(type="pil"),
outputs=['plot', 'text', "image", gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)], ).launch(share=True)