|
import gradio as gr |
|
import torch |
|
import os |
|
import kornia.filters |
|
import torchvision.transforms.functional |
|
import requests |
|
from PIL import Image |
|
from torchvision import transforms |
|
from operator import itemgetter |
|
import pickle |
|
import io |
|
from skimage.transform import resize |
|
|
|
from utils_functions.imports import * |
|
|
|
from util_models.resnet_with_skip import * |
|
from util_models.densenet_with_skip import * |
|
from util_models.glyphnet_with_skip import * |
|
|
|
|
|
def create_retrieval_figure(res):
    """Build a figure showing the gallery images ranked 1-10 in ``res``.

    Args:
        res: Mapping (or other iterable) of gallery image paths ordered from
            most to least similar. Only the keys/items iterated are used. The
            first entry (rank 0) is skipped — presumably the query image
            matching itself; confirm against the gallery contents.

    Returns:
        ``(fig, names)``: the matplotlib figure with the ranked images laid
        out on a 2 x 5 grid, and a text listing with one
        ``"Top <rank> item UUID is <uuid>\\n"`` line per shown image.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the ranked images.
    """
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)

    lines = []
    for rank, image_path in enumerate(res):
        if rank == 0:
            continue
        if rank > 10:
            # Only ten result slots exist; stop instead of scanning the
            # rest of the (possibly large) ranked gallery.
            break

        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read retrieved image: {image_path}")

        ax = fig.add_subplot(rows, cols, rank)
        ax.axis('off')
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title('Top {}'.format(rank), fontsize=40)

        # File names contain "itemUUID_<uuid>_photoUUID..."; parse from the
        # base name rather than a fixed path component so any directory
        # depth works.
        file_name = os.path.basename(image_path)
        item_uuid = file_name.split("_photoUUID")[0].split("itemUUID_")[1]
        lines.append("Top " + str(rank) + " item UUID is " + item_uuid + "\n")

    return fig, "".join(lines)
|
|
|
def knn_calc(image_name, query_feature, features):
    """Return the negated mean cosine similarity between a query and one gallery item.

    Args:
        image_name: key of the gallery entry in ``features``.
        query_feature: embedding tensor of the query image.
        features: mapping from gallery key to an embedding tensor whose shape
            is compatible with ``query_feature``.

    Returns:
        ``-mean(cosine_similarity)`` as a Python float, so smaller values mean
        more similar and an ascending sort puts the best matches first.
    """
    # Move the stored embedding next to the query instead of relying on a
    # module-level ``device``; the two tensors must share a device anyway.
    gallery_feature = features[image_name].to(query_feature.device)
    # Similarity along dim 1 (presumably the feature/channel axis of a
    # (batch, C, ...) tensor), averaged over all remaining positions.
    similarity = torch.nn.functional.cosine_similarity(query_feature, gallery_feature, dim=1)
    return -similarity.mean().item()
|
|
|
|
|
def return_all_features(model_test, query_images_paths, glyph=False):
    """Compute an embedding for every image path with ``model_test``.

    Args:
        model_test: network applied to a normalized (1, 3, 224, 224) batch,
            or (1, 1, 224, 224) when ``glyph`` is True.
        query_images_paths: iterable of image file paths readable by OpenCV.
        glyph: if True, collapse the three grayscale channels to one before
            the forward pass (for models expecting single-channel input).

    Returns:
        dict mapping each path to the model's output tensor (left on the
        compute device).

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the paths.

    NOTE: images are converted from OpenCV's BGR order to RGB before the
    grayscale step, matching the RGB PIL images that ``predict`` embeds.
    Embeddings produced before this conversion existed (e.g. an existing
    features.pkl) used different channel weights and should be regenerated.
    """
    model_test.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_test.to(device)

    # ndarray -> PIL, then the same steps as the query-time ``transform``.
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    gray_scale = transforms.Grayscale(num_output_channels=1)

    features = {}
    with torch.no_grad():
        for index, image_path in enumerate(query_images_paths):
            print(index)  # progress indicator

            bgr_image = cv2.imread(image_path)
            if bgr_image is None:
                raise FileNotFoundError(f"Could not read image: {image_path}")
            # Grayscale weights channels as R, G, B; OpenCV returns B, G, R.
            img = transform(cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB))

            img = img.unsqueeze(0).contiguous().to(device)
            if glyph:
                img = gray_scale(img)

            features[image_path] = model_test(img)
            torch.cuda.empty_cache()

    return features
|
|
|
|
|
# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Not used in the code visible here; kept as a module-level setting.
experiment = "experiment_0"

