import os
import time
import datetime

import torch
import torchvision

from utils import misc, metrics

# Best mean 'All' validation PSNR seen so far in this process. `val` compares each
# epoch's result against it and, on improvement, updates it and writes best.pth.
# It is not stored in the checkpoint dict built by `train`, so it starts from 0 again
# after a resume.
best_psnr = 0


def _unpack_batch(batch, opt):
    """Move one training batch to ``opt.device`` and return the model inputs/targets.

    Returns ``(composite_image, real_image, mask, fg_INR_coordinates)``.

    With ``opt.INRDecode and opt.hr_train`` the batch holds four variants of every field
    under keys suffixed ``0``-``3``: variant 0 is the encoder input, and variants 1-3 are
    the crops at three different resolutions fed to the INR decoder. In that case each
    returned image/mask is a list of four tensors, and ``fg_INR_coordinates`` is the list
    of the three decoder-side coordinate maps. Otherwise every returned value is a
    single tensor.
    """
    device = opt.device
    if opt.INRDecode and opt.hr_train:
        def variants(prefix):
            return [batch[f'{prefix}{index}'].to(device) for index in range(4)]

        composite_image = variants('composite_image')
        real_image = variants('real_image')
        mask = variants('mask')
        coordinate_map = variants('coordinate_map')
        # Only the three decoder-side crops carry INR query coordinates.
        return composite_image, real_image, mask, coordinate_map[1:]

    return (batch['composite_image'].to(device),
            batch['real_image'].to(device),
            batch['mask'].to(device),
            batch['fg_INR_coordinates'].to(device))


