import os.path
import torch
import torch.utils.data as data
from PIL import Image
import random
import utils
import numpy as np
import torchvision.transforms as transforms
from utils_core import flow_viz
import cv2

class DDDataset(data.Dataset):
    """Dataset of (source, target) colour-image / mask pairs.

    The sample list is a text file (``opt.datapath``) with one sample per
    line, made of four whitespace-separated paths::

        <src_color> <src_mask> <tar_color> <tar_mask>

    Because fields are split on whitespace, paths containing spaces are not
    supported. Call :meth:`initialize` before using the dataset.
    """

    def __init__(self):
        super(DDDataset, self).__init__()

    def initialize(self, opt):
        """Read the sample list and keep only entries whose files all exist.

        Args:
            opt: options object. This method reads ``opt.datapath``; other
                methods read ``opt.width``, ``opt.height``, ``opt.crop_width``
                and ``opt.crop_height``.

        Blank lines are ignored. Lines with fewer than four fields, or with a
        path that does not exist, are reported on stdout and skipped.
        """
        self.opt = opt
        self.dir_txt = opt.datapath
        self.paths = []
        # `with` guarantees the list file is closed even if parsing raises.
        with open(self.dir_txt, "r") as in_file:
            for line in in_file:
                fields = line.split()
                if not fields:
                    # Blank line (e.g. trailing newline at end of file).
                    continue
                if len(fields) < 4:
                    print("skipping malformed line (expected 4 paths): " + line.strip())
                    continue

                entry = fields[:4]
                # Report only the first missing path, then drop the sample.
                missing = next((p for p in entry if not os.path.exists(p)), None)
                if missing is not None:
                    print(missing + " not exists")
                    continue

                self.paths.append(entry)
        self.data_size = len(self.paths)
        print("num data: ", len(self.paths))

    def process_data(self, color, mask, bound=10):
        """Mask out the background and crop/resize around the foreground.

        Args:
            color: H x W x C image array.
            mask: H x W array; non-zero entries mark the foreground.
            bound: margin in pixels added around the foreground bounding box.

        Returns:
            ``(crop_color, crop_params)`` where ``crop_color`` is the masked
            crop resized to ``(opt.crop_height, opt.crop_width)`` and
            ``crop_params`` is ``[[min_x], [max_x], [min_y], [max_y]]``.

        Raises:
            ValueError: if ``mask`` has no non-zero entry.
        """
        rows, cols = mask.nonzero()[:2]
        if rows.size == 0:
            raise ValueError("mask is empty: cannot compute a foreground bounding box")

        # Box is clamped to opt.width / opt.height.
        # NOTE(review): assumes these match the actual image size - confirm.
        min_x = max(0, int(cols.min()) - bound)
        max_x = min(self.opt.width - 1, int(cols.max()) + bound)
        min_y = max(0, int(rows.min()) - bound)
        max_y = min(self.opt.height - 1, int(rows.max()) + bound)

        # Zero every pixel outside the mask (result is a float array).
        color = color * (mask != 0).astype(float)[:, :, None]
        crop_color = color[min_y:max_y, min_x:max_x, :]
        crop_color = cv2.resize(
            np.ascontiguousarray(crop_color),
            (self.opt.crop_width, self.opt.crop_height),
            interpolation=cv2.INTER_LINEAR,
        )
        crop_params = [[min_x], [max_x], [min_y], [max_y]]

        return crop_color, crop_params

    def _load_view(self, color_path, mask_path):
        """Load one colour/mask pair and convert it to tensors.

        Returns ``(raw_color, crop_color, mask, crop_params)``: the full image
        and the masked crop as C x H x W float tensors scaled by 1/255, the
        mask as a 1 x H x W bool tensor, and the crop box from
        :meth:`process_data`.
        """
        with Image.open(color_path) as img:
            color = np.array(img).astype(np.uint8)
        with Image.open(mask_path) as img:
            mask = np.array(img)
        if mask.ndim == 3:
            # Multi-channel mask file: the first channel carries the mask.
            mask = mask[:, :, 0]

        crop_color, crop_params = self.process_data(color, mask)

        # HWC --> CHW
        raw_color = torch.from_numpy(color).permute(2, 0, 1).float() / 255.0
        crop_color = torch.from_numpy(crop_color).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.tensor((mask != 0)[np.newaxis, :, :])
        return raw_color, crop_color, mask_tensor, crop_params

    def __getitem__(self, index):
        """Return the sample at ``index`` (wrapped modulo the dataset size).

        The result is a dict with keys ``path_flow`` (output file name built
        from the source and target file stems), ``src_crop_color``,
        ``tar_crop_color``, ``src_color``, ``tar_color``, ``src_mask``,
        ``tar_mask`` and ``Crop_param`` (8 x 1 tensor: source box followed by
        target box, each as min_x, max_x, min_y, max_y).

        Raises:
            IndexError: if the dataset holds no samples.
        """
        if not self.paths:
            raise IndexError("dataset is empty")
        paths = self.paths[index % self.data_size]

        raw_src_color, src_crop_color, src_mask, src_crop_params = self._load_view(paths[0], paths[1])
        raw_tar_color, tar_crop_color, tar_mask, tar_crop_params = self._load_view(paths[2], paths[3])

        Crop_param = torch.tensor(src_crop_params + tar_crop_params)

        # Strip directory and extension whatever the extension length is.
        src_stem = os.path.splitext(os.path.basename(paths[0]))[0]
        tar_stem = os.path.splitext(os.path.basename(paths[2]))[0]
        path1 = src_stem + "_" + tar_stem + ".oflow"

        return {"path_flow":path1, "src_crop_color":src_crop_color, "tar_crop_color":tar_crop_color, "src_color":raw_src_color, "tar_color":raw_tar_color, "src_mask":src_mask, "tar_mask":tar_mask, "Crop_param":Crop_param}

    def __len__(self):
        """Number of usable samples found by :meth:`initialize`."""
        return self.data_size

    def name(self):
        """Return the dataset's display name."""
        return 'DDDataset'