import argparse
import numpy as np
import pickle
import os
import yaml
import torch
import torch.nn as nn
from models import UNetEncoder, Decoder


def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load the pickled training rasters.

    Returns (C, names, Y, M): covariate rasters, covariate names, SO4 rasters,
    and a validity mask. With return_pp_data=True, the power-plant locations
    (pp_locs) are also returned.
    """
    with open(path, "rb") as io:
        data = pickle.load(io)
    C = data["covars_rast"]  # [:, weather_cols]
    names = data["covars_names"]
    if standardize_weather:
        # standardize each covariate channel over time and space
        C -= C.mean((0, 2, 3), keepdims=True)
        C /= C.std((0, 2, 3), keepdims=True)
    if year_averages:
        # append, per covariate, the average of the previous 12 time steps as
        # extra channels (the first 12 steps use the mean of steps 0-11)
        Cyearly_average = np.zeros_like(C)
        for t in range(C.shape[0]):
            if t < 12:
                Cyearly_average[t] = np.mean(C[:12], 0)
            else:
                Cyearly_average[t] = np.mean(C[(t - 12) : t], 0)
        C = np.concatenate([C, Cyearly_average], 1)
        names = names + [x + ".yavg" for x in names]
        names = [x.replace(".", "_") for x in names]

    Y = data["so4_rast"]
    M = data["so4_mask"]
    M[92:, 185:] = 0.0  # mask out artifacts in one corner of the raster
    M[80:, :60] = 0.0  # mask out artifacts in another corner
    if remove_zeros:
        # keep a pixel only if it is masked-in and has SO4 > 0 at every time step
        M = (Y > 0) * M
        M = M * np.prod(M, 0)
    else:
        # replicate the static mask across all time steps
        M = np.stack([M] * Y.shape[0])
    if log_so4:
        # Y = np.log(M * Y + 1e-8)
        Y = np.log(M * Y + 1.0)  # log1p of the masked SO4 values
    if standardize_so4:
        # standardize SO4 using statistics from the valid (masked-in) pixels only
        ix = np.where(M)
        Y -= Y[ix].mean()
        Y /= Y[ix].std()

    if not return_pp_data:
        return C, names, Y, M
    else:
        return C, names, Y, M, data["pp_locs"]
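

# A minimal usage sketch of load_training_data (defined above, not executed on
# import). The pickle path is hypothetical; the shape comment only reflects how
# the arrays are indexed in the loader.
def _example_load_training_data():
    C, names, Y, M = load_training_data(
        "data/so4_training_data.pkl",  # hypothetical path to the pickled rasters
        standardize_weather=True,
        log_so4=True,
        remove_zeros=True,
    )
    # C: (time, covariate, height, width); Y and M: (time, height, width)
    print(len(names), C.shape, Y.shape, M.shape)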


def radius_from_dir(s: str, prefix: str):
    """Parse the integer radius from a run directory named '<prefix><radius>_...'."""
    return int(s.split("/")[-1].split("_")[0].replace(prefix, ""))
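

# Tiny illustration of the run-directory naming convention assumed by
# load_models: names (hypothetical here) start with "<prefix><radius>_".
def _example_radius_from_dir():
    assert radius_from_dir("results/h25_run0", "h") == 25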


def load_models(dirs: dict, prefix="h", nd=5):
    """Load trained encoder/decoder pairs from a mapping of run names to run directories.

    Each directory must contain args.yaml and model.pt. Returns a dict keyed by
    directory, holding the frozen model, its args, and the parsed radius.
    """
    D = {}
    for name, datadir in dirs.items():
        radius = radius_from_dir(datadir, prefix)
        # rebuild the run's argparse.Namespace from its saved args.yaml
        args = argparse.Namespace()
        with open(os.path.join(datadir, "args.yaml"), "r") as io:
            for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
                setattr(args, k, v)
                # keep the two argument aliases nbrs_av / av_nbrs in sync
                if k == "nbrs_av":
                    setattr(args, "av_nbrs", v)
                elif k == "av_nbrs":
                    setattr(args, "nbrs_av", v)

        # default to "frn" when args.yaml does not record a bn_type
        bn_type = getattr(args, "bn_type", "frn")
        # encoder construction arguments
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )

        # decoder construction arguments
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)
        if not args.local and args.nbrs_av == 0:
            # non-local run without neighbor averaging: use the U-Net encoder
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            # local or neighbor-averaged run: the decoder sees the inputs directly
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)
        mod = nn.ModuleDict({"enc": enc, "dec": dec})

        # load the trained weights, then freeze the model for inference
        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        D[datadir] = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )
    return D
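

# A minimal usage sketch of load_models (defined above, not executed on
# import). The run directories are hypothetical; each is expected to contain
# the args.yaml and model.pt files read above. Models come back frozen and in
# eval mode, so any inference would be run under torch.no_grad().
def _example_load_models():
    dirs = {"r10": "results/h10_run0", "r25": "results/h25_run0"}
    models = load_models(dirs, prefix="h", nd=5)
    for datadir, objs in models.items():
        print(datadir, objs["radius"], objs["local"], objs["nbrs_av"])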