Spaces:
Runtime error
Runtime error
Upload encoders.py
Browse files- encoders.py +243 -0
encoders.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from utils import normalize_data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
+
|
9 |
+
|
10 |
+
class StyleEncoder(nn.Module):
    """Embeds a vector of hyperparameters into the transformer embedding space."""

    def __init__(self, num_hyperparameters, em_size):
        super().__init__()
        self.em_size = em_size
        # Plain linear projection from the hyperparameter vector to em_size.
        self.embedding = nn.Linear(num_hyperparameters, self.em_size)

    def forward(self, hyperparameters):  # B x num_hps
        """Project a B x num_hyperparameters batch to B x em_size."""
        projected = self.embedding(hyperparameters)
        return projected
|
18 |
+
|
19 |
+
|
20 |
+
class StyleEmbEncoder(nn.Module):
    """Looks up a learned embedding for a single integer-valued hyperparameter."""

    def __init__(self, num_hyperparameters, em_size, num_embeddings=100):
        super().__init__()
        # Only a single (categorical) hyperparameter is supported.
        assert num_hyperparameters == 1
        self.em_size = em_size
        self.embedding = nn.Embedding(num_embeddings, self.em_size)

    def forward(self, hyperparameters):  # B x num_hps
        """Map a B x 1 integer tensor to B x em_size embeddings."""
        indices = hyperparameters.squeeze(1)
        return self.embedding(indices)
|
29 |
+
|
30 |
+
|
31 |
+
class _PositionalEncoding(nn.Module):
|
32 |
+
def __init__(self, d_model, dropout=0.):
|
33 |
+
super().__init__()
|
34 |
+
self.dropout = nn.Dropout(p=dropout)
|
35 |
+
self.d_model = d_model
|
36 |
+
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
37 |
+
|
38 |
+
def forward(self, x):# T x B x num_features
|
39 |
+
assert self.d_model % x.shape[-1]*2 == 0
|
40 |
+
d_per_feature = self.d_model // x.shape[-1]
|
41 |
+
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
42 |
+
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
43 |
+
interval_size = 10
|
44 |
+
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
|
45 |
+
#print(div_term/2/math.pi)
|
46 |
+
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
47 |
+
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
48 |
+
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
49 |
+
|
50 |
+
|
51 |
+
def Positional(_, emsize):
    """Encoder creator matching the (num_features, emsize) signature; the
    feature count argument is ignored — the encoding is sized by emsize alone."""
    return _PositionalEncoding(d_model=emsize)
