Anonymous committed
Commit 2bb865d · 1 parent: 5b4bf3b

Revert "Final code"

This reverts commit 5b4bf3bc57df685ce7894faa33d20baca68215fc.

Files changed (2):
  1. app.py +1 -1
  2. transformer.py +232 -0
app.py CHANGED
@@ -17,7 +17,7 @@ def compute(table: np.array):
     vfunc = np.vectorize(lambda s: len(str(s)))
     non_empty_row_mask = (vfunc(table).sum(1) != 0)
     table = table[non_empty_row_mask]
-    empty_mask = table == ''
+    empty_mask = table == '(predict)'
     empty_inds = np.where(empty_mask)
     if table.shape[0] > 1024:
         return "⚠️ **ERROR: TabPFN is not made for datasets with a trainingsize > 1024.**", None, None
transformer.py ADDED
@@ -0,0 +1,232 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module, TransformerEncoder
+
+from layer import TransformerEncoderLayer, _get_activation_fn
+from utils import SeqBN, bool_mask_to_att_mask
+
+
+class TransformerModel(nn.Module):
+    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
+                 pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
+                 activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
+                 all_layers_same_init=False, efficient_eval_masking=True):
+        super().__init__()
+        self.model_type = 'Transformer'
+        encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
+                                                                pre_norm=pre_norm, recompute_attn=recompute_attn)
+        self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
+            if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
+        self.ninp = ninp
+        self.encoder = encoder
+        self.y_encoder = y_encoder
+        self.pos_encoder = pos_encoder
+        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
+        self.input_ln = SeqBN(ninp) if input_normalization else None
+        self.style_encoder = style_encoder
+        self.init_method = init_method
+        if num_global_att_tokens is not None:
+            assert not full_attention
+        self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
+        self.full_attention = full_attention
+        self.efficient_eval_masking = efficient_eval_masking
+
+        self.n_out = n_out
+        self.nhid = nhid
+
+        self.init_weights()
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self.__dict__.setdefault('efficient_eval_masking', False)
+
+    @staticmethod
+    def generate_square_subsequent_mask(sz):
+        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+        return bool_mask_to_att_mask(mask)
+
+    @staticmethod
+    def generate_D_q_matrix(sz, query_size):
+        train_size = sz-query_size
+        mask = torch.zeros(sz,sz) == 0
+        mask[:,train_size:].zero_()
+        mask |= torch.eye(sz) == 1
+        return bool_mask_to_att_mask(mask)
+
+    @staticmethod
+    def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
+        train_size = seq_len + num_global_att_tokens - num_query_tokens
+        sz = seq_len + num_global_att_tokens
+        mask = torch.zeros(num_query_tokens, sz) == 0
+        mask[:,train_size:].zero_()
+        mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
+        return bool_mask_to_att_mask(mask)
+
+    @staticmethod
+    def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
+        train_size = seq_len + num_global_att_tokens - num_query_tokens
+        trainset_size = seq_len - num_query_tokens
+        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
+        #mask[:,num_global_att_tokens:].zero_()
+        #mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
+        return bool_mask_to_att_mask(mask)
+
+    @staticmethod
+    def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
+        mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
+        return bool_mask_to_att_mask(mask)
+
+    def init_weights(self):
+        initrange = 1.
+        # if isinstance(self.encoder,EmbeddingEncoder):
+        #     self.encoder.weight.data.uniform_(-initrange, initrange)
+        #     self.decoder.bias.data.zero_()
+        #     self.decoder.weight.data.uniform_(-initrange, initrange)
+        if self.init_method is not None:
+            self.apply(self.init_method)
+        for layer in self.transformer_encoder.layers:
+            nn.init.zeros_(layer.linear2.weight)
+            nn.init.zeros_(layer.linear2.bias)
+            attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
+            for attn in attns:
+                nn.init.zeros_(attn.out_proj.weight)
+                nn.init.zeros_(attn.out_proj.bias)
+
+    def forward(self, src, src_mask=None, single_eval_pos=None):
+        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'
+
+        if len(src) == 2: # (x,y) and no style
+            src = (None,) + src
+
+        style_src, x_src, y_src = src
+        x_src = self.encoder(x_src)
+        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
+        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else \
+            torch.tensor([], device=x_src.device)
+        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
+            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
+
+        if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
+        if src_mask is None:
+            if self.global_att_embeddings is None:
+                full_len = len(x_src) + len(style_src)
+                if self.full_attention:
+                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
+                elif self.efficient_eval_masking:
+                    src_mask = single_eval_pos + len(style_src)
+                else:
+                    src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to(x_src.device)
+            else:
+                src_mask_args = (self.global_att_embeddings.num_embeddings,
+                                 len(x_src) + len(style_src),
+                                 len(x_src) + len(style_src) - single_eval_pos)
+                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
+                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
+                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
+
+        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
+        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)
+
+        if self.input_ln is not None:
+            src = self.input_ln(src)
+
+        if self.pos_encoder is not None:
+            src = self.pos_encoder(src)
+
+        output = self.transformer_encoder(src, src_mask)
+        output = self.decoder(output)
+        return output[single_eval_pos+len(style_src)+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]
+
+    @torch.no_grad()
+    def init_from_small_model(self, small_model):
+        assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
+               and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))
+
+        def set_encoder_weights(my_encoder, small_model_encoder):
+            my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
+                if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
+            small_in_dim = small_encoder_linear.out_features
+            my_encoder_linear.weight.zero_()
+            my_encoder_linear.bias.zero_()
+            my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
+            my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias
+
+        set_encoder_weights(self.encoder, small_model.encoder)
+        set_encoder_weights(self.y_encoder, small_model.y_encoder)
+
+        small_in_dim = small_model.decoder.in_features
+
+        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
+        self.decoder.bias = small_model.decoder.bias
+
+        for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
+            small_hid_dim = small_layer.linear1.out_features
+            my_in_dim = my_layer.linear1.in_features
+
+            # packed along q,k,v order in first dim
+            my_in_proj_w = my_layer.self_attn.in_proj_weight
+            small_in_proj_w = small_layer.self_attn.in_proj_weight
+
+            my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
+                                                                                                               small_in_dim,
+                                                                                                               small_in_dim)
+            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
+            :small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)
+
+            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
+            my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias
+
+            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
+            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias
+
+            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
+            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias
+
+            my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
+            my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight
+
+            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
+            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
+
+
+class TransformerEncoderDiffInit(Module):
+    r"""TransformerEncoder is a stack of N encoder layers
+
+    Args:
+        encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
+        num_layers: the number of sub-encoder-layers in the encoder (required).
+        norm: the layer normalization component (optional).
+    """
+    __constants__ = ['norm']
+
+    def __init__(self, encoder_layer_creator, num_layers, norm=None):
+        super().__init__()
+        self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        r"""Pass the input through the encoder layers in turn.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            see the docs in Transformer class.
+        """
+        output = src
+
+        for mod in self.layers:
+            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
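
For reference, the boolean pattern built by generate_D_q_matrix above (before utils.bool_mask_to_att_mask converts it into an attention mask; utils.py is not part of this diff) can be reproduced with plain torch ops. In this toy 5x5 case with 2 query positions, every position may attend to the 3 training positions, and each query position may additionally attend only to itself:

import torch

sz, query_size = 5, 2            # 3 training positions followed by 2 query positions
train_size = sz - query_size
mask = torch.zeros(sz, sz) == 0  # start from "attend to everything"
mask[:, train_size:].zero_()     # forbid attending to the query positions ...
mask |= torch.eye(sz) == 1       # ... except every position to itself
print(mask.long())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 1]])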
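
And a rough sketch of how the restored TransformerModel can be driven in the in-context style its forward() implies: rows before single_eval_pos act as the training set (their x and y embeddings are summed), the remaining rows are queries whose predictions come back from forward(). The dimensions and the two Linear encoders below are placeholders chosen only for illustration, and the repo's layer.py and utils.py must be importable for this to run:

import torch
import torch.nn as nn
from transformer import TransformerModel  # the file restored by this commit

# Hypothetical sizes, not taken from the repo.
num_features, ninp, nhead, nhid, nlayers, n_out = 10, 64, 4, 128, 2, 2

model = TransformerModel(
    encoder=nn.Linear(num_features, ninp),  # x encoder: features -> embedding
    n_out=n_out,
    ninp=ninp, nhead=nhead, nhid=nhid, nlayers=nlayers,
    y_encoder=nn.Linear(1, ninp),           # y encoder: scalar label -> embedding
    efficient_eval_masking=False,           # ask forward() for an explicit D_q attention mask
)

seq_len, batch_size, single_eval_pos = 20, 1, 15  # 15 "train" rows, 5 query rows
x = torch.randn(seq_len, batch_size, num_features)
y = torch.randint(0, n_out, (seq_len, batch_size)).float()

# forward() takes an (x, y) tuple; only y[:single_eval_pos] is used as context.
out = model((x, y), single_eval_pos=single_eval_pos)
print(out.shape)  # expected: (seq_len - single_eval_pos, batch_size, n_out)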