ZeroIG / utils.py
syedaoon's picture
Upload 7 files
eb5b895 verified
raw
history blame
5.1 kB
import os
import numpy as np
import torch
import shutil
from torch.autograd import Variable
import matplotlib.pyplot as plt
from PIL import Image
def pair_downsampler(img):
    """Split a batch of images into two half-resolution views.

    Every non-overlapping 2x2 cell of each channel is reduced twice: once by
    averaging its anti-diagonal pixels (top-right, bottom-left) and once by
    averaging its main-diagonal pixels (top-left, bottom-right). Channels are
    processed independently (depthwise convolution with stride 2).

    Args:
        img: Tensor of shape (B, C, H, W).

    Returns:
        Tuple ``(output1, output2)``: the anti-diagonal and main-diagonal
        averages, each of shape (B, C, H // 2, W // 2).
    """
    channels = img.shape[1]
    # Build the kernels in the input's floating dtype so float64 / half
    # batches do not hit a dtype mismatch inside conv2d. Non-floating inputs
    # keep the historical float32 kernels.
    kernel_dtype = img.dtype if img.is_floating_point() else torch.float32
    anti_diag = torch.tensor(
        [[[[0.0, 0.5], [0.5, 0.0]]]], dtype=kernel_dtype, device=img.device
    ).repeat(channels, 1, 1, 1)
    main_diag = torch.tensor(
        [[[[0.5, 0.0], [0.0, 0.5]]]], dtype=kernel_dtype, device=img.device
    ).repeat(channels, 1, 1, 1)
    # groups=channels makes each channel use its own (identical) 2x2 kernel.
    output1 = torch.nn.functional.conv2d(img, anti_diag, stride=2, groups=channels)
    output2 = torch.nn.functional.conv2d(img, main_diag, stride=2, groups=channels)
    return output1, output2
def gauss_cdf(x):
    """Evaluate the standard normal CDF element-wise.

    Args:
        x: Tensor of evaluation points.

    Returns:
        Tensor holding ``0.5 * (1 + erf(x / sqrt(2)))`` for each element.
    """
    # A Python scalar divisor adopts x's dtype, so float64 input keeps full
    # precision (a float32 tensor constant would truncate sqrt(2)).
    sqrt_two = 2.0 ** 0.5
    return 0.5 * (1 + torch.erf(x / sqrt_two))


def gauss_kernel(kernlen=21, nsig=3, channels=1, device=None):
    """Build a normalized 2-D Gaussian kernel for depthwise convolution.

    The 1-D profile is the standard normal probability mass of ``kernlen``
    equal-width bins spanning ``[-nsig - interval / 2, nsig + interval / 2]``
    with ``interval = (2 * nsig + 1) / kernlen``. The 2-D kernel is the
    square root of the outer product of that profile with itself, rescaled
    to sum to one.

    Args:
        kernlen: Side length of the square kernel.
        nsig: Half-width of the sampled range, in standard deviations.
        channels: Number of copies of the kernel stacked along dim 0.
        device: Device for the result. Defaults to CUDA when it is available
            and to CPU otherwise, so the function also works on CPU-only
            hosts.

    Returns:
        Tensor of shape (channels, 1, kernlen, kernlen).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    interval = (2 * nsig + 1.0) / kernlen
    # kernlen + 1 bin edges give kernlen bins after differencing the CDF.
    edges = torch.linspace(
        -nsig - interval / 2.0, nsig + interval / 2.0, kernlen + 1
    ).to(device)
    kern1d = torch.diff(gauss_cdf(edges))
    kernel_raw = torch.sqrt(torch.outer(kern1d, kern1d))
    kernel = kernel_raw / torch.sum(kernel_raw)
    # One identical (1, kernlen, kernlen) filter per channel.
    return kernel.view(1, 1, kernlen, kernlen).repeat(channels, 1, 1, 1)
class LocalMean(torch.nn.Module):
    """Sliding-window mean over dims 2 and 3 with reflect padding.

    The input is reflect-padded by ``patch_size // 2`` on every side, so for
    an odd ``patch_size`` the output has the same spatial size as the input.
    """

    def __init__(self, patch_size=5):
        """Store the window size and derive the matching border width."""
        super().__init__()
        self.patch_size = patch_size
        self.padding = patch_size // 2

    def forward(self, image):
        """Return the mean of the ``patch_size`` x ``patch_size`` window around each pixel.

        Args:
            image: Tensor whose dims 2 and 3 are the spatial axes
                (presumably (B, C, H, W) -- confirm against callers).

        Returns:
            Tensor of per-pixel window means.
        """
        border = (self.padding,) * 4
        padded = torch.nn.functional.pad(image, border, mode='reflect')
        window = self.patch_size
        # Two stride-1 unfolds append the window's rows and columns as dims 4 and 5.
        patches = padded.unfold(2, window, 1)
        patches = patches.unfold(3, window, 1)
        return patches.mean(dim=(4, 5))
def blur(x):
    """Blur each channel of ``x`` with a 21x21 Gaussian kernel.

    The kernel comes from ``gauss_kernel(21, 1, C)`` and is applied as a
    depthwise convolution. The input is reflect-padded by 10 pixels per side
    first, so the output keeps the input's spatial size.

    Args:
        x: Tensor with channels on dim 1 (assumed (B, C, H, W) -- reflect
            padding by 10 needs both spatial dims to exceed 10).

    Returns:
        Blurred tensor on the same device as ``x``.
    """
    kernel_size = 21
    half = kernel_size // 2
    channels = x.size(1)
    weight = gauss_kernel(kernel_size, 1, channels).to(x.device)
    padded = torch.nn.functional.pad(x, (half,) * 4, mode='reflect')
    # Padding was applied explicitly above, so the convolution adds none.
    return torch.nn.functional.conv2d(padded, weight, padding=0, groups=channels)
def padr_tensor(img):
    """Zero-pad the last two dimensions of ``img`` by two elements per side.

    Args:
        img: Tensor to pad.

    Returns:
        New tensor whose last two dimensions are each 4 elements larger.
    """
    border = 2
    return torch.nn.functional.pad(img, (border,) * 4, mode='constant', value=0)
def calculate_local_variance(train_noisy):
    """Compute a per-pixel local variance map over 5x5 neighbourhoods.

    A 5x5 stride-1 average pool gives a local-mean map. Both that map and the
    input are zero-padded by 2 and cut into 5x5 windows. For every pixel, the
    result is the mean over its window of the squared difference between each
    neighbour's value and that neighbour's own local mean.

    Args:
        train_noisy: 4-D tensor; dims 2 and 3 are the spatial axes.

    Returns:
        Tensor with the same shape as ``train_noisy``.
    """
    b, c, w, h = train_noisy.shape
    window = 5

    def to_patches(tensor):
        # Zero-pad, slide a 5x5 window at stride 1, then flatten
        # (channel, row, col) into a single axis: (B, C*dim2*dim3, 5, 5).
        patches = padr_tensor(tensor).unfold(2, window, 1).unfold(3, window, 1)
        return patches.reshape(patches.shape[0], -1, window, window)

    local_mean = torch.nn.AvgPool2d(kernel_size=window, stride=1, padding=2)(train_noisy)
    mean_patches = to_patches(local_mean)
    value_patches = to_patches(train_noisy)
    squared_dev = (value_patches - mean_patches) ** 2
    # Average inside each window, then restore the original layout.
    return torch.mean(squared_dev, dim=(2, 3)).view(b, c, w, h)
def count_parameters_in_MB(model):
    """Return the model's parameter count in millions.

    Parameters whose qualified name contains ``"auxiliary"`` are skipped.
    Despite the ``MB`` suffix, the value is a count of scalar parameters
    divided by 1e6, not a size in bytes.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).

    Returns:
        Number of counted parameters divided by 1e6, as a float.
    """
    # Builtin sum over numel() replaces np.sum over a generator, which NumPy
    # has deprecated.
    total = sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
    return total / 1e6
def save_checkpoint(state, is_best, save):
    """Write a training checkpoint and optionally mark it as the best one.

    Args:
        state: Object to serialize with ``torch.save``.
        is_best: When truthy, the checkpoint is also copied to
            ``model_best.pth.tar`` in the same directory.
        save: Directory that receives ``checkpoint.pth.tar``.
    """
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(save, 'model_best.pth.tar'))
def save(model, model_path):
    """Serialize ``model``'s state dict to ``model_path`` with ``torch.save``.

    Args:
        model: Object exposing ``state_dict()``.
        model_path: Destination passed straight to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, model_path)
