import argparse
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import yaml

from models import UNetEncoder, Decoder


def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load covariate rasters, SO4 rasters, and masks from a pickle file,
    with optional standardization, log transform, and 12-month rolling averages."""
    with open(path, "rb") as io:
        data = pickle.load(io)

    C = data["covars_rast"]  # [:, weather_cols]
    names = data["covars_names"]

    if standardize_weather:
        # standardize each covariate channel over time and space
        C -= C.mean((0, 2, 3), keepdims=True)
        C /= C.std((0, 2, 3), keepdims=True)

    if year_averages:
        # append a 12-month rolling average of each covariate as extra channels
        Cyearly_average = np.zeros_like(C)
        for t in range(C.shape[0]):
            if t < 12:
                Cyearly_average[t] = np.mean(C[:12], 0)
            else:
                Cyearly_average[t] = np.mean(C[(t - 12):t], 0)
        C = np.concatenate([C, Cyearly_average], 1)
        names = names + [x + ".yavg" for x in names]

    names = [x.replace(".", "_") for x in names]

    Y = data["so4_rast"]
    M = data["so4_mask"]
    M[92:, 185:] = 0.0  # annoying weird corner
    M[80:, :60] = 0.0  # annoying weird corner

    if remove_zeros:
        # mask out non-positive values and keep only pixels observed at every timestep
        M = (Y > 0) * M
        M = M * np.prod(M, 0)
    else:
        M = np.stack([M] * Y.shape[0])

    if log_so4:
        # Y = np.log(M * Y + 1e-8)
        Y = np.log(M * Y + 1.0)

    if standardize_so4:
        # standardize using only the masked (observed) pixels
        ix = np.where(M)
        Y -= Y[ix].mean()
        Y /= Y[ix].std()

    if not return_pp_data:
        return C, names, Y, M
    else:
        return C, names, Y, M, data["pp_locs"]


def radius_from_dir(s: str, prefix: str):
    """Parse the radius from a directory name of the form '<prefix><radius>_...'."""
    return int(s.split("/")[-1].split("_")[0].replace(prefix, ""))


def load_models(dirs: dict, prefix="h", nd=5):
    """Rebuild the encoder/decoder pair saved in each directory (args.yaml + model.pt)
    and return them frozen in eval mode, keyed by directory path."""
    D = {}
    for name, datadir in dirs.items():
        radius = radius_from_dir(datadir, prefix)

        # reconstruct the training arguments from args.yaml
        args = argparse.Namespace()
        with open(os.path.join(datadir, "args.yaml"), "r") as io:
            for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
                setattr(args, k, v)
                # keep both spellings of the neighbor-averaging flag in sync
                if k == "nbrs_av":
                    setattr(args, "av_nbrs", v)
                elif k == "av_nbrs":
                    setattr(args, "nbrs_av", v)

        bn_type = "frn" if not hasattr(args, "bn_type") else args.bn_type
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)
        dev = "cuda" if torch.cuda.is_available() else "cpu"  # not used below; weights are loaded onto CPU

        if not args.local and args.nbrs_av == 0:
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)
        mod = nn.ModuleDict({"enc": enc, "dec": dec})

        objs = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )

        # freeze the model for inference
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)

        D[datadir] = objs
    return D
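

# Minimal usage sketch. The pickle path "data/training_data.pkl" and the model
# directory "results/h5_run" are hypothetical placeholders: load_models expects
# each directory to contain args.yaml and model.pt, and to be named
# "<prefix><radius>_..." so that radius_from_dir can parse the radius.
if __name__ == "__main__":
    C, names, Y, M = load_training_data(
        "data/training_data.pkl",
        standardize_weather=True,
        log_so4=True,
        standardize_so4=True,
    )
    print("covariates:", C.shape, "targets:", Y.shape, "mask:", M.shape)

    models = load_models({"hybrid_r5": "results/h5_run"}, prefix="h", nd=5)
    for datadir, objs in models.items():
        print(datadir, "radius =", objs["radius"], "local =", objs["local"])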