w2vec-app / utils.py
mauriciogtec's picture
working app
89012a4
import argparse
import numpy as np
import pickle
import os
import yaml
import torch
import torch.nn as nn
from models import UNetEncoder, Decoder
def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load covariate rasters, the SO4 raster and its mask from a pickle file.

    The pickle must hold a dict with the keys ``covars_rast``,
    ``covars_names``, ``so4_rast`` and ``so4_mask`` (plus ``pp_locs`` when
    ``return_pp_data`` is set). The covariates are reduced over axes
    ``(0, 2, 3)`` and concatenated along axis 1, so they are a 4-D array whose
    axis 1 indexes the covariate; the mask is 2-D and is expanded to one slice
    per entry of the SO4 raster's first axis.

    Args:
        path: Location of the pickle file.
        standardize_weather: Center and scale every covariate channel.
            Channels with zero spread are only centered (no division).
        standardize_so4: Center and scale the SO4 raster using the statistics
            of the cells selected by the mask.
        log_so4: Replace the SO4 raster by ``log(mask * so4 + 1)``.
        remove_zeros: Mask out non-positive SO4 cells, and additionally drop
            every spatial cell that is masked at any index of the first axis.
            When false the 2-D mask is simply repeated along a new first axis.
        return_pp_data: Also return ``data["pp_locs"]``.
        year_averages: Append, for each covariate, a trailing 12-step mean
            (the first 12 steps all use the mean of the first 12 steps).
            Presumably the steps are months -- TODO confirm.

    Returns:
        ``(C, names, Y, M)`` or, with ``return_pp_data``,
        ``(C, names, Y, M, pp_locs)``. Dots in covariate names are replaced by
        underscores.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # use this on files from a trusted source.
    with open(path, "rb") as io:
        data = pickle.load(io)

    C = data["covars_rast"]  # [:, weather_cols]
    names = data["covars_names"]

    if standardize_weather:
        C -= C.mean((0, 2, 3), keepdims=True)
        scale = C.std((0, 2, 3), keepdims=True)
        # A constant channel has zero spread; dividing by it would fill the
        # whole channel with NaN, so leave such channels centered instead.
        scale[scale == 0] = 1.0
        C /= scale

    if year_averages:
        Cyearly_average = np.zeros_like(C)
        n_steps = C.shape[0]
        if n_steps:
            # The first 12 steps share one value: compute it once.
            Cyearly_average[:12] = np.mean(C[:12], 0)
        for t in range(12, n_steps):
            # Mean of the 12 steps preceding (and excluding) step t.
            Cyearly_average[t] = np.mean(C[(t - 12) : t], 0)
        C = np.concatenate([C, Cyearly_average], 1)
        names = names + [x + ".yavg" for x in names]
    names = [x.replace(".", "_") for x in names]

    Y = data["so4_rast"]
    M = data["so4_mask"]
    M[92:, 185:] = 0.0  # annoying weird corner
    M[80:, :60] = 0.0  # annoying weird corner

    if remove_zeros:
        M = (Y > 0) * M
        # Keep only the cells that are valid at every index of axis 0.
        M = M * np.prod(M, 0)
    else:
        M = np.stack([M] * Y.shape[0])

    if log_so4:
        Y = np.log(M * Y + 1.0)

    if standardize_so4:
        ix = np.where(M)
        Y -= Y[ix].mean()
        spread = Y[ix].std()
        # Skip the division only for an exactly-zero spread (constant values
        # under the mask), which would otherwise produce NaN/inf.
        if spread != 0:
            Y /= spread

    if return_pp_data:
        return C, names, Y, M, data["pp_locs"]
    return C, names, Y, M
def radius_from_dir(s: str, prefix: str) -> int:
    """Parse the integer radius encoded in a run directory's name.

    The last ``/``-separated component of ``s`` is cut at its first
    underscore, every occurrence of ``prefix`` is removed from that leading
    token, and the remainder is converted to ``int`` (``"runs/h5_local"``
    with ``prefix="h"`` gives ``5``).

    Args:
        s: Directory path using ``/`` separators; trailing slashes are ignored.
        prefix: Text stripped from the leading token before conversion.

    Returns:
        The radius as an integer.

    Raises:
        ValueError: If the remaining text is not a valid integer.
    """
    # Without stripping, "runs/h5_local/" would yield an empty last component.
    basename = s.rstrip("/").rsplit("/", 1)[-1]
    token = basename.split("_", 1)[0]
    return int(token.replace(prefix, ""))
def load_models(dirs: dict, prefix="h", nd=5):
    """Rebuild the saved encoder/decoder pair of every run directory.

    For each directory in ``dirs`` the hyper-parameters are read from
    ``args.yaml``, the matching modules are constructed, switched to eval
    mode with gradients disabled, and the weights from ``model.pt`` are
    loaded onto the CPU.

    Args:
        dirs: Mapping whose values are run-directory paths. The keys are not
            used.
        prefix: Prefix handed to ``radius_from_dir`` to parse the radius from
            the directory name.
        nd: Passed as the first argument of ``UNetEncoder`` and as the
            first/second argument of ``Decoder``; presumably the number of
            data channels -- TODO confirm against ``models``.

    Returns:
        A dict keyed by directory path (not by the keys of ``dirs``). Each
        value is a dict with the entries ``mod`` (an ``nn.ModuleDict`` holding
        ``enc`` and ``dec``), ``args``, ``radius``, ``nbrs_av`` and ``local``.
    """
    D = {}
    for datadir in dirs.values():
        radius = radius_from_dir(datadir, prefix)

        args = argparse.Namespace()
        # NOTE(review): FullLoader can construct more than plain data types;
        # only load configuration files from trusted run directories.
        with open(os.path.join(datadir, "args.yaml"), "r") as io:
            for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
                setattr(args, k, v)
                # The option exists under two spellings; mirror whichever one
                # is present so both attribute names can be read.
                if k == "nbrs_av":
                    setattr(args, "av_nbrs", v)
                elif k == "av_nbrs":
                    setattr(args, "nbrs_av", v)

        # Configurations without a ``bn_type`` entry fall back to "frn".
        bn_type = getattr(args, "bn_type", "frn")
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)

        if not args.local and args.nbrs_av == 0:
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            # Local / neighbour-average runs have no encoder: the decoder is
            # fed the ``nd`` inputs directly.
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)

        mod = nn.ModuleDict({"enc": enc, "dec": dec})
        objs = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )

        # Inference only: eval mode and no gradient tracking.
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        # NOTE(review): torch.load unpickles the file; use trusted checkpoints.
        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)

        D[datadir] = objs
    return D