|
import os |
|
import pickle |
|
from operator import itemgetter |
|
|
|
import cv2 |
|
import gradio as gr |
|
import kornia.filters |
|
import kornia.filters |
|
import scipy.ndimage |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import random |
|
import zipfile |
|
|
|
from torchvision import transforms, models |
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding equals ``dilation``, so with ``stride=1`` the spatial size of
    the input is preserved.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
|
|
|
|
|
class DoubleConv(nn.Module):
    """Residual block: two (3x3 conv => InstanceNorm) stages plus a shortcut.

    ``forward`` computes
    ``relu(inst2(conv2(relu(inst1(conv1(x))))) + shortcut(x))``.
    The shortcut is the identity when ``in_channels == out_channels`` and a
    1x1 convolution followed by batch normalisation otherwise.

    ``bn1`` and ``bn2`` are created but never applied in ``forward``. They
    are kept on purpose: their parameters and buffers are part of this
    module's ``state_dict``, so removing them would change the keys a saved
    checkpoint has to provide.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        mid_channels: Number of channels between the two convolutions.
            Falls back to ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        # First stage. Creation order is kept stable so that module
        # registration (and random initialisation order) does not change.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)  # unused in forward, see class docstring
        self.inst1 = nn.InstanceNorm2d(mid_channels)

        self.relu = nn.ReLU(inplace=True)

        # Second stage.
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)  # unused in forward, see class docstring
        self.inst2 = nn.InstanceNorm2d(out_channels)

        # Projection for the shortcut, needed only when the channel count changes.
        self.downsample = None
        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        identity = x

        out = self.conv1(x)
        out = self.inst1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.inst2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out
|
|
|
|
|
class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        block = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, block)

    def forward(self, x):
        """Run ``x`` through the pooling + convolution pipeline."""
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved
|
|
|
|
|
class Up(nn.Module):
    """Decoder stage: upsample ``x1``, align it to ``x2``, concatenate, convolve.

    Args:
        in_channels: Channel count of the concatenated tensor handed to the
            final ``DoubleConv``.
        out_channels: Channel count produced by the stage.
        bilinear: When true, upsample with bilinear interpolation. Otherwise
            use a transposed convolution that halves the channel count, or
            no upsampling at all when ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half_channels = in_channels // 2

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)
            return

        if in_channels == out_channels:
            self.up = nn.Identity()
        else:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with the skip tensor ``x2``.

        ``x1`` is padded (or cropped, when it is larger) so that its last two
        dimensions match those of ``x2`` before both are concatenated along
        the channel dimension, with ``x2`` first.
        """
        upsampled = self.up(x1)

        # Height/width gap between the skip tensor and the upsampled one.
        # A negative gap yields negative padding, i.e. F.pad crops.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        aligned = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, aligned], dim=1)
        return self.conv(merged)
|
|
|
|
|
class OutConv(nn.Module):
    """Pointwise (1x1) convolution with bias, mapping to ``out_channels``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels of the produced tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` through the 1x1 convolution."""
        projected = self.conv(x)
        return projected
|
|
|
class GaussianLayer(nn.Module):
    """Single-channel 5x5 convolution initialised as a Gaussian blur.

    The kernel lives in an ordinary ``nn.Conv2d`` weight, so it is part of
    the module's ``state_dict`` and is replaced by whatever a loaded
    checkpoint contains.
    """

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
        )
        self.weights_init()

    def forward(self, x):
        """Convolve ``x`` (one channel expected by the layer) with the kernel."""
        return self.seq(x)

    def weights_init(self):
        """Overwrite every parameter with a centred sigma=1 Gaussian kernel."""
        impulse = np.zeros((5, 5))
        # A 5x5 grid is centred at index 2. Putting the impulse at [3, 3]
        # would shift the blur by one pixel towards the bottom-right.
        impulse[2, 2] = 1
        kernel = torch.from_numpy(scipy.ndimage.gaussian_filter(impulse, sigma=1))

        # copy_ broadcasts the (5, 5) kernel into the (1, 1, 5, 5) weight and
        # casts it to the parameter's dtype.
        with torch.no_grad():
            for _, param in self.named_parameters():
                param.copy_(kernel)
|
|
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a 3x3 convolution without bias, padded by ``dilation``.

    NOTE(review): a helper with the same name and the same behaviour is
    already declared earlier in this file; this later definition is the one
    that remains bound to ``conv3x3`` once the module has been executed.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
|
|
|
class Decoder(nn.Module):
    """Container for the layers driven by ``Resnet_with_skip.forward_decode``.

    The class defines no ``forward`` of its own; the owning model calls the
    attributes one by one.

    ``instance1``, ``upsample4_conv``, ``gn2`` and ``instance3`` are not
    referenced by ``forward_decode`` in this file. They are kept so that the
    set of registered sub-modules (and therefore the ``state_dict`` keys a
    checkpoint must provide) stays the same.

    ``kernel`` and ``nms_conv.weight`` are the same ``nn.Parameter``: eight
    3x3 filters, each with 1.0 at the centre and one random value drawn from
    [-1, 0] on a single neighbouring tap.
    """

    # (row, col) of the neighbour tap inside each of the eight 3x3 filters,
    # listed in output-channel order.
    _NEIGHBOUR_TAPS = (
        (1, 2), (2, 2), (2, 1), (2, 0),
        (1, 0), (0, 0), (0, 1), (0, 2),
    )

    def __init__(self):
        super().__init__()
        self.up1 = Up(2048, 1024, False)
        self.up2 = Up(1024, 512, False)
        self.up3 = Up(512, 256, False)
        self.conv2d_2_1 = conv3x3(256, 128)
        self.gn1 = nn.GroupNorm(4, 128)
        self.instance1 = nn.InstanceNorm2d(128)
        self.up4 = Up(128, 64, False)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
        self.up_ = Up(128, 128, False)
        self.conv2d_2_2 = conv3x3(128, 6)
        self.instance2 = nn.InstanceNorm2d(6)
        self.gn2 = nn.GroupNorm(3, 6)
        # Built once; a second construction would only replace this instance.
        self.gaussian_blur = GaussianLayer()
        self.up5 = Up(6, 3, False)
        self.conv2d_2_3 = conv3x3(3, 1)
        self.instance3 = nn.InstanceNorm2d(1)

        # Eight directional filters of shape (8, 1, 3, 3): centre tap 1.0,
        # one neighbour tap with a random weight in [-1, 0].
        taps = self._NEIGHBOUR_TAPS
        directional = torch.zeros(len(taps), 1, 3, 3, dtype=torch.float32)
        directional[:, 0, 1, 1] = 1.0
        for channel, (row, col) in enumerate(taps):
            directional[channel, 0, row, col] = random.uniform(-1.0, 0.0)
        self.kernel = nn.Parameter(directional)

        # out_channels matches the number of filters in ``kernel``; the
        # convolution shares that parameter instead of owning its own weight.
        self.nms_conv = nn.Conv2d(1, len(taps), kernel_size=3, stride=1, padding=1, bias=False, groups=1)
        self.nms_conv.weight = self.kernel
|
|
|
|
|
class Resnet_with_skip(nn.Module):
    """Backbone classifier paired with a skip-connected ``Decoder``.

    Args:
        model: Backbone network. ``forward_decode`` reads its ``conv1``,
            ``bn1``, ``relu``, ``maxpool`` and ``layer1`` .. ``layer4``
            attributes (presumably a torchvision ResNet; confirm at the
            call site).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.decoder = Decoder()

    def forward_pred(self, image):
        """Return the output of the wrapped backbone for ``image``."""
        return self.model(image)

    def forward_decode(self, image):
        """Decode backbone features into a single-map reconstruction.

        The backbone stages are run one by one, their outputs are merged by
        the decoder's ``Up`` stages, and the decoded map is then weighted by
        a direction-dependent response:

        1. blur the decoded map and take its spatial gradient;
        2. quantise the gradient angle to multiples of 45 degrees;
        3. pick, from the 8-filter ``nms_conv`` response, the channels for
           that direction and for the opposite one, and keep the smaller;
        4. clamp that value from below at 0.01, multiply it with the decoded
           map and apply a gamma of 2.0.
        """
        import math

        net, dec = self.model, self.decoder
        source = image

        # Backbone: keep every intermediate map for the skip connections.
        stem = net.relu(net.bn1(net.conv1(image)))
        pooled = net.maxpool(stem)
        stage1 = net.layer1(pooled)
        stage2 = net.layer2(stage1)
        stage3 = net.layer3(stage2)
        stage4 = net.layer4(stage3)

        # Decoder: each step merges the running map with a skip tensor.
        merged = dec.up1(stage4, stage3)
        merged = dec.up2(merged, stage2)
        merged = dec.up3(merged, stage1)
        merged = F.relu(dec.gn1(dec.conv2d_2_1(merged)))
        merged = dec.up4(merged, pooled)
        merged = dec.upsample4(merged)
        merged = dec.up_(merged, stem)
        merged = F.relu(dec.instance2(dec.conv2d_2_2(merged)))
        merged = dec.up5(merged, source)
        reconst = F.relu(dec.conv2d_2_3(merged))

        # Gradient direction of the blurred map, in degrees, snapped to 45.
        blurred = dec.gaussian_blur(reconst)
        grads = kornia.filters.spatial_gradient(blurred, normalized=False)
        grad_x, grad_y = grads[:, :, 0], grads[:, :, 1]
        angle_deg = 180.0 * torch.atan2(grad_y, grad_x) / math.pi
        angle_deg = torch.round(angle_deg / 45) * 45

        directional = dec.nms_conv(blurred)

        # Channel index for the gradient direction and for its opposite.
        octant = angle_deg / 45
        forward_idx = (octant % 8).long()
        backward_idx = ((octant + 4) % 8).long()

        along = torch.gather(directional, 1, forward_idx)
        against = torch.gather(directional, 1, backward_idx)
        weakest = torch.stack([along, against], 1).min(dim=1)[0]
        # Values not above 0.01 are replaced by 0.01.
        weakest = F.threshold(weakest, 0.01, 0.01)

        magnitude = torch.mul(reconst, weakest)
        return kornia.enhance.adjust_gamma(magnitude, 2.0)

    def forward(self, image):
        """Return ``(prediction, reconstruction)`` for ``image``.

        The decode path is evaluated first, then the backbone is run again
        end-to-end for the prediction.
        """
        decoded = self.forward_decode(image)
        prediction = self.forward_pred(image)
        return prediction, decoded
|
|
|
|
|
def create_retrieval_figure(res):
    """Plot retrieved images ranked 1-10 and list their item UUIDs.

    Args:
        res: Iterable of image paths ordered from best to worst match
            (``predict`` passes a dict keyed by path). The entry at
            position 0 is skipped (presumably the query itself; confirm
            with the caller) and positions 1-10 are drawn. Each drawn path
            needs at least five ``/``-separated components; components 3 and
            4 form the member name read from ``dataset.zip``.

    Returns:
        ``(fig, names)``: the matplotlib figure and a newline-terminated
        text line per drawn image of the form
        ``"Top <rank> item UUID is <uuid>"``.
    """
    fig = plt.figure(figsize=[10 * 3, 10 * 3])
    cols = 5
    rows = 2
    ax_query = fig.add_subplot(rows, 1, 1)
    plt.rcParams['figure.facecolor'] = 'white'
    ax_query.axis('off')
    ax_query.set_title('Top 10 most similar scarabs', fontsize=40)

    # Only ranks 1..10 are ever drawn, so nothing beyond them is processed.
    top_paths = list(res)[1:11]
    if not top_paths:
        return fig, ""

    lines = []
    # Open the archive once and close it deterministically.
    with zipfile.ZipFile('dataset.zip', 'r') as archive:
        for rank, path in enumerate(top_paths, start=1):
            parts = path.split("/")
            member = parts[3] + "/" + parts[4]
            raw = archive.read(member)
            image = cv2.imdecode(np.frombuffer(raw, np.uint8), 1)

            # Draw on the axes object directly instead of relying on
            # pyplot's notion of the "current" axes.
            ax = fig.add_subplot(rows, cols, rank)
            ax.axis('off')
            ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            ax.set_title('Top {}'.format(rank), fontsize=40)

            item_uuid = parts[4].split("_photoUUID")[0].split("itemUUID_")[1]
            lines.append("Top " + str(rank) + " item UUID is " + item_uuid + "\n")

    return fig, "".join(lines)
|
|
|
def knn_calc(image_name, query_feature, features):
    """Return the negated mean cosine similarity to a stored feature.

    Args:
        image_name: Key of the stored feature to compare against.
        query_feature: Feature tensor of the query image.
        features: Mapping from image name to its stored feature tensor.

    Returns:
        A Python float. The similarity (taken along dimension 1 and then
        averaged) is negated, so smaller values mean a closer match.
    """
    stored = features[image_name]
    similarity = torch.nn.functional.cosine_similarity(query_feature, stored, dim=1)
    return -similarity.mean().item()
|
|
|
# ---------------------------------------------------------------------------
# Model and data loading. Everything below runs once, at import time.
# ---------------------------------------------------------------------------
checkpoint_path = "multi_label.pth.tar"

# The backbone is created without pretrained weights: load_state_dict() below
# is strict, so every parameter and buffer is overwritten by the checkpoint
# anyway and fetching ImageNet weights first would be wasted work.
resnet = models.resnet101()
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet)
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
# Resnet_with_skip registers (model, decoder) in that order, so dropping the
# last child leaves only the backbone, which is shared with ``model``.
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))

# Separate backbone for the period classifier (5 classes).
periods_model = models.resnet101()
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)

# NOTE: pickle.load can execute arbitrary code stored in the file; these two
# files must come from a trusted source.
with open('query_images_paths.pkl', 'rb') as fp:
    query_images_paths = pickle.load(fp)

with open('features.pkl', 'rb') as fp:
    features = pickle.load(fp)

model.eval()

# Preprocessing: 224x224, three identical grey channels, values in [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Inverse of the normalisation above: x -> x * 0.5 + 0.5.
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
    transforms.Normalize(mean=[-0.5, -0.5, -0.5], std=[1., 1., 1.]),
])

# Class names, in the order of the 13 outputs of the multi-label head.
labels = ['ankh', 'anthropomorphic', 'bands', 'beetle', 'bird', 'circles', 'cross', 'duck', 'head', 'ibex', 'lion', 'sa', 'snake']

# Class names, in the order of the 5 outputs of the period head.
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
periods_model.eval()
|
|
|
def predict(inp):
    """Classify an image, predict its period, draw it and retrieve neighbours.

    Args:
        inp: PIL image supplied by the Gradio ``Image`` input.

    Returns:
        A 5-tuple ``(confidences, periods_confidences, plot_recon, fig,
        names)``:

        * ``confidences``: one-entry dict mapping the ``&``-joined names of
          all labels whose sigmoid score is at least 0.8 (the empty string
          when there is none) to ``1.0``;
        * ``periods_confidences``: dict mapping each period label to its
          softmax probability;
        * ``plot_recon``: the reconstruction as an ``(H, W, 3)`` NumPy array;
        * ``fig`` and ``names``: the output of ``create_retrieval_figure``.
    """
    batch = transform(inp).unsqueeze(0)

    # Inference only: no forward pass needs an autograd graph.
    with torch.no_grad():
        classification, reconstruction = model(batch)
        periods_classification = periods_model(batch)
        feature = embedding_model_test(batch)

        recon_tensor = reconstruction[0].repeat(3, 1, 1)
        recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
    plot_recon = recon_tensor.permute(1, 2, 0).detach().numpy()

    # Multi-label decision for the single sample in the batch.
    is_present = (torch.sigmoid(classification)[0] >= 0.8).tolist()
    true_labels = "&".join(
        label for label, present in zip(labels, is_present) if present
    )
    # Plain floats are what gr.Label documents for its confidence values.
    confidences = {true_labels: 1.0}

    periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
    periods_confidences = dict(zip(periods_labels, periods_prediction.tolist()))

    # Rank every stored image by (negated) similarity to the query feature.
    with torch.no_grad():
        dists = {
            image_name: knn_calc(image_name, feature, features)
            for image_name in query_images_paths
        }
    res = dict(sorted(dists.items(), key=itemgetter(1)))
    fig, names = create_retrieval_figure(res)
    return confidences, periods_confidences, plot_recon, fig, names
|
|
|
|
|
# Build the Gradio UI around ``predict`` and start serving it. The five
# outputs line up with the 5-tuple returned by ``predict``: two labels, the
# generated drawing, the retrieval figure and the UUID text.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1),
        gr.Label(num_top_classes=1),
        "image",
        'plot',
        'text',
    ],
    title="ArcAid: Analysis of Archaeological Artifacts using Drawings",
    description="Easily classify artifacs, retrieve similar ones and generate drawings. "
                "https://arxiv.org/abs/2211.09480.",
).launch(share=True, enable_queue=True)
|
|