TabPFN committed
Commit f1c7310 · 1 Parent(s): 6e512b6

Upload encoders.py

Files changed (1):
  1. encoders.py +243 -0
encoders.py ADDED
@@ -0,0 +1,243 @@
import math

import torch
import torch.nn as nn
from utils import normalize_data
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class StyleEncoder(nn.Module):
    def __init__(self, num_hyperparameters, em_size):
        super().__init__()
        self.em_size = em_size
        self.embedding = nn.Linear(num_hyperparameters, self.em_size)

    def forward(self, hyperparameters):  # B x num_hps
        return self.embedding(hyperparameters)


class StyleEmbEncoder(nn.Module):
    def __init__(self, num_hyperparameters, em_size, num_embeddings=100):
        super().__init__()
        assert num_hyperparameters == 1
        self.em_size = em_size
        self.embedding = nn.Embedding(num_embeddings, self.em_size)

    def forward(self, hyperparameters):  # B x num_hps
        return self.embedding(hyperparameters.squeeze(1))
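
# Usage sketch (illustrative; the shapes below are assumptions): both style encoders
# map a batch of hyperparameter values to em_size-dimensional style embeddings.
#
#   style_enc = StyleEncoder(num_hyperparameters=3, em_size=512)
#   hps = torch.rand(8, 3)                  # B x num_hps, continuous values
#   style = style_enc(hps)                  # -> 8 x 512
#
#   emb_enc = StyleEmbEncoder(num_hyperparameters=1, em_size=512)
#   hp_ids = torch.randint(0, 100, (8, 1))  # B x 1, integer-coded hyperparameter
#   style = emb_enc(hp_ids)                 # -> 8 x 512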


class _PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        self.device_test_tensor = nn.Parameter(torch.tensor(1.))

    def forward(self, x):  # T x B x num_features
        # d_model must split into an even number of dimensions per feature,
        # since sin/cos pairs are interleaved below.
        assert self.d_model % (x.shape[-1] * 2) == 0
        d_per_feature = self.d_model // x.shape[-1]
        pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
        # position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        interval_size = 10
        div_term = (1. / interval_size) * 2 * math.pi * torch.exp(
            torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float() * math.log(math.sqrt(2)))
        # print(div_term/2/math.pi)
        pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
        pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
        return self.dropout(pe).view(x.shape[0], x.shape[1], self.d_model)


Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
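
# Usage sketch (illustrative; the shapes below are assumptions): `Positional` has the
# standard (num_features, emsize) encoder-creator signature; the feature count is
# ignored at construction time and only read from the input shape.
#
#   pos_enc = Positional(4, 512)       # emsize=512 must be divisible by 2*num_features
#   x = torch.rand(60, 8, 4)           # T x B x num_features
#   emb = pos_enc(x)                   # -> 60 x 8 x 512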

class EmbeddingEncoder(nn.Module):
    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2, +2)

    @property
    def width(self):
        return self.min_max[1] - self.min_max[0]

    def init_weights(self, initrange):
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        # Map values in [min, max] to bin indices in [0, num_embs - 1].
        split_size = self.width / self.num_embs
        return ((x - self.min_max[0]) // split_size).int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        x_idxs = self.discretize(x)
        # Offset each feature into its own block of num_embs embeddings.
        x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        # print(x_idxs,self.embeddings.weight.shape)
        return self.embeddings(x_idxs).mean(-2)
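
# Usage sketch (illustrative; the shapes below are assumptions): each feature value is
# discretized into one of num_embs bins over [-2, 2], looked up in a per-feature
# embedding table, and the per-feature embeddings are averaged.
#
#   enc = EmbeddingEncoder(num_features=10, em_size=512, num_embs=100)
#   x = torch.randn(60, 8, 10).clamp(-2, 2)   # T x B x num_features
#   emb = enc(x)                              # -> 60 x 8 x 512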


class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        return (x - self.mean) / self.std


def get_normalized_uniform_encoder(encoder_creator):
    """
    This can be used to wrap an encoder that is fed uniform samples in [0,1] and normalizes these to 0 mean and 1 std.
    For example, it can be used as `encoder_creator = get_normalized_uniform_encoder(encoders.Linear)`; the result
    can then be initialized with `encoder_creator(feature_dim, em_size)`.
    :param encoder_creator: callable mapping (in_dim, out_dim) to an encoder module
    :return: callable with the same signature that prepends the normalization
    """
    return lambda in_dim, out_dim: nn.Sequential(Normalize(.5, math.sqrt(1/12)), encoder_creator(in_dim, out_dim))
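
# Usage sketch (illustrative; the shapes below are assumptions): wrap the Linear
# encoder so that uniform [0, 1] inputs are shifted and scaled to zero mean and unit
# variance (a U[0,1] variable has mean 0.5 and std sqrt(1/12)) before the projection.
#
#   encoder_creator = get_normalized_uniform_encoder(Linear)
#   enc = encoder_creator(10, 512)     # 10 features -> 512-dim embeddings
#   x = torch.rand(60, 8, 10)          # uniform samples in [0, 1]
#   emb = enc(x)                       # -> 60 x 8 x 512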


def get_normalized_encoder(encoder_creator, data_std):
    return lambda in_dim, out_dim: nn.Sequential(Normalize(0., data_std), encoder_creator(in_dim, out_dim))


class ZNormalize(nn.Module):
    def forward(self, x):
        return (x - x.mean(-1, keepdim=True)) / x.std(-1, keepdim=True)


class AppendEmbeddingEncoder(nn.Module):
    def __init__(self, base_encoder, num_features, emsize):
        super().__init__()
        self.num_features = num_features
        self.base_encoder = base_encoder
        self.emb = nn.Parameter(torch.zeros(emsize))

    def forward(self, x):
        if (x[-1] == 1.).all():
            append_embedding = True
        else:
            assert (x[-1] == 0.).all(), "You need to specify as last position whether to append embedding. " \
                                        "If you don't want this behavior, please use the wrapped encoder instead."
            append_embedding = False
        x = x[:-1]
        encoded_x = self.base_encoder(x)
        if append_embedding:
            encoded_x = torch.cat([encoded_x, self.emb[None, None, :].repeat(1, encoded_x.shape[1], 1)], 0)
        return encoded_x


def get_append_embedding_encoder(encoder_creator):
    return lambda num_features, emsize: AppendEmbeddingEncoder(encoder_creator(num_features, emsize), num_features, emsize)
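
# Usage sketch (illustrative; the shapes below are assumptions): the wrapped encoder
# expects one extra row at the end of the sequence that is all ones (append the
# learned embedding) or all zeros (do not append); that flag row is stripped before
# the base encoder is applied.
#
#   enc = get_append_embedding_encoder(Linear)(10, 512)
#   x = torch.rand(61, 8, 10)
#   x[-1] = 1.                         # request the appended embedding
#   emb = enc(x)                       # -> 61 x 8 x 512 (60 encoded rows + 1 embedding row)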


class VariableNumFeaturesEncoder(nn.Module):
    def __init__(self, base_encoder, num_features):
        super().__init__()
        self.base_encoder = base_encoder
        self.num_features = num_features

    def forward(self, x):
        # Rescale so the overall feature magnitude is comparable across different
        # numbers of active features, then zero-pad up to num_features.
        x = x * (self.num_features / x.shape[-1])
        x = torch.cat((x, torch.zeros(*x.shape[:-1], self.num_features - x.shape[-1], device=x.device)), -1)
        return self.base_encoder(x)


def get_variable_num_features_encoder(encoder_creator):
    return lambda num_features, emsize: VariableNumFeaturesEncoder(encoder_creator(num_features, emsize), num_features)
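
# Usage sketch (illustrative; the shapes below are assumptions): an encoder built for
# num_features=10 can be fed batches with fewer features; they are rescaled and
# zero-padded to 10 before the wrapped Linear is applied.
#
#   enc = get_variable_num_features_encoder(Linear)(10, 512)
#   x = torch.rand(60, 8, 7)           # only 7 of the 10 features are present
#   emb = enc(x)                       # -> 60 x 8 x 512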

class NoMeanEncoder(nn.Module):
    """
    This can be useful for any prior that is translation invariant in x or y.
    A standard GP for example is translation invariant in x.
    That is, GP(x_test+const, x_train+const, y_train) = GP(x_test, x_train, y_train).
    """
    def __init__(self, base_encoder):
        super().__init__()
        self.base_encoder = base_encoder

    def forward(self, x):
        return self.base_encoder(x - x.mean(0, keepdim=True))


def get_no_mean_encoder(encoder_creator):
    return lambda num_features, emsize: NoMeanEncoder(encoder_creator(num_features, emsize))
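
# Usage sketch (illustrative; the shapes below are assumptions): the per-feature mean
# is taken over the sequence dimension and subtracted, so shifting the whole sequence
# by a constant leaves the encoding (essentially) unchanged, matching the translation
# invariance described in the docstring.
#
#   enc = get_no_mean_encoder(Linear)(10, 512)
#   x = torch.rand(60, 8, 10)
#   assert torch.allclose(enc(x), enc(x + 3.), atol=1e-5)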

Linear = nn.Linear
MLP = lambda num_features, emsize: nn.Sequential(nn.Linear(num_features + 1, emsize * 2),
                                                 nn.ReLU(),
                                                 nn.Linear(emsize * 2, emsize))


class NanHandlingEncoder(nn.Module):
    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        if self.keep_nans:
            # Append indicator features (-1 for NaN, 1 for +inf, 2 for -inf, 0 otherwise),
            # normalized before concatenation.
            x = torch.cat([torch.nan_to_num(x, nan=0.0), normalize_data(torch.isnan(x) * -1
                                                                        + torch.logical_and(torch.isinf(x), torch.sign(x) == 1) * 1
                                                                        + torch.logical_and(torch.isinf(x), torch.sign(x) == -1) * 2
                                                                        )], -1)
        else:
            x = torch.nan_to_num(x, nan=0.0)
        return self.layer(x)
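
# Usage sketch (illustrative; the shapes below are assumptions): with keep_nans=True
# the input width doubles (values plus missingness indicators), so the linear layer
# maps 2*num_features -> emsize.
#
#   enc = NanHandlingEncoder(num_features=10, emsize=512, keep_nans=True)
#   x = torch.rand(60, 8, 10)
#   x[0, 0, 0] = float('nan')
#   emb = enc(x)                       # -> 60 x 8 x 512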


# Note: this subclass shadows the `Linear = nn.Linear` alias defined above.
class Linear(nn.Linear):
    def __init__(self, num_features, emsize, replace_nan_by_zero=False):
        super().__init__(num_features, emsize)
        self.num_features = num_features
        self.emsize = emsize
        self.replace_nan_by_zero = replace_nan_by_zero

    def forward(self, x):
        if self.replace_nan_by_zero:
            x = torch.nan_to_num(x, nan=0.0)
        return super().forward(x)

    def __setstate__(self, state):
        super().__setstate__(state)
        self.__dict__.setdefault('replace_nan_by_zero', True)
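
# Usage sketch (illustrative; the shapes below are assumptions): this `Linear`
# optionally zeroes NaNs before the affine projection, and __setstate__ keeps older
# pickled checkpoints (without the flag) loadable.
#
#   enc = Linear(10, 512, replace_nan_by_zero=True)
#   x = torch.rand(60, 8, 10)
#   x[0, 0, 0] = float('nan')
#   emb = enc(x)                       # -> 60 x 8 x 512, with no NaNs in the output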


class Conv(nn.Module):
    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        # Expects flattened square images in the last dimension.
        size = math.isqrt(x.shape[-1])
        assert size * size == x.shape[-1]
        x = x.reshape(*x.shape[:-1], 1, size, size)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        return self.linear(x)
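
# Usage sketch (illustrative; the shapes below are assumptions): Conv treats the last
# dimension as a flattened square image, so the sketch feeds a plain batch because
# nn.Conv2d expects a 4-d N x C x H x W tensor after the reshape.
#
#   enc = Conv(input_size=1024, emsize=512)
#   x = torch.rand(8, 1024)            # 8 flattened 32x32 images
#   emb = enc(x)                       # -> 8 x 512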


class CanEmb(nn.Embedding):
    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        embedding_dim = embedding_dim // num_features
        super().__init__(num_embeddings, embedding_dim, *args, **kwargs)

    def forward(self, x):
        lx = x.long()
        assert (lx == x).all(), "CanEmb only works with tensors of whole numbers"
        x = super().forward(lx)
        return x.view(*x.shape[:-2], -1)


def get_Canonical(num_classes):
    return lambda num_features, emsize: CanEmb(num_features, num_classes, emsize)
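
# Usage sketch (illustrative; the shapes below are assumptions): each categorical
# feature gets emsize // num_features embedding dimensions, and the per-feature
# embeddings are concatenated back to emsize.
#
#   enc = get_Canonical(num_classes=10)(4, 512)    # 4 features, 128 dims each
#   x = torch.randint(0, 10, (60, 8, 4)).float()   # whole-numbered class indices
#   emb = enc(x)                                   # -> 60 x 8 x 512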


def get_Embedding(num_embs_per_feature=100):
    return lambda num_features, emsize: EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)