Upload 8 files
Browse files- .gitattributes +2 -0
- data/train.en +3 -0
- data/train.vi +3 -0
- data/tst2012.en +0 -0
- data/tst2012.vi +0 -0
- data/tst2013.en +0 -0
- data/tst2013.vi +0 -0
- transformer.pth +3 -0
- transformer.py +787 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/train.en filter=lfs diff=lfs merge=lfs -text
|
37 |
+
data/train.vi filter=lfs diff=lfs merge=lfs -text
|
data/train.en
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c26dfeed74b6bf3752f5ca552f2412456f0de153f7c804df8717931fb3a5c78a
|
3 |
+
size 13603614
|
data/train.vi
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:707206edf2dc0280273952c7b70544ea8a1363aa69aaeb9d70514b888dc3067d
|
3 |
+
size 18074646
|
data/tst2012.en
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/tst2012.vi
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/tst2013.en
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/tst2013.vi
ADDED
The diff for this file is too large to render.
See raw diff
|
|
transformer.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0a647a76a564f3ae42b372c26ad4361a333df342ffc9fb1a773d22fa9123b6ad
|
3 |
+
size 347866211
|
transformer.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
!python -m spacy download en_core_web_sm
|
5 |
+
import nltk
|
6 |
+
nltk.download('wordnet')
|
7 |
+
|
8 |
+
!pip install https://gitlab.com/trungtv/vi_spacy/-/raw/master/packages/vi_core_news_lg-3.6.0/dist/vi_core_news_lg-3.6.0.tar.gz
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.autograd import Variable
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import numpy as np
|
15 |
+
import os
|
16 |
+
import math
|
17 |
+
import nltk
|
18 |
+
import spacy
|
19 |
+
|
20 |
+
|
21 |
+
class Embedder(nn.Module):
    """Lookup table that turns token indices into ``d_model``-sized vectors."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.vocab_size, self.d_model = vocab_size, d_model

        # One learnable row per vocabulary entry.
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Return the embedding rows selected by the index tensor ``x``."""
        embedded = self.embed(x)
        return embedded
|
31 |
+
|
32 |
+
|
33 |
+
class PositionalEncoder(nn.Module):
    """Scale embeddings by sqrt(d_model) and add a fixed sinusoidal position table.

    The table is registered as a buffer named ``pe`` with shape
    ``(1, max_seq_length, d_model)``, so it is stored in the module's
    ``state_dict`` and moves together with the module on ``.to()``/``.cuda()``.
    """

    def __init__(self, d_model, max_seq_length=200, dropout=0.1):
        """Precompute the position table.

        Args:
            d_model: size of the last dimension of the embeddings.
            max_seq_length: number of positions to precompute.
            dropout: dropout probability applied after adding the table.
        """
        super().__init__()

        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_seq_length, d_model)

        # NOTE: the exponents are kept exactly as in the original
        # implementation so the table stays numerically unchanged.
        for pos in range(max_seq_length):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos/(10000**(2*i/d_model)))
                # Guard the cosine slot so an odd d_model does not index
                # past the last column.
                if i + 1 < d_model:
                    pe[pos, i+1] = math.cos(pos/(10000**((2*i+1)/d_model)))
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x * sqrt(d_model) + pe[:, :x.size(1)])``."""
        x = x*math.sqrt(self.d_model)
        seq_length = x.size(1)

        # The buffer already follows the module's device; the previous
        # ``pe.cuda()`` call discarded its result and had no effect.
        pe = self.pe[:, :seq_length]

        # add the positional table to the (scaled) embedding vectors
        x = x + pe
        x = self.dropout(x)

        return x
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
def attention(q, k, v, mask=None, dropout=None):
    """Scaled dot-product attention.

    Shapes as documented by the original author (not enforced here):
    q, k, v: batch_size x head x seq_length x per-head dim
    mask:    gets an extra axis inserted at position 1 before being applied
    Returns a tuple ``(output, weights)`` where ``weights`` are the
    (optionally dropped-out) softmax attention weights.
    """
    # Similarity of every query with every key, scaled by sqrt of the last dim.
    scale = math.sqrt(q.size(-1))
    logits = torch.matmul(q, k.transpose(-2, -1)) / scale

    if mask is not None:
        head_mask = mask.unsqueeze(1)
        # Positions where the mask is 0 get a very negative score so that
        # softmax assigns them (numerically) zero weight.
        logits = logits.masked_fill(head_mask == 0, -1e9)

    # normalise with softmax over the key axis
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    return torch.matmul(weights, v), weights
|
91 |
+
|
92 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, merge, project out."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        assert d_model % heads == 0

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        # Attention weights of the most recent forward pass.
        self.attn = None

        # three weight matrices: q_linear, k_linear, v_linear
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, bs):
        """Reshape (bs, seq, d_model) into (bs, heads, seq, d_k)."""
        return projected.view(bs, -1, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """
        q: batch_size x seq_length x d_model
        k: batch_size x seq_length x d_model
        v: batch_size x seq_length x d_model
        mask: batch_size x 1 x seq_length
        output: batch_size x seq_length x d_model
        """
        bs = q.size(0)

        # apply the q/k/v projections to the inputs and split into heads
        q = self._split_heads(self.q_linear(q), bs)
        k = self._split_heads(self.k_linear(k), bs)
        v = self._split_heads(self.v_linear(v), bs)

        # per-head attention; keep the weights on the module
        context, self.attn = attention(q, k, v, mask, self.dropout)

        # merge the heads back into one d_model-wide vector per position
        merged = context.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(merged)
|
135 |
+
|
136 |
+
"""# Normalization Layer
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
"""
|
142 |
+
|
143 |
+
class Norm(nn.Module):
    """Normalisation over the last dimension with learnable scale and shift."""

    def __init__(self, d_model, eps = 1e-6):
        super().__init__()

        self.size = d_model

        # Learnable gain (initialised to 1) and offset (initialised to 0).
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))

        # Added to the standard deviation to avoid division by zero.
        self.eps = eps

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` along the last axis."""
        centered = x - x.mean(dim=-1, keepdim=True)
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias
|
159 |
+
|
160 |
+
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__()

        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = F.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)
