Spaces:
Running
on
Zero
Running
on
Zero
import math | |
from functools import partial | |
import cv2 | |
import numpy as np | |
import torch | |
from basicsr.data import degradations as degradations | |
from basicsr.data.transforms import augment | |
from basicsr.utils import img2tensor | |
from torch.nn.functional import interpolate | |
from torchvision.transforms import Compose | |
from utils.basicsr_custom import ( | |
random_mixed_kernels, | |
random_add_gaussian_noise, | |
random_add_jpg_compression, | |
) | |
def create_degradation(degradation):
    """Build the degradation callable registered under ``degradation``.

    Every returned callable takes a single image tensor and returns a 2-tuple
    ``(degraded, target)``. ``target`` is ``None`` for every degradation except
    ``'difface'``, which returns the ground-truth image it actually degraded
    (after its random flip / optional gray conversion).

    Inputs are indexed as channel-first ``(C, H, W)`` tensors; values are
    presumably expected in ``[0, 1]`` since outputs are clipped to that range
    (TODO confirm against callers).

    Supported names:
        * ``'sr_bicubic_x8_gaussian_noise_005'``: x8 bicubic down-sampling,
          Gaussian noise (std 0.05), x8 nearest-exact up-sampling.
        * ``'gaussian_noise_035'``: Gaussian noise (std 0.35).
        * ``'colorization_gaussian_noise_025'``: channel mean (single channel
          output) plus Gaussian noise (std 0.25).
        * ``'random_inpainting_gaussian_noise_01'``: 90% of the pixel locations
          zeroed, plus Gaussian noise (std 0.1).
        * ``'difface'``: blur, random down-sampling, noise and JPEG compression.

    Args:
        degradation: Name of the degradation pipeline to build.

    Returns:
        A callable mapping an image tensor to ``(degraded, target)``.

    Raises:
        NotImplementedError: If ``degradation`` is not one of the names above.
    """
    if degradation == 'sr_bicubic_x8_gaussian_noise_005':
        return Compose([
            # down_scale adds a leading batch dimension; it is squeezed away
            # again at the end of the pipeline.
            partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
            partial(add_gaussian_noise, std=0.05),
            partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
            partial(torch.clip, min=0, max=1),
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])

    if degradation == 'gaussian_noise_035':
        return Compose([
            partial(add_gaussian_noise, std=0.35),
            partial(torch.clip, min=0, max=1),
            # Only has an effect when the leading dimension has size 1.
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])

    if degradation == 'colorization_gaussian_noise_025':
        return Compose([
            # Collapse the channels into a single averaged channel.
            lambda x: torch.mean(x, dim=0, keepdim=True),
            partial(add_gaussian_noise, std=0.25),
            partial(torch.clip, min=0, max=1),
            lambda x: (x, None),
        ])

    if degradation == 'random_inpainting_gaussian_noise_01':
        def inpainting_dps(x):
            """Zero a random 90% of the pixel locations, then add noise."""
            height, width = x.shape[1], x.shape[2]
            n_pixels = height * width
            # The range is degenerate (always 0.9); the draw is kept so the
            # NumPy RNG stream is consumed exactly as before.
            low, high = 0.9, 0.9
            prob = np.random.uniform(low, high)
            # Random pixel sampling: the chosen locations are masked out.
            samples = np.random.choice(n_pixels, int(n_pixels * prob), replace=False)
            mask_vec = torch.ones([1, n_pixels])
            mask_vec[:, samples] = 0
            # A (1, H, W) mask broadcasts over any number of channels, so the
            # same locations are dropped in every channel.
            mask = mask_vec.view(1, height, width).to(device=x.device, dtype=x.dtype)
            return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None

        return inpainting_dps

    if degradation == 'difface':
        def deg(x):
            """Apply blur, down-sampling, noise and JPEG compression to ``x``.

            Returns ``(img_lq, img_gt)``: the degraded image and the (randomly
            flipped, possibly gray) ground truth, both as tensors.
            """
            blur_kernel_size = 41
            kernel_list = ['iso', 'aniso']
            kernel_prob = [0.5, 0.5]
            blur_sigma = [0.1, 15]
            downsample_range = [0.8, 32]
            noise_range = [0, 20]
            jpeg_range = [30, 100]
            gt_gray = True
            gray_prob = 0.01

            # CHW tensor -> HWC float32 array with the channel order reversed
            # (treated as BGR by the OpenCV calls below).
            x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)

            # random horizontal flip
            img_gt = augment(x.copy(), hflip=True, rotation=False)
            h, w, _ = img_gt.shape

            # ------------------------ generate lq image ------------------------ #
            # blur
            kernel = degradations.random_mixed_kernels(
                kernel_list,
                kernel_prob,
                blur_kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                noise_range=None)
            img_lq = cv2.filter2D(img_gt, -1, kernel)

            # downsample; keep at least one pixel per side so cv2.resize does
            # not fail on small inputs combined with a large scale.
            scale = np.random.uniform(downsample_range[0], downsample_range[1])
            low_w = max(1, int(w // scale))
            low_h = max(1, int(h // scale))
            img_lq = cv2.resize(img_lq, (low_w, low_h), interpolation=cv2.INTER_LINEAR)

            # noise
            if noise_range is not None:
                img_lq = random_add_gaussian_noise(img_lq, noise_range)

            # jpeg compression
            if jpeg_range is not None:
                img_lq = random_add_jpg_compression(img_lq, jpeg_range)

            # resize to original size
            img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

            # NOTE: no color jitter (numpy or pytorch variant) is applied in
            # this pipeline.

            # random to gray (only for lq)
            # NOTE(review): the GT conversion is assumed to be nested under the
            # random gray branch, i.e. GT only turns gray when LQ does.
            if np.random.uniform() < gray_prob:
                img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
                img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
                if gt_gray:  # whether convert GT to gray images
                    img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                    img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

            # BGR to RGB, HWC to CHW, numpy to tensor
            img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

            # round and clip: quantize the LQ image to 8-bit levels
            img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
            return img_lq, img_gt.clip(0, 1)

        return deg

    raise NotImplementedError(f"Unknown degradation: {degradation!r}")
def down_scale(x, scale_factor, mode):
    """Resize ``x`` by ``scale_factor`` with anti-aliasing and clip to [0, 1].

    A leading batch dimension is added before interpolation and is NOT removed
    afterwards, so the result has one more dimension than ``x``. The whole
    computation runs with gradient tracking disabled.

    Args:
        x: Image tensor without a batch dimension (presumably ``(C, H, W)``).
        scale_factor: Spatial scaling factor forwarded to ``interpolate``.
        mode: Interpolation mode forwarded to ``interpolate``; it must be one
            that accepts ``antialias=True`` and ``align_corners=False``.

    Returns:
        The resized tensor with a leading batch dimension, clipped to [0, 1].
    """
    with torch.no_grad():
        batched = x.unsqueeze(0)
        resized = interpolate(
            batched,
            scale_factor=scale_factor,
            mode=mode,
            antialias=True,
            align_corners=False,
        )
        return resized.clip(0, 1)
def add_gaussian_noise(x, std):
    """Return ``x`` plus zero-mean Gaussian noise scaled by ``std``.

    The noise has the same shape, dtype and device as ``x``. The input tensor
    is not modified, and the computation runs with gradient tracking disabled.

    Args:
        x: Input tensor.
        std: Standard deviation multiplier applied to the sampled noise.

    Returns:
        A new tensor ``x + noise``; no clipping is applied here.
    """
    with torch.no_grad():
        noise = torch.randn_like(x) * std
        noisy = x + noise
    return noisy