import argparse
import os
try:
    import ruamel_yaml as yaml
except ImportError:  # newer ruamel.yaml releases drop the underscore alias
    import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist


from models.epalm import ePALM
from models.utils import freeze_whole_model, unfreeze_parameters, print_trainable_params_percentage
from models.utils import filter_state, filter_msg, exclude_list

from transformers import AutoTokenizer

import utils

from dataset.audio_caption import get_loader

from scheduler import create_scheduler
from optim import create_optimizer


def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config):
    model.train()  
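    # One training epoch: captions are tokenized with the LM tokenizer, padding tokens are
    # masked out of the labels, and the model is optimized with a teacher-forced LM loss.
    # During epoch 0 the scheduler is also stepped every `step_size` iterations for warmup.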
    
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50    
    step_size = 100
    warmup_iterations = warmup_steps*step_size  
    lm_loss_weight = config.get('lm_loss_weight', 1)
    append_eos_token = config.get('append_eos_token', False)
    eos_token = tokenizer.eos_token

    config_optim = utils.AttrDict(config['optimizer'])
    prompt_lr = config_optim.prompt_lr if hasattr(config_optim, 'prompt_lr') else None

    task_prompt = config.get('task_prompt', None)

    if prompt_lr is not None:
        metric_logger.add_meter('prompt_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))


    for i, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        image = batch["images"].to(device,non_blocking=True)

        text = batch["sent"]

        if append_eos_token:
            text = [t.replace(eos_token, '') + eos_token for t in text]

        if task_prompt is not None:
            text = [task_prompt + ' ' + t for t in text]



        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device) 


        targets = text_input.input_ids.masked_fill(text_input.input_ids == tokenizer.pad_token_id, -100)
        

        answer_output = model(image=image, 
                              text=text_input, 
                              labels = targets,
                              return_dict = True,   
                              mode='train',
                              reduction='none',
                             )      
        
        # with reduction='none' the LM loss is returned per token; sum and normalize by batch size
        loss = answer_output.loss
        loss = loss.sum() / image.size(0)
        loss = loss * lm_loss_weight
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        if prompt_lr is not None:
            metric_logger.update(prompt_lr=optimizer.param_groups[1]["lr"])
        if epoch == 0 and i % step_size == 0 and i <= warmup_iterations:
            scheduler.step(i // step_size)

            
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())     
    return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 




@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config) :
    model.eval()
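    # Generate a caption for each audio clip from an empty (or task-prompt-only) text input,
    # strip special tokens from the decoded output, and score the predictions with the
    # evaluator attached to the data loader.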
            
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Generate Caption test result:'
    print_freq = 50

    predictions = []
    targets = []


    task_prompt = config.get('task_prompt', None)

    pad_token = tokenizer.pad_token
    eos_token = tokenizer.eos_token

    for n, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):        
        
        image = batch["images"].to(device,non_blocking=True)
        text = ['' for q in image]  
        if task_prompt is not None:
            text = [task_prompt + ' ' + t for t in text]
        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device) 

        out = model(image=image, text=text_input, mode='generate', return_dict=True, max_length=30, do_sample=True)
        out_decode = []
        for o in out:
            try:
                res = tokenizer.decode(o)
                if task_prompt is not None:
                    res = res.replace(task_prompt, '')
                # keep only the text after the first '</s>' and strip the remaining
                # special tokens (equivalent to skip_special_tokens=True)
                response = res.split('</s>')[1].replace(pad_token, '').replace('</s>', '').replace(eos_token, '')
            except (TypeError, IndexError):
                # decoding failed or no '</s>' marker was found in the generated ids
                print(o)
                response = ' '
            out_decode.append(response)


        predictions.extend(out_decode)

        if 'targets' in batch:
            targets.extend(batch['targets'])




    evaluator = data_loader.evaluator
    eval_results = evaluator.evaluate(predictions, targets)


    wandb_log_dict = {}

    for score_name, score in eval_results.items():
        wandb_log_dict[f'Valid/{score_name}'] = score


    print(wandb_log_dict)

    return wandb_log_dict



def main(args, config):

    utils.init_distributed_mode(args)    
    
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True
    
    start_epoch = 0
    max_epoch = config['schedular']['epochs']
    warmup_steps = config['schedular']['warmup_epochs']
    
    print(args, config)


    tokenizer = AutoTokenizer.from_pretrained(args.text_model, use_fast=False, local_files_only=True)



    
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()     
    else:
        num_tasks = None
        global_rank = None


    #########
    num_workers = config.get('num_workers', 4)
    train_topk = config.get('train_topk', -1)
    valid_topk = config.get('valid_topk', -1)
    data_dir = args.data_dir

    args.image_size = config.get('image_res', 224)
    args.use_data_augmentation = True 

    black_image = config.get('black_image', False)
    print("black image:", black_image)



    # audio 
    args.melbins = config.get('melbins', 128)
    args.target_length = config.get('target_length', 1024)
    args.num_tries = config.get('num_tries', 1)

    args.skip_norm = config.get('skip_norm', True)
    args.norm_mean = config.get('norm_mean', None)
    args.norm_std = config.get('norm_std', None)
    args.noise = config.get('noise', False)

    args.freqm_p = config.get('freqm_p', 48)
    args.timem_p = config.get('timem_p', 192)



    train_split = config.get('train_split', 'train') 
    val_split = config.get('val_split', 'val') 
    test_split = config.get('test_split', 'test') 
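    # Build the train/val/test loaders; only the train loader is sharded across ranks,
    # validation and test run with non-distributed loaders.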




    train_loader = get_loader(
        args,
        split=train_split, mode='train', batch_size=config['batch_size_train'],
        distributed=args.distributed,
        workers=num_workers,
        topk=train_topk,
        data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True, black_image=black_image
    )

    print('# len train loader:', len(train_loader))
    print('Building val loader')
    val_loader = get_loader(
        args,
        split=val_split, mode='val', batch_size=config['batch_size_test'],
        distributed=False, 
        workers=4,
        topk=valid_topk,data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True, black_image=black_image
    )
    print('# len val loader:', len(val_loader))

    print('Building test loader')
    test_loader = get_loader(
        args,
        split=test_split, mode='val', batch_size=config['batch_size_test'],
        distributed=False, 
        workers=4,
        topk=valid_topk,data_dir=data_dir,
        local_rank=global_rank, world_size=num_tasks, verbose=True
    )


    print('# len test loader:', len(test_loader))

    #### Model #### 
    print("Creating model")
    
    start_layer_idx = config.get('start_layer_idx', 0)
    end_layer_idx = config.get('end_layer_idx', 0)

    vision_model_name = config.get('vision_model_name', args.vision_model)
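    # ePALM couples a frozen language model (args.text_model) with a perceptual encoder whose
    # features are injected as a prefix; start/end_layer_idx select the LM layers that receive it
    # (a high-level description; the exact wiring lives in models.epalm).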

    model = ePALM(opt_model_name = args.text_model, 
                   vision_model_name = vision_model_name, 
                   use_vis_prefix = True, 
                   start_layer_idx = start_layer_idx, 
                   end_layer_idx = end_layer_idx, 
                   return_hidden_state_vision = True, 
                   config=config,
    )
    
        
    model = model.to(device)   
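    # create_optimizer may place prompt parameters in a separate param group so they can use
    # their own learning rate (prompt_lr); param_groups[0] then holds the base lr.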
    
    arg_opt = utils.AttrDict(config['optimizer'])
    optimizer = create_optimizer(arg_opt, model, config=config)

    if hasattr(arg_opt, 'prompt_lr') and arg_opt.prompt_lr is not None:
        print('\tInitial lr for non-prompt params: %f' % optimizer.param_groups[0]['lr'])
        print('\tInitial prompt params lr: %f' % optimizer.param_groups[1]['lr'])

    arg_sche = utils.AttrDict(config['schedular'])
    lr_scheduler, _ = create_scheduler(arg_sche, optimizer)          
         
    best_epoch = 0 
    best_valid = 0 
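    # Optionally warm-start from a checkpoint: weights are loaded non-strictly and the mismatch
    # message is filtered against exclude_list; with --resume the optimizer, scheduler and
    # epoch counter are restored as well.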
    
    if args.checkpoint:    

        checkpoint = torch.load(args.checkpoint, map_location='cpu') 
        state_dict = checkpoint['model']
        msg = model.load_state_dict(state_dict,strict=False)  
        msg = filter_msg(msg, exclude_list)
        print('load checkpoint from %s'%args.checkpoint)
        print(msg)  

        if args.resume:
            model = model.to(device) 
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch']+1  
            print(checkpoint.keys())
        if 'best_valid' in checkpoint:
            best_valid = checkpoint['best_valid'] 
            best_epoch = checkpoint['best_epoch'] 
            print("load best valid {} at epoch {}".format(best_valid, best_epoch))

    
    # parameter-efficient training: freeze the full model, then unfreeze only the parameter
    # groups selected in the config
    freeze_whole_model(model)
    unfreeze_parameters(model, config)
    print_trainable_params_percentage(model)
    
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module    
    
    
    print("Start training")
    start_time = time.time()
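    # Per-epoch loop: step the LR scheduler, train, validate, then (on the main process) log
    # stats and save checkpoint_last.pth; with --save_best, checkpoint_best.pth is refreshed
    # whenever the validation CIDEr improves.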


    for epoch in range(start_epoch, max_epoch):
        if epoch>0:
            lr_scheduler.step(epoch+warmup_steps)  
        
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config)  

        if args.evaluate:
            break
            

        valid_results = evaluation(model, val_loader, tokenizer, device, config) 

        if utils.is_main_process():               
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                        }                
            with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
                f.write(json.dumps(log_stats) + "\n")                        
                    
            save_obj = {
                'model': filter_state(model_without_ddp.state_dict(), exclude_list),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'config': config,
                'epoch': epoch,
                'best_valid': best_valid,
                'best_epoch': best_epoch,
            }

            if args.save_best:
                valid_score = valid_results['Valid/CIDEr']

                if valid_score > best_valid or epoch == 0:
                    best_valid = valid_score
                    best_epoch = epoch
                    print("Save best epoch:", best_epoch)

                    save_obj['best_valid'] = best_valid
                    save_obj['best_epoch'] = best_epoch

                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))  
            # always keep the most recent checkpoint
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_last.pth'))


        if args.distributed:
            dist.barrier()
    

    if not args.evaluate:
        # reload the best checkpoint for the final test evaluation
        # (assumes --save_best was set so that checkpoint_best.pth exists)
        best_ckpt_path = os.path.join(args.output_dir, 'checkpoint_best.pth')
        checkpoint = torch.load(best_ckpt_path, map_location='cpu')
        state_dict = checkpoint['model']
        msg = model_without_ddp.load_state_dict(state_dict, strict=False)
        msg = filter_msg(msg, exclude_list)
        print('load checkpoint for test from %s' % best_ckpt_path)
        print(msg)
    test_results = evaluation(model, test_loader, tokenizer, device, config)

                     
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str)) 
    
            

if __name__ == '__main__':
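    # Example launch (a sketch; assumes the usual torchrun / env:// setup, and the script name
    # and config path below are placeholders to replace with an actual audio-captioning config):
    #   torchrun --nproc_per_node=4 this_script.py --config configs/audio_caption.yaml --save_best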
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/VQA.yaml') 
    parser.add_argument('--checkpoint', default='') 
    parser.add_argument('--output_dir', default='output/vqa')
    parser.add_argument('--evaluate', action='store_true')    
    parser.add_argument('--text_model', default='facebook/opt-350m')
    parser.add_argument('--vision_model', default='vit_base_patch16_224')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')    
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # note: with type=bool, argparse treats any non-empty string (including "False") as True
    parser.add_argument('--distributed', default=True, type=bool)
    
    parser.add_argument('--data_dir', default='/data/mshukor/data')   
    parser.add_argument('--resume', action='store_true')    

    parser.add_argument('--save_best', action='store_true') 
    
    parser.add_argument('--image_dir', default='/data/mshukor/data')   

    

    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)
        
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)
    
    main(args, config)