"""Utilities for loading YAML configs, instantiating objects and models from
config dicts, and loading pre-trained checkpoints."""

import importlib
import os
import sys
from typing import Callable, Dict, Union

import numpy as np
import yaml
import torch


def merge_a_into_b(a, b):
    """Recursively merge dict `a` into dict `b`; values in `a` overwrite those in `b`."""
    for k, v in a.items():
        if isinstance(v, dict) and k in b:
            assert isinstance(
                b[k], dict
            ), "Cannot inherit key '{}' from base!".format(k)
            merge_a_into_b(v, b[k])
        else:
            b[k] = v


def load_config(config_file):
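    """Load a YAML config, recursively resolving the optional `inherit_from` key.

    If the config contains `inherit_from: <path relative to this file>`, the base
    config is loaded first and the current config is merged on top, so child
    values override inherited ones.
    """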
    with open(config_file, "r") as reader:
        config = yaml.load(reader, Loader=yaml.FullLoader)
    if "inherit_from" in config:
        base_config_file = config["inherit_from"]
        base_config_file = os.path.join(
            os.path.dirname(config_file), base_config_file
        )
        assert not os.path.samefile(config_file, base_config_file), \
            "inherit from itself"
        base_config = load_config(base_config_file)
        del config["inherit_from"]
        merge_a_into_b(config, base_config)
        return base_config
    return config


def get_cls_from_str(string, reload=False):
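    """Resolve a dotted path like "package.module.ClassName" to the class object,
    optionally reloading the module first."""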
    module_name, cls_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name)
    if reload:
        module = importlib.reload(module)
    return getattr(module, cls_name)


def init_obj_from_dict(config, **kwargs):
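    """Instantiate an object from a config dict of the form
    {"type": "package.module.ClassName", "args": {...}}.

    Keyword arguments passed here override entries in config["args"]. Any other
    dict-valued key not supplied via kwargs is built recursively and passed as an
    additional constructor argument, e.g.
    {"type": "torch.nn.Linear", "args": {"in_features": 16, "out_features": 4}}.
    """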
    obj_args = config["args"].copy()
    obj_args.update(kwargs)
    for k in config:
        if k not in ["type", "args"] and isinstance(config[k], dict) and k not in kwargs:
            obj_args[k] = init_obj_from_dict(config[k])
    try:
        obj = get_cls_from_str(config["type"])(**obj_args)
        return obj
    except Exception:
        print(f"Initializing {config} failed, detailed error stack:")
        raise


def init_model_from_config(config, print_fn=sys.stdout.write):
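    """Build a model from a config dict, constructing sub-modules recursively.

    Every key other than "type", "args" and "pretrained" is treated as a
    sub-model config; if that sub-config has a "pretrained" entry, the
    checkpoint is loaded into the sub-model before it is passed to the parent
    constructor.
    """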
    kwargs = {}
    for k in config:
        if k not in ["type", "args", "pretrained"]:
            sub_model = init_model_from_config(config[k], print_fn)
            if "pretrained" in config[k]:
                load_pretrained_model(sub_model,
                                      config[k]["pretrained"],
                                      print_fn)
            kwargs[k] = sub_model
    model = init_obj_from_dict(config, **kwargs)
    return model


def merge_load_state_dict(state_dict,
                          model: torch.nn.Module,
                          output_fn: Callable = sys.stdout.write):
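    """Copy into `model` only those entries of `state_dict` whose key and shape
    match the model's own state dict.

    Mismatched keys are reported via `output_fn` and skipped. Returns the keys
    that were actually loaded.
    """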
    model_dict = model.state_dict()
    pretrained_dict = {}
    mismatch_keys = []
    for key, value in state_dict.items():
        if key in model_dict and model_dict[key].shape == value.shape:
            pretrained_dict[key] = value
        else:
            mismatch_keys.append(key)
    output_fn(f"Loading pre-trained model, with mismatched keys {mismatch_keys}\n")
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict, strict=True)
    return pretrained_dict.keys()


def load_pretrained_model(model: torch.nn.Module,
                          pretrained: Union[str, Dict],
                          output_fn: Callable = sys.stdout.write):
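    """Load pre-trained weights from a checkpoint path or an in-memory state dict.

    Models defining their own `load_pretrained` hook are delegated to. Checkpoints
    wrapped as {"model": state_dict} are unwrapped before merging.
    """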
    if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
        output_fn(f"Pretrained checkpoint {pretrained} does not exist!\n")
        return
    
    if hasattr(model, "load_pretrained"):
        model.load_pretrained(pretrained, output_fn)
        return

    if isinstance(pretrained, dict):
        state_dict = pretrained
    else:
        state_dict = torch.load(pretrained, map_location="cpu")

    if "model" in state_dict:
        state_dict = state_dict["model"]
    
    merge_load_state_dict(state_dict, model, output_fn)


def pad_sequence(data, pad_value=0):
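    """Pad a list of variable-length arrays/tensors into a single batch-first
    tensor. Returns the padded tensor and a NumPy array of original lengths."""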
    if isinstance(data[0], (np.ndarray, torch.Tensor)):
        data = [torch.as_tensor(arr) for arr in data]
    padded_seq = torch.nn.utils.rnn.pad_sequence(data,
                                                 batch_first=True,
                                                 padding_value=pad_value)
    length = np.array([x.shape[0] for x in data])
    return padded_seq, length