def load(model, model_path):
    """Load a saved state dict from ``model_path`` into ``model`` in place.

    The checkpoint is first materialized on the CPU, so files saved from
    CUDA tensors can still be read on hosts without a GPU.
    ``load_state_dict`` then copies the values into the model's existing
    parameters, which stay on whatever device the model already uses.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded weights.
        model_path: Checkpoint location passed to ``torch.load``.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale the rest.

    Each sample along dim 0 is kept with probability ``1 - drop_prob``.
    Kept samples are divided by the keep probability and dropped samples
    become zero. Nothing happens when ``drop_prob`` is not positive.

    Args:
        x: Floating-point tensor with the batch on dim 0; modified in place.
        drop_prob: Probability of dropping each sample.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        # Allocating on x.device lets this run on CPU as well as on GPU.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(
            mask_shape, dtype=torch.float32, device=x.device
        ).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x
def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot scripts into it.

    Args:
        path: Directory to create (with parents) when it does not exist yet.
        scripts_to_save: Optional iterable of file paths; each one is copied
            to ``<path>/scripts/<basename>``.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is None:
        return

    scripts_dir = os.path.join(path, 'scripts')
    os.makedirs(scripts_dir, exist_ok=True)
    for script in scripts_to_save:
        target = os.path.join(scripts_dir, os.path.basename(script))
        shutil.copyfile(script, target)
def show_pic(pic, name, path):
    """Draw a list of image batches into a 5x6 subplot grid and save it.

    Only the first item of each batch (``batch[0]``) is drawn. Items with 3
    channels are shown as colour images, items with 1 channel as grayscale,
    and anything else is skipped. Values are scaled by 255 and clipped to
    [0, 255] before display.

    Args:
        pic: Sequence of tensors; ``pic[i][0]`` is assumed to be laid out as
            (C, H, W) with values nominally in [0, 1] -- confirm with callers.
        name: Sequence of labels; ``name[i]`` becomes the x-label of plot i.
        path: Destination handed to ``plt.savefig``.
    """
    for idx, batch in enumerate(pic):
        array = batch[0].cpu().float().numpy()
        channel_count = array.shape[0]
        if channel_count == 3:
            # Channel-first to channel-last for PIL.
            pixels = np.transpose(array, (1, 2, 0))
            extra_args = ()
        elif channel_count == 1:
            pixels = array[0]
            extra_args = (plt.cm.gray,)
        else:
            continue

        im = Image.fromarray(np.clip(pixels * 255.0, 0, 255.0).astype('uint8'))
        label = name[idx]
        plt.subplot(5, 6, idx + 1)
        plt.xlabel(str(label))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(im, *extra_args)
    plt.savefig(path)