|
174 |
+
|
175 |
+
class EncoderLayer(nn.Module):
    """One encoder layer: pre-norm self-attention and pre-norm feed-forward,
    each wrapped in a residual connection with dropout."""

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.attn = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: batch_size x seq_length x d_model
        mask: batch_size x 1 x seq_length
        output: batch_size x seq_length x d_model
        """
        # Self-attention sub-layer: normalise, attend, add back to the input.
        normed = self.norm_1(x)
        attended = self.attn(normed, normed, normed, mask)
        x = x + self.dropout_1(attended)

        # Feed-forward sub-layer with the same residual pattern.
        normed = self.norm_2(x)
        transformed = self.ff(normed)
        x = x + self.dropout_2(transformed)
        return x
|
199 |
+
|
200 |
+
"""# Decoder
|
201 |
+
Decoder thực hiện chức năng giải mã vector của câu nguồn thành câu đích
|
202 |
+
|
203 |
+
## Và Masked Multi Head Attention
|
204 |
+
|
205 |
+
"""
|
206 |
+
|
207 |
+
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention over the target, attention over the
    encoder outputs, then a feed-forward block. Each sub-layer is pre-normed
    and wrapped in a residual connection with dropout."""

    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        self.norm_3 = Norm(d_model)

        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

        self.attn_1 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.attn_2 = MultiHeadAttention(heads, d_model, dropout=dropout)
        self.ff = FeedForward(d_model, dropout=dropout)

    def forward(self, x, e_outputs, src_mask, trg_mask):
        """
        x: batch_size x seq_length x d_model
        e_outputs: batch_size x seq_length x d_model
        src_mask: batch_size x 1 x seq_length
        trg_mask: batch_size x 1 x seq_length
        """
        # First attention: the target attends to itself, restricted by trg_mask.
        normed = self.norm_1(x)
        self_attended = self.attn_1(normed, normed, normed, trg_mask)
        x = x + self.dropout_1(self_attended)

        # Second attention: queries come from the decoder, keys and values
        # are the encoder outputs, restricted by src_mask.
        normed = self.norm_2(x)
        cross_attended = self.attn_2(normed, e_outputs, e_outputs, src_mask)
        x = x + self.dropout_2(cross_attended)

        # Position-wise feed-forward.
        normed = self.norm_3(x)
        transformed = self.ff(normed)
        x = x + self.dropout_3(transformed)
        return x
