|
import os |
|
import random |
|
import copy |
|
from PIL import Image |
|
import numpy as np |
|
import json |
|
|
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor |
|
|
|
from utils.image_utils import random_augmentation, crop_img |
|
from utils.degradation_utils import Degradation |
|
|
|
|
|
class DerainDehazeDataset(Dataset):
    """Single-sample dataset wrapping one in-memory image and a text prompt.

    The dataset always has length 1. Its only item is built from the image
    handed to the constructor; no file is read by this class.
    """

    def __init__(self, args, img, text_prompt, task="derain"):
        """Store the inputs used to build the single sample.

        Args:
            args: Opaque configuration object; stored as ``self.args`` and not
                otherwise used by this class.
            img: Image object exposing ``.convert('RGB')`` (e.g. a PIL image).
            text_prompt: Prompt returned unchanged alongside the image tensors.
            task: Accepted for interface compatibility; not used by this class.
        """
        super().__init__()
        self.args = args
        self.toTensor = ToTensor()
        self.img = img
        self.text_prompt = text_prompt

    def __getitem__(self, idx):
        """Return the single sample held by this dataset.

        Args:
            idx: Sample index. Only ``0`` (or ``-1``) is valid.

        Returns:
            A tuple ``([[""]], "", degraded_img, clean_img, text_prompt)``.
            Both image tensors are produced from the same input image; there
            is no separate clean reference.

        Raises:
            IndexError: If ``idx`` is outside the valid range. Without this,
                plain ``for sample in dataset`` iteration (which relies on
                ``IndexError`` to stop) would never terminate.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {size}"
            )

        # Convert and crop once; the original did the identical work twice.
        # NOTE(review): assumes crop_img is deterministic for a given input
        # (presumably trims the array to multiples of `base`) - confirm in
        # utils.image_utils.
        rgb_array = crop_img(np.array(self.img.convert('RGB')), base=16)

        # toTensor is invoked twice so the two returned tensors remain
        # separate objects, as in the original implementation.
        clean_img = self.toTensor(rgb_array)
        degraded_img = self.toTensor(rgb_array)

        # Placeholders: this dataset carries no file name or degradation tag.
        degraded_name = [""]
        degradation = ""

        return [degraded_name], degradation, degraded_img, clean_img, self.text_prompt

    def __len__(self):
        """Return the number of samples, which is always 1."""
        return 1