File size: 3,537 Bytes
6f47252
f049087
 
6f47252
 
 
 
 
f049087
 
 
 
 
 
 
 
 
89012a4
f049087
 
 
89012a4
f049087
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89012a4
f049087
 
 
 
6f47252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89012a4
6f47252
 
 
 
 
 
 
 
 
 
 
 
 
89012a4
 
6f47252
 
89012a4
6f47252
 
 
 
 
 
 
 
 
 
 
89012a4
 
 
6f47252
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import numpy as np
import pickle
import os
import yaml
import torch
import torch.nn as nn
from models import UNetEncoder, Decoder


def load_training_data(
    path: str,
    standardize_weather: bool = False,
    standardize_so4: bool = False,
    log_so4: bool = False,
    remove_zeros: bool = True,
    return_pp_data: bool = False,
    year_averages: bool = False,
):
    """Load the pickled SO4 training rasters and apply optional preprocessing.

    Args:
        path: pickle file containing the keys 'covars_rast', 'covars_names',
            'so4_rast', 'so4_mask' (and 'pp_locs' when return_pp_data).
        standardize_weather: z-score each covariate channel over the
            time/lat/lon axes (mutates the loaded array in place).
        standardize_so4: z-score the SO4 raster over the masked cells.
        log_so4: replace the SO4 raster with log(mask * so4 + 1).
        remove_zeros: mask out cells whose SO4 is not strictly positive at
            every time step; otherwise replicate the static mask over time.
        return_pp_data: additionally return the 'pp_locs' entry.
        year_averages: append 12-month running averages of each covariate
            as extra channels (names get a '.yavg' suffix, dots become '_').

    Returns:
        (covariates, covariate names, so4 raster, mask) and, when
        return_pp_data is set, the power-plant locations as a fifth element.
    """
    # NOTE: pickle assumes the file comes from a trusted source.
    with open(path, "rb") as fh:
        payload = pickle.load(fh)

    covars = payload["covars_rast"]
    names = payload["covars_names"]

    if standardize_weather:
        # z-score per channel across time and space (in place, matching
        # the original behavior).
        covars -= covars.mean((0, 2, 3), keepdims=True)
        covars /= covars.std((0, 2, 3), keepdims=True)

    if year_averages:
        yearly = np.zeros_like(covars)
        for month in range(covars.shape[0]):
            # First year: mean of the first 12 months; afterwards a
            # trailing 12-month window excluding the current month.
            window = covars[:12] if month < 12 else covars[(month - 12) : month]
            yearly[month] = window.mean(0)
        covars = np.concatenate([covars, yearly], 1)
        augmented = names + [f"{n}.yavg" for n in names]
        names = [n.replace(".", "_") for n in augmented]

    so4 = payload["so4_rast"]
    mask = payload["so4_mask"]
    # Zero out two artifact corners of the raster (in place).
    mask[92:, 185:] = 0.0  # annoying weird corner
    mask[80:, :60] = 0.0  # annoying weird corner

    if remove_zeros:
        # Broadcast the static mask to (time, H, W), then keep only cells
        # that are valid at every time step.
        mask = (so4 > 0) * mask
        mask = mask * np.prod(mask, 0)
    else:
        mask = np.stack([mask] * so4.shape[0])

    if log_so4:
        # Y = np.log(M * Y + 1e-8)
        so4 = np.log(mask * so4 + 1.0)

    if standardize_so4:
        valid = np.where(mask)
        so4 -= so4[valid].mean()
        so4 /= so4[valid].std()

    if return_pp_data:
        return covars, names, so4, mask, payload["pp_locs"]
    return covars, names, so4, mask


def radius_from_dir(s: str, prefix: str) -> int:
    """Extract the integer radius encoded in a run-directory name.

    The directory's basename is expected to look like
    ``<prefix><radius>_<rest>`` (e.g. ``h5_lr001`` with prefix ``"h"`` -> 5).

    Args:
        s: path to the run directory.
        prefix: leading tag stripped from the radius token.

    Returns:
        The radius as an int.

    Raises:
        ValueError: if the remaining token is not an integer.
    """
    # os.path.basename instead of s.split("/")[-1]: portable across
    # platform path separators, identical on POSIX-style paths.
    token = os.path.basename(s).split("_")[0]
    return int(token.replace(prefix, ""))


def load_models(dirs: dict, prefix="h", nd=5):
    """Rebuild the frozen encoder/decoder model saved in each run directory.

    For every directory in ``dirs``, reads its ``args.yaml`` hyperparameters,
    reconstructs the architecture they describe, loads the weights from
    ``model.pt`` onto the CPU, freezes all parameters, and puts the model in
    eval mode.

    Args:
        dirs: mapping of run names to run-directory paths. NOTE(review): the
            returned dict is keyed by the directory path, not the run name.
        prefix: directory-name prefix used to recover the neighborhood
            radius (see ``radius_from_dir``).
        nd: number of raster data channels the models consume/produce.

    Returns:
        dict keyed by datadir; each value has keys 'mod' (nn.ModuleDict with
        'enc' and 'dec'), 'args', 'radius', 'nbrs_av', and 'local'.
    """
    D = {}
    # Only the directory paths are used; the mapping's keys are ignored.
    for datadir in dirs.values():
        radius = radius_from_dir(datadir, prefix)
        args = argparse.Namespace()
        # Mirror the two historical spellings of the neighbor-averaging
        # flag so either attribute name works downstream.
        # NOTE(review): if BOTH keys appear in the yaml, the later one
        # silently wins — assumed mutually exclusive; confirm.
        # NOTE(review): FullLoader can construct arbitrary Python objects;
        # yaml.safe_load would suffice for plain configs — confirm before
        # swapping, some runs may rely on full loading.
        with open(os.path.join(datadir, "args.yaml"), "r") as io:
            for k, v in yaml.load(io, Loader=yaml.FullLoader).items():
                setattr(args, k, v)
                if k == "nbrs_av":
                    setattr(args, "av_nbrs", v)
                elif k == "av_nbrs":
                    setattr(args, "nbrs_av", v)

        # Older runs predate the bn_type option; default them to "frn".
        bn_type = "frn" if not hasattr(args, "bn_type") else args.bn_type
        mkw = dict(
            n_hidden=args.nhidden,
            depth=args.depth,
            num_res=args.nres,
            ksize=args.ksize,
            groups=args.groups,
            batchnorm=True,
            batchnorm_type=bn_type,
        )
        dkw = dict(batchnorm=True, offset=True, batchnorm_type=bn_type)

        # (Removed an unused `dev` local that probed CUDA availability:
        # the weights are always loaded onto the CPU below.)
        if not args.local and args.nbrs_av == 0:
            # Full convolutional pipeline: U-Net encoder + decoder.
            enc = UNetEncoder(nd, args.nhidden, **mkw)
            dec = Decoder(args.nhidden, nd, args.nhidden, **dkw)
        else:
            # Local / neighbor-averaged variants decode the raw channels.
            enc = nn.Identity()
            dec = Decoder(nd, nd, args.nhidden, **dkw)
        mod = nn.ModuleDict({"enc": enc, "dec": dec})

        # Freeze for inference-only use.
        mod.eval()
        for p in mod.parameters():
            p.requires_grad = False

        weights_path = os.path.join(datadir, "model.pt")
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        mod.load_state_dict(state_dict)

        D[datadir] = dict(
            mod=mod,
            args=args,
            radius=radius,
            nbrs_av=args.nbrs_av,
            local=args.local,
        )
    return D