import os, sys
import torch

# Make project-local modules importable from the current working directory
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from architecture.rrdb import RRDBNet
from architecture.grl import GRL
from architecture.dat import DAT
from architecture.swinir import SwinIR
from architecture.cunet import UNet_Full


def load_rrdb(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load RRDB model from Real-ESRGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor
        print_options (bool): whether to print the 'opt' entry stored in the checkpoint (if present)
    Returns:
        generator (torch): the generator instance of the model, in eval mode and moved to CUDA
    Raises:
        ValueError: if the checkpoint holds none of the supported weight keys
    '''

    # Load the checkpoint on CPU so it does not depend on the device it was saved from;
    # the model is moved to CUDA after the weights are loaded.
    # NOTE(review): torch.load can unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight. Keys are tried in priority order:
    #   'params_ema' / 'params' : official ESRNET/ESRGAN weight
    #   'model_state_dict'      : personally trained weight
    supported_keys = ('params_ema', 'params', 'model_state_dict')
    for weight_key in supported_keys:
        if weight_key in checkpoint_g:
            weight = checkpoint_g[weight_key]
            break
    else:
        # Raise instead of os._exit(0): a failed load must not end the process with a
        # "success" status, skip cleanup handlers, or lose the buffered message.
        raise ValueError(
            f"This weight is not supported: checkpoint has none of the keys {supported_keys}"
        )

    # NOTE(original author): default blocks num is 6
    generator = RRDBNet(3, 3, scale=scale)

    # Handle torch.compile weight key rename (done in place so any metadata attached
    # to the state dict object is kept)
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        for key, value in checkpoint_g['opt'].items():
            print(f'{key} : {value}')

    return generator


def load_cunet(generator_weight_PATH, scale, print_options=False):
    ''' A simpler API to load CUNET model from Real-CUGAN
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int): the scaling factor (only 2 is supported)
        print_options (bool): whether to print the 'opt' entry stored in the checkpoint (if present)
    Returns:
        generator (torch): the generator instance of the model, in eval mode and moved to CUDA
    Raises:
        NotImplementedError: if scale is not 2
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''
    # This func is deprecated now

    if scale != 2:
        raise NotImplementedError("We only support 2x in CUNET")

    # Load the checkpoint on CPU so it does not depend on the device it was saved from;
    # the model is moved to CUDA after the weights are loaded.
    # NOTE(review): torch.load can unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight (only the personally trained format is handled)
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): a failed load must not end the process with a
        # "success" status, skip cleanup handlers, or lose the buffered message.
        raise ValueError(
            "This weight is not supported: checkpoint has no 'model_state_dict' key"
        )
    weight = checkpoint_g['model_state_dict']

    # Both values are informational only, so a missing entry falls back to "NAN"
    # rather than aborting the load.
    loss = checkpoint_g.get("lowest_generator_weight", "NAN")
    iteration = checkpoint_g.get("iteration", "NAN")

    generator = UNet_Full()
    print(f"the generator weight is {loss} at iteration {iteration}")

    # Handle torch.compile weight key rename (done in place so any metadata attached
    # to the state dict object is kept)
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Print options to show what kinds of setting is used
    if print_options and 'opt' in checkpoint_g:
        for key, value in checkpoint_g['opt'].items():
            print(f'{key} : {value}')

    return generator

def load_grl(generator_weight_PATH, scale=4):
    ''' A simpler API to load GRL model
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int):        Scale Factor (Usually Set as 4)
    Returns:
        generator (torch): the generator instance of the model, in eval mode and moved to CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''

    # Load the checkpoint on CPU so it does not depend on the device it was saved from;
    # the model is moved to CUDA after the weights are loaded.
    # NOTE(review): torch.load can unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): a failed load must not end the process with a
        # "success" status, skip cleanup handlers, or lose the buffered message.
        raise ValueError(
            "This weight is not supported: checkpoint has no 'model_state_dict' key"
        )
    weight = checkpoint_g['model_state_dict']

    # Handle torch.compile weight key rename, as the RRDB/CUNET loaders do
    # (done in place so any metadata attached to the state dict object is kept)
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    # GRL tiny model (Note: tiny2 version)
    generator = GRL(
        upscale = scale,
        img_size = 64,
        window_size = 8,
        depths = [4, 4, 4, 4],
        embed_dim = 64,
        num_heads_window = [2, 2, 2, 2],
        num_heads_stripe = [2, 2, 2, 2],
        mlp_ratio = 2,
        qkv_proj_type = "linear",
        anchor_proj_type = "avgpool",
        anchor_window_down_factor = 2,
        out_proj_type = "linear",
        conv_type = "1conv",
        upsampler = "nearest+conv",     # Change
    )

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator



def load_dat(generator_weight_PATH, scale=4):
    ''' A simpler API to load DAT model
    Args:
        generator_weight_PATH (str): The path to the weight
        scale (int):        Scale Factor passed to the model as its upscale (default 4)
    Returns:
        generator (torch): the generator instance of the model, in eval mode and moved to CUDA
    Raises:
        ValueError: if the checkpoint has no 'model_state_dict' entry
    '''

    # Load the checkpoint on CPU so it does not depend on the device it was saved from;
    # the model is moved to CUDA after the weights are loaded.
    # NOTE(review): torch.load can unpickle arbitrary objects -- only load trusted checkpoints.
    checkpoint_g = torch.load(generator_weight_PATH, map_location="cpu")

    # Find the generator weight
    if 'model_state_dict' not in checkpoint_g:
        # Raise instead of os._exit(0): a failed load must not end the process with a
        # "success" status, skip cleanup handlers, or lose the buffered message.
        raise ValueError(
            "This weight is not supported: checkpoint has no 'model_state_dict' key"
        )
    weight = checkpoint_g['model_state_dict']

    # Handle torch.compile weight key rename, as the RRDB/CUNET loaders do
    # (done in place so any metadata attached to the state dict object is kept)
    compile_prefix = "_orig_mod."
    for old_key in list(weight):
        if old_key.startswith(compile_prefix):
            weight[old_key[len(compile_prefix):]] = weight.pop(old_key)

    # DAT small model in default
    generator = DAT(upscale = scale,       # honour the caller's scale (was hard-coded to 4)
                    in_chans = 3,
                    img_size = 64,
                    img_range = 1.,
                    depth = [6, 6, 6, 6, 6, 6],
                    embed_dim = 180,
                    num_heads = [6, 6, 6, 6, 6, 6],
                    expansion_factor = 2,
                    resi_connection = '1conv',
                    split_size = [8, 16],
                    upsampler = 'pixelshuffledirect',
                    )

    generator.load_state_dict(weight)
    generator = generator.eval().cuda()

    # Report the trainable parameter count in millions
    num_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")

    return generator