import os
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import transforms

from sam_diffsr.utils_sr.hparams import set_hparams, hparams
from sam_diffsr.utils_sr.matlab_resize import imresize
from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori


# Directory containing this file, resolved to an absolute path so that paths
# built from it (e.g. the checkpoint path) do not depend on the current working
# directory, even when ``__file__`` is relative.
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))


class sam_diffsr_demo:
    """Single-image super-resolution demo wrapping the SAM-DiffSR model.

    On construction it calls ``set_hparams()``, builds the network through the
    ``SRDiffDf2k_sam`` trainer and loads weights from
    ``<ROOT_PATH>/weight/model_ckpt_steps_400000.ckpt``. Call :meth:`infer`
    with a PIL image to get the super-resolved PIL image.

    A CUDA device is required: the weights are moved with ``.cuda()`` and the
    inputs with ``.to('cuda')``.
    """

    def __init__(self):
        """Load hyper-parameters and initialise the model from the bundled checkpoint."""
        set_hparams()
        ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        """Convert a PIL image into the normalised LR tensors fed to the model.

        Args:
            img_PIL: input low-resolution image; it is converted to RGB first.
            hparams: hyper-parameter mapping; only ``hparams['sr_scale']`` is
                read, and it sets the upsampling factor for ``img_lr_up``.
            sr_scale: scale used by the crop arithmetic. For an integer value
                the crop result does not depend on it (see comment below).

        Returns:
            ``(img_lr, img_lr_up)``: float tensors of shape ``[1, 3, H, W]``
            normalised with mean/std 0.5 (roughly ``[-1, 1]``). ``img_lr_up`` is
            ``img_lr`` resized by ``imresize`` with factor ``hparams['sr_scale']``.

        Raises:
            ValueError: if cropping leaves an empty image (e.g. a side of
                1 pixel), which the resize and the model cannot handle.
        """
        img_lr = np.uint8(np.asarray(img_PIL.convert('RGB')))  # [H, W, 3]
        h, w = img_lr.shape[:2]

        # Trim in HR coordinates to a multiple of 2 * sr_scale and map back to
        # LR. For an integer sr_scale this simply trims each LR side to an even
        # length: (h * s - s * (h % 2)) // s == h - h % 2.
        h_hr, w_hr = h * sr_scale, w * sr_scale
        h_hr -= h_hr % (sr_scale * 2)
        w_hr -= w_hr % (sr_scale * 2)
        h_l = h_hr // sr_scale
        w_l = w_hr // sr_scale
        if h_l < 1 or w_l < 1:
            raise ValueError(
                f'input image ({w}x{h}) is too small: cropping it for the model '
                f'leaves an empty image')

        img_lr = img_lr[:h_l, :w_l]

        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # NOTE: divides by 256 (not 255). This presumably mirrors the training
        # data pipeline, so it is kept as-is; confirm before changing.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        # ToTensor rescales the uint8 img_lr to [0, 1]; the float img_lr_up is
        # passed through unscaled (it is already in [0, 1) up to resize overshoot).
        img_lr, img_lr_up = [to_tensor_norm(x).float() for x in [img_lr, img_lr_up]]

        return torch.unsqueeze(img_lr, dim=0), torch.unsqueeze(img_lr_up, dim=0)

    def load_checkpoint(self, ckpt_path):
        """Load weights from ``ckpt_path`` into ``self.model.model`` and move it to the GPU.

        Weights are read from ``checkpoint['state_dict']['model']``. A leading
        ``module.`` prefix, typically added by ``DataParallel``-style wrappers,
        is stripped from each key so that the names match the bare model.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loding check from: {ckpt_path}')
        state_dict = checkpoint['state_dict']['model']

        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)  # strip `module.`
            for k, v in state_dict.items()
        )

        self.model.model.load_state_dict(new_state_dict)
        self.model.model.cuda()
        # Drop the CPU copy of the checkpoint and release cached GPU memory.
        del checkpoint
        torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        """Build the network through the trainer and load weights from ``ckpt_path``."""
        self.model = trainer_ori()

        self.model.build_model()
        self.load_checkpoint(ckpt_path)

        # Presumably disabled because input sizes vary from image to image,
        # which would make cuDNN autotuning re-run for every new shape.
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        """Super-resolve ``img_PIL`` and return the result as a PIL image."""
        with torch.no_grad():
            self.model.model.eval()
            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)

            img_lr = img_lr.to('cuda')
            img_lr_up = img_lr_up.to('cuda')

            # img_lr_up.shape is passed as the third argument, presumably the
            # shape of the sample to generate — confirm against sample().
            img_sr, _ = self.model.model.sample(img_lr, img_lr_up, img_lr_up.shape)

            img_sr = img_sr.clamp(-1, 1)
            # tensor2img presumably returns one array per batch item; take the first.
            img_sr = self.model.tensor2img(img_sr)[0]
            img_sr = Image.fromarray(img_sr)

        return img_sr