Spaces:
Running
Running
import argparse | |
import numpy as np | |
import pickle | |
import os | |
import yaml | |
import torch | |
import torch.nn as nn | |
from models import UNetEncoder, Decoder | |
def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load covariate rasters, SO4 targets and masks from a pickle file.

    Args:
        path: Pickle file holding a dict with keys ``covars_rast``,
            ``covars_names``, ``so4_rast``, ``so4_mask`` and (only read when
            ``return_pp_data`` is true) ``pp_locs``.
        standardize_weather: Centre and scale each covariate channel
            (axis 1) by its mean/std over axes 0, 2 and 3. A channel with
            zero std is only centred, so it becomes 0 instead of NaN.
        standardize_so4: Centre and scale ``Y`` by the mean/std of the
            entries selected by the mask. If that std is zero, ``Y`` is only
            centred.
        log_so4: Replace ``Y`` with ``log(M * Y + 1)``. Cells where the mask
            is 0 therefore map to 0.
        remove_zeros: Multiply the mask by ``Y > 0`` and by its own product
            along axis 0. With a 0/1 mask this zeroes every cell whose target
            is <= 0 at any step. When false, the static mask is instead
            repeated ``Y.shape[0]`` times along a new first axis.
        return_pp_data: Also return ``data["pp_locs"]``.
        year_averages: Append, for every covariate, its mean over the
            previous 12 steps of axis 0 (current step excluded). Steps 0-11
            all use the mean of steps 0-11. The new channels are named
            ``<name>.yavg`` before the dot-to-underscore conversion.

    Returns:
        ``(C, names, Y, M)``, or ``(C, names, Y, M, pp_locs)`` when
        ``return_pp_data`` is true. Every ``.`` in ``names`` is replaced by
        ``_``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        data = pickle.load(f)

    C = data["covars_rast"]
    names = data["covars_names"]

    if standardize_weather:
        C -= C.mean((0, 2, 3), keepdims=True)
        scale = C.std((0, 2, 3), keepdims=True)
        # A constant channel has zero std; dividing by it would give 0/0 = NaN.
        scale[scale == 0] = 1.0
        C /= scale

    if year_averages:
        Cyearly_average = np.zeros_like(C)
        for t in range(C.shape[0]):
            if t < 12:
                # Not enough history yet: fall back to the first 12 steps.
                Cyearly_average[t] = np.mean(C[:12], 0)
            else:
                Cyearly_average[t] = np.mean(C[(t - 12) : t], 0)
        C = np.concatenate([C, Cyearly_average], 1)
        names = names + [x + ".yavg" for x in names]
    names = [x.replace(".", "_") for x in names]

    Y = data["so4_rast"]
    M = data["so4_mask"]
    # Hard-coded exclusion of two grid corners ("annoying weird corner").
    # NOTE(review): assumes so4_mask is 2-D (rows, cols) here — confirm.
    M[92:, 185:] = 0.0
    M[80:, :60] = 0.0
    if remove_zeros:
        M = (Y > 0) * M
        M = M * np.prod(M, 0)
    else:
        M = np.stack([M] * Y.shape[0])

    if log_so4:
        # +1 keeps masked-out cells (M * Y == 0) finite: log(1) == 0.
        Y = np.log(M * Y + 1.0)

    if standardize_so4:
        ix = np.where(M)
        Y -= Y[ix].mean()
        spread = Y[ix].std()
        # Skip scaling when all masked targets are equal (avoids 0/0 = NaN).
        if spread > 0:
            Y /= spread

    if return_pp_data:
        return C, names, Y, M, data["pp_locs"]
    return C, names, Y, M
def radius_from_dir(s: str, prefix: str) -> int:
    """Parse the integer radius encoded in a run directory's name.

    The final path component is split on ``_``. ``prefix`` is then removed
    from the first piece and the remainder is parsed as an int, e.g.
    ``radius_from_dir("runs/h12_seed0", "h") == 12``.

    Trailing path separators are ignored, so ``"runs/h12_seed0/"`` also
    yields 12. Previously that input produced an empty component and
    ``int("")`` raised ValueError.
    """
    last_component = os.path.basename(os.path.normpath(s))
    return int(last_component.split("_")[0].replace(prefix, ""))
def load_models(dirs: dict, prefix="h", nd=5):
    """Rebuild and load the trained encoder/decoder stored in each run dir.

    For every directory in ``dirs.values()`` this:

    1. Reads ``args.yaml`` into an ``argparse.Namespace``.
    2. Builds an ``nn.ModuleDict`` with ``enc``/``dec`` entries as described
       by those args.
    3. Switches it to eval mode and freezes all parameters.
    4. Loads ``model.pt`` with weights mapped to the CPU.

    Args:
        dirs: Mapping whose values are run directories. The keys are not
            used; the result is keyed by directory path.
        prefix: Prefix handed to ``radius_from_dir`` to parse the radius
            from each directory name.
        nd: Channel count passed positionally to the ``UNetEncoder`` /
            ``Decoder`` constructors (see below for exact positions).

    Returns:
        dict mapping each directory path to a dict with keys ``mod``,
        ``args``, ``radius``, ``nbrs_av`` and ``local``.
    """
    D = {}
    for datadir in dirs.values():
        radius = radius_from_dir(datadir, prefix)

        args = argparse.Namespace()
        with open(os.path.join(datadir, "args.yaml"), "r") as f:
            # FullLoader (not safe_load) is kept so existing args.yaml files
            # keep loading; only read run directories you trust.
            saved_args = yaml.load(f, Loader=yaml.FullLoader)
        for k, v in saved_args.items():
            setattr(args, k, v)
            # The option appears under two spellings; mirror whichever is
            # present so both `args.nbrs_av` and `args.av_nbrs` resolve.
            if k == "nbrs_av":
                setattr(args, "av_nbrs", v)
            elif k == "av_nbrs":
                setattr(args, "nbrs_av", v)

        bn_type = getattr(args, "bn_type", "frn")
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)

        if not args.local and args.nbrs_av == 0:
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            # No learned encoder: the decoder is built with nd as its first
            # argument and receives the input unchanged via nn.Identity.
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)
        mod = nn.ModuleDict({"enc": enc, "dec": dec})

        objs = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )

        # Eval mode and frozen parameters: the loaded model is not trained here.
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        # torch.load unpickles the checkpoint; only load trusted files.
        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)

        D[datadir] = objs
    return D