add files
Browse files- decoder.pt +3 -0
- inference.py +60 -0
- model.py +167 -0
- rolename.txt +0 -0
decoder.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6574bd2e0f77d393da6412bd11886c176e551dce94f4383b3bf81a5e1a61d745
|
3 |
+
size 180232
|
inference.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
class DecoderGRU(nn.Module):
|
9 |
+
def __init__(self, hidden_size, output_size):
|
10 |
+
super(DecoderGRU, self).__init__()
|
11 |
+
self.proj = nn.Linear(hidden_size, hidden_size)
|
12 |
+
self.embedding = nn.Embedding(output_size, hidden_size)
|
13 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
|
14 |
+
self.out = nn.Linear(hidden_size, output_size)
|
15 |
+
|
16 |
+
def forward(self, encoder_sample, target_tensor=None, max_length=16):
|
17 |
+
batch_size = encoder_sample.size(0)
|
18 |
+
decoder_hidden = self.proj(encoder_sample).unsqueeze(0)
|
19 |
+
if target_tensor is not None:
|
20 |
+
decoder_input = target_tensor
|
21 |
+
decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
|
22 |
+
else:
|
23 |
+
decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
|
24 |
+
decoder_outputs = []
|
25 |
+
for i in range(max_length):
|
26 |
+
decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
|
27 |
+
decoder_outputs.append(decoder_output)
|
28 |
+
_, topi = decoder_output.topk(1)
|
29 |
+
decoder_input = topi.squeeze(-1).detach()
|
30 |
+
decoder_outputs = torch.cat(decoder_outputs, dim=1)
|
31 |
+
decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
|
32 |
+
return decoder_outputs, decoder_hidden
|
33 |
+
|
34 |
+
def forward_step(self, input, hidden):
|
35 |
+
output = self.embedding(input)
|
36 |
+
output = F.relu(output)
|
37 |
+
output, hidden = self.gru(output, hidden)
|
38 |
+
output = self.out(output)
|
39 |
+
return output, hidden
|
40 |
+
|
41 |
+
dec = torch.load('decoder.pt').to('cpu')
|
42 |
+
|
43 |
+
SOS_token = 1
|
44 |
+
EOS_token = 2
|
45 |
+
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
|
46 |
+
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
|
47 |
+
vocab_dict = {v: k for k, v in enumerate(vocab)}
|
48 |
+
|
49 |
+
h=64
|
50 |
+
max_len=40
|
51 |
+
|
52 |
+
def detokenize(tokens):
|
53 |
+
if EOS_token in tokens:
|
54 |
+
return ''.join(vocab[token] for token in tokens[:tokens.index(EOS_token)])
|
55 |
+
else:
|
56 |
+
return None
|
57 |
+
|
58 |
+
for name in [detokenize(seq) for seq in dec(torch.randn(16,h), max_length=max_len)[0].topk(1)[1].squeeze().tolist()]:
|
59 |
+
if name is not None:
|
60 |
+
print(name)
|
model.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch import optim
|
7 |
+
from torch.utils.data import DataLoader, Dataset
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
torch.manual_seed(114514)
|
12 |
+
torch.set_default_device('cuda')
|
13 |
+
|
14 |
+
SOS_token = 1
|
15 |
+
EOS_token = 2
|
16 |
+
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
|
17 |
+
vocab = ['<pad>', '<sos>', '<eos>'] + katakana
|
18 |
+
vocab_dict = {v: k for k, v in enumerate(vocab)}
|
19 |
+
|
20 |
+
texts = pd.read_csv('rolename.txt', header=None)[0].tolist()
|
21 |
+
vocab_size=len(vocab)
|
22 |
+
h=64
|
23 |
+
max_len=40
|
24 |
+
bs=64
|
25 |
+
lr=1e-3
|
26 |
+
epochs=20
|
27 |
+
|
28 |
+
def tokenize(text):
|
29 |
+
return [vocab_dict[ch] for ch in text]
|
30 |
+
|
31 |
+
def detokenize(tokens):
|
32 |
+
if EOS_token in tokens:
|
33 |
+
tokens = tokens[:tokens.index(EOS_token)]
|
34 |
+
return ''.join(vocab[token] for token in tokens)
|
35 |
+
|
36 |
+
class BatchNormVAE(nn.Module): # https://spaces.ac.cn/archives/7381/
|
37 |
+
def __init__(self, num_features, **kwargs):
|
38 |
+
super(BatchNormVAE, self).__init__()
|
39 |
+
kwargs['affine'] = False
|
40 |
+
self.TAU = 0.5
|
41 |
+
self.bn = nn.BatchNorm1d(num_features, **kwargs)
|
42 |
+
self.theta = nn.Parameter(torch.zeros(1))
|
43 |
+
|
44 |
+
def forward(self, mu, sigma):
|
45 |
+
mu = self.bn(mu)
|
46 |
+
sigma = self.bn(sigma)
|
47 |
+
scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * F.sigmoid(self.theta))
|
48 |
+
scale_sigma = torch.sqrt((1 - self.TAU) * F.sigmoid(-self.theta))
|
49 |
+
return mu*scale_mu, sigma*scale_sigma
|
50 |
+
|
51 |
+
class EncoderVAEBiGRU(nn.Module):
|
52 |
+
def __init__(self, input_size, hidden_size, dropout_p=0.1):
|
53 |
+
super(EncoderVAEBiGRU, self).__init__()
|
54 |
+
self.hidden_size = hidden_size
|
55 |
+
self.embedding = nn.Embedding(input_size, hidden_size)
|
56 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
57 |
+
self.proj_mu = nn.Linear(2 * hidden_size, hidden_size)
|
58 |
+
self.proj_sigma = nn.Linear(2 * hidden_size, hidden_size)
|
59 |
+
self.dropout = nn.Dropout(dropout_p)
|
60 |
+
self.bn = BatchNormVAE(hidden_size)
|
61 |
+
|
62 |
+
def forward(self, input, input_lengths):
|
63 |
+
input_lengths = input_lengths.to('cpu')
|
64 |
+
embedded = self.dropout(self.embedding(input))
|
65 |
+
embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
|
66 |
+
_, hidden = self.gru(embedded)
|
67 |
+
hidden = hidden.permute(1, 0, 2).flatten(1, 2)
|
68 |
+
mu = self.proj_mu(hidden)
|
69 |
+
sigma = self.proj_sigma(hidden) # not std, can be negative
|
70 |
+
mu, sigma = self.bn(mu, sigma)
|
71 |
+
return self._reparameterize(mu, sigma), mu, sigma ** 2
|
72 |
+
|
73 |
+
def _reparameterize(self, mu, sigma):
|
74 |
+
eps = torch.randn_like(sigma)
|
75 |
+
return eps * sigma + mu # var is sigma^2
|
76 |
+
|
77 |
+
class DecoderGRU(nn.Module):
|
78 |
+
def __init__(self, hidden_size, output_size):
|
79 |
+
super(DecoderGRU, self).__init__()
|
80 |
+
self.proj = nn.Linear(hidden_size, hidden_size)
|
81 |
+
self.embedding = nn.Embedding(output_size, hidden_size)
|
82 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
|
83 |
+
self.out = nn.Linear(hidden_size, output_size)
|
84 |
+
|
85 |
+
def forward(self, encoder_sample, target_tensor=None, max_length=16):
|
86 |
+
batch_size = encoder_sample.size(0)
|
87 |
+
decoder_hidden = self.proj(encoder_sample).unsqueeze(0)
|
88 |
+
if target_tensor is not None:
|
89 |
+
decoder_input = target_tensor
|
90 |
+
decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
|
91 |
+
else:
|
92 |
+
decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
|
93 |
+
decoder_outputs = []
|
94 |
+
for i in range(max_length):
|
95 |
+
decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
|
96 |
+
decoder_outputs.append(decoder_output)
|
97 |
+
_, topi = decoder_output.topk(1)
|
98 |
+
decoder_input = topi.squeeze(-1).detach()
|
99 |
+
decoder_outputs = torch.cat(decoder_outputs, dim=1)
|
100 |
+
decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
|
101 |
+
return decoder_outputs, decoder_hidden
|
102 |
+
|
103 |
+
def forward_step(self, input, hidden):
|
104 |
+
output = self.embedding(input)
|
105 |
+
output = F.relu(output)
|
106 |
+
output, hidden = self.gru(output, hidden)
|
107 |
+
output = self.out(output)
|
108 |
+
return output, hidden
|
109 |
+
|
110 |
+
class KatakanaDataset(Dataset):
|
111 |
+
def __init__(self, texts, tokenizer, max_length):
|
112 |
+
self.texts = texts
|
113 |
+
self.tokenizer = tokenizer
|
114 |
+
self.max_length = max_length
|
115 |
+
|
116 |
+
def __len__(self):
|
117 |
+
return len(self.texts)
|
118 |
+
|
119 |
+
def __getitem__(self, idx):
|
120 |
+
text = self.texts[idx]
|
121 |
+
tokens = self.tokenizer(text)
|
122 |
+
enc_text = tokens
|
123 |
+
enc_len = len(enc_text)
|
124 |
+
input_text = [SOS_token] + tokens
|
125 |
+
target_text = tokens + [EOS_token]
|
126 |
+
enc_text = torch.tensor(enc_text + [0] * (self.max_length - len(enc_text)), dtype=torch.long)
|
127 |
+
input_text = torch.tensor(input_text + [0] * (self.max_length - len(input_text)), dtype=torch.long)
|
128 |
+
target_text = torch.tensor(target_text + [0] * (self.max_length - len(target_text)), dtype=torch.long)
|
129 |
+
return enc_text, enc_len, input_text, target_text
|
130 |
+
|
131 |
+
dataloader = DataLoader(
|
132 |
+
KatakanaDataset(texts, tokenize, max_len),
|
133 |
+
batch_size=bs,
|
134 |
+
shuffle=True,
|
135 |
+
generator=torch.Generator(device='cuda'),
|
136 |
+
)
|
137 |
+
|
138 |
+
def train_epoch(dataloader, encoder, decoder, optimizer):
|
139 |
+
total_loss = 0
|
140 |
+
nll = nn.NLLLoss()
|
141 |
+
for enc_text, enc_len, input_text, target_text in dataloader:
|
142 |
+
optimizer.zero_grad()
|
143 |
+
|
144 |
+
encoder_sample, mu, var = encoder(enc_text, enc_len)
|
145 |
+
decoder_outputs, _ = decoder(encoder_sample, input_text)
|
146 |
+
|
147 |
+
loss_recons = nll(decoder_outputs.view(-1, decoder_outputs.size(-1)), target_text.view(-1))
|
148 |
+
loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
|
149 |
+
loss = loss_recons + loss_kld
|
150 |
+
loss.backward()
|
151 |
+
|
152 |
+
optimizer.step()
|
153 |
+
|
154 |
+
total_loss += loss.item()
|
155 |
+
return total_loss / len(dataloader)
|
156 |
+
|
157 |
+
enc = EncoderVAEBiGRU(vocab_size, h).train()
|
158 |
+
dec = DecoderGRU(h, vocab_size).train()
|
159 |
+
optimizer = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=lr)
|
160 |
+
|
161 |
+
for i in range(epochs):
|
162 |
+
print('epoch=%d, loss=%f' % (i, train_epoch(dataloader, enc, dec, optimizer)))
|
163 |
+
|
164 |
+
dec.eval()
|
165 |
+
for name in [detokenize(seq) for seq in dec(torch.randn(8,h), max_length=max_len)[0].topk(1)[1].squeeze().tolist()]:
|
166 |
+
print(name)
|
167 |
+
torch.save(dec, 'decoder.pt')
|
rolename.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|