# Weights for the multi-label shape model (the only checkpoint path that is
# actually loaded; a stale, immediately-overwritten path was removed).
checkpoint_path = "multi_label.pth.tar"

# ResNet-101 backbone with a 13-way head — presumably one output per entry in
# `labels`; confirm the label directory has exactly 13 entries.
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
# Wrapper returning (classification, reconstruction), as unpacked in predict().
model = Resnet_with_skip(resnet).to(device)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)

# Embedding network: all top-level children of `model` except the last one.
# The child modules are shared with `model`, so they are already on `device`.
embedding_model_test = torch.nn.Sequential(*list(model.children())[:-1])
embedding_model_test.to(device)
|
|
|
# Period classifier: plain ResNet-101 with a 5-way head, one output per entry
# in `periods_labels`. The strict `load_state_dict` below overwrites every
# parameter and buffer, so downloading ImageNet weights first is unnecessary.
periods_model = models.resnet101(pretrained=False)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
periods_model.to(device)
|
|
|
# Gallery images every query is compared against.
data_dir = "../cssl_dataset/all_image_base/1/"
# Sorted so the gallery order — and therefore how the stable distance sort in
# predict() breaks ties — is reproducible across filesystems and runs.
# NOTE(review): every entry must have a key in `features`, otherwise knn_calc
# raises KeyError; stray files in data_dir would break queries.
query_images_paths = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]

# Precomputed gallery embeddings keyed by image path (presumably produced by
# return_all_features). pickle can execute arbitrary code: only load a
# trusted features.pkl.
with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)
|
|
|
# Both networks are used for inference only.
model.eval()
periods_model.eval()

_CHANNEL_MEAN = [0.5] * 3
_CHANNEL_STD = [0.5] * 3

# Query preprocessing: 224x224, grayscale replicated to 3 channels, then
# mapped to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Inverse of the Normalize above: first x / 2, then x + 0.5.
invTrans = transforms.Compose(
    [
        transforms.Normalize(mean=[0.] * 3, std=[1 / 0.5] * 3),
        transforms.Normalize(mean=[-0.5] * 3, std=[1.] * 3),
    ]
)

# Shape label names, in sorted directory order — presumably aligned with the
# outputs of the 13-way shape head.
labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
# Period names, aligned with the 5-way period head.
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
|
|
|
def predict(inp):
    """Run all models on one uploaded image and build the Gradio outputs.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input.

    Returns:
        ``(fig, names, plot_recon, confidences, periods_confidences)``:

        - ``fig``, ``names``: retrieval figure and UUID listing from
          ``create_retrieval_figure``.
        - ``plot_recon``: the inverted, de-normalized reconstruction as an
          array resized to the input's (height, width).
        - ``confidences``: a one-entry dict mapping the "&"-joined names of all
          shape labels with sigmoid probability >= 0.8 (empty string if none)
          to 1.0.
        - ``periods_confidences``: softmax probability (float) per period label.
    """
    batch = transform(inp).to(device).unsqueeze(0)

    with torch.no_grad():
        classification, reconstruction = model(batch)
        periods_classification = periods_model(batch)
        # Embedding computed under no_grad too, so no autograd graph is built.
        feature = embedding_model_test(batch)

    # Reconstruction: tile to 3 channels (presumably from 1), invert, then
    # undo the input normalization before resizing to the original size.
    recon_tensor = reconstruction[0].repeat(3, 1, 1)
    recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
    w, h = inp.size
    plot_recon = resize(plot_recon, (h, w))

    # Multi-label shape prediction: keep every label at or above 0.8.
    hits = (torch.sigmoid(classification) >= 0.8).flatten().tolist()
    true_labels = "&".join(label for label, hit in zip(labels, hits) if hit)
    # Plain floats are the value type gr.Label expects.
    confidences = {true_labels: 1.0}

    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = {
        label: float(prob) for label, prob in zip(periods_labels, periods_prediction)
    }

    # knn_calc returns the negated similarity, so ascending order = best first.
    with torch.no_grad():
        dists = {
            image_name: knn_calc(image_name, feature, features)
            for image_name in query_images_paths
        }
    res = dict(sorted(dists.items(), key=itemgetter(1)))

    fig, names = create_retrieval_figure(res)
    return fig, names, plot_recon, confidences, periods_confidences
|
|
|
|
|
# Web UI: one PIL image in; retrieval plot, UUID text, reconstruction image,
# top shape label and top period label out. share=True asks Gradio for a
# public share link.
_output_components = [
    'plot',
    'text',
    "image",
    gr.Label(num_top_classes=1),
    gr.Label(num_top_classes=1),
]
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=_output_components,
)
demo.launch(share=True)
|
|