Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,763 Bytes
b7f3942 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import math
from functools import partial
import cv2
import numpy as np
import torch
from basicsr.data import degradations as degradations
from basicsr.data.transforms import augment
from basicsr.utils import img2tensor
from torch.nn.functional import interpolate
from torchvision.transforms import Compose
from utils.basicsr_custom import (
random_mixed_kernels,
random_add_gaussian_noise,
random_add_jpg_compression,
)
def create_degradation(degradation):
    """Build the degradation callable registered under ``degradation``.

    Every returned callable takes a single image tensor and returns a
    2-tuple ``(degraded, second)``. ``second`` is ``None`` for all the
    synthetic degradations; for ``'difface'`` it is the (flipped, grayscale)
    ground-truth image clipped to ``[0, 1]``.

    Supported names:
        * ``'sr_bicubic_x8_gaussian_noise_005'``
        * ``'gaussian_noise_035'``
        * ``'colorization_gaussian_noise_025'``
        * ``'random_inpainting_gaussian_noise_01'``
        * ``'difface'``

    Args:
        degradation: Name of the degradation pipeline to build.

    Returns:
        A callable ``f(x) -> (degraded, second)``.

    Raises:
        NotImplementedError: If ``degradation`` is not one of the names above.
    """
    if degradation == 'sr_bicubic_x8_gaussian_noise_005':
        return Compose([
            # down_scale adds a leading batch dimension before resizing.
            partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
            partial(add_gaussian_noise, std=0.05),
            # Bring the noisy low-resolution image back to the input size.
            partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
            partial(torch.clip, min=0, max=1),
            # Drop the batch dimension that down_scale introduced.
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])

    if degradation == 'gaussian_noise_035':
        return Compose([
            partial(add_gaussian_noise, std=0.35),
            partial(torch.clip, min=0, max=1),
            # Only has an effect when the leading dimension has size 1.
            partial(torch.squeeze, dim=0),
            lambda x: (x, None),
        ])

    if degradation == 'colorization_gaussian_noise_025':
        return Compose([
            # Average over the leading (channel) dimension, keeping it as size 1.
            lambda x: torch.mean(x, dim=0, keepdim=True),
            partial(add_gaussian_noise, std=0.25),
            partial(torch.clip, min=0, max=1),
            lambda x: (x, None),
        ])

    if degradation == 'random_inpainting_gaussian_noise_01':
        return _inpainting_dps

    if degradation == 'difface':
        return _difface_deg

    raise NotImplementedError(f'Unknown degradation: {degradation!r}')


def _inpainting_dps(x):
    """Zero out a random 90% of the pixels of ``x`` and add Gaussian noise (std 0.1).

    The same spatial mask is applied to every channel. ``x`` is treated as
    ``(..., H, W)``; the image does not have to be square and may have any
    number of channels.

    Returns:
        ``(degraded, None)`` with ``degraded`` clipped to ``[0, 1]``.
    """
    height, width = x.shape[-2:]
    total = height * width
    # Both bounds are 0.9, so the dropped fraction is constant; the draw is
    # kept so the NumPy random stream is consumed exactly as before.
    low, high = 0.9, 0.9
    prob = np.random.uniform(low, high)
    # Flat mask: 1 = keep pixel, 0 = drop pixel.
    mask_vec = torch.ones([1, total])
    samples = np.random.choice(total, int(total * prob), replace=False)
    mask_vec[:, samples] = 0
    # A (1, H, W) mask broadcasts over the channel dimension of x.
    mask = mask_vec.view(1, height, width).to(device=x.device, dtype=x.dtype)
    return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None


