import os
import random
import copy
from PIL import Image
import numpy as np
import json

from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor

from utils.image_utils import random_augmentation, crop_img
from utils.degradation_utils import Degradation


class DerainDehazeDataset(Dataset):
    """Single-sample dataset that wraps one in-memory image.

    The wrapped image is returned as both the "degraded" input and the
    "clean" tensor; no separate ground-truth image is involved.

    Args:
        args: Options object; stored on the instance but not read by this
            class.
        img: Input image. Must support ``.convert('RGB')`` — presumably a
            ``PIL.Image.Image`` (TODO confirm against callers).
        text_prompt: Value returned unchanged with every sample.
        task: Stored as ``self.task``; not otherwise used by this class.
    """

    def __init__(self, args, img, text_prompt, task="derain"):
        super(DerainDehazeDataset, self).__init__()
        self.args = args
        self.toTensor = ToTensor()
        self.img = img
        self.text_prompt = text_prompt
        # Previously accepted and silently dropped; kept for introspection.
        self.task = task

    def __getitem__(self, idx):
        """Return the wrapped image as the dataset's only sample.

        ``idx`` is ignored: every index yields the same sample.

        Returns:
            A list ``[names, degradation, degraded_img, clean_img,
            text_prompt]`` where ``names`` is ``[[""]]``, ``degradation`` is
            ``""``, ``degraded_img`` and ``clean_img`` are equal tensors
            built from the same cropped RGB array (separate storage), and
            ``text_prompt`` is the value given to the constructor.
        """
        # Convert and crop once; the original did this twice to produce two
        # identical images. NOTE(review): crop_img(..., base=16) presumably
        # trims H/W to multiples of 16 — confirm in utils.image_utils.
        rgb = crop_img(np.array(self.img.convert('RGB')), base=16)

        degraded_img = self.toTensor(rgb)
        # clone() gives clean_img its own storage, so in-place edits to one
        # tensor cannot leak into the other (as with two separate crops).
        clean_img = degraded_img.clone()

        degradation = ""
        # The name is wrapped twice ([[""]]); layout preserved for callers.
        degraded_name = [""]

        return [degraded_name], degradation, degraded_img, clean_img, self.text_prompt

    def __len__(self):
        """Return 1: the dataset always holds exactly one sample."""
        return 1