import kornia.filters
import scipy.ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        dilation: Kernel dilation (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
class DoubleConv(nn.Module):
    """Two 3x3 convolutions with instance normalisation and a residual skip.

    The forward pass is ``conv1 -> inst1 -> ReLU -> conv2 -> inst2``, then
    the (optionally projected) input is added and a final ReLU is applied.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        mid_channels: Number of channels between the two convolutions;
            falls back to ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        # bn1 / bn2 are not used by forward(); they stay registered so the
        # module's parameters and buffers are unchanged.
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.inst1 = nn.InstanceNorm2d(mid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.inst2 = nn.InstanceNorm2d(out_channels)

        # A 1x1 projection is only needed when the skip path must change
        # its channel count to match the main path.
        self.downsample = None
        if in_channels != out_channels:
            # NOTE(review): only the 1x1 convolution was present in this
            # Sequential; confirm no normalisation layer is expected after it.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            )

    def forward(self, x):
        """Apply the residual double convolution to ``x`` and return the result."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.conv1(x)
        out = self.inst1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.inst2(out)

        out += identity
        return self.relu(out)
class Down(nn.Module):
    """Downscaling with maxpool then double conv.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        # NOTE(review): the 2x2 max-pool is restored from the class
        # docstring and the attribute name; confirm the pooling size.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Return ``x`` after pooling and the double convolution."""
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscaling then double conv.

    ``forward(x1, x2)`` upsamples ``x1``, pads it to the spatial size of
    ``x2``, concatenates both along the channel dimension and applies a
    ``DoubleConv``.

    Args:
        in_channels: Channel count of the concatenated tensor fed to the
            ``DoubleConv``.
        out_channels: Channel count of the result.
        bilinear: If true, upsample with bilinear interpolation; otherwise
            use a transposed convolution (or ``nn.Identity`` when
            ``in_channels == out_channels``).
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()

        if bilinear:
            # Bilinear upsampling keeps the channel count, so the normal
            # convolutions are used to reduce the number of channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            if in_channels == out_channels:
                # No learned upsampling in this configuration.
                self.up = nn.Identity()
            else:
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2``, concatenate and convolve.

        Args:
            x1: Tensor to be upsampled; dims 2 and 3 are treated as the
                spatial height and width.
            x2: Skip tensor whose spatial size ``x1`` is padded to.

        Returns:
            Output of the ``DoubleConv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # Split each spatial size difference between the two sides; when
        # x1 is larger than x2 the differences (and pad amounts) are negative.
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class OutConv(nn.Module):
    """Single 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv = pointwise

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        projected = self.conv(x)
        return projected
class GaussianLayer(nn.Module):
    """Single-channel 5x5 convolution that can be loaded with a Gaussian kernel.

    The convolution has no bias and uses ``padding=2``. Its weights are
    only set to a Gaussian kernel when ``weights_init()`` is called.
    """

    def __init__(self):
        super(GaussianLayer, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 1, 5, stride=1, padding=2, bias=False)
        )
        # NOTE(review): weights_init() is not invoked here; confirm whether
        # the constructor is expected to call it.

    def forward(self, x):
        """Return ``x`` filtered by the 5x5 convolution."""
        return self.seq(x)

    def weights_init(self):
        """Overwrite every parameter with a sigma=1 Gaussian kernel.

        The kernel is obtained by Gaussian-filtering a 5x5 unit impulse.
        """
        impulse = np.zeros((5, 5))
        # NOTE(review): the impulse sits at [3, 3], not at the geometric
        # centre [2, 2] of a 5x5 grid; kept as-is, confirm this is intended.
        impulse[3, 3] = 1
        kernel = torch.from_numpy(scipy.ndimage.gaussian_filter(impulse, sigma=1))

        # Initialisation only: the copy must not be tracked by autograd.
        with torch.no_grad():
            for _name, param in self.named_parameters():
                param.copy_(kernel)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    This re-binds a helper of the same name and signature that is also
    defined earlier in this module.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return layer
class Decoder(nn.Module):
    """Container for the decoder layers used by ``Resnet_with_skip``.

    The class defines no ``forward``; its sub-modules are called one by
    one from ``Resnet_with_skip.forward_decode``.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.up1 = Up(2048, 1024, False)
        self.up2 = Up(1024, 512, False)
        self.up3 = Up(512, 256, False)
        self.conv2d_2_1 = conv3x3(256, 128)
        self.gn1 = nn.GroupNorm(4, 128)
        # Several layers below (instance1, upsample4_conv, gn2, instance3)
        # are not called by forward_decode; they stay registered so the
        # module's parameters and state-dict keys are unchanged.
        self.instance1 = nn.InstanceNorm2d(128)
        self.up4 = Up(128, 64, False)
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample4_conv = DoubleConv(64, 64, 64 // 2)
        self.up_ = Up(128, 128, False)
        self.conv2d_2_2 = conv3x3(128, 6)
        self.instance2 = nn.InstanceNorm2d(6)
        self.gn2 = nn.GroupNorm(3, 6)
        self.gaussian_blur = GaussianLayer()
        self.up5 = Up(6, 3, False)
        self.conv2d_2_3 = conv3x3(3, 1)
        self.instance3 = nn.InstanceNorm2d(1)

        # Eight 3x3 filters, one per neighbour of the centre pixel: weight
        # 1.0 at the centre and a random weight in [-1, 0] at one neighbour.
        # The literal is kept in its original order so the sequence of
        # `random` draws is unchanged (the third filter draws twice).
        directional = torch.tensor(
            [[[0.0, 0.0, 0.0], [0.0, 1.0, random.uniform(-1.0, 0.0)], [0.0, 0.0, 0.0]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, random.uniform(-1.0, 0.0)]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, random.uniform(random.uniform(-1.0, 0.0), -0.0), 0.0]],
             [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [random.uniform(-1.0, 0.0), 0.0, 0.0]],
             [[0.0, 0.0, 0.0], [random.uniform(-1.0, 0.0), 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[random.uniform(-1.0, 0.0), 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[0.0, random.uniform(-1.0, 0.0), 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
             [[0.0, 0.0, random.uniform(-1.0, 0.0)], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], ],
            dtype=torch.float32,
        )
        # NOTE(review): the (8, 3, 3) stack is given a singleton input-channel
        # dim so it is a valid conv weight of shape (8, 1, 3, 3); the decode
        # path gathers over 8 output channels of nms_conv. Confirm the shape.
        self.kernel = nn.Parameter(directional.unsqueeze(1))

        self.nms_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False, groups=1)
        # Share the directional kernel as the convolution weight; the same
        # Parameter object is registered under both names.
        self.nms_conv.weight = self.kernel
class Resnet_with_skip(nn.Module):
    """Wrap a backbone with a skip-connected decoder.

    ``forward`` returns both the backbone's own prediction and a decoded
    single-channel map produced from the backbone's intermediate features.

    Args:
        model: Backbone exposing ``conv1``, ``bn1``, ``relu``, ``maxpool``
            and ``layer1`` .. ``layer4`` and being callable on an image
            batch. NOTE(review): the decoder's channel sizes (2048, 1024,
            512, 256, 64) presumably match a ResNet-50-like backbone;
            verify against the caller.
    """

    def __init__(self, model):
        super(Resnet_with_skip, self).__init__()
        self.model = model
        self.decoder = Decoder()

    def forward_pred(self, image):
        """Return the backbone's output for ``image``."""
        return self.model(image)

    def forward_decode(self, image):
        """Decode backbone features into a thinned, gamma-adjusted map.

        Args:
            image: Input batch; it is also used as the last skip
                connection of the decoder.

        Returns:
            The decoded map after blur, directional suppression and gamma
            adjustment.
        """
        import math  # not imported at module level; kept local to this method

        decoder = self.decoder
        identity = image

        # Encoder: keep every intermediate activation for the skip paths.
        image = self.model.conv1(image)
        image = self.model.bn1(image)
        image = self.model.relu(image)
        image1 = self.model.maxpool(image)
        image2 = self.model.layer1(image1)
        image3 = self.model.layer2(image2)
        image4 = self.model.layer3(image3)
        image5 = self.model.layer4(image4)

        # Decoder: merge each stage with the matching encoder activation.
        reconst1 = decoder.up1(image5, image4)
        reconst2 = decoder.up2(reconst1, image3)
        reconst3 = decoder.up3(reconst2, image2)

        reconst = decoder.conv2d_2_1(reconst3)
        reconst = decoder.gn1(reconst)
        reconst = F.relu(reconst)

        reconst4 = decoder.up4(reconst, image1)
        reconst5 = decoder.upsample4(reconst4)
        reconst5 = decoder.up_(reconst5, image)

        reconst5 = decoder.conv2d_2_2(reconst5)
        reconst5 = decoder.instance2(reconst5)
        reconst5 = F.relu(reconst5)

        reconst = decoder.up5(reconst5, identity)
        reconst = decoder.conv2d_2_3(reconst)
        reconst = F.relu(reconst)

        # Gradient direction of the blurred map, quantised to 45 degrees.
        blurred = decoder.gaussian_blur(reconst)
        gradients = kornia.filters.spatial_gradient(blurred, normalized=False)
        gx = gradients[:, :, 0]
        gy = gradients[:, :, 1]
        angle = torch.atan2(gy, gx)
        angle = 180.0 * angle / math.pi  # radians to degrees
        angle = torch.round(angle / 45) * 45

        # Directional responses from the decoder's kernel bank.
        nms_magnitude = decoder.nms_conv(blurred)

        # Non-maximal suppression: pick the response along the gradient
        # direction and along the opposite direction (index + 4 mod 8).
        positive_idx = ((angle / 45) % 8).long()
        negative_idx = (((angle / 45) + 4) % 8).long()
        channel_select_filtered_positive = torch.gather(nms_magnitude, 1, positive_idx)
        channel_select_filtered_negative = torch.gather(nms_magnitude, 1, negative_idx)
        channel_select_filtered = torch.stack(
            [channel_select_filtered_positive, channel_select_filtered_negative], 1
        )

        # Soft mask: the smaller of the two responses, with values not
        # above 0.01 replaced by 0.01 so the product below is never zeroed.
        max_matrix = channel_select_filtered.min(dim=1)[0]
        max_matrix = F.threshold(max_matrix, 0.01, 0.01)
        magnitude = torch.mul(reconst, max_matrix)

        magnitude = kornia.enhance.adjust_gamma(magnitude, 2.0)
        return magnitude

    def forward(self, image):
        """Return ``(pred, reconst)`` for ``image``.

        The decoder branch is evaluated before the prediction branch.
        """
        reconst = self.forward_decode(image)
        pred = self.forward_pred(image)
        return pred, reconst