import glob
import json
import os
import cv2
import pickle
import random
import re
import subprocess
from functools import partial

import librosa.core
import numpy as np
import torch
import torch.distributions
import torch.distributed as dist
import torch.optim
import torch.utils.data

from utils.commons.indexed_datasets import IndexedDataset
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import csv
from utils.commons.hparams import hparams, set_hparams
from utils.commons.meters import Timer
from data_util.face3d_helper import Face3DHelper
from utils.audio import librosa_wav2mfcc
from utils.commons.dataset_utils import collate_xd
from utils.commons.tensor_utils  import convert_to_tensor
from data_gen.utils.process_video.extract_segment_imgs import decode_segmap_mask_from_image
from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
from utils.commons.image_utils import load_image_as_uint8_tensor
from modules.eg3ds.camera_utils.pose_sampler import UnifiedCameraPoseSampler


def sample_idx(img_dir, num_frames):
    """Randomly pick a frame index whose image exists in every required image directory.

    Draws ``idx`` uniformly from ``[0, num_frames - 1]`` until ``find_img_name`` finds
    that frame in ``img_dir`` and in the sibling directories obtained by replacing
    ``/gt_imgs/`` with ``/head_imgs/``, ``/inpaint_torso_imgs/`` and ``/com_imgs/``.

    Args:
        img_dir: path containing a ``/gt_imgs/`` component.
        num_frames: number of frames in the clip; must be >= 1.

    Returns:
        int: a frame index present in all four directories.

    Raises:
        ValueError: if ``num_frames < 1`` (there is nothing to sample from).

    Note:
        Retries indefinitely. A warning is printed after 1000 failed draws and then
        once every further 1000 draws, rather than on every draw.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames} for {img_dir}")
    # Checked in this order; all() short-circuits on the first missing image.
    required_dirs = [img_dir] + [
        img_dir.replace("/gt_imgs/", f"/{sub}/")
        for sub in ("head_imgs", "inpaint_torso_imgs", "com_imgs")
    ]
    cnt = 0
    while True:
        cnt += 1
        # Throttled: fires at 1001, 2001, ... instead of flooding the log every draw.
        if cnt > 1000 and cnt % 1000 == 1:
            print(f"recycle for more than 1000 times, check this {img_dir}")
        idx = random.randint(0, num_frames - 1)
        if all(find_img_name(d, idx) != 'None' for d in required_dirs):
            return idx
    

def find_img_name(img_dir, idx):
    """Resolve the on-disk image file for frame ``idx`` inside ``img_dir``.

    Candidate file names are tried in this fixed order:
    5-digit ``.jpg``, unpadded ``.jpg``, 8-digit ``.jpg``,
    8-digit ``.png``, 5-digit ``.png``, unpadded ``.png``.

    Returns:
        The first candidate path that exists, or the *string* ``'None'`` (not the
        ``None`` object) when none exists; callers compare against ``'None'``.
    """
    five, plain, eight = format(idx, "05d"), str(idx), format(idx, "08d")
    candidates = (
        five + ".jpg",
        plain + ".jpg",
        eight + ".jpg",
        eight + ".png",
        five + ".png",
        plain + ".png",
    )
    for name in candidates:
        path = os.path.join(img_dir, name)
        if os.path.exists(path):
            return path
    return 'None'
    
    
def get_win_from_arr(arr, index, win_size):
    """Return a zero-padded window of ``win_size`` rows of ``arr`` around ``index``.

    The window spans rows ``[index - win_size // 2, index - win_size // 2 + win_size)``
    along axis 0. Rows that fall outside ``arr`` are filled with zeros of the same
    dtype (and, for torch tensors, the same device). The result therefore always has
    exactly ``win_size`` rows, even when ``arr`` is shorter than the window or
    ``index`` lies outside ``arr``.

    Args:
        arr: ``np.ndarray`` or ``torch.Tensor``; windowing is along axis 0.
        index: centre row; may be out of range.
        win_size: number of rows in the returned window.

    Returns:
        Same type as ``arr`` with shape ``[win_size, *arr.shape[1:]]``. When no padding
        is needed this is a slice (view) of ``arr``; otherwise it is a new array.
    """
    num_rows = arr.shape[0]
    left = index - win_size // 2
    right = left + win_size  # == index + (win_size - win_size // 2)
    # Clamp the in-range part of the window to [0, num_rows] so slicing never wraps
    # around (negative bounds) or runs past the end.
    lo = min(max(left, 0), num_rows)
    hi = min(max(right, lo), num_rows)
    win = arr[lo:hi]

    pad_left = min(max(lo - left, 0), win_size)
    pad_right = win_size - pad_left - (hi - lo)
    if pad_left == 0 and pad_right == 0:
        return win

    # Zero blocks get an explicit shape. zeros_like(win[:pad]) would be too short
    # whenever the in-range slice has fewer rows than the padding.
    tail_shape = tuple(arr.shape[1:])
    if isinstance(arr, np.ndarray):
        return np.concatenate([
            np.zeros((pad_left,) + tail_shape, dtype=arr.dtype),
            win,
            np.zeros((pad_right,) + tail_shape, dtype=arr.dtype),
        ], axis=0)
    return torch.cat([
        torch.zeros((pad_left,) + tail_shape, dtype=arr.dtype, device=arr.device),
        win,
        torch.zeros((pad_right,) + tail_shape, dtype=arr.dtype, device=arr.device),
    ], dim=0)


class Img2Plane_Dataset(Dataset):
    """Camera-only dataset: yields a (ws, ref, mv) camera triplet per clip.

    Per item:
      * ``ws_camera``: always drawn from the pose sampler with pitch in [-10, 10] deg,
        yaw in [-16, 16] deg and distance in [2.7, 3.2], looking at (0, 0, 0.2).
      * ``real_camera`` / ``fake_camera``: independently, if
        ``hparams['random_sample_pose'] is True``, with probability 0.5 drawn from the
        pose sampler with a wider range (pitch [-26, 26] deg, yaw [-38, 38] deg,
        distance [2.7, 4.0]). Otherwise they are the camera of a random valid frame
        of the clip (flattened c2w, 16 values, followed by flattened intrinsics,
        9 values).
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix  # split name; records live at f'{ds_path}/{db_key}'
        self.ds = None  # IndexedDataset, opened lazily
        # The next four attributes are not read by this class's own methods.
        self.sizes = None
        self.x_maxframes = 200 # 50 video frames
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.pose_sampler = UnifiedCameraPoseSampler()
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # Opens a fresh IndexedDataset on every call and caches it on self.ds.
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        Return the raw record at ``index``, opening the indexed dataset on first use.
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def _sample_random_camera(self, max_pitch_deg, max_yaw_deg, min_dist, max_dist):
        """Sample a pose-sampler camera with uniform pitch, yaw and distance.

        Pitch and yaw are symmetric around 0 (given in degrees). Distance is in
        [min_dist, max_dist]. The camera looks at (0, 0, 0.2). Consumes exactly three
        ``random.random()`` draws, in pitch, yaw, distance order.
        """
        max_pitch = max_pitch_deg / 180 * np.pi
        pitch = random.uniform(-max_pitch, max_pitch)
        max_yaw = max_yaw_deg / 180 * np.pi
        yaw = random.uniform(-max_yaw, max_yaw)
        distance = random.uniform(min_dist, max_dist)
        return self.pose_sampler.get_camera_pose(
            pitch, yaw, lookat_location=torch.tensor([0, 0, 0.2]), distance_to_orig=distance)[0]

    def _camera_from_frame(self, raw_item, img_dir, num_frames):
        """Return the camera of a random valid frame: flattened c2w (16) + intrinsics (9)."""
        frame_idx = sample_idx(img_dir, num_frames)
        c2w = raw_item['c2w'][frame_idx]
        intrinsics = raw_item['intrinsics'][frame_idx]
        camera = np.concatenate([c2w.reshape([16,]), intrinsics.reshape([9,])], axis=0)
        return convert_to_tensor(camera)

    def _sample_ref_or_mv_camera(self, raw_item, img_dir, num_frames):
        """Pick a wide-range random camera (50% when enabled) or a real frame's camera."""
        # `and` short-circuits, so the coin flip is only drawn when the flag is on.
        if self.hparams.get("random_sample_pose", False) is True and random.random() < 0.5:
            return self._sample_random_camera(26, 38, 2.7, 4.0)
        return self._camera_from_frame(raw_item, img_dir, num_frames)

    def __getitem__(self, idx):
        """Build ``{'idx', 'item_name', 'ws_camera', 'real_camera', 'fake_camera'}``.

        Returns ``None`` (after printing a message) if the raw record cannot be loaded.
        """
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }
        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        # The ws camera uses a narrower pose range than the optional random ref/mv poses.
        ws_camera = self._sample_random_camera(10, 16, 2.7, 3.2)
        real_camera = self._sample_ref_or_mv_camera(raw_item, img_dir, num_frames)
        fake_camera = self._sample_ref_or_mv_camera(raw_item, img_dir, num_frames)

        item.update({
            'ws_camera': ws_camera,
            'real_camera': real_camera,
            'fake_camera': fake_camera,
        })
        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-item cameras into batch tensors.

        ``None`` samples (failed loads from ``__getitem__``) are dropped. Returns ``{}``
        if nothing is left.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}
        batch = {}
        # Frame-derived cameras have 25 values (16 + 9). Pose-sampler cameras are
        # assumed to match — TODO confirm UnifiedCameraPoseSampler output size.
        batch['ffhq_ws_cameras'] = torch.stack([s['ws_camera'] for s in samples], dim=0)
        batch['ffhq_ref_cameras'] = torch.stack([s['real_camera'] for s in samples], dim=0)
        batch['ffhq_mv_cameras'] = torch.stack([s['fake_camera'] for s in samples], dim=0)
        return batch



