from typing import Sequence, Dict, Union, List, Mapping, Any, Optional, Tuple
import math
import time
import io
import random

import numpy as np
import cv2
from PIL import Image
import torch.utils.data as data

from dataset.degradation import (
    random_mixed_kernels,
    random_add_gaussian_noise,
    random_add_jpg_compression
)
from dataset.utils import load_file_list, center_crop_arr, random_crop_arr
from utils.common import instantiate_from_config


class CodeformerDataset(data.Dataset):

    def __init__(
        self,
        file_list: str,
        file_backend_cfg: Mapping[str, Any],
        out_size: int,
        crop_type: str,
        blur_kernel_size: int,
        kernel_list: Sequence[str],
        kernel_prob: Sequence[float],
        blur_sigma: Sequence[float],
        downsample_range: Sequence[float],
        noise_range: Sequence[float],
        jpeg_range: Sequence[int]
    ) -> None:
        super(CodeformerDataset, self).__init__()
        self.file_list = file_list
        self.image_files = load_file_list(file_list)
        self.file_backend = instantiate_from_config(file_backend_cfg)
        self.out_size = out_size
        self.crop_type = crop_type
        assert self.crop_type in ["none", "center", "random"]
        # degradation configurations
        self.blur_kernel_size = blur_kernel_size
        self.kernel_list = kernel_list
        self.kernel_prob = kernel_prob
        self.blur_sigma = blur_sigma
        self.downsample_range = downsample_range
        self.noise_range = noise_range
        self.jpeg_range = jpeg_range

    def load_gt_image(self, image_path: str, max_retry: int = 5) -> Optional[np.ndarray]:
        # read raw bytes from the file backend, retrying a few times before giving up
        image_bytes = None
        while image_bytes is None:
            if max_retry == 0:
                return None
            image_bytes = self.file_backend.get(image_path)
            max_retry -= 1
            if image_bytes is None:
                time.sleep(0.5)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if self.crop_type != "none":
            if image.height == self.out_size and image.width == self.out_size:
                image = np.array(image)
            else:
                if self.crop_type == "center":
                    image = center_crop_arr(image, self.out_size)
                elif self.crop_type == "random":
                    image = random_crop_arr(image, self.out_size, min_crop_frac=0.7)
        else:
            assert image.height == self.out_size and image.width == self.out_size
            image = np.array(image)
        # hwc, rgb, [0, 255], uint8
        return image

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray, str]:
        # load gt image
        img_gt = None
        while img_gt is None:
            # load meta file
            image_file = self.image_files[index]
            gt_path = image_file["image_path"]
            prompt = image_file["prompt"]
            img_gt = self.load_gt_image(gt_path)
            if img_gt is None:
                print(f"failed to load {gt_path}, try another image")
                index = random.randint(0, len(self) - 1)
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32)
        h, w, _ = img_gt.shape
        # randomly drop the prompt: half of the samples use an empty prompt
        if np.random.uniform() < 0.5:
            prompt = ""

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma,
            [-math.pi, math.pi],
            noise_range=None
        )
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(
            img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR
        )
        # noise
        if self.noise_range is not None:
            img_lq = random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = random_add_jpg_compression(img_lq, self.jpeg_range)
        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)

        # BGR to RGB, [-1, 1]
        gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32)
        # BGR to RGB, [0, 1]
        lq = img_lq[..., ::-1].astype(np.float32)
        return gt, lq, prompt

    def __len__(self) -> int:
        return len(self.image_files)
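

# Minimal usage sketch (illustrative only). The file list path, backend target, and
# degradation ranges below are assumptions for demonstration; in practice they come
# from the project's training configuration.
if __name__ == "__main__":
    dataset = CodeformerDataset(
        file_list="data/train_list.txt",  # hypothetical file list path
        file_backend_cfg={
            # hypothetical backend config; assumed to follow instantiate_from_config's
            # {"target": ..., "params": ...} convention
            "target": "dataset.file_backend.HardDiskFileBackend",
            "params": {},
        },
        out_size=512,
        crop_type="center",
        blur_kernel_size=41,
        kernel_list=["iso", "aniso"],
        kernel_prob=[0.5, 0.5],
        blur_sigma=[0.1, 10],
        downsample_range=[1, 8],
        noise_range=[0, 20],
        jpeg_range=[60, 100],
    )
    gt, lq, prompt = dataset[0]
    # gt: RGB float32 in [-1, 1]; lq: RGB float32 in [0, 1]; prompt: possibly empty string
    print(gt.shape, lq.shape, repr(prompt))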