import os
import sys
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as T

import utils.transforms as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS

from tools import *

def test_video_dataset():
    """Dump one preview GIF per sample of the configured video dataset.

    The configuration loaded through ``pConfig(load=True)`` is merged into the
    global ``cfg``; the dataset described by ``cfg.vid_dataset`` is then built
    and iterated.  For each sample, every video frame is de-normalized and
    placed side by side with the de-normalized reference frame and a white
    panel holding the caption text.  The result is written to
    ``workspace/test_data/datasets/<dataset type>/<config name>/<idx>.gif``.

    The output directory is deleted and recreated on every run.
    """
    # Local import keeps this function self-contained; only used for cleanup.
    import shutil

    cfg_update = pConfig(load=True)

    # Dict-valued entries are merged into the existing cfg entry; everything
    # else (including dicts for keys cfg does not have yet) is overwritten.
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
    save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
    # Start from a clean directory.  shutil.rmtree avoids building a shell
    # command from the path (which breaks on spaces / shell metacharacters);
    # ignore_errors mirrors `rm -rf` being silent when the path is missing.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = T.Compose([
        data.CenterCropWide(cfg.vit_resolution),
        T.ToTensor(),
        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Broadcastable statistics used to undo Normalize (x * std + mean).
    # Shape (1, C, 1, 1): assumes video_data is (frames, C, H, W), which is
    # what the permute(0, 2, 3, 1) below expects -- TODO confirm with dataset.
    video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1)
    video_std = torch.tensor(cfg.std).view(1, -1, 1, 1)

    # Shape (C, 1, 1) for the single reference frame (C, H, W).
    img_mean = torch.tensor(cfg.mean).view(-1, 1, 1)
    img_std = torch.tensor(cfg.std).view(-1, 1, 1)

    # The caption panel is a square whose side is cfg.resolution[1]; nc is the
    # number of characters per line, scaled from 38 characters at 256 px.
    txt_size = cfg.resolution[1]
    nc = int(38 * (txt_size / 256))
    # NOTE(review): test_vit_image loads the same font from 'artist/font/' --
    # confirm which of the two locations is the right one.
    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)

    dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))
    for idx, item in enumerate(dataset):
        # vit_frame and video_key are part of the sample but not visualised.
        ref_frame, vit_frame, video_data, caption, video_key = item

        # Video frames: undo normalization, clamp to [0, 1], move channels
        # last and convert each frame to a uint8 array.
        video_data = video_data.mul_(video_std).add_(video_mean)
        video_data.clamp_(0, 1)
        video_data = video_data.permute(0, 2, 3, 1)
        video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]

        # Reference frame: same inverse transform, x * std + mean.
        # (The std and mean operands used to be swapped here, which produced
        # wrongly coloured reference images.)
        ref_frame = ref_frame.mul_(img_std).add_(img_mean)
        ref_frame.clamp_(0, 1)
        ref_frame = ref_frame.permute(1, 2, 0)
        ref_frame = (ref_frame.numpy() * 255).astype('uint8')

        # Caption panel: hard-wrap the caption every nc characters.
        txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        draw.text((0, 0), lines, fill="black", font=font)
        txt_img = np.array(txt_img)

        # Lay out [reference | frame | caption] horizontally for every frame.
        video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
        spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)

        # if idx > 100: break


def test_vit_image(test_video_flag=True):
    """Dump one preview GIF per sample of the configured image dataset.

    The configuration loaded through ``pConfig(load=True)`` is merged into the
    global ``cfg``; the dataset described by ``cfg.img_dataset`` is then built
    and iterated.  For each sample, every frame of ``video_data`` is
    de-normalized and placed side by side with the de-normalized ViT input
    (pasted into the top-left corner of a black square) and a white panel
    holding the caption text.  The result is written to
    ``workspace/test_data/datasets/<dataset type>/<config name>/<idx>.gif``.

    The output directory is deleted and recreated on every run.

    Args:
        test_video_flag: Currently unused; kept so existing callers that pass
            it keep working.
    """
    # Local import keeps this function self-contained; only used for cleanup.
    import shutil

    cfg_update = pConfig(load=True)

    # Dict-valued entries are merged into the existing cfg entry; everything
    # else (including dicts for keys cfg does not have yet) is overwritten.
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
    save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
    # Start from a clean directory.  shutil.rmtree avoids building a shell
    # command from the path (which breaks on spaces / shell metacharacters);
    # ignore_errors mirrors `rm -rf` being silent when the path is missing.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = data.Compose([
        data.CenterCropWide(cfg.resolution),
        data.Resize(cfg.vit_resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Broadcastable statistics used to undo Normalize (x * std + mean).
    # Shape (C, 1, 1); this broadcasts over a (C, H, W) image and also over a
    # (frames, C, H, W) stack, which is what the permute(0, 2, 3, 1) below
    # expects video_data to be -- TODO confirm with the dataset.
    img_mean = torch.tensor(cfg.mean).view(-1, 1, 1)
    img_std = torch.tensor(cfg.std).view(-1, 1, 1)

    vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1)
    vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1)

    # The caption panel is a square whose side is cfg.resolution[1]; nc is the
    # number of characters per line, scaled from 38 characters at 256 px.
    txt_size = cfg.resolution[1]
    nc = int(38 * (txt_size / 256))
    # NOTE(review): test_video_dataset loads the same font from 'data/font/'
    # -- confirm which of the two locations is the right one.
    font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)

    dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))
    for idx, item in enumerate(dataset):
        # ref_frame and video_key are part of the sample but not visualised.
        ref_frame, vit_frame, video_data, caption, video_key = item

        # Frames: undo normalization, clamp to [0, 1], move channels last and
        # convert each frame to a uint8 array.
        video_data = video_data.mul_(img_std).add_(img_mean)
        video_data.clamp_(0, 1)
        video_data = video_data.permute(0, 2, 3, 1)
        video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]

        # ViT input: same inverse transform with the ViT statistics.
        vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
        vit_frame.clamp_(0, 1)
        vit_frame = vit_frame.permute(1, 2, 0)
        vit_frame = (vit_frame.numpy() * 255).astype('uint8')

        # Paste the ViT image into the top-left corner of a black square so
        # it can be concatenated with the frames.  This assumes the ViT image
        # is no larger than cfg.resolution[1] in either dimension.
        zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
        zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame

        # Caption panel: hard-wrap the caption every nc characters.
        txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        draw.text((0, 0), lines, fill="black", font=font)
        txt_img = np.array(txt_img)

        # Lay out [ViT input | frame | caption] horizontally for every frame.
        video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
        spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)

        # if idx > 100: break


if __name__ == '__main__':
    # Script entry point: only the image-dataset preview runs by default.
    # Uncomment the next line to also dump previews for the video dataset.
    # test_video_dataset()
    test_vit_image()