def _compute_losses(outputs, real_image, mask, loss_fn, opt):
    """Return ``(construct_loss, lut_transform_loss, lut_regularize_loss)`` for one step.

    ``outputs`` is the model's ``(fg_content_bg_appearance_construct, fit_lut3d,
    lut_transform_image)`` triple. Terms that do not apply to the current configuration
    are returned as the plain integer ``0``, so they can still be summed with tensors.
    """
    fg_content_bg_appearance_construct, fit_lut3d, lut_transform_image = outputs

    if not opt.INRDecode:
        return 0, loss_fn['masked_mse'](lut_transform_image, real_image, mask), 0

    # The LRIP module produces three resolution levels, so the reconstruction loss is
    # averaged over all three.
    loss_construct = 0
    if opt.hr_train:
        # RSC strategy (paper Section 3.4): real_image / mask are lists of crops, and
        # decoder level n is paired with crop 3 - n.
        for n in range(3):
            loss_construct += loss_fn['masked_mse'](
                fg_content_bg_appearance_construct[n], real_image[3 - n], mask[3 - n])
        loss_construct /= 3
        loss_lut_transform = loss_fn['masked_mse'](lut_transform_image, real_image[1], mask[1])
    else:
        for n in range(3):
            # Decoder level n is supervised at INR_input_size // 2 ** (2 - n).
            resize = torchvision.transforms.Resize(opt.INR_input_size // 2 ** (3 - n - 1))
            loss_construct += loss_fn['MaskWeightedMSE'](
                fg_content_bg_appearance_construct[n], resize(real_image), resize(mask))
        loss_construct /= 3
        loss_lut_transform = loss_fn['masked_mse'](lut_transform_image, real_image, mask)

    loss_lut_regularize = loss_fn['regularize_LUT'](fit_lut3d)
    return loss_construct, loss_lut_transform, loss_lut_regularize


def _loss_value(term):
    """Return a loggable number for a loss term that may be the placeholder int ``0``."""
    return 0 if isinstance(term, int) else term.item()


def train(train_loader, val_loader, model, optimizer, scheduler, loss_fn, logger, opt):
    """Train ``model`` for ``opt.epochs`` epochs, validating or checkpointing after each.

    Args:
        train_loader: iterable of dict batches; ``len(train_loader)`` drives step bookkeeping.
        val_loader: handed to :func:`val` after every epoch, unless
            ``opt.isFullRes and opt.hr_train`` is set.
        model: called as ``model(composite_image, mask, fg_INR_coordinates)`` and expected
            to return ``(fg_content_bg_appearance_construct, fit_lut3d, lut_transform_image)``.
        optimizer: stepped once per training step.
        scheduler: stepped once per training step (not per epoch).
        loss_fn: mapping of loss callables. ``'masked_mse'`` is always used.
            ``'regularize_LUT'`` is used when ``opt.INRDecode`` is set, and
            ``'MaskWeightedMSE'`` when ``opt.INRDecode`` is set without ``opt.hr_train``.
        logger: receives progress messages through ``logger.info``.
        opt: configuration namespace. Reads ``epochs``, ``device``, ``pretrained``,
            ``INRDecode``, ``hr_train``, ``INR_input_size``, ``print_freq``, ``wandb``,
            ``isFullRes`` and ``save_path``.
    """
    total_step = opt.epochs * len(train_loader)

    step_time_log = misc.AverageMeter()
    loss_log = misc.AverageMeter(':6f')
    loss_fg_content_bg_appearance_construct_log = misc.AverageMeter(':6f')
    loss_lut_transform_image_log = misc.AverageMeter(':6f')
    loss_lut_regularize_log = misc.AverageMeter(':6f')

    start_epoch = 0

    # Resume from a checkpoint written by this function (model/optimizer/scheduler/last_epoch).
    # NOTE(review): the module-level best_psnr is not restored, so the first validation
    # after resuming will typically overwrite best.pth.
    if opt.pretrained is not None:
        logger.info(f"Load pretrained weight from {opt.pretrained}")
        # Load onto CPU first. The checkpoint may have been written from a device that is
        # busy or absent here. Optimizer.load_state_dict casts its state tensors to the
        # device of each (already moved) parameter.
        load_state = torch.load(opt.pretrained, map_location='cpu')
        model = model.cpu()
        model.load_state_dict(load_state['model'])
        model = model.to(opt.device)
        optimizer.load_state_dict(load_state['optimizer'])
        scheduler.load_state_dict(load_state['scheduler'])
        start_epoch = load_state['last_epoch'] + 1

    for epoch in range(start_epoch, opt.epochs):
        model.train()
        time_ckp = time.time()
        for step, batch in enumerate(train_loader):
            current_step = epoch * len(train_loader) + step + 1

            composite_image, real_image, mask, fg_INR_coordinates = _unpack_batch(batch, opt)
            outputs = model(composite_image, mask, fg_INR_coordinates)
            loss_construct, loss_lut_transform, loss_lut_regularize = _compute_losses(
                outputs, real_image, mask, loss_fn, opt)

            loss = loss_construct + loss_lut_transform + loss_lut_regularize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            step_time_log.update(time.time() - time_ckp)

            loss_fg_content_bg_appearance_construct_log.update(_loss_value(loss_construct))
            loss_lut_transform_image_log.update(_loss_value(loss_lut_transform))
            loss_lut_regularize_log.update(_loss_value(loss_lut_regularize))
            loss_log.update(loss.item())

            if current_step % opt.print_freq == 0:
                # ETA extrapolates the running average step time over the remaining steps.
                remain_secs = (total_step - current_step) * step_time_log.avg
                remain_time = datetime.timedelta(seconds=round(remain_secs))
                finish_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remain_secs))

                log_msg = (
                    f'Epoch: [{epoch}/{opt.epochs}]\t'
                    f'Step: [{step}/{len(train_loader)}]\t'
                    f'StepTime {step_time_log.val:.3f} ({step_time_log.avg:.3f})\t'
                    f'lr {optimizer.param_groups[0]["lr"]}\t'
                    f'Loss {loss_log.val:.4f} ({loss_log.avg:.4f})\t'
                    f'Loss_fg_bg_cons {loss_fg_content_bg_appearance_construct_log.val:.4f} ({loss_fg_content_bg_appearance_construct_log.avg:.4f})\t'
                    f'Loss_lut_trans {loss_lut_transform_image_log.val:.4f} ({loss_lut_transform_image_log.avg:.4f})\t'
                    f'Loss_lut_reg {loss_lut_regularize_log.val:.4f} ({loss_lut_regularize_log.avg:.4f})\t'
                    f'Remaining Time {remain_time} ({finish_time})'
                )
                logger.info(log_msg)

                if opt.wandb:
                    import wandb
                    wandb.log(
                        {'Train/Epoch': epoch, 'Train/lr': optimizer.param_groups[0]['lr'], 'Train/Step': current_step,
                         'Train/Loss': loss_log.val,
                         'Train/Loss_fg_bg_cons': loss_fg_content_bg_appearance_construct_log.val,
                         'Train/Loss_lut_trans': loss_lut_transform_image_log.val,
                         'Train/Loss_lut_reg': loss_lut_regularize_log.val,
                         })

            time_ckp = time.time()

        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'last_epoch': epoch,
                 'scheduler': scheduler.state_dict()}

        # Original-resolution validation cannot be batched (images differ in size) and may
        # run out of memory alongside training. With isFullRes + hr_train we therefore only
        # save checkpoints here and leave evaluation to `inference.py`.
        if opt.isFullRes and opt.hr_train:
            if epoch % 5 == 0:
                torch.save(state, os.path.join(opt.save_path, f"epoch{epoch}.pth"))
            else:
                torch.save(state, os.path.join(opt.save_path, "last.pth"))
        else:
            val(val_loader, model, logger, opt, state)


