# -*- coding: utf-8 -*-
import sys 
sys.path.append(".") 

import cv2
import os
import numpy as np
import argparse
from PIL import Image

import torch
from torch.utils.data import DataLoader

from core.dataset import TestDataset
from model.modules.flow_comp_raft import RAFT_bi
from model.recurrent_flow_completion import RecurrentFlowCompleteNet

from RAFT.utils.flow_viz_pt import flow_to_image

import cvbase
import imageio
from time import time

import warnings
warnings.filterwarnings("ignore")

def create_dir(dir):
    """Create directory `dir` (including missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check so that a
    concurrent creator (e.g. another process writing to the same results
    tree) cannot make this call fail between the check and the creation.

    Args:
        dir: Path of the directory to create.

    Raises:
        FileExistsError: If `dir` already exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)

def save_flows(output, videoFlowF, videoFlowB):
    """Write RGB visualisations of forward/backward flows as PNG files.

    Images go to ``<output>/forward_png`` and ``<output>/backward_png`` and
    are named by zero-padded index (``00000.png``, ``00001.png``, ...).
    Only PNG visualisations are written; raw ``.flo`` files are not.

    Args:
        output: Directory under which the two PNG sub-directories are created.
        videoFlowF: Forward flows; the last axis indexes the flow field.
            Presumably (H, W, 2, N) as produced by `tensor2np` -- confirm.
        videoFlowB: Backward flows, indexed the same way as `videoFlowF`.
    """
    forward_dir = os.path.join(output, 'forward_png')
    backward_dir = os.path.join(output, 'backward_png')
    create_dir(forward_dir)
    create_dir(backward_dir)

    # The number of flow fields is taken from the forward stack only.
    num_flows = videoFlowF.shape[-1]
    for idx in range(num_flows):
        fwd_field = videoFlowF[..., idx]
        bwd_field = videoFlowB[..., idx]

        fwd_rgb = cvbase.flow2rgb(fwd_field)
        bwd_rgb = cvbase.flow2rgb(bwd_field)

        # Scale the visualisation by 255 and store it as 8-bit.
        fwd_img = (fwd_rgb * 255.0).astype(np.uint8)
        bwd_img = (bwd_rgb * 255.0).astype(np.uint8)

        png_name = '{:05d}.png'.format(idx)
        imageio.imwrite(os.path.join(forward_dir, png_name), fwd_img)
        imageio.imwrite(os.path.join(backward_dir, png_name), bwd_img)

def tensor2np(array):
    """Convert a sequence of tensors into a single NumPy array.

    The tensors are stacked along a new trailing axis, a leading axis of
    size one is squeezed away, and the remaining four axes are reordered
    with ``permute(1, 2, 0, 3)`` before the result is copied to the CPU.

    Args:
        array: Sequence of equally-shaped tensors.

    Returns:
        The stacked, axis-reordered data as a ``numpy.ndarray``.
    """
    stacked = torch.stack(array, dim=-1)
    without_batch = stacked.squeeze(0)
    reordered = without_batch.permute(1, 2, 0, 3)
    return reordered.cpu().numpy()

def _estimate_raft_flows(fix_raft, frames, device, raft_iter, short_len=60):
    """Estimate bidirectional flows for a clip, chunking long clips.

    Clips with at most `short_len` frames are passed to `fix_raft` in one
    call. Longer clips are processed in windows of `short_len` frames; every
    window after the first is extended one frame backwards so the flow
    across the window boundary is computed as well. Per-window results are
    concatenated along dim 1.

    Args:
        fix_raft: Callable returning a (forward, backward) flow pair for a clip.
        frames: Clip tensor; dim 1 is used as the frame axis.
        device: Device the clip is moved to before calling `fix_raft`.
        raft_iter: Value forwarded as `iters` to `fix_raft`.
        short_len: Maximum number of frames handled in one call.

    Returns:
        The (forward, backward) flow pair for the whole clip.
    """
    video_length = frames.size(1)
    # NOTE(review): frames come from the DataLoader; they are moved to
    # `device` here on the assumption that RAFT expects inputs on the same
    # device it was built with. `.to` is a no-op if already there.
    if video_length <= short_len:
        return fix_raft(frames.to(device), iters=raft_iter)

    forward_chunks, backward_chunks = [], []
    for start in range(0, video_length, short_len):
        end = min(video_length, start + short_len)
        clip = frames[:, max(start - 1, 0):end].to(device)
        chunk_f, chunk_b = fix_raft(clip, iters=raft_iter)
        forward_chunks.append(chunk_f)
        backward_chunks.append(chunk_b)

    # Concatenate once after the loop instead of on every iteration.
    return torch.cat(forward_chunks, dim=1), torch.cat(backward_chunks, dim=1)


def main_worker(args):
    """Evaluate flow completion on a test set and report end-point error.

    For every video yielded by the test loader, reference flows are either
    taken from the dataset (`args.load_flow`) or estimated with RAFT, the
    flow-completion network fills the masked flows, and the mean end-point
    error (EPE) between reference and completed flows is logged. Per-video
    lines and the final summary are printed and written to
    ``results_flow/<dataset>/<dataset>_metrics.txt``.

    Args:
        args: Parsed command-line namespace. Reads `width`, `height`,
            `num_workers`, `raft_model_path`, `fc_model_path`, `dataset`,
            `load_flow`, `raft_iter` and `save_results`; sets `args.size`.
            The whole namespace is also handed to `TestDataset` as a dict.
    """
    # set up datasets and data loader
    args.size = (args.width, args.height)
    test_dataset = TestDataset(vars(args))

    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == 'cuda'
    fix_raft = RAFT_bi(args.raft_model_path, device)

    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()

    total_frame_epe = []
    time_all = []
    # Stays NaN when the loader yields no videos, so the summary still prints.
    avg_time = float('nan')

    print('Start evaluation...')
    # create results directory
    result_path = os.path.join('results_flow', f'{args.dataset}')
    os.makedirs(result_path, exist_ok=True)

    summary_path = os.path.join(result_path, f"{args.dataset}_metrics.txt")
    # Context manager guarantees the metrics file is closed on any error.
    with open(summary_path, "w") as eval_summary:
        for index, items in enumerate(test_loader):
            frames, masks, flows_f, flows_b, video_name, _frames_pil = items
            local_masks = masks.float().to(device)

            video_length = frames.size(1)

            if args.load_flow:
                gt_flows_bi = (flows_f.to(device), flows_b.to(device))
            else:
                gt_flows_bi = _estimate_raft_flows(
                    fix_raft, frames, device, args.raft_iter)

            # Synchronise only on CUDA; the call raises on CPU-only setups.
            if use_cuda:
                torch.cuda.synchronize()
            time_start = time()

            pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
            pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)

            if use_cuda:
                torch.cuda.synchronize()
            time_i = (time() - time_start) / video_length
            time_all.extend([time_i] * video_length)

            # Compare against the flows that were actually fed to the network
            # (`gt_flows_bi`), not the raw dataset items, which are only valid
            # references when `load_flow` is set.
            gt_f, gt_b = gt_flows_bi[0].cpu(), gt_flows_bi[1].cpu()
            pred_f, pred_b = pred_flows_bi[0].cpu(), pred_flows_bi[1].cpu()

            # Reduce the squared error over dim 2 before the square root, then
            # average over everything else.
            epe1 = torch.mean(torch.sum((gt_f - pred_f) ** 2, dim=2).sqrt()).item()
            epe2 = torch.mean(torch.sum((gt_b - pred_b) ** 2, dim=2).sqrt()).item()

            # Weight each video by its number of flow fields in the global mean.
            num_flows = gt_f.size(1)
            total_frame_epe.extend([epe1] * num_flows)
            total_frame_epe.extend([epe2] * num_flows)

            cur_epe = (epe1 + epe2) / 2
            avg_time = sum(time_all) / len(time_all)
            report = f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
            print(report)
            eval_summary.write(report + '\n')

            # saving images for evaluating warping errors
            if args.save_results:
                forward_flows = pred_f.permute(1, 0, 2, 3, 4)
                backward_flows = pred_b.permute(1, 0, 2, 3, 4)

                videoFlowF = tensor2np(list(forward_flows))
                videoFlowB = tensor2np(list(backward_flows))

                save_frame_path = os.path.join(result_path, video_name[0])
                save_flows(save_frame_path, videoFlowF, videoFlowB)

        if total_frame_epe:
            avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe)
        else:
            avg_frame_epe = float('nan')

        final_report = f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}'
        print(final_report)
        eval_summary.write(final_report + '\n')
    
def str2bool(value):
    """Parse a command-line token into a bool.

    ``type=bool`` in argparse treats every non-empty string (including
    ``"False"``) as true; this parser interprets the text instead.

    Args:
        value: A bool, or a string such as ``"true"``, ``"False"``, ``"1"``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
    parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
    parser.add_argument('--video_root', default='dataset_root', type=str)
    parser.add_argument('--mask_root', default='mask_root', type=str)
    parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
    # Accepts `--load_flow`, `--load_flow True` and `--load_flow False`.
    parser.add_argument('--load_flow', default=False, type=str2bool, nargs='?', const=True)
    parser.add_argument("--raft_iter", type=int, default=20)
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)