import warnings

import numpy as np
import tensorflow as tf
import torch

from interpolator import Interpolator


def translate_state_dict(var_dict, state_dict):
    """Copy TF weights (numpy arrays in *var_dict*) into *state_dict* in place.

    Pairs entries positionally: relies on both dicts listing corresponding
    parameters in the same iteration order. The shape assert below catches
    mispaired parameters.
    """
    for pt_name, (tf_name, array) in zip(state_dict, var_dict.items()):
        print('Mapping', tf_name, '->', pt_name)
        tensor = torch.from_numpy(array)
        if 'kernel' in tf_name:
            # Transpose the conv2d kernel weights, since TF uses (H, W, C, K) and PyTorch uses (K, C, H, W)
            tensor = tensor.permute(3, 2, 0, 1)

        expected = state_dict[pt_name].shape
        assert expected == tensor.shape, f'Shape mismatch {expected} != {tensor.shape}'

        state_dict[pt_name] = tensor


def import_state_dict(interpolator: Interpolator, saved_model):
    """Load weights from a TF SavedModel into the interpolator's three submodules."""
    extract_sd = interpolator.extract.state_dict()
    flow_sd = interpolator.predict_flow.state_dict()
    fuse_sd = interpolator.fuse.state_dict()

    # Bucket TF variables by submodule name, stripping '<prefix>/' from each
    # variable name (len(prefix) + 1 skips the trailing slash as well).
    buckets = {'feat_net': {}, 'predict_flow': {}, 'fusion': {}}
    for var in saved_model.keras_api.variables:
        for prefix, bucket in buckets.items():
            if var.name.startswith(prefix):
                bucket[var.name[len(prefix) + 1:]] = var.numpy()
                break

    extract_vars = buckets['feat_net']
    # reverse order of modules to allow jit export
    # TODO: improve this hack
    flow_vars = dict(sorted(buckets['predict_flow'].items(), key=lambda x: x[0].split('/')[0], reverse=True))
    fuse_vars = dict(sorted(buckets['fusion'].items(), key=lambda x: int((x[0].split('/')[0].split('_')[1:] or [0])[0]) // 3, reverse=True))

    assert len(extract_vars) == len(extract_sd), f'{len(extract_vars)} != {len(extract_sd)}'
    assert len(flow_vars) == len(flow_sd), f'{len(flow_vars)} != {len(flow_sd)}'
    assert len(fuse_vars) == len(fuse_sd), f'{len(fuse_vars)} != {len(fuse_sd)}'

    for target_sd, source_vars in ((extract_sd, extract_vars), (flow_sd, flow_vars), (fuse_sd, fuse_vars)):
        translate_state_dict(source_vars, target_sd)

    interpolator.extract.load_state_dict(extract_sd)
    interpolator.predict_flow.load_state_dict(flow_sd)
    interpolator.fuse.load_state_dict(fuse_sd)


def verify_debug_outputs(pt_outputs, tf_outputs):
    """Assert every intermediate PyTorch output matches its TF counterpart.

    Converts PyTorch tensors from NCHW to NHWC before comparing; the final
    'image' entry is checked separately by the caller, so it is skipped here.
    """
    worst = 0
    for name, predicted in pt_outputs.items():
        if name == 'image':
            continue
        as_nhwc = [t.permute(0, 2, 3, 1).detach().cpu().numpy() for t in predicted]
        reference = [t.numpy() for t in tf_outputs[name]]

        for i, (pred, true) in enumerate(zip(as_nhwc, reference)):
            assert pred.shape == true.shape, f'{name} {i} shape mismatch {pred.shape} != {true.shape}'
            delta = np.max(np.abs(pred - true))
            worst = max(worst, delta)
            # loose sanity bound; exact tolerance is reported via the print below
            assert delta < 1, f'{name} {i} max error: {delta}'
    print('Max intermediate error:', worst)


def test_model(interpolator, model, half=False, gpu=False):
    """Run both models on a fixed random input pair and print the max error.

    The TF model runs first (always FP32 on CPU); the PyTorch inputs are then
    optionally cast to half / moved to CUDA before the comparison.
    """
    torch.manual_seed(0)
    t = torch.full((1, 1), .5)
    frame0 = torch.rand(1, 3, 256, 256)
    frame1 = torch.rand(1, 3, 256, 256)

    # TF side consumes NHWC tensors.
    tf_inputs = {
        'x0': tf.convert_to_tensor(frame0.permute(0, 2, 3, 1).numpy(), dtype=tf.float32),
        'x1': tf.convert_to_tensor(frame1.permute(0, 2, 3, 1).numpy(), dtype=tf.float32),
        'time': tf.convert_to_tensor(t.numpy(), dtype=tf.float32),
    }
    tf_outputs = model(tf_inputs, training=False)

    if half:
        frame0, frame1, t = frame0.half(), frame1.half(), t.half()

    if gpu and torch.cuda.is_available():
        frame0, frame1, t = frame0.cuda(), frame1.cuda(), t.cuda()

    with torch.no_grad():
        pt_outputs = interpolator.debug_forward(frame0, frame1, t)

    verify_debug_outputs(pt_outputs, tf_outputs)

    with torch.no_grad():
        prediction = interpolator(frame0, frame1, t)
    pred_nhwc = prediction.permute(0, 2, 3, 1).detach().cpu().numpy()
    ref_nhwc = tf_outputs['image'].numpy()

    print('Color max error:', np.abs(pred_nhwc - ref_nhwc).max())


def main(model_path, save_path, export_to_torchscript=True, use_gpu=False, fp16=True, skiptest=False):
    """Convert a TF frame-interpolator SavedModel to a PyTorch artifact.

    Args:
        model_path: Path to the TF SavedModel directory.
        save_path: Destination file for the exported PyTorch model.
        export_to_torchscript: Save a scripted module; otherwise a raw state dict.
        use_gpu: Run on CUDA when available (silently falls back to CPU).
        fp16: Export the model at half precision.
        skiptest: Skip the numerical comparison against the TF model.
    """
    print(f'Exporting model to FP{["32", "16"][fp16]} {["state_dict", "torchscript"][export_to_torchscript]} '
          f'using {"CG"[use_gpu]}PU')
    model = tf.compat.v2.saved_model.load(model_path)
    interpolator = Interpolator()
    interpolator.eval()
    import_state_dict(interpolator, model)

    if use_gpu and torch.cuda.is_available():
        interpolator = interpolator.cuda()
    else:
        use_gpu = False

    if fp16:
        interpolator = interpolator.half()

    # Single branch instead of the previous duplicated `if export_to_torchscript`
    # checks: script-and-save, or dump the plain state dict.
    if export_to_torchscript:
        interpolator = torch.jit.script(interpolator)
        interpolator.save(save_path)
    else:
        torch.save(interpolator.state_dict(), save_path)

    if not skiptest:
        if not use_gpu and fp16:
            warnings.warn('Testing FP16 model on CPU is impossible, casting it back')
            interpolator = interpolator.float()
            fp16 = False
        test_model(interpolator, model, fp16, use_gpu)


if __name__ == '__main__':
    import argparse

    # CLI flags are phrased as opt-outs (--statedict, --fp32) so that the
    # defaults match main()'s positive defaults (torchscript, fp16).
    cli = argparse.ArgumentParser(description='Export frame-interpolator model to PyTorch state dict')
    cli.add_argument('model_path', type=str, help='Path to the TF SavedModel')
    cli.add_argument('save_path', type=str, help='Path to save the PyTorch state dict')
    cli.add_argument('--statedict', action='store_true', help='Export to state dict instead of TorchScript')
    cli.add_argument('--fp32', action='store_true', help='Save at full precision')
    cli.add_argument('--skiptest', action='store_true', help='Skip testing and save model immediately instead')
    cli.add_argument('--gpu', action='store_true', help='Use GPU')

    opts = cli.parse_args()

    main(opts.model_path, opts.save_path, not opts.statedict, opts.gpu, not opts.fp32, opts.skiptest)