Spaces:
Runtime error

Anonymous committed on
Commit · 2bb865d
1 Parent(s): 5b4bf3b

Revert "Final code"
This reverts commit 5b4bf3bc57df685ce7894faa33d20baca68215fc.
- app.py +1 -1
- transformer.py +232 -0
app.py
CHANGED
@@ -17,7 +17,7 @@ def compute(table: np.array):
     vfunc = np.vectorize(lambda s: len(str(s)))
     non_empty_row_mask = (vfunc(table).sum(1) != 0)
     table = table[non_empty_row_mask]
-    empty_mask = table == ''
+    empty_mask = table == '(predict)'
     empty_inds = np.where(empty_mask)
     if table.shape[0] > 1024:
         return "⚠️ **ERROR: TabPFN is not made for datasets with a trainingsize > 1024.**", None, None
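For context, a minimal sketch (with made-up data, not part of the commit) of what the changed line does: after this change, cells holding the literal string '(predict)' mark the entries to be predicted, and np.where collects their indices.

import numpy as np

# Hypothetical two-row table; the last cell is the one the user wants predicted.
table = np.array([
    ["5.1", "3.5", "setosa"],
    ["6.2", "2.9", "(predict)"],
], dtype=object)

empty_mask = table == '(predict)'   # the comparison introduced by this commit
empty_inds = np.where(empty_mask)   # (array([1]), array([2])) -> row 1, column 2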
transformer.py
ADDED
@@ -0,0 +1,232 @@
import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Module, TransformerEncoder

from layer import TransformerEncoderLayer, _get_activation_fn
from utils import SeqBN, bool_mask_to_att_mask


class TransformerModel(nn.Module):
    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                 pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                 activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
                 all_layers_same_init=False, efficient_eval_masking=True):
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
                                                                pre_norm=pre_norm, recompute_attn=recompute_attn)
        self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
            if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None
        self.style_encoder = style_encoder
        self.init_method = init_method
        if num_global_att_tokens is not None:
            assert not full_attention
        self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
        self.full_attention = full_attention
        self.efficient_eval_masking = efficient_eval_masking

        self.n_out = n_out
        self.nhid = nhid

        self.init_weights()

    def __setstate__(self, state):
        super().__setstate__(state)
        self.__dict__.setdefault('efficient_eval_masking', False)

    @staticmethod
    def generate_square_subsequent_mask(sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        train_size = sz - query_size
        mask = torch.zeros(sz, sz) == 0
        mask[:, train_size:].zero_()
        mask |= torch.eye(sz) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        sz = seq_len + num_global_att_tokens
        mask = torch.zeros(num_query_tokens, sz) == 0
        mask[:, train_size:].zero_()
        mask[:, train_size:] |= torch.eye(num_query_tokens) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        trainset_size = seq_len - num_query_tokens
        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
        #mask[:, num_global_att_tokens:].zero_()
        #mask[:, num_global_att_tokens:] |= torch.eye(trainset_size) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        mask = torch.zeros(num_global_att_tokens, num_global_att_tokens + seq_len - num_query_tokens) == 0
        return bool_mask_to_att_mask(mask)

    def init_weights(self):
        initrange = 1.
        # if isinstance(self.encoder, EmbeddingEncoder):
        #     self.encoder.weight.data.uniform_(-initrange, initrange)
        #     self.decoder.bias.data.zero_()
        #     self.decoder.weight.data.uniform_(-initrange, initrange)
        if self.init_method is not None:
            self.apply(self.init_method)
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
            for attn in attns:
                nn.init.zeros_(attn.out_proj.weight)
                nn.init.zeros_(attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'

        if len(src) == 2:  # (x,y) and no style
            src = (None,) + src

        style_src, x_src, y_src = src
        x_src = self.encoder(x_src)
        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else \
            torch.tensor([], device=x_src.device)
        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)

        if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
            if self.global_att_embeddings is None:
                full_len = len(x_src) + len(style_src)
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
                elif self.efficient_eval_masking:
                    src_mask = single_eval_pos + len(style_src)
                else:
                    src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to(x_src.device)
            else:
                src_mask_args = (self.global_att_embeddings.num_embeddings,
                                 len(x_src) + len(style_src),
                                 len(x_src) + len(style_src) - single_eval_pos)
                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))

        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output[single_eval_pos + len(style_src) + (self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]

    @torch.no_grad()
    def init_from_small_model(self, small_model):
        assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
               and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))

        def set_encoder_weights(my_encoder, small_model_encoder):
            my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
                if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
            small_in_dim = small_encoder_linear.out_features
            my_encoder_linear.weight.zero_()
            my_encoder_linear.bias.zero_()
            my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
            my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias

        set_encoder_weights(self.encoder, small_model.encoder)
        set_encoder_weights(self.y_encoder, small_model.y_encoder)

        small_in_dim = small_model.decoder.in_features

        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
        self.decoder.bias = small_model.decoder.bias

        for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
            small_hid_dim = small_layer.linear1.out_features
            my_in_dim = my_layer.linear1.in_features

            # packed along q,k,v order in first dim
            my_in_proj_w = my_layer.self_attn.in_proj_weight
            small_in_proj_w = small_layer.self_attn.in_proj_weight

            my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(
                3, small_in_dim, small_in_dim)
            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:, :small_in_dim] = \
                small_layer.self_attn.in_proj_bias.view(3, small_in_dim)

            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
            my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias

            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias

            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias

            my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
            my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight

            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias


class TransformerEncoderDiffInit(Module):
    r"""TransformerEncoder is a stack of N encoder layers.

    Args:
        encoder_layer_creator: a function generating objects of TransformerEncoderLayer class without args (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList([encoder_layer_creator() for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
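For orientation, a minimal usage sketch (not part of the commit): it assumes the companion layer.py and utils.py modules from this Space are importable, that their TransformerEncoderLayer accepts the integer mask produced by efficient_eval_masking, and that inputs follow the (sequence, batch, feature) layout forward() expects, with the first single_eval_pos positions acting as training rows and the rest as queries. All sizes below are made up for illustration.

import torch
import torch.nn as nn
from transformer import TransformerModel  # the file added in this commit

# Hypothetical sizes, chosen only for the example.
num_features, emsize, nhid, n_out = 10, 64, 128, 2

model = TransformerModel(
    encoder=nn.Linear(num_features, emsize),  # embeds each table row
    n_out=n_out, ninp=emsize, nhead=4, nhid=nhid, nlayers=2,
    y_encoder=nn.Linear(1, emsize),            # embeds the scalar labels
)

seq_len, batch = 20, 3
x = torch.randn(seq_len, batch, num_features)
y = torch.randint(0, n_out, (seq_len, batch)).float()

# The first 15 positions serve as the "training set", the last 5 are queried.
out = model((x, y), single_eval_pos=15)
print(out.shape)  # expected: torch.Size([5, 3, 2]), one output vector per query row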