|
239 |
+
|
240 |
+
"""# Cài đặt Encoder
|
241 |
+
bao gồm N encoder layer
|
242 |
+
"""
|
243 |
+
|
244 |
+
import copy
|
245 |
+
|
246 |
+
def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)
|
248 |
+
|
249 |
+
class Encoder(nn.Module):
    """Encoder stack: embedding, positional encoding, N encoder layers, final norm."""

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(EncoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, src, mask):
        """
        src: batch_size x seq_length
        mask: batch_size x 1 x seq_length
        output: batch_size x seq_length x d_model
        """
        hidden = self.pe(self.embed(src))
        # Run the first N layers in order.
        for idx in range(self.N):
            layer = self.layers[idx]
            hidden = layer(hidden, mask)
        return self.norm(hidden)
|
271 |
+
|
272 |
+
"""# Cài đặt Decoder
|
273 |
+
bao gồm N decoder layers
|
274 |
+
"""
|
275 |
+
|
276 |
+
class Decoder(nn.Module):
    """Decoder stack: embedding, positional encoding, N decoder layers, final norm."""

    def __init__(self, vocab_size, d_model, N, heads, dropout):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model, dropout=dropout)
        self.layers = get_clones(DecoderLayer(d_model, heads, dropout), N)
        self.norm = Norm(d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        """
        trg: batch_size x seq_length
        e_outputs: batch_size x seq_length x d_model
        src_mask: batch_size x 1 x seq_length
        trg_mask: batch_size x 1 x seq_length
        output: batch_size x seq_length x d_model
        """
        hidden = self.pe(self.embed(trg))
        # Every layer sees the same encoder outputs and masks.
        for idx in range(self.N):
            layer = self.layers[idx]
            hidden = layer(hidden, e_outputs, src_mask, trg_mask)
        return self.norm(hidden)
|
299 |
+
|
300 |
+
"""# Cài đặt Transformer
|
301 |
+
bao gồm encoder và decoder
|
302 |
+
"""
|
303 |
+
|
304 |
+
class Transformer(nn.Module):
    """Complete encoder-decoder model with a linear projection onto the target vocabulary."""

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)

    def forward(self, src, trg, src_mask, trg_mask):
        """
        src: batch_size x seq_length
        trg: batch_size x seq_length
        src_mask: batch_size x 1 x seq_length
        trg_mask: batch_size x 1 x seq_length
        output: batch_size x seq_length x vocab_size
        """
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(trg, memory, src_mask, trg_mask)
        # Unnormalised scores per target-vocabulary entry.
        return self.out(decoded)
|
324 |
+
|
325 |
+
from torchtext import data
|
326 |
+
#torchtext để load dữ liệu, giúp giảm thời gian và hiệu quả
|
327 |
+
class MyIterator(data.Iterator):
    """torchtext iterator that builds length-sorted batches."""

    def create_batches(self):
        """Populate ``self.batches``.

        Evaluation mode: batch the data in order and sort each batch by
        ``sort_key``. Training mode: read chunks of ``batch_size * 100``,
        sort each chunk by ``sort_key``, cut it into batches and yield those
        batches in the order produced by ``random_shuffler``.
        """
        if not self.train:
            self.batches = []
            for chunk in data.batch(self.data(), self.batch_size,
                                    self.batch_size_fn):
                self.batches.append(sorted(chunk, key=self.sort_key))
            return

        def pool(examples, random_shuffler):
            for chunk in data.batch(examples, self.batch_size * 100):
                bucketed = data.batch(
                    sorted(chunk, key=self.sort_key),
                    self.batch_size, self.batch_size_fn)
                yield from random_shuffler(list(bucketed))

        self.batches = pool(self.data(), self.random_shuffler)
|
344 |
+
|
345 |
+
global max_src_in_batch, max_tgt_in_batch


def batch_size_fn(new, count, sofar):
    """Return the padded size of the batch after adding example ``new``.

    Tracks the longest source and target (target length + 2) seen in the
    current batch in module-level globals, resetting them when ``count`` is 1.
    The result is ``count`` times the longer of the two maxima. ``sofar`` is
    accepted for the torchtext callback signature but not used.
    """
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        # First example of a new batch: start the running maxima from zero.
        max_src_in_batch = 0
        max_tgt_in_batch = 0

    src_len = len(new.src)
    tgt_len = len(new.trg) + 2
    if src_len > max_src_in_batch:
        max_src_in_batch = src_len
    if tgt_len > max_tgt_in_batch:
        max_tgt_in_batch = tgt_len

    padded_src = count * max_src_in_batch
    padded_tgt = count * max_tgt_in_batch
    return max(padded_src, padded_tgt)