def _difface_deg(x):
    """Apply a random blur / downsample / noise / JPEG degradation to ``x``.

    ``x`` is converted to a NumPy array with ``.numpy()``, so it presumably
    has to be a CPU tensor that does not require grad -- TODO confirm with
    callers. It is permuted from channel-first to channel-last and its
    channel order is reversed (assumed RGB -> BGR for OpenCV; the final
    ``img2tensor(..., bgr2rgb=True)`` reverses it back).

    Returns:
        ``(img_lq, img_gt)``: the degraded image, quantised to 8-bit levels
        in ``[0, 1]``, and the augmented grayscale ground truth clipped to
        ``[0, 1]``. Both are produced by ``img2tensor``.
    """
    blur_kernel_size = 41
    kernel_list = ['iso', 'aniso']
    kernel_prob = [0.5, 0.5]
    blur_sigma = [0.1, 15]
    downsample_range = [0.8, 32]
    noise_range = [0, 20]
    jpeg_range = [30, 100]
    gt_gray = True
    gray_prob = 0.01

    x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)
    # random horizontal flip
    img_gt = augment(x.copy(), hflip=True, rotation=False)
    h, w, _ = img_gt.shape

    # ------------------------ generate lq image ------------------------ #
    # blur
    # NOTE(review): this uses basicsr's random_mixed_kernels, not the
    # utils.basicsr_custom one imported at the top of the file -- confirm
    # that this is intended.
    kernel = degradations.random_mixed_kernels(
        kernel_list,
        kernel_prob,
        blur_kernel_size,
        blur_sigma,
        blur_sigma,
        [-math.pi, math.pi],
        noise_range=None)
    img_lq = cv2.filter2D(img_gt, -1, kernel)

    # downsample
    scale = np.random.uniform(downsample_range[0], downsample_range[1])
    # Clamp to at least one pixel: for small inputs w // scale (or h // scale)
    # can reach 0, which cv2.resize rejects.
    small_w = max(1, int(w // scale))
    small_h = max(1, int(h // scale))
    img_lq = cv2.resize(img_lq, (small_w, small_h), interpolation=cv2.INTER_LINEAR)

    # noise
    if noise_range is not None:
        img_lq = random_add_gaussian_noise(img_lq, noise_range)
    # jpeg compression
    if jpeg_range is not None:
        img_lq = random_add_jpg_compression(img_lq, jpeg_range)

    # resize to original size
    img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

    # Color jitter (NumPy and PyTorch variants) is not applied in this pipeline.

    # random to gray (only for lq)
    if np.random.uniform() < gray_prob:
        img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
        img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
    if gt_gray:  # whether convert GT to gray images
        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
        img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

    # BGR to RGB, HWC to CHW, numpy to tensor
    img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

    # round and clip
    img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
    return img_lq, img_gt.clip(0, 1)
def down_scale(x, scale_factor, mode):
    """Resize ``x`` by ``scale_factor`` with anti-aliasing and clip to ``[0, 1]``.

    A leading batch dimension is added to ``x`` before interpolation and is
    NOT removed from the result. The whole computation runs under
    ``torch.no_grad()``.

    Args:
        x: Input tensor; a dimension of size 1 is prepended before resizing.
        scale_factor: Passed through to ``interpolate``.
        mode: Interpolation mode passed through to ``interpolate``.

    Returns:
        The resized tensor (with the extra leading dimension), clipped to
        ``[0, 1]``.
    """
    with torch.no_grad():
        batched = x.unsqueeze(0)
        resized = interpolate(
            batched,
            scale_factor=scale_factor,
            mode=mode,
            antialias=True,
            align_corners=False,
        )
        return resized.clip(0, 1)
def add_gaussian_noise(x, std):
    """Return ``x`` plus zero-mean Gaussian noise scaled by ``std``.

    The noise is drawn with ``torch.randn_like(x)``, and the addition runs
    under ``torch.no_grad()``. The input is not modified in place and the
    result is not clipped.

    Args:
        x: Input tensor.
        std: Multiplier applied to the standard-normal noise.

    Returns:
        A new tensor ``x + noise * std``.
    """
    with torch.no_grad():
        noise = torch.randn_like(x) * std
        noisy = x + noise
    return noisy
|