def _empty_metric_log():
    """Return a fresh accumulator with per-category metric sums and a sample count."""
    return {
        category: {'Samples': 0, 'MSE': 0, 'fMSE': 0, 'PSNR': 0, 'SSIM': 0}
        for category in ('HAdobe5k', 'HCOCO', 'Hday2night', 'HFlickr', 'All')
    }


def _accumulate_metrics(log, category, mse, fmse, psnr, ssim):
    """Add one sample's metrics to its own category bucket and to the ``'All'`` bucket."""
    for key in (category, 'All'):
        bucket = log[key]
        bucket['Samples'] += 1
        bucket['MSE'] += mse
        bucket['fMSE'] += fmse
        bucket['PSNR'] += psnr
        bucket['SSIM'] += ssim


def _metric_mean(log, key, metric):
    """Return the per-sample mean of ``metric`` in bucket ``key`` (bucket must be non-empty)."""
    return log[key][metric] / log[key]['Samples']


def val(val_loader, model, logger, opt, state):
    """Evaluate ``model`` on ``val_loader``, log per-category metrics and save ``state``.

    MSE, fMSE, PSNR and SSIM are accumulated per benchmark category and for ``'All'``.
    The LUT-branch output is always scored against ``real_image``. When ``opt.INRDecode``
    is set, the INR-decoder output is additionally scored against the dataset's
    INR-resolution target.

    Checkpoint selection uses the mean ``'All'`` PSNR of the INR branch, or of the LUT
    branch when ``opt.INRDecode`` is off. If that mean beats the module-level
    ``best_psnr``, ``best_psnr`` is updated and ``state`` is saved as ``best.pth``;
    otherwise ``state`` is saved as ``last.pth``. Categories with no samples are left
    out of the report, and an empty loader results in a ``last.pth`` save.

    Args:
        val_loader: yields dicts with ``composite_image``, ``real_image``, ``mask``,
            ``category`` (one of HAdobe5k/HCOCO/Hday2night/HFlickr per sample),
            ``fg_INR_coordinates``, ``bg_INR_coordinates`` and ``fg_transfer_INR_RGB``.
            ``val_loader.dataset.INR_dataset.size`` is read when ``opt.INRDecode`` is set.
        model: called under ``torch.no_grad()`` as
            ``model(composite_image, mask, fg_INR_coordinates, bg_INR_coordinates)``.
        logger: receives progress and metric messages via ``logger.info``.
        opt: configuration namespace (``device``, ``INRDecode``, ``INR_input_size``,
            ``wandb``, ``save_path``, plus whatever ``misc``/``metrics`` helpers read).
        state: checkpoint dict to save. ``state['last_epoch']`` also decides whether every
            batch (every 10th epoch) or only the first batch is visualized.
    """
    global best_psnr
    current_process = 10
    model.eval()

    metric_log = _empty_metric_log()      # INR-decoder branch; only filled when opt.INRDecode
    lut_metric_log = _empty_metric_log()  # LUT branch; always filled

    num_batches = len(val_loader)
    for step, batch in enumerate(val_loader):
        composite_image = batch['composite_image'].to(opt.device)
        real_image = batch['real_image'].to(opt.device)
        mask = batch['mask'].to(opt.device)
        category = batch['category']

        fg_INR_coordinates = batch['fg_INR_coordinates'].to(opt.device)
        bg_INR_coordinates = batch['bg_INR_coordinates'].to(opt.device)
        fg_transfer_INR_RGB = batch['fg_transfer_INR_RGB'].to(opt.device)

        with torch.no_grad():
            fg_content_bg_appearance_construct, _, lut_transform_image = model(
                composite_image,
                mask,
                fg_INR_coordinates,
                bg_INR_coordinates)

        # Keep predictions where mask > 100/255 and paste the ground truth elsewhere.
        fg_region = mask > 100 / 255.
        bg_region = ~fg_region
        if opt.INRDecode:
            pred_fg_image = fg_content_bg_appearance_construct[-1]
            # The INR prediction is scored against the flattened INR target reshaped back
            # to an image, using a mask resized to the INR input size.
            fg_transfer_INR_RGB = misc.lin2img(fg_transfer_INR_RGB, val_loader.dataset.INR_dataset.size)
            mask_INR = torchvision.transforms.Resize(opt.INR_input_size)(mask)
            pred_harmonized_image = pred_fg_image * fg_region + real_image * bg_region
        else:
            pred_fg_image = None
            pred_harmonized_image = None
        lut_transform_image = lut_transform_image * fg_region + real_image * bg_region

        # Every 10th epoch save visualizations for all batches; otherwise only the first
        # batch, to limit storage.
        if state['last_epoch'] % 10 == 0:
            misc.visualize(real_image, composite_image, mask, pred_fg_image,
                           pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                           wandb=opt.wandb, isAll=True, step=step)
        elif step == 0:
            misc.visualize(real_image, composite_image, mask, pred_fg_image,
                           pred_harmonized_image, lut_transform_image, opt, state['last_epoch'], show=False,
                           wandb=opt.wandb, step=step)

        if opt.INRDecode:
            mse, fmse, psnr, ssim = metrics.calc_metrics(misc.normalize(pred_harmonized_image, opt, 'inv'),
                                                         misc.normalize(fg_transfer_INR_RGB, opt, 'inv'), mask_INR)

        lut_mse, lut_fmse, lut_psnr, lut_ssim = metrics.calc_metrics(misc.normalize(lut_transform_image, opt, 'inv'),
                                                                     misc.normalize(real_image, opt, 'inv'), mask)

        for idx, sample_category in enumerate(category):
            if opt.INRDecode:
                _accumulate_metrics(metric_log, sample_category,
                                    mse[idx], fmse[idx], psnr[idx], ssim[idx])
            _accumulate_metrics(lut_metric_log, sample_category,
                                lut_mse[idx], lut_fmse[idx], lut_psnr[idx], lut_ssim[idx])

        # Report every 10% milestone this batch reached. Small loaders can cross several
        # at once; integer arithmetic avoids float rounding at exact milestones.
        while current_process <= 100 and (step + 1) * 100 >= current_process * num_batches:
            logger.info(f'Processing: {current_process}')
            current_process += 10

    logger.info('=========================')
    for key in metric_log:
        # Both logs receive the same samples per key, so the LUT count decides emptiness.
        # A validation subset may not cover every benchmark category, and averaging an
        # empty bucket would divide by zero.
        if lut_metric_log[key]['Samples'] == 0:
            continue

        msg = ''
        if opt.INRDecode:
            msg += (f"{key}-'MSE': {_metric_mean(metric_log, key, 'MSE'):.2f}\n"
                    f"{key}-'fMSE': {_metric_mean(metric_log, key, 'fMSE'):.2f}\n"
                    f"{key}-'PSNR': {_metric_mean(metric_log, key, 'PSNR'):.2f}\n"
                    f"{key}-'SSIM': {_metric_mean(metric_log, key, 'SSIM'):.4f}\n")
        msg += (f"{key}-'LUT_MSE': {_metric_mean(lut_metric_log, key, 'MSE'):.2f}\n"
                f"{key}-'LUT_fMSE': {_metric_mean(lut_metric_log, key, 'fMSE'):.2f}\n"
                f"{key}-'LUT_PSNR': {_metric_mean(lut_metric_log, key, 'PSNR'):.2f}\n"
                f"{key}-'LUT_SSIM': {_metric_mean(lut_metric_log, key, 'SSIM'):.4f}\n")

        logger.info(msg)

        if opt.wandb:
            import wandb
            payload = {f'Val/{key}/Epoch': state['last_epoch']}
            if opt.INRDecode:
                for metric in ('MSE', 'fMSE', 'PSNR', 'SSIM'):
                    payload[f'Val/{key}/{metric}'] = _metric_mean(metric_log, key, metric)
            for metric in ('MSE', 'fMSE', 'PSNR', 'SSIM'):
                payload[f'Val/{key}/LUT_{metric}'] = _metric_mean(lut_metric_log, key, metric)
            wandb.log(payload)

    logger.info('=========================')

    # Model selection follows the INR branch when enabled, otherwise the LUT branch.
    selection_log = metric_log if opt.INRDecode else lut_metric_log
    num_samples = selection_log['All']['Samples']
    mean_psnr = selection_log['All']['PSNR'] / num_samples if num_samples else None

    if mean_psnr is not None and mean_psnr > best_psnr:
        logger.info("Best Save!")
        best_psnr = mean_psnr
        torch.save(state, os.path.join(opt.save_path, "best.pth"))
    else:
        if mean_psnr is None:
            logger.info("No validation samples were evaluated; skipping best-model check.")
        logger.info("Last Save!")
        torch.save(state, os.path.join(opt.save_path, "last.pth"))