|
52 |
+
|
53 |
+
class EmbeddingEncoder(nn.Module):
    """Encodes continuous features by discretizing each value into one of
    ``num_embs`` bins over [-2, 2], looking up a per-feature embedding for the
    bin, and averaging the embeddings over the feature axis.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # One slice of num_embs rows per feature in a single shared table.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2, +2)  # value range mapped onto the bins

    @property
    def width(self):
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Map values in [min, max] to integer bin indices in [0, num_embs - 1]."""
        split_size = self.width / self.num_embs
        # BUG FIX: the original `(x - self.min_max[0] // split_size)` applied
        # floor-division only to the constant min_max[0] (operator precedence),
        # shifting x by a constant instead of binning it — most values clamped
        # into the wrong bins.
        return ((x - self.min_max[0]) / split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        x_idxs = self.discretize(x)
        # Offset each feature into its own slice of the embedding table.
        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        return self.embeddings(x_idxs).mean(-2)
|
77 |
+
|
78 |
+
|
79 |
+
class Normalize(nn.Module):
    """Applies a fixed affine standardization: (x - mean) / std."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std
|
87 |
+
|
88 |
+
|
89 |
+
def get_normalized_uniform_encoder(encoder_creator):
    """
    Wrap an encoder creator so inputs that are uniform on [0, 1] are first
    standardized to zero mean and unit std (the std of U[0, 1] is sqrt(1/12)).

    For example, `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`
    can then be initialized with `encoder_creator(feature_dim, in_dim)`.

    :param encoder_creator: callable (in_dim, out_dim) -> nn.Module
    :return: callable (in_dim, out_dim) -> nn.Sequential
    """
    def creator(in_dim, out_dim):
        return nn.Sequential(Normalize(.5, math.sqrt(1 / 12)),
                             encoder_creator(in_dim, out_dim))
    return creator
|
98 |
+
|
99 |
+
|
100 |
+
def get_normalized_encoder(encoder_creator, data_std):
    """Wrap an encoder creator so inputs are divided by ``data_std`` first."""
    def creator(in_dim, out_dim):
        return nn.Sequential(Normalize(0., data_std),
                             encoder_creator(in_dim, out_dim))
    return creator
|
102 |
+
|
103 |
+
|
104 |
+
class ZNormalize(nn.Module):
    """Z-normalizes along the last dimension (per-sample feature standardization)."""

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return (x - mean) / std
|
107 |
+
|
108 |
+
|
109 |
+
class AppendEmbeddingEncoder(nn.Module):
    """Wraps an encoder; the input's last sequence position is a 0/1 flag row
    that selects whether a learned embedding vector is appended to the encoded
    sequence."""

    def __init__(self, base_encoder, num_features, emsize):
        super().__init__()
        self.num_features = num_features
        self.base_encoder = base_encoder
        # The learned vector appended when the flag row is all ones.
        self.emb = nn.Parameter(torch.zeros(emsize))

    def forward(self, x):
        # The last position along dim 0 carries the flag, not data.
        flag = x[-1]
        if (flag == 1.).all():
            append_embedding = True
        else:
            assert (flag == 0.).all(), "You need to specify as last position whether to append embedding. " \
                                       "If you don't want this behavior, please use the wrapped encoder instead."
            append_embedding = False
        encoded_x = self.base_encoder(x[:-1])
        if append_embedding:
            tail = self.emb[None, None, :].repeat(1, encoded_x.shape[1], 1)
            encoded_x = torch.cat([encoded_x, tail], 0)
        return encoded_x
|
128 |
+
|
129 |
+
def get_append_embedding_encoder(encoder_creator):
    """Wrap ``encoder_creator`` so the built encoder supports the
    append-embedding flag row (see AppendEmbeddingEncoder)."""
    def creator(num_features, emsize):
        base = encoder_creator(num_features, emsize)
        return AppendEmbeddingEncoder(base, num_features, emsize)
    return creator
|
131 |
+
|
132 |
+
|
133 |
+
class VariableNumFeaturesEncoder(nn.Module):
    """Accepts inputs with fewer than ``num_features`` features: rescales the
    present features by num_features / actual_features and zero-pads the rest
    before delegating to the wrapped encoder."""

    def __init__(self, base_encoder, num_features):
        super().__init__()
        self.base_encoder = base_encoder
        self.num_features = num_features

    def forward(self, x):
        scale = self.num_features / x.shape[-1]
        pad = torch.zeros(*x.shape[:-1], self.num_features - x.shape[-1], device=x.device)
        padded = torch.cat((x * scale, pad), -1)
        return self.base_encoder(padded)
|
143 |
+
|
144 |
+
|
145 |
+
def get_variable_num_features_encoder(encoder_creator):
    """Wrap ``encoder_creator`` so the built encoder accepts a variable
    number of input features (see VariableNumFeaturesEncoder)."""
    def creator(num_features, emsize):
        base = encoder_creator(num_features, emsize)
        return VariableNumFeaturesEncoder(base, num_features)
    return creator
|
147 |
+
|
148 |
+
class NoMeanEncoder(nn.Module):
    """
    Subtracts the mean over the sequence dimension (dim 0) before encoding.
    This can be useful for any prior that is translation invariant in x or y.
    A standard GP for example is translation invariant in x:
    GP(x_test+const, x_train+const, y_train) = GP(x_test, x_train, y_train).
    """

    def __init__(self, base_encoder):
        super().__init__()
        self.base_encoder = base_encoder

    def forward(self, x):
        centered = x - x.mean(0, keepdim=True)
        return self.base_encoder(centered)
|
160 |
+
|
161 |
+
|
162 |
+
def get_no_mean_encoder(encoder_creator):
    """Wrap ``encoder_creator`` so inputs are sequence-mean-centered first."""
    def creator(num_features, emsize):
        return NoMeanEncoder(encoder_creator(num_features, emsize))
    return creator
|
164 |
+
|
165 |
+
# Plain linear encoder. NOTE: this alias is shadowed by the `Linear` class
# defined further down in this module.
Linear = nn.Linear


def MLP(num_features, emsize):
    """Two-layer MLP encoder mapping (num_features + 1) inputs to emsize."""
    hidden = emsize * 2
    return nn.Sequential(
        nn.Linear(num_features + 1, hidden),
        nn.ReLU(),
        nn.Linear(hidden, emsize),
    )
|
169 |
+
|
170 |
+
class NanHandlingEncoder(nn.Module):
    """Linear encoder that zeroes out NaNs and, with ``keep_nans``, appends
    normalized indicator channels marking NaN / +inf / -inf entries."""

    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        # The indicator channels double the effective input width.
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        if self.keep_nans:
            # Indicator coding: NaN -> -1, +inf -> 1, -inf -> 2.
            nan_mark = torch.isnan(x) * -1
            posinf_mark = torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
            neginf_mark = torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
            indicators = normalize_data(nan_mark + posinf_mark + neginf_mark)
            x = torch.cat([torch.nan_to_num(x, nan=0.0), indicators], -1)
        else:
            x = torch.nan_to_num(x, nan=0.0)
        return self.layer(x)
|
187 |
+
|
188 |
+
|
189 |
+
class Linear(nn.Linear):
    """Linear encoder with optional NaN-to-zero replacement of inputs."""

    def __init__(self, num_features, emsize, replace_nan_by_zero=False):
        super().__init__(num_features, emsize)
        self.num_features = num_features
        self.emsize = emsize
        self.replace_nan_by_zero = replace_nan_by_zero

    def forward(self, x):
        cleaned = torch.nan_to_num(x, nan=0.0) if self.replace_nan_by_zero else x
        return super().forward(cleaned)

    def __setstate__(self, state):
        super().__setstate__(state)
        # Older pickles predate the replace_nan_by_zero attribute; default it.
        self.__dict__.setdefault('replace_nan_by_zero', True)
|
204 |
+
|
205 |
+
|
206 |
+
class Conv(nn.Module):
    """Encodes inputs whose feature dimension forms a square image, via a
    small stack of 3x3 convolutions followed by global pooling and a linear
    projection."""

    def __init__(self, input_size, emsize):
        super().__init__()
        # First conv maps 1 channel -> 64; the rest are 64 -> 64.
        self.convs = torch.nn.ModuleList(
            [nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        side = math.isqrt(x.shape[-1])
        assert side * side == x.shape[-1]
        x = x.reshape(*x.shape[:-1], 1, side, side)
        for conv in self.convs:
            # Stop once the spatial extent is too small for a 3x3 kernel.
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        pooled = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        return self.linear(pooled)
|
223 |
+
|
224 |
+
|
225 |
+
class CanEmb(nn.Embedding):
    """Categorical embedding that gives each feature ``embedding_dim //
    num_features`` dimensions and concatenates them back to the full width."""

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        per_feature_dim = embedding_dim // num_features
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        lx = x.long()
        assert (lx == x).all(), "CanEmb only works with tensors of whole numbers"
        emb = super().forward(lx)
        # Merge the per-feature embedding dims into one trailing dimension.
        return emb.view(*emb.shape[:-2], -1)
|
236 |
+
|
237 |
+
|
238 |
+
def get_Canonical(num_classes):
    """Return an encoder creator building a CanEmb over ``num_classes`` categories."""
    def creator(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)
    return creator
|
240 |
+
|
241 |
+
|
242 |
+
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder creator building a binned EmbeddingEncoder with
    ``num_embs_per_feature`` bins per feature."""
    def creator(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)
    return creator
|