|
357 |
+
|
358 |
+
def nopeak_mask(size, device):
    """Build the decoder's look-ahead mask.

    Returns a boolean tensor of shape ``(1, size, size)`` on ``device`` that is
    True on and below the diagonal, so that during training a position cannot
    see later (future) positions.
    """
    lower_triangle = torch.tril(torch.ones(1, size, size)) == 1
    return lower_triangle.to(device)
|
367 |
+
|
368 |
+
def create_masks(src, trg, src_pad, trg_pad, device):
    """Build the source padding mask and the combined target mask.

    Args:
        src: source index tensor.
        trg: target index tensor, or None when only the source mask is needed.
        src_pad: padding index in the source vocabulary.
        trg_pad: padding index in the target vocabulary.
        device: device on which the look-ahead mask is created.

    Returns:
        ``(src_mask, trg_mask)``. ``src_mask`` is True where ``src`` is not
        padding, with an extra axis inserted before the last one.
        ``trg_mask`` is the target padding mask AND-ed with the look-ahead
        mask from ``nopeak_mask``, or None when ``trg`` is None.
    """
    # Padding positions are marked False so attention can ignore them.
    src_mask = (src != src_pad).unsqueeze(-2)

    if trg is None:
        return src_mask, None

    trg_mask = (trg != trg_pad).unsqueeze(-2)
    # nopeak_mask already places the mask on ``device``; the former
    # ``np_mask.cuda()`` call discarded its result and did nothing.
    np_mask = nopeak_mask(trg.size(1), device)
    trg_mask = trg_mask & np_mask

    return src_mask, trg_mask
|
384 |
+
|
385 |
+
from nltk.corpus import wordnet
|
386 |
+
import re
|
387 |
+
|
388 |
+
def get_synonym(word, SRC):
    """Return the vocabulary index of the first WordNet lemma of ``word`` whose
    index in ``SRC.vocab.stoi`` is non-zero, or 0 when there is none."""
    stoi = SRC.vocab.stoi
    for synset in wordnet.synsets(word):
        for lemma in synset.lemmas():
            idx = stoi[lemma.name()]
            if idx != 0:
                return idx
    return 0
|
396 |
+
|
397 |
+
def multiple_replace(dict, text):
    """Replace every occurrence of each key of ``dict`` in ``text`` with its value.

    All keys are matched literally in a single left-to-right pass, so a
    replacement is never re-scanned. An empty mapping returns ``text``
    unchanged (previously it built the pattern ``()`` which matched the empty
    string everywhere and raised ``KeyError``).
    """
    if not dict:
        return text

    regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))

    return regex.sub(lambda mo: dict[mo.group(0)], text)
|
401 |
+
|
402 |
+
def init_vars(src, model, SRC, TRG, device, k, max_len):
    """Prepare the tensors beam search starts from.

    Encodes ``src`` once, decodes a single ``<sos>`` step, and seeds ``k``
    beams with the ``k`` most probable first tokens.

    Returns:
        outputs: ``(k, max_len)`` long tensor; column 0 is ``<sos>``, column 1
            the first predicted token of each beam.
        e_outputs: the encoder output of the first batch item copied ``k`` times.
        log_scores: ``(1, k)`` tensor with the log-probability of each beam.
    """
    sos_idx = TRG.vocab.stoi['<sos>']
    pad_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)

    # compute the encoder output once up front
    memory = model.encoder(src, pad_mask)

    start = torch.LongTensor([[sos_idx]]).to(device)
    first_mask = nopeak_mask(1, device)

    # predict the first token
    decoded = model.decoder(start, memory, pad_mask, first_mask)
    dist = F.softmax(model.out(decoded), dim=-1)

    probs, ix = dist[:, -1].data.topk(k)
    log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0)

    outputs = torch.zeros(k, max_len).long().to(device)
    outputs[:, 0] = sos_idx
    outputs[:, 1] = ix[0]

    # One copy of the encoder output per beam.
    e_outputs = torch.zeros(k, memory.size(-2), memory.size(-1)).to(device)
    e_outputs[:, :] = memory[0]

    return outputs, e_outputs, log_scores