class Motion2Video_Dataset(Dataset):
    """Paired reference ('real') / target ('fake') frames of one clip.

    For each of the two frames an item holds:
      * the camera: flattened c2w (16 values) + intrinsics (9 values);
      * gt / head / com / inpaint_torso images mapped to [-1, 1];
      * the segmap and head/torso masks;
      * id / exp / euler / trans coefficients (used to generate the SECC map);
      * two perturbed expressions derived from a neighbouring frame.

    The clip's background image is included once, as ``bg_img``. The target frame is
    resampled until it is at least ``_min_frame_offset`` frames from the reference.
    """

    def __init__(self, prefix='train', data_dir=None):
        self.db_key = prefix  # split name; records live at f'{ds_path}/{db_key}'
        self.ds = None  # IndexedDataset, opened lazily
        # The next four attributes are not read by this class's own methods.
        self.sizes = None
        self.x_maxframes = 200 # 50 video frames
        self.face3d_helper = Face3DHelper('deep_3drecon/BFM')
        self.x_multiply = 8
        self.hparams = hparams
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir

    def __len__(self):
        # Opens a fresh IndexedDataset on every call and caches it on self.ds.
        ds = self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        Return the raw record at ``index``, opening the indexed dataset on first use.
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    @staticmethod
    def _load_rgb(fname):
        """Load an image, drop any alpha channel, and map uint8 [0, 255] to float [-1, 1]."""
        return load_image_as_uint8_tensor(fname)[..., :3].float() / 127.5 - 1

    @staticmethod
    def _min_frame_offset(frame_idx, num_frames):
        """Required |target - reference| gap: half of the larger side of ``frame_idx``, capped at 50."""
        return min(50, max((num_frames - 1 - frame_idx) // 2, frame_idx // 2))

    def _load_frame(self, item, raw_item, img_dir, frame_idx, num_frames, prefix):
        """Write all per-frame entries for ``frame_idx`` into ``item`` under ``prefix``.

        Args:
            item: output dict, mutated in place.
            raw_item: raw record with 'c2w' and 'intrinsics' already attached.
            img_dir: the clip's ``/gt_imgs/`` directory.
            frame_idx: frame to load.
            num_frames: clip length, used to pick a neighbouring frame.
            prefix: 'real' or 'fake'.

        Raises:
            FileNotFoundError: if the frame's segmap image cannot be read.
        """
        camera = np.concatenate([
            raw_item['c2w'][frame_idx].reshape([16,]),
            raw_item['intrinsics'][frame_idx].reshape([9,]),
        ], axis=0)
        item[f'{prefix}_camera'] = convert_to_tensor(camera)

        gt_img_fname = find_img_name(img_dir, frame_idx)
        item[f'{prefix}_gt_img'] = self._load_rgb(gt_img_fname)
        for key in ['head', 'com', 'inpaint_torso']:
            key_img_dir = img_dir.replace("/gt_imgs/", f"/{key}_imgs/")
            item[f'{prefix}_{key}_img'] = self._load_rgb(find_img_name(key_img_dir, frame_idx))

        seg_img_name = gt_img_fname.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png")
        seg_img = cv2.imread(seg_img_name)
        if seg_img is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file.
            raise FileNotFoundError(f"failed to read segmap image: {seg_img_name}")
        # BGR -> RGB before decoding; the result is [6, H, W] per the original note.
        segmap = torch.from_numpy(decode_segmap_mask_from_image(seg_img[:, :, ::-1]))
        item[f'{prefix}_segmap'] = segmap
        item[f'{prefix}_head_mask'] = segmap[[1, 3, 5]].sum(dim=0)
        item[f'{prefix}_torso_mask'] = segmap[[2, 4]].sum(dim=0)

        item[f'{prefix}_identity'] = convert_to_tensor(raw_item['id']).reshape([80,])
        item[f'{prefix}_expression'] = convert_to_tensor(raw_item['exp'][frame_idx]).reshape([64,])
        item[f'{prefix}_euler'] = convert_to_tensor(raw_item['euler'][frame_idx]).reshape([3,])
        item[f'{prefix}_trans'] = convert_to_tensor(raw_item['trans'][frame_idx]).reshape([3,])

        # Perturbation source: an adjacent frame. A 1-frame clip has no neighbour, so
        # the frame itself is used (random.choice on an empty list would raise).
        neighbours = [i for i in (frame_idx - 1, frame_idx + 1) if 0 <= i <= num_frames - 1]
        pertube_idx = random.choice(neighbours) if neighbours else frame_idx
        item[f'{prefix}_pertube_expression_1'] = convert_to_tensor(raw_item['exp'][pertube_idx]).reshape([64,])
        # Neighbour's expression mirrored about the current one: 2 * exp - exp_neighbour.
        item[f'{prefix}_pertube_expression_2'] = item[f'{prefix}_expression'] * 2 - item[f'{prefix}_pertube_expression_1']

    def __getitem__(self, idx):
        """Build the reference/target item; returns ``None`` if the raw record fails to load."""
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_name': raw_item['img_dir'],
        }

        camera_ret = get_eg3d_convention_camera_pose_intrinsic({
            'euler': convert_to_tensor(raw_item['euler']).cpu(),
            'trans': convert_to_tensor(raw_item['trans']).cpu(),
        })
        raw_item['c2w'] = camera_ret['c2w']
        raw_item['intrinsics'] = camera_ret['intrinsics']

        img_dir = raw_item['img_dir'].replace('/com_imgs/', '/gt_imgs/')
        num_frames = len(raw_item['exp'])

        # src (reference) frame
        real_idx = sample_idx(img_dir, num_frames)
        self._load_frame(item, raw_item, img_dir, real_idx, num_frames, 'real')

        item['bg_img'] = self._load_rgb(img_dir.replace("/gt_imgs/", "/bg_img/") + '.jpg')

        # tgt frame: resample until it is far enough from the reference frame
        fake_idx = sample_idx(img_dir, num_frames)
        while abs(fake_idx - real_idx) < self._min_frame_offset(fake_idx, num_frames):
            fake_idx = sample_idx(img_dir, num_frames)
        self._load_frame(item, raw_item, img_dir, fake_idx, num_frames, 'fake')

        return item

    def get_dataloader(self, batch_size=1, num_workers=0):
        loader = DataLoader(self, pin_memory=True,collate_fn=self.collater, batch_size=batch_size, num_workers=num_workers)
        return loader

    def collater(self, samples):
        """Stack per-item tensors into ``th1kh_*`` batch entries.

        'real' items map to ``th1kh_ref_*`` and 'fake' items to ``th1kh_mv_*``. Images
        are permuted from [B, H, W, 3] to [B, 3, H, W]. ``None`` samples (failed loads)
        are dropped; returns ``{}`` if nothing is left.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}

        def stack(key):
            return torch.stack([s[key] for s in samples], dim=0)

        def stack_img(key):  # [B, H, W, 3] ==> [B, 3, H, W]
            return stack(key).permute(0, 3, 1, 2)

        batch = {'th1kh_item_names': [s['item_name'] for s in samples]}
        for src, dst in (('real', 'ref'), ('fake', 'mv')):
            batch[f'th1kh_{dst}_gt_imgs'] = stack_img(f'{src}_gt_img')
            for key in ['head', 'com', 'inpaint_torso']:
                batch[f'th1kh_{dst}_{key}_imgs'] = stack_img(f'{src}_{key}_img')
            batch[f'th1kh_{dst}_head_masks'] = stack(f'{src}_head_mask')  # [B, H, W]
            batch[f'th1kh_{dst}_torso_masks'] = stack(f'{src}_torso_mask')  # [B, H, W]
            batch[f'th1kh_{dst}_cameras'] = stack(f'{src}_camera')  # [B, 25]
            batch[f'th1kh_{dst}_ids'] = stack(f'{src}_identity')  # [B, 80]
            batch[f'th1kh_{dst}_exps'] = stack(f'{src}_expression')  # [B, 64]
            batch[f'th1kh_{dst}_eulers'] = stack(f'{src}_euler')  # [B, 3]
            batch[f'th1kh_{dst}_trans'] = stack(f'{src}_trans')  # [B, 3]
            batch[f'th1kh_{dst}_pertube_exps_1'] = stack(f'{src}_pertube_expression_1')  # [B, 64]
            batch[f'th1kh_{dst}_pertube_exps_2'] = stack(f'{src}_pertube_expression_2')  # [B, 64]
        # Only the reference frame's full segmap and the background are batched.
        batch['th1kh_ref_segmaps'] = stack('real_segmap')  # [B, 6, H, W]
        batch['th1kh_bg_imgs'] = stack_img('bg_img')
        return batch

if __name__ == '__main__':
    # Smoke test: iterate once over the training split; tqdm reports throughput.
    # NOTE(review): OMP_NUM_THREADS is set after numpy/torch are imported, so it may
    # not affect their already-initialised thread pools — confirm.
    os.environ["OMP_NUM_THREADS"] = "1"

    dataset = Img2Plane_Dataset("train", 'data/binary/th1kh')
    # To smoke-test the video dataset instead, swap in:
    # dataset = Motion2Video_Dataset("train", 'data/binary/th1kh')
    loader = dataset.get_dataloader()
    for _batch in tqdm(loader):
        pass