ICDR / utils /dataset_utils_CDD.py
Siwon123's picture
q
7f43945
import os
import random
import copy
from PIL import Image
import numpy as np
import json
from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
from utils.image_utils import random_augmentation, crop_img
from utils.degradation_utils import Degradation
class DerainDehazeDataset(Dataset):
    """Single-image dataset wrapping one in-memory image and its text prompt.

    The dataset always has length 1. Its only item is built from the image
    passed to the constructor; no files are read.
    """

    def __init__(self, args, img, text_prompt, task="derain"):
        """Store the inputs used to build the single item.

        Args:
            args: Caller-supplied options object; stored as ``self.args`` and
                not otherwise used by this class.
            img: Input image. It must provide ``convert('RGB')``
                (presumably a ``PIL.Image.Image`` -- confirm against caller).
            text_prompt: Prompt returned unchanged with the item.
            task: Accepted for interface compatibility; not used by this class.
        """
        super().__init__()
        self.args = args
        self.toTensor = ToTensor()
        self.img = img
        self.text_prompt = text_prompt

    def __getitem__(self, idx):
        """Return the single item held by this dataset.

        Args:
            idx: Item index. Only indices valid for a length-1 sequence
                (``0`` or ``-1``) are accepted.

        Returns:
            A tuple ``([[""]], "", degraded_img, clean_img, text_prompt)``.
            The first two entries are placeholder name / degradation values.
            ``degraded_img`` and ``clean_img`` are both produced from the same
            input image via ``self.toTensor``.

        Raises:
            IndexError: If ``idx`` is out of range. Without this, Python's
                ``__getitem__``-based iteration (``for x in ds``, ``list(ds)``)
                would never terminate.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for a dataset of length {size}"
            )

        # Both outputs come from the same image, so convert and crop it once.
        # NOTE(review): crop_img(..., base=16) presumably trims height/width
        # to multiples of 16 -- confirm in utils.image_utils.
        cropped = crop_img(np.array(self.img.convert('RGB')), base=16)

        # toTensor is still invoked once per output, as before.
        clean_img = self.toTensor(cropped)
        degraded_img = self.toTensor(cropped)

        # Placeholders: there is no source file name or degradation label for
        # an in-memory image. The nested-list shape of the name is preserved
        # because callers may unpack it.
        degraded_name = [""]
        degradation = ""
        return [degraded_name], degradation, degraded_img, clean_img, self.text_prompt

    def __len__(self):
        """Return the number of items, which is always 1."""
        return 1