|
435 |
+
|
436 |
+
def k_best_outputs(outputs, out, log_scores, i, k):
    """Advance beam search by one step, keeping the ``k`` best continuations.

    Args:
        outputs: ``(k, max_len)`` long tensor of beam token indices; updated
            in place and returned.
        out: probabilities from the decoder; only the last time step
            ``out[:, -1]`` is used.
        log_scores: ``(1, k)`` accumulated log-probability of each beam.
        i: column of ``outputs`` to fill with the newly chosen tokens.
        k: beam width.

    Returns:
        ``(outputs, log_scores)`` with the surviving beams and their scores.
    """
    probs, ix = out[:, -1].data.topk(k)

    # torch.log maps a zero probability to -inf; the previous per-element
    # math.log loop raised ValueError in that case. ``.to(log_scores)`` keeps
    # the result on the same device/dtype as the running scores.
    step_log_probs = torch.log(probs).to(log_scores).view(k, -1)
    log_probs = step_log_probs + log_scores.transpose(0, 1)
    k_probs, k_ix = log_probs.view(-1).topk(k)

    # Flat index -> (beam the candidate extends, rank within that beam).
    row = k_ix // k
    col = k_ix % k

    # Advanced indexing on the right-hand side copies, so reordering the
    # prefixes in place is safe.
    outputs[:, :i] = outputs[row, :i]
    outputs[:, i] = ix[row, col]

    log_scores = k_probs.unsqueeze(0)

    return outputs, log_scores
