# Arcaid / app.py
import os
import pickle
from operator import itemgetter
import cv2
import gradio as gr
import kornia.enhance
import kornia.filters
import math
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import random
from skimage.transform import resize
from torchvision import transforms, models
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
        # The BatchNorm layers below are unused in forward (InstanceNorm is used
        # instead) but are kept so existing checkpoints still load.
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_channels)
self.inst1 = nn.InstanceNorm2d(mid_channels)
# self.gn1 = nn.GroupNorm(4, mid_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.inst2 = nn.InstanceNorm2d(out_channels)
# self.gn2 = nn.GroupNorm(4, out_channels)
self.downsample = None
if in_channels != out_channels:
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
)
def forward(self, x):
identity = x
out = self.conv1(x)
# out = self.bn1(out)
out = self.inst1(out)
# out = self.gn1(out)
out = self.relu(out)
out = self.conv2(out)
# out = self.bn2(out)
out = self.inst2(out)
# out = self.gn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
if in_channels == out_channels:
self.up = nn.Identity()
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
        # inputs are NCHW; pad x1 to match x2's spatial size before concatenation
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
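# OutConv is not instantiated in this app; it is kept alongside the other
# U-Net building blocks (DoubleConv, Down, Up).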
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class GaussianLayer(nn.Module):
def __init__(self):
super(GaussianLayer, self).__init__()
self.seq = nn.Sequential(
# nn.ReflectionPad2d(10),
nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
)
self.weights_init()
def forward(self, x):
return self.seq(x)
    def weights_init(self):
        # Build a 5x5 Gaussian kernel by smoothing a centered impulse.
        n = np.zeros((5, 5))
        n[2, 2] = 1  # center of a 5x5 kernel (the original [3, 3] was off-center)
        k = scipy.ndimage.gaussian_filter(n, sigma=1)
        for name, f in self.named_parameters():
            f.data.copy_(torch.from_numpy(k))
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.up1 = Up(2048, 1024 // 1, False)
self.up2 = Up(1024, 512 // 1, False)
self.up3 = Up(512, 256 // 1, False)
self.conv2d_2_1 = conv3x3(256, 128)
self.gn1 = nn.GroupNorm(4, 128)
self.instance1 = nn.InstanceNorm2d(128)
self.up4 = Up(128, 64 // 1, False)
self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# self.upsample4 = nn.ConvTranspose2d(64, 64, 2, stride=2)
self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
self.up_ = Up(128, 128 // 1, False)
self.conv2d_2_2 = conv3x3(128, 6)
self.instance2 = nn.InstanceNorm2d(6)
self.gn2 = nn.GroupNorm(3, 6)
self.gaussian_blur = GaussianLayer()
self.up5 = Up(6, 3, False)
self.conv2d_2_3 = conv3x3(3, 1)
self.instance3 = nn.InstanceNorm2d(1)
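        # Eight 3x3 directional kernels (one per 45-degree direction): each
        # compares the center pixel (weight 1.0) against a single neighbor
        # (random negative weight), used for Canny-style non-maximum suppression.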
self.kernel = nn.Parameter(torch.tensor(
[[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(random.uniform(-1.0, 0.0), -0.0), 0.0]],
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
[[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
[[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
[[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
[[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ],
).unsqueeze(1))
self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
        with torch.no_grad():
            # self.kernel is already float32, so .float() returns the same
            # nn.Parameter; with this 8x1x3x3 weight the conv emits one channel
            # per directional kernel (8 channels from a 1-channel input).
            self.nms_conv.weight = self.kernel.float()
class Resnet_with_skip(nn.Module):
def __init__(self, model):
super(Resnet_with_skip, self).__init__()
self.model = model
self.decoder = Decoder()
def forward_pred(self, image):
pred_net = self.model(image)
return pred_net
def forward_decode(self, image):
identity = image
image = self.model.conv1(image)
image = self.model.bn1(image)
image = self.model.relu(image)
image1 = self.model.maxpool(image)
image2 = self.model.layer1(image1)
image3 = self.model.layer2(image2)
image4 = self.model.layer3(image3)
image5 = self.model.layer4(image4)
reconst1 = self.decoder.up1(image5, image4)
reconst2 = self.decoder.up2(reconst1, image3)
reconst3 = self.decoder.up3(reconst2, image2)
reconst = self.decoder.conv2d_2_1(reconst3)
# reconst = self.decoder.instance1(reconst)
reconst = self.decoder.gn1(reconst)
reconst = F.relu(reconst)
reconst4 = self.decoder.up4(reconst, image1)
# reconst5 = self.decoder.upsample4(reconst4)
reconst5 = self.decoder.upsample4(reconst4)
# reconst5 = self.decoder.upsample4_conv(reconst4)
reconst5 = self.decoder.up_(reconst5, image)
# reconst5 = reconst5 + image
reconst5 = self.decoder.conv2d_2_2(reconst5)
reconst5 = self.decoder.instance2(reconst5)
# reconst5 = self.decoder.gn2(reconst5)
reconst5 = F.relu(reconst5)
reconst = self.decoder.up5(reconst5, identity)
reconst = self.decoder.conv2d_2_3(reconst)
# reconst = self.decoder.instance3(reconst)
reconst = F.relu(reconst)
# return reconst
blurred = self.decoder.gaussian_blur(reconst)
gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
# Unpack the edges
gx = gradients[:, :, 0]
gy = gradients[:, :, 1]
angle = torch.atan2(gy, gx)
        # Radians to degrees (math is imported at the top of the file)
        angle = 180.0 * angle / math.pi
        # Round angle to the nearest 45 degrees
        angle = torch.round(angle / 45) * 45
nms_magnitude = self.decoder.nms_conv(blurred)
# nms_magnitude = F.conv2d(blurred, kernel.unsqueeze(1), padding=kernel.shape[-1]//2)
# Non-maximal suppression
# Get the indices for both directions
positive_idx = (angle / 45) % 8
positive_idx = positive_idx.long()
negative_idx = ((angle / 45) + 4) % 8
negative_idx = negative_idx.long()
# Apply the non-maximum suppression to the different directions
channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)
channel_select_filtered = torch.stack(
[channel_select_filtered_positive, channel_select_filtered_negative], 1
)
# is_max = channel_select_filtered.min(dim=1)[0] > 0.0
# magnitude = reconst * is_max
thresh = nn.Threshold(0.01, 0.01)
max_matrix = channel_select_filtered.min(dim=1)[0]
max_matrix = thresh(max_matrix)
magnitude = torch.mul(reconst, max_matrix)
# magnitude = torchvision.transforms.functional.invert(magnitude)
# magnitude = self.decoder.sharpen(magnitude)
# magnitude = self.decoder.threshold(magnitude)
magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
# magnitude = F.leaky_relu(magnitude)
return magnitude
def forward(self, image):
reconst = self.forward_decode(image)
pred = self.forward_pred(image)
return pred, reconst
def create_retrieval_figure(res):
fig = plt.figure(figsize=[10 * 3, 10 * 3])
cols = 5
rows = 2
ax_query = fig.add_subplot(rows, 1, 1)
plt.rcParams['figure.facecolor'] = 'white'
plt.axis('off')
ax_query.set_title('Top 10 most similar scarabs', fontsize=40)
names = ""
    for i, current_image_path in enumerate(res):
        if i == 0:
            continue  # skip the closest match (rank 0)
        if i < 11:
            image = cv2.imread(current_image_path)
            # image_resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LINEAR)
            ax = fig.add_subplot(rows, cols, i)
            plt.axis('off')
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            item_uuid = current_image_path.split("/")[4].split("_photoUUID")[0].split("itemUUID_")[1]
            ax.set_title('Top {}'.format(i), fontsize=40)
            names = names + "Top " + str(i) + " item UUID is " + item_uuid + "\n"
# img_buf = io.BytesIO()
# plt.savefig(img_buf, format='png')
# im_fig = Image.open(img_buf)
# img_buf.close()
# return im_fig
return fig, names
def knn_calc(image_name, query_feature, features):
current_image_feature = features[image_name]
criterion = torch.nn.CosineSimilarity(dim=1)
dist = criterion(query_feature, current_image_feature).mean()
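    # Negate the cosine similarity so an ascending sort ranks the most similar first.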
dist = -dist.item()
return dist
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
experiment = "experiment_0"
# Original training checkpoint path (kept for reference; overridden below):
# checkpoint_path = os.path.join("../shapes_classification/checkpoints/"
#                                "50_50_pretrained_resnet101_experiment_0_train_images_with_drawings_batch_8_10:29:06/" +
#                                "experiment_0_last_auto_model.pth.tar")
checkpoint_path = "multi_label.pth.tar"
resnet = models.resnet101(pretrained=True)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 13)
model = Resnet_with_skip(resnet).to(device)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
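# Retrieval feature extractor: drop the last child of Resnet_with_skip (the
# decoder), keeping the classification ResNet as the embedding network.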
embedding_model_test = torch.nn.Sequential(*(list(model.children())[:-1]))
embedding_model_test.to(device)
periods_model = models.resnet101(pretrained=True)
periods_model.fc = nn.Linear(num_ftrs, 5)
periods_checkpoint = torch.load("periods.pth.tar", map_location="cpu")
periods_model.load_state_dict(periods_checkpoint)
periods_model.to(device)
data_dir = "../cssl_dataset/all_image_base/1/"
query_images_paths = []
for path in os.listdir(data_dir):
query_images_paths.append(os.path.join(data_dir, path))
with open('features.pkl', 'rb') as fp:
features = pickle.load(fp)
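# features.pkl maps each dataset image path to a precomputed embedding.
# A minimal sketch of how it could be regenerated offline (an assumption, not
# part of this app; uses the transform defined below and `from PIL import Image`):
#
#     feats = {}
#     with torch.no_grad():
#         for p in query_images_paths:
#             t = transform(Image.open(p)).unsqueeze(0).to(device)
#             feats[p] = embedding_model_test(t)
#     with open('features.pkl', 'wb') as fp:
#         pickle.dump(feats, fp)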
model.eval()
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
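# invTrans undoes the Normalize above: x -> x * 0.5 + 0.5.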
invTrans = transforms.Compose([transforms.Normalize(mean=[0., 0., 0.],
std=[1 / 0.5, 1 / 0.5, 1 / 0.5]),
transforms.Normalize(mean=[-0.5, -0.5, -0.5],
std=[1., 1., 1.]),
])
labels = sorted(os.listdir("../cssl_dataset/shape_multi_label/photos"))
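# Period labels: Middle Bronze I/II, Late Bronze, Iron Age I/II.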
periods_labels = ["MB1", "MB2", "LB", "Iron1", 'Iron2']
periods_model.eval()
def predict(inp):
image_tensor = transform(inp)
image_tensor = image_tensor.to(device)
with torch.no_grad():
classification, reconstruction = model(image_tensor.unsqueeze(0))
periods_classification = periods_model(image_tensor.unsqueeze(0))
recon_tensor = reconstruction[0].repeat(3, 1, 1)
recon_tensor = invTrans(kornia.enhance.invert(recon_tensor))
plot_recon = recon_tensor.to("cpu").permute(1, 2, 0).detach().numpy()
w, h = inp.size
plot_recon = resize(plot_recon, (h, w))
    # Multi-label prediction: a class counts as present when its sigmoid
    # probability is at least 0.8.
    y = torch.sigmoid(classification)
    preds = (y >= 0.8).int().flatten().tolist()
# prediction = torch.tensor(preds).to(device)
    confidences = {}
    true_labels = ""
    for i in range(len(labels)):
        if preds[i] == 1:
            if true_labels == "":
                true_labels = labels[i]
            else:
                true_labels = true_labels + "&" + labels[i]
    # Assign the confidence once the full label string is built, so partial
    # prefixes do not remain as stale keys.
    if true_labels:
        confidences[true_labels] = torch.tensor(1.0).to(device)
periods_prediction = torch.nn.functional.softmax(periods_classification[0], dim=0)
periods_confidences = {periods_labels[i]: periods_prediction[i] for i in range(len(periods_labels))}
feature = embedding_model_test(image_tensor.unsqueeze(0)).to(device)
dists = dict()
with torch.no_grad():
for i, image_name in enumerate(query_images_paths):
dist = knn_calc(image_name, feature, features)
dists[image_name] = dist
res = dict(sorted(dists.items(), key=itemgetter(1)))
fig, names = create_retrieval_figure(res)
return fig, names, plot_recon, confidences, periods_confidences
gr.Interface(fn=predict,
inputs=gr.Image(type="pil"),
outputs=['plot', 'text', "image", gr.Label(num_top_classes=1), gr.Label(num_top_classes=1)], ).launch(share=True)