Sunbread committed
Commit db0dcb9 · 1 Parent(s): 84d86b3
Files changed (4)
  1. decoder.pt +3 -0
  2. inference.py +60 -0
  3. model.py +167 -0
  4. rolename.txt +0 -0
decoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6574bd2e0f77d393da6412bd11886c176e551dce94f4383b3bf81a5e1a61d745
+ size 180232
inference.py ADDED
@@ -0,0 +1,60 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
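+ # GRU decoder: maps a latent vector sampled from the VAE prior to a katakana
+ # token sequence, one token per step.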
+ class DecoderGRU(nn.Module):
+     def __init__(self, hidden_size, output_size):
+         super(DecoderGRU, self).__init__()
+         self.proj = nn.Linear(hidden_size, hidden_size)
+         self.embedding = nn.Embedding(output_size, hidden_size)
+         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
+         self.out = nn.Linear(hidden_size, output_size)
+
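+     # With target_tensor given, run one teacher-forced pass over the whole
+     # sequence; otherwise decode greedily from <sos> for up to max_length steps.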
+     def forward(self, encoder_sample, target_tensor=None, max_length=16):
+         batch_size = encoder_sample.size(0)
+         decoder_hidden = self.proj(encoder_sample).unsqueeze(0)
+         if target_tensor is not None:
+             decoder_input = target_tensor
+             decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
+         else:
+             decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
+             decoder_outputs = []
+             for i in range(max_length):
+                 decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
+                 decoder_outputs.append(decoder_output)
+                 _, topi = decoder_output.topk(1)
+                 decoder_input = topi.squeeze(-1).detach()
+             decoder_outputs = torch.cat(decoder_outputs, dim=1)
+         decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
+         return decoder_outputs, decoder_hidden
+
+     def forward_step(self, input, hidden):
+         output = self.embedding(input)
+         output = F.relu(output)
+         output, hidden = self.gru(output, hidden)
+         output = self.out(output)
+         return output, hidden
+
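+ # The checkpoint stores the entire module (model.py saves it with
+ # torch.save(dec, 'decoder.pt') under a CUDA default device), so the class
+ # above must be defined before loading, map_location is needed on CPU-only
+ # machines, and PyTorch >= 2.6 additionally needs weights_only=False.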
+ dec = torch.load('decoder.pt', map_location='cpu', weights_only=False).eval()
+
+ SOS_token = 1
+ EOS_token = 2
+ katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
+ vocab = ['<pad>', '<sos>', '<eos>'] + katakana
+ vocab_dict = {v: k for k, v in enumerate(vocab)}
+
+ h = 64        # hidden / latent size (must match training)
+ max_len = 40
+
+ def detokenize(tokens):
+     if EOS_token in tokens:
+         return ''.join(vocab[token] for token in tokens[:tokens.index(EOS_token)])
+     else:
+         return None
+
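+ # Sample 16 latent vectors from the standard normal prior and decode them
+ # greedily; sequences that never emit <eos> detokenize to None and are skipped.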
+ for name in [detokenize(seq) for seq in dec(torch.randn(16, h), max_length=max_len)[0].topk(1)[1].squeeze().tolist()]:
+     if name is not None:
+         print(name)
model.py ADDED
@@ -0,0 +1,167 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import torch
+ import torch.nn as nn
+ from torch import optim
+ from torch.utils.data import DataLoader, Dataset
+ import torch.nn.functional as F
+ import pandas as pd
+
+ torch.manual_seed(114514)
+ torch.set_default_device('cuda')
+
+ SOS_token = 1
+ EOS_token = 2
+ katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
+ vocab = ['<pad>', '<sos>', '<eos>'] + katakana
+ vocab_dict = {v: k for k, v in enumerate(vocab)}
+
+ texts = pd.read_csv('rolename.txt', header=None)[0].tolist()
+ vocab_size = len(vocab)
+ h = 64        # hidden / latent size
+ max_len = 40  # maximum (padded) sequence length
+ bs = 64       # batch size
+ lr = 1e-3
+ epochs = 20
+
+ def tokenize(text):
+     return [vocab_dict[ch] for ch in text]
+
+ def detokenize(tokens):
+     if EOS_token in tokens:
+         tokens = tokens[:tokens.index(EOS_token)]
+     return ''.join(vocab[token] for token in tokens)
+
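+ # BN-VAE trick (see the link below): batch-normalize mu and sigma, then
+ # rescale them so that scale_mu^2 + scale_sigma^2 = 1 with scale_mu^2 >= TAU.
+ # Lower-bounding the scale of mu keeps the KL term away from zero, guarding
+ # against posterior collapse.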
+ class BatchNormVAE(nn.Module):  # https://spaces.ac.cn/archives/7381/
+     def __init__(self, num_features, **kwargs):
+         super(BatchNormVAE, self).__init__()
+         kwargs['affine'] = False
+         self.TAU = 0.5
+         self.bn = nn.BatchNorm1d(num_features, **kwargs)
+         self.theta = nn.Parameter(torch.zeros(1))
+
+     def forward(self, mu, sigma):
+         mu = self.bn(mu)
+         sigma = self.bn(sigma)
+         scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * torch.sigmoid(self.theta))
+         scale_sigma = torch.sqrt((1 - self.TAU) * torch.sigmoid(-self.theta))
+         return mu * scale_mu, sigma * scale_sigma
+
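+ # Encoder: a bidirectional GRU reads the packed katakana sequence; its two
+ # final hidden states are projected to mu and sigma of the approximate posterior.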
+ class EncoderVAEBiGRU(nn.Module):
+     def __init__(self, input_size, hidden_size, dropout_p=0.1):
+         super(EncoderVAEBiGRU, self).__init__()
+         self.hidden_size = hidden_size
+         self.embedding = nn.Embedding(input_size, hidden_size)
+         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
+         self.proj_mu = nn.Linear(2 * hidden_size, hidden_size)
+         self.proj_sigma = nn.Linear(2 * hidden_size, hidden_size)
+         self.dropout = nn.Dropout(dropout_p)
+         self.bn = BatchNormVAE(hidden_size)
+
+     def forward(self, input, input_lengths):
+         input_lengths = input_lengths.to('cpu')  # pack_padded_sequence requires CPU lengths
+         embedded = self.dropout(self.embedding(input))
+         embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True, enforce_sorted=False)
+         _, hidden = self.gru(embedded)
+         hidden = hidden.permute(1, 0, 2).flatten(1, 2)  # (2, B, H) -> (B, 2H)
+         mu = self.proj_mu(hidden)
+         sigma = self.proj_sigma(hidden)  # not std, can be negative
+         mu, sigma = self.bn(mu, sigma)
+         return self._reparameterize(mu, sigma), mu, sigma ** 2
+
+     def _reparameterize(self, mu, sigma):
+         eps = torch.randn_like(sigma)
+         return eps * sigma + mu  # var is sigma^2
+
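+ # GRU decoder; inference.py carries an identical copy of this class so that
+ # the pickled decoder.pt module can be loaded standalone.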
+ class DecoderGRU(nn.Module):
+     def __init__(self, hidden_size, output_size):
+         super(DecoderGRU, self).__init__()
+         self.proj = nn.Linear(hidden_size, hidden_size)
+         self.embedding = nn.Embedding(output_size, hidden_size)
+         self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
+         self.out = nn.Linear(hidden_size, output_size)
+
+     def forward(self, encoder_sample, target_tensor=None, max_length=16):
+         batch_size = encoder_sample.size(0)
+         decoder_hidden = self.proj(encoder_sample).unsqueeze(0)
+         if target_tensor is not None:
+             decoder_input = target_tensor
+             decoder_outputs, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
+         else:
+             decoder_input = torch.empty(batch_size, 1, dtype=torch.long).fill_(SOS_token)
+             decoder_outputs = []
+             for i in range(max_length):
+                 decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
+                 decoder_outputs.append(decoder_output)
+                 _, topi = decoder_output.topk(1)
+                 decoder_input = topi.squeeze(-1).detach()
+             decoder_outputs = torch.cat(decoder_outputs, dim=1)
+         decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
+         return decoder_outputs, decoder_hidden
+
+     def forward_step(self, input, hidden):
+         output = self.embedding(input)
+         output = F.relu(output)
+         output, hidden = self.gru(output, hidden)
+         output = self.out(output)
+         return output, hidden
+
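+ # Dataset: pads every sequence to max_length with <pad> (token 0) and returns
+ # the encoder input, its true length, and the <sos>-shifted / <eos>-terminated
+ # decoder input/target pair.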
+ class KatakanaDataset(Dataset):
+     def __init__(self, texts, tokenizer, max_length):
+         self.texts = texts
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = self.texts[idx]
+         tokens = self.tokenizer(text)
+         enc_text = tokens
+         enc_len = len(enc_text)
+         input_text = [SOS_token] + tokens
+         target_text = tokens + [EOS_token]
+         enc_text = torch.tensor(enc_text + [0] * (self.max_length - len(enc_text)), dtype=torch.long)
+         input_text = torch.tensor(input_text + [0] * (self.max_length - len(input_text)), dtype=torch.long)
+         target_text = torch.tensor(target_text + [0] * (self.max_length - len(target_text)), dtype=torch.long)
+         return enc_text, enc_len, input_text, target_text
+
+ dataloader = DataLoader(
+     KatakanaDataset(texts, tokenize, max_len),
+     batch_size=bs,
+     shuffle=True,
+     generator=torch.Generator(device='cuda'),  # default device is cuda, so the shuffle generator must match
+ )
+
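+ # One epoch of VAE training: loss = reconstruction NLL (over all positions,
+ # padding included, since NLLLoss is not told to ignore token 0) plus the
+ # analytic KL divergence KL(N(mu, var) || N(0, I)), averaged per dimension.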
+ def train_epoch(dataloader, encoder, decoder, optimizer):
+     total_loss = 0
+     nll = nn.NLLLoss()
+     for enc_text, enc_len, input_text, target_text in dataloader:
+         optimizer.zero_grad()
+
+         encoder_sample, mu, var = encoder(enc_text, enc_len)
+         decoder_outputs, _ = decoder(encoder_sample, input_text)
+
+         loss_recons = nll(decoder_outputs.view(-1, decoder_outputs.size(-1)), target_text.view(-1))
+         loss_kld = 0.5 * torch.mean(mu ** 2 + var - var.log() - 1)
+         loss = loss_recons + loss_kld
+         loss.backward()
+
+         optimizer.step()
+
+         total_loss += loss.item()
+     return total_loss / len(dataloader)
+
+ enc = EncoderVAEBiGRU(vocab_size, h).train()
+ dec = DecoderGRU(h, vocab_size).train()
+ optimizer = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=lr)
+
+ for i in range(epochs):
+     print('epoch=%d, loss=%f' % (i, train_epoch(dataloader, enc, dec, optimizer)))
+
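+ # Sanity check: decode a few prior samples, then save the whole decoder module
+ # (this is the decoder.pt that inference.py loads).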
+ dec.eval()
+ for name in [detokenize(seq) for seq in dec(torch.randn(8, h), max_length=max_len)[0].topk(1)[1].squeeze().tolist()]:
+     print(name)
+ torch.save(dec, 'decoder.pt')
rolename.txt ADDED
The diff for this file is too large to render. See raw diff