|
451 |
+
|
452 |
+
def beam_search(src, model, SRC, TRG, device, k, max_len):
    """Translate ``src`` with beam search of width ``k`` and return the text.

    Decoding stops once every beam contains ``<eos>``; the winner is the beam
    with the best length-normalised score (``log_score / length ** 0.7``).
    If ``max_len`` is reached first, the first beam is returned instead.

    Returns:
        The chosen target tokens (without ``<sos>``/``<eos>``) joined by spaces.
    """
    outputs, e_outputs, log_scores = init_vars(src, model, SRC, TRG, device, k, max_len)
    eos_tok = TRG.vocab.stoi['<eos>']
    src_mask = (src != SRC.vocab.stoi['<pad>']).unsqueeze(-2)
    ind = None
    for i in range(2, max_len):

        trg_mask = nopeak_mask(i, device)

        out = model.out(model.decoder(outputs[:, :i],
                                      e_outputs, src_mask, trg_mask))

        out = F.softmax(out, dim=-1)

        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, k)

        # (beam, position) pairs of every <eos> generated so far.
        eos_positions = (outputs == eos_tok).nonzero()
        # Allocate next to ``outputs`` instead of hard-coding ``.cuda()``,
        # which crashed on CPU-only runs.
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long,
                                       device=outputs.device)
        # Distinct loop names: the original reused ``i`` here, shadowing the
        # decoding-step index of the outer loop.
        for beam, pos in eos_positions:
            # Only the first <eos> of a beam defines its length.
            if sentence_lengths[beam] == 0:
                sentence_lengths[beam] = pos

        num_finished_sentences = int((sentence_lengths > 0).sum())

        if num_finished_sentences == k:
            alpha = 0.7
            div = 1/(sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break

    if ind is None:
        # Not all beams finished: fall back to the first beam. When it has no
        # <eos>, the slice end of -1 drops its last token (unchanged from the
        # original behaviour).
        eos_hits = (outputs[0] == eos_tok).nonzero()
        length = eos_hits[0] if len(eos_hits) > 0 else -1
        return ' '.join([TRG.vocab.itos[tok] for tok in outputs[0][1:length]])

    else:
        length = (outputs[ind] == eos_tok).nonzero()[0]
        return ' '.join([TRG.vocab.itos[tok] for tok in outputs[ind][1:length]])
|
493 |
+
|
494 |
+
def translate_sentence(sentence, model, SRC, TRG, device, k, max_len):
    """Translate one sentence using beam search.

    Args:
        sentence: raw source sentence.
        model: the translation model; it is switched to eval mode.
        SRC, TRG: source and target fields providing preprocessing and vocab.
        device: device on which the source tensor is placed.
        k: beam width.
        max_len: maximum number of target positions to decode.

    Returns:
        The translated string with the space before ``? ! . ,`` and after
        ``'`` removed.
    """
    model.eval()
    # Compare against the unknown-token index explicitly. The original looked
    # up '<eos>', which is not a special token of the source field and so
    # only resolved to the unknown index by accident.
    unk_idx = SRC.vocab.stoi['<unk>']

    indexed = []
    for tok in SRC.preprocess(sentence):
        idx = SRC.vocab.stoi[tok]
        if idx != unk_idx:
            indexed.append(idx)
        else:
            # Out-of-vocabulary word: try a WordNet synonym that is in the vocab.
            indexed.append(get_synonym(tok, SRC))

    src = torch.LongTensor([indexed]).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        translated = beam_search(src, model, SRC, TRG, device, k, max_len)

    return multiple_replace({' ?' : '?',' !':'!',' .':'.','\' ':'\'',' ,':','}, translated)
|
514 |
+
|
515 |
+
import re
|
516 |
+
|
517 |
+
class tokenize(object):
    """Sentence cleaner plus spaCy tokenizer for one language model."""

    def __init__(self, lang):
        # ``lang`` is passed straight to spacy.load.
        self.nlp = spacy.load(lang)

    def tokenizer(self, sentence):
        """Clean ``sentence`` and return its tokens as a list of strings.

        The substitutions run in this order: listed punctuation and newlines
        become spaces, runs of spaces collapse to one, repeated ``!``, ``,``
        and ``?`` collapse to one. The text is then lower-cased and
        tokenized; tokens equal to a single space are dropped.
        """
        cleanup_steps = (
            (r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " "),
            (r"[ ]+", " "),
            (r"\!+", "!"),
            (r"\,+", ","),
            (r"\?+", "?"),
        )
        text = str(sentence)
        for pattern, replacement in cleanup_steps:
            text = re.sub(pattern, replacement, text)
        text = text.lower()

        tokens = []
        for tok in self.nlp.tokenizer(text):
            if tok.text != " ":
                tokens.append(tok.text)
        return tokens
|
531 |
+
|
532 |
+
"""## Data loader
|
533 |
+
|
534 |
+
"""
|
535 |
+
|
536 |
+
!pip install dill
|
537 |
+
|
538 |
+
import os
|
539 |
+
import dill as pickle
|
540 |
+
import pandas as pd
|
541 |
+
|
542 |
+
def read_data(src_file, trg_file):
    """Read the parallel corpus files and return ``(src_lines, trg_lines)``.

    Each file is read as UTF-8, stripped of leading/trailing whitespace and
    split on newlines. Files are opened with context managers so the handles
    are closed, and the encoding is explicit so non-ASCII text does not
    depend on the platform default.
    """
    with open(src_file, encoding='utf-8') as src_handle:
        src_data = src_handle.read().strip().split('\n')

    with open(trg_file, encoding='utf-8') as trg_handle:
        trg_data = trg_handle.read().strip().split('\n')

    return src_data, trg_data
|
548 |
+
|
549 |
+
def create_fields(src_lang, trg_lang):
    """Create the torchtext fields for the source and target languages.

    Both fields lower-case their input and use a ``tokenize`` instance for
    the given spaCy model name; the target field additionally wraps each
    sentence in ``<sos>`` / ``<eos>``.

    Returns:
        ``(SRC, TRG)``
    """
    print("loading spacy tokenizers...")

    src_tok = tokenize(src_lang)
    trg_tok = tokenize(trg_lang)

    TRG = data.Field(
        lower=True,
        tokenize=trg_tok.tokenizer,
        init_token='<sos>',
        eos_token='<eos>',
    )
    SRC = data.Field(lower=True, tokenize=src_tok.tokenizer)

    return SRC, TRG
|
560 |
+
|
561 |
+
def create_dataset(src_data, trg_data, max_strlen, batchsize, device, SRC, TRG, istrain=True):
    """Build a ``MyIterator`` over the sentence pairs.

    Pairs where either side contains ``max_strlen`` or more spaces are
    dropped. The remaining pairs are written to a temporary CSV, loaded as a
    torchtext ``TabularDataset``, and the CSV is deleted again. When
    ``istrain`` is true the vocabularies of ``SRC`` and ``TRG`` are built
    from this dataset.
    """
    print("creating dataset and iterator... ")

    frame = pd.DataFrame(
        {'src': list(src_data), 'trg': list(trg_data)},
        columns=["src", "trg"],
    )

    # Keep only pairs whose space count is below the limit on both sides.
    short_src = frame['src'].str.count(' ') < max_strlen
    short_trg = frame['trg'].str.count(' ') < max_strlen
    frame = frame.loc[short_src & short_trg]

    frame.to_csv("translate_transformer_temp.csv", index=False)

    dataset = data.TabularDataset(
        './translate_transformer_temp.csv',
        format='csv',
        fields=[('src', SRC), ('trg', TRG)],
    )

    iterator = MyIterator(
        dataset,
        batch_size=batchsize,
        device=device,
        repeat=False,
        sort_key=lambda x: (len(x.src), len(x.trg)),
        batch_size_fn=batch_size_fn,
        train=istrain,
        shuffle=True,
    )

    os.remove('translate_transformer_temp.csv')

    if istrain:
        SRC.build_vocab(dataset)
        TRG.build_vocab(dataset)

    return iterator
|
587 |
+
|
588 |
+
def step(model, optimizer,batch, criterion):
    """
    Perform one model update on ``batch`` and return the loss as a float.

    Reads the module-level ``src_pad``, ``trg_pad`` and ``opt``. The decoder
    input is the target without its last token; the labels are the target
    without its first token.
    """
    model.train()

    # Follow the configured device instead of hard-coding ``.cuda()``, so the
    # tensors and the masks built below always agree.
    src = batch.src.transpose(0,1).to(opt['device'])
    trg = batch.trg.transpose(0,1).to(opt['device'])
    trg_input = trg[:, :-1]
    src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt['device'])
    preds = model(src, trg_input, src_mask, trg_mask)

    ys = trg[:, 1:].contiguous().view(-1)

    optimizer.zero_grad()
    loss = criterion(preds.view(-1, preds.size(-1)), ys)
    loss.backward()
    optimizer.step_and_update_lr()

    loss = loss.item()

    return loss
