from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import math
import time
import io
import random

import numpy as np
import cv2
from PIL import Image
import torch.utils.data as data

from dataset.degradation import (
    random_mixed_kernels,
    random_add_gaussian_noise,
    random_add_jpg_compression,
)
from dataset.utils import load_file_list, center_crop_arr, random_crop_arr
from utils.common import instantiate_from_config


class CodeformerDataset(data.Dataset):
    """Dataset yielding ground-truth images with synthetically degraded copies.

    Each sample is built by loading an image through the configured file
    backend, optionally cropping it to ``out_size``, and then applying, in
    order: a random blur kernel, a random downsample, optional Gaussian noise,
    optional JPEG compression, and a resize back to the original resolution.
    """

    def __init__(
        self,
        file_list: str,
        file_backend_cfg: Mapping[str, Any],
        out_size: int,
        crop_type: str,
        blur_kernel_size: int,
        kernel_list: Sequence[str],
        kernel_prob: Sequence[float],
        blur_sigma: Sequence[float],
        downsample_range: Sequence[float],
        noise_range: Sequence[float],
        jpeg_range: Sequence[int]
    ) -> None:
        """Store the degradation settings and load the list of image entries.

        Args:
            file_list: Passed to ``load_file_list``. Every resulting entry is
                later indexed with the ``"image_path"`` and ``"prompt"`` keys.
            file_backend_cfg: Config handed to ``instantiate_from_config``.
                The resulting object must expose ``get(path)`` returning the
                encoded image bytes, or ``None`` when the read failed.
            out_size: Height and width images are cropped to (or are required
                to already have when ``crop_type`` is ``"none"``).
            crop_type: One of ``"none"``, ``"center"`` or ``"random"``.
            blur_kernel_size: Forwarded to ``random_mixed_kernels``.
            kernel_list: Forwarded to ``random_mixed_kernels``.
            kernel_prob: Forwarded to ``random_mixed_kernels``.
            blur_sigma: Forwarded twice to ``random_mixed_kernels``
                (presumably as the x and y sigma ranges; confirm against the
                helper's signature).
            downsample_range: ``(low, high)`` bounds of the uniformly sampled
                downsample factor.
            noise_range: Forwarded to ``random_add_gaussian_noise``; ``None``
                disables the noise step.
            jpeg_range: Forwarded to ``random_add_jpg_compression``; ``None``
                disables the JPEG step.

        Raises:
            ValueError: If ``crop_type`` is not a supported value.
        """
        super(CodeformerDataset, self).__init__()
        self.file_list = file_list
        self.image_files = load_file_list(file_list)
        self.file_backend = instantiate_from_config(file_backend_cfg)
        self.out_size = out_size
        self.crop_type = crop_type
        # Raise explicitly rather than assert: asserts vanish under
        # ``python -O`` and an unknown crop type would then leak a PIL image
        # out of ``load_gt_image``.
        if self.crop_type not in ("none", "center", "random"):
            raise ValueError(
                "crop_type must be one of 'none', 'center' or 'random', "
                f"got {self.crop_type!r}"
            )

        # Degradation settings.
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range

    def load_gt_image(self, image_path: str, max_retry: int=5) -> Optional[np.ndarray]:
        """Read, decode and crop one ground-truth image.

        Args:
            image_path: Path handed to the file backend's ``get``.
            max_retry: Maximum number of read attempts; ``0`` returns ``None``
                without touching the backend.

        Returns:
            The image decoded as RGB (an array in height/width/channel
            layout), or ``None`` when every read attempt returned ``None``.
            The return type of the crop helpers is assumed to be an array as
            well -- confirm in ``dataset.utils``.

        Raises:
            ValueError: If ``crop_type`` is ``"none"`` and the image is not
                already ``out_size`` x ``out_size``.
        """
        image_bytes = None
        while image_bytes is None:
            if max_retry == 0:
                return None
            image_bytes = self.file_backend.get(image_path)
            max_retry -= 1
            # Back off only when another attempt will follow; sleeping after
            # the last failure would just delay the ``None`` return.
            if image_bytes is None and max_retry != 0:
                time.sleep(0.5)

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # An image that already has the target size is used as is, whatever
        # the crop type.
        if image.height == self.out_size and image.width == self.out_size:
            return np.array(image)
        if self.crop_type == "center":
            return center_crop_arr(image, self.out_size)
        if self.crop_type == "random":
            return random_crop_arr(image, self.out_size, min_crop_frac=0.7)

        # crop_type == "none": no resampling is allowed, so a size mismatch is
        # an error in the data.
        raise ValueError(
            f"{image_path} is {image.width}x{image.height} but crop_type is "
            f"'none' and out_size is {self.out_size}"
        )

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, str]:
        """Build one training sample.

        Args:
            index: Position in the loaded file list. If that image cannot be
                read, a random other index is tried until one succeeds.

        Returns:
            A ``(gt, lq, prompt)`` tuple. ``gt`` is the float32 ground truth
            rescaled from ``[0, 1]`` to ``[-1, 1]``; ``lq`` is the float32
            degraded image on the ``[0, 1]`` scale (not clipped here -- whether
            it stays inside that range depends on the degradation helpers).
            Both are back in the channel order the image was loaded with.
            ``prompt`` is the entry's prompt, replaced by ``""`` with
            probability 0.5.
        """
        # Keep drawing entries until one can be loaded.
        img_gt = None
        while img_gt is None:
            image_file = self.image_files[index]
            gt_path = image_file["image_path"]
            prompt = image_file["prompt"]
            img_gt = self.load_gt_image(gt_path)
            if img_gt is None:
                print(f"failed to load {gt_path}, try another image")
                index = random.randint(0, len(self) - 1)

        # Reverse the channel axis (RGB -> BGR) and scale to [0, 1] float32.
        img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
        h, w, _ = img_gt.shape
        # Drop the prompt for half of the samples.
        if np.random.uniform() < 0.5:
            prompt = ""

        # ------------------------- degradation ------------------------- #
        # Blur with a randomly drawn kernel.
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)

        # Downsample by a uniformly drawn factor. Clamp the target size so an
        # extreme factor on a small image cannot request a 0-pixel resize.
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        lq_w = max(1, int(w // scale))
        lq_h = max(1, int(h // scale))
        img_lq = cv2.resize(img_lq, (lq_w, lq_h), interpolation=cv2.INTER_LINEAR)

        # Optional Gaussian noise.
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)

        # Optional JPEG compression.
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)

        # Resize back to the ground-truth resolution.
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # Reverse the channel axis again (back to RGB); gt goes to [-1, 1].
        gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        # lq keeps its [0, 1] scale.
        lq = img_lq[..., ::-1].astype(np.float32)

        return gt, lq, prompt

    def __len__(self) -> int:
        """Return the number of entries in the loaded file list."""
        return len(self.image_files)