import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from thop import profile
from ptflops import get_model_complexity_info

import artist.data as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, MODEL


def test_model():
    """Smoke-test the configured UNet: build it, run one forward pass, and profile it.

    Loads config overrides via ``pConfig(load=True)``, merges them into the
    global ``cfg``, builds ``cfg.UNet`` through the ``MODEL`` registry, prints
    the parameter count, runs a single forward pass on random inputs and
    prints the output shape, then reports FLOPs/params with ptflops.

    Requires a CUDA device.
    """
    cfg_update = pConfig(load=True)

    # Merge overrides: dict values update an existing section in place,
    # anything else (new keys or non-dict values) replaces the entry.
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    model = MODEL.build(cfg.UNet)
    num_params = sum(p.numel() for p in model.parameters())
    print('{:.2f}M parameters'.format(num_params / 1e6))

    # Optionally load pretrained weights before profiling:
    # state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu')
    # model.load_state_dict(state_dict, strict=False)
    device = torch.device('cuda')
    # Inference mode for profiling (disables dropout / batch-norm updates).
    model = model.to(device).eval()

    # Dummy inputs, created directly on the device. Use initialized values:
    # torch.Tensor(*shape) returns uninitialized memory that may hold NaN/Inf.
    # What each axis means is defined by the UNet's forward() -- TODO confirm.
    x = torch.randn(1, 4, 16, 32, 56, device=device)
    t = torch.zeros(1, device=device)
    sims = torch.randn(1, 32, device=device)
    fps = torch.tensor([8.0], device=device)
    y = torch.randn(1, 1, 1024, device=device)
    image = torch.randn(1, 3, 256, 448, device=device)

    # Single source of truth for the forward() keyword arguments, shared by
    # the direct call and the ptflops input constructor.
    inputs = dict(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)

    with torch.no_grad():
        ret = model(**inputs)
    print('Out shape is {}'.format(ret.shape))

    # flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps))
    # print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))

    def prepare_input(resolution):
        # ptflops calls model(**returned_dict), so the keys must be the
        # forward() keyword arguments. ``resolution`` is ignored because the
        # pre-built inputs already fix every shape.
        return inputs

    with torch.no_grad():
        flops, params = get_model_complexity_info(
            model, tuple(x.shape),
            input_constructor=prepare_input,
            as_strings=True, print_per_layer_stat=True)
    print('      - Flops:  ' + flops)
    print('      - Params: ' + params)

if __name__ == "__main__":
    # Script entry point: run the UNet build/forward/profiling smoke test.
    test_model()