|
610 |
+
|
611 |
+
def validiate(model, valid_iter, criterion):
    """Return the mean loss over all batches of ``valid_iter``.

    Runs in eval mode without gradient tracking. Reads the module-level
    ``src_pad``, ``trg_pad`` and ``opt``.
    """
    model.eval()

    with torch.no_grad():
        total_loss = []
        for batch in valid_iter:
            # Follow the configured device instead of hard-coding ``.cuda()``.
            src = batch.src.transpose(0,1).to(opt['device'])
            trg = batch.trg.transpose(0,1).to(opt['device'])
            trg_input = trg[:, :-1]
            src_mask, trg_mask = create_masks(src, trg_input, src_pad, trg_pad, opt['device'])
            preds = model(src, trg_input, src_mask, trg_mask)

            ys = trg[:, 1:].contiguous().view(-1)

            loss = criterion(preds.view(-1, preds.size(-1)), ys)

            loss = loss.item()

            total_loss.append(loss)

    avg_loss = np.mean(total_loss)

    return avg_loss
|
636 |
+
|
637 |
+
"""# Optimizer
|
638 |
+
|
639 |
+
"""
|
640 |
+
|
641 |
+
class ScheduledOptim():
    '''Wraps an optimizer and sets its learning rate before every step.

    lr = init_lr * d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)
    '''

    def __init__(self, optimizer, init_lr, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.init_lr = init_lr
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        # Number of updates performed so far.
        self.n_steps = 0

    def step_and_update_lr(self):
        "Update the learning rate, then step the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Clear the gradients through the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        '''Return the schedule factor for the current step count.'''
        step = self.n_steps
        decay = step ** (-0.5)
        ramp = step * self.n_warmup_steps ** (-1.5)
        return (self.d_model ** -0.5) * min(decay, ramp)

    def state_dict(self):
        '''Return the schedule settings plus the inner optimizer state.'''
        return {
            'init_lr': self.init_lr,
            'd_model': self.d_model,
            'n_warmup_steps': self.n_warmup_steps,
            'n_steps': self.n_steps,
            '_optimizer': self._optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        '''Restore what ``state_dict`` produced.'''
        self.init_lr = state_dict['init_lr']
        self.d_model = state_dict['d_model']
        self.n_warmup_steps = state_dict['n_warmup_steps']
        self.n_steps = state_dict['n_steps']

        self._optimizer.load_state_dict(state_dict['_optimizer'])

    def _update_learning_rate(self):
        ''' Advance the step counter and write the new rate to every param group '''
        self.n_steps += 1
        new_lr = self.init_lr * self._get_lr_scale()

        for group in self._optimizer.param_groups:
            group['lr'] = new_lr
|
695 |
+
|
696 |
+
"""# Label Smoothing
|
697 |
+
hạn chế hiện tượng overfit
|
698 |
+
|
699 |
+
|
700 |
+
"""
|
701 |
+
|
702 |
+
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets ``1 - smoothing``; every other class gets
    ``smoothing / (classes - 2)``; the padding class gets 0. Rows whose
    target is the padding index are zeroed entirely. The result is averaged
    over all rows.
    """

    def __init__(self, classes, padding_idx, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim
        self.padding_idx = padding_idx

    def forward(self, pred, target):
        """``pred`` holds unnormalised scores (rows x classes); ``target`` the class index per row."""
        log_probs = pred.log_softmax(dim=self.dim)

        # The target distribution is a constant: build it without gradients.
        with torch.no_grad():
            smooth_value = self.smoothing / (self.cls - 2)
            true_dist = torch.full_like(log_probs, smooth_value)
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
            true_dist[:, self.padding_idx] = 0

            # Rows that are padding contribute nothing to the loss.
            pad_rows = torch.nonzero(
                target.data == self.padding_idx, as_tuple=False
            ).reshape(-1)
            true_dist.index_fill_(0, pad_rows, 0.0)

        per_row = torch.sum(-true_dist * log_probs, dim=self.dim)
        return torch.mean(per_row)
|
724 |
+
|
725 |
+
from torchtext.data.metrics import bleu_score
|
726 |
+
|
727 |
+
def bleu(valid_src_data, valid_trg_data, model, SRC, TRG, device, k, max_strlen):
    """Translate every source sentence and return the BLEU score.

    Hypotheses are tokenized with ``TRG.preprocess``; each reference is the
    whitespace-split target sentence, wrapped as a single-reference list.
    """
    translations = [
        translate_sentence(sentence, model, SRC, TRG, device, k, max_strlen)
        for sentence in valid_src_data
    ]

    hypotheses = [TRG.preprocess(sent) for sent in translations]
    references = [[sent.split()] for sent in valid_trg_data]

    return bleu_score(hypotheses, references)
|
737 |
+
|
738 |
+
# Experiment configuration read by the setup code below and, as a
# module-level global, by step() and validiate().
opt = {
    # Paths of the parallel training and validation corpora.
    'train_src_data':'./data/train.en',
    'train_trg_data':'./data/train.vi',
    'valid_src_data':'./data/tst2013.en',
    'valid_trg_data':'./data/tst2013.vi',
    # spaCy model names handed to the tokenizers.
    'src_lang':'en_core_web_sm',
    'trg_lang':'vi_core_news_lg',
    # Sentences with this many spaces or more are dropped by create_dataset.
    'max_strlen':160,
    # Passed as batch_size to the iterator, whose batch_size_fn measures a
    # batch as count * longest sequence rather than as a number of sentences.
    'batchsize':1500,
    'device':'cuda',
    # Model dimensions.
    'd_model': 512,
    'n_layers': 6,
    'heads': 8,
    'dropout': 0.1,
    # NOTE(review): 'lr', 'epochs', 'printevery' and 'k' are not read anywhere
    # in the visible part of this script - confirm where they are used.
    'lr':0.0001,
    'epochs':30,
    'printevery': 200,
    'k':5,
}
|
757 |
+
|
758 |
+
|
759 |
+
|
760 |
+
# Load the raw sentence lists for training and validation.
train_src_data, train_trg_data = read_data(opt['train_src_data'], opt['train_trg_data'])
valid_src_data, valid_trg_data = read_data(opt['valid_src_data'], opt['valid_trg_data'])

# Build the fields, then the iterators. The training call (istrain=True) also
# builds the SRC/TRG vocabularies, so it must run before the validation call.
SRC, TRG = create_fields(opt['src_lang'], opt['trg_lang'])
train_iter = create_dataset(train_src_data, train_trg_data, opt['max_strlen'], opt['batchsize'], opt['device'], SRC, TRG, istrain=True)
valid_iter = create_dataset(valid_src_data, valid_trg_data, opt['max_strlen'], opt['batchsize'], opt['device'], SRC, TRG, istrain=False)

# Padding indices; read as globals by step() and validiate().
src_pad = SRC.vocab.stoi['<pad>']
trg_pad = TRG.vocab.stoi['<pad>']

model = Transformer(len(SRC.vocab), len(TRG.vocab), opt['d_model'], opt['n_layers'], opt['heads'], opt['dropout'])

# Xavier-uniform initialisation for every parameter with more than one
# dimension; one-dimensional parameters keep their default initialisation.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(opt['device'])

# Adam wrapped in the warm-up schedule: init_lr 0.2, 4000 warm-up steps.
# NOTE(review): the literal 0.2 is used here, not opt['lr'] - confirm intent.
optimizer = ScheduledOptim(
    torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09),
    0.2, opt['d_model'], 4000)

# Label-smoothed loss over the target vocabulary.
criterion = LabelSmoothingLoss(len(TRG.vocab), padding_idx=trg_pad, smoothing=0.1)
|
783 |
+
|
784 |
+
|
785 |
+
|
786 |
+
|
787 |
+
|