Upload layer.py
layer.py ADDED
@@ -0,0 +1,131 @@
from functools import partial
from typing import Optional

import torch
from torch import Tensor, nn
# Explicit imports for the names used below (Module, MultiheadAttention, Linear, Dropout,
# LayerNorm, F), so the file does not depend on what the wildcard import happens to re-export.
from torch.nn import Dropout, LayerNorm, Linear, Module, MultiheadAttention
from torch.nn import functional as F
from torch.nn.modules.transformer import *
from torch.nn.modules.transformer import _get_activation_fn

from torch.utils.checkpoint import checkpoint


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        # pre_norm applies layer norm before attention/feedforward (pre-LN) instead of after (post-LN);
        # recompute_attn trades compute for memory by checkpointing the attention call.
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Checkpoints saved without an 'activation' entry default to relu.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        if self.pre_norm:
            src_ = self.norm1(src)
        else:
            src_ = src

        if isinstance(src_mask, tuple):
            # Global attention setup: the sequence is laid out as
            # [global tokens | training tokens | evaluation tokens], and each group
            # attends through its own mask from the (global, trainset, valset) tuple.
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[num_global_tokens:num_global_tokens + num_train_tokens]
            global_and_train_tokens_src = src_[:num_global_tokens + num_train_tokens]
            eval_tokens_src = src_[num_global_tokens + num_train_tokens:]

            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn

            # Positional arguments map to MultiheadAttention.forward(query, key, value,
            # key_padding_mask, need_weights, attn_mask); checkpoint requires positional args.
            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
            eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
                                    None, True, valset_src_mask)[0]

            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)

        elif isinstance(src_mask, int):
            # An integer mask marks the train/eval split position: tokens before it attend
            # to each other, tokens after it attend only to the tokens before the split.
            assert src_key_padding_mask is None
            single_eval_position = src_mask
            src_left = self.self_attn(src_[:single_eval_position], src_[:single_eval_position], src_[:single_eval_position])[0]
            src_right = self.self_attn(src_[single_eval_position:], src_[:single_eval_position], src_[:single_eval_position])[0]
            src2 = torch.cat([src_left, src_right], dim=0)
        else:
            # Standard self-attention over the full sequence.
            if self.recompute_attn:
                src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
            else:
                src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
                                      key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        if self.pre_norm:
            src_ = self.norm2(src)
        else:
            src_ = src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src
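
For orientation, a minimal sketch of how this layer could be exercised, assuming a non-batch-first layout of shape (seq, batch, d_model) and the integer src_mask branch, where the integer marks the boundary between training and evaluation tokens. The names below (layer, src, split) are illustrative and not part of this file.

import torch

# Hypothetical smoke test for the integer-mask branch (train/eval split attention).
layer = TransformerEncoderLayer(d_model=512, nhead=8, pre_norm=True)
src = torch.rand(10, 32, 512)       # (seq, batch, d_model), batch_first=False
split = 6                           # first 6 positions are treated as training tokens
out = layer(src, src_mask=split)    # tokens after `split` attend only to tokens before it
assert out.shape == src.shape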