#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderGRU(nn.Module):
    """GRU decoder that generates token sequences from a latent vector."""

    def __init__(self, latent_size, hidden_size, output_size):
        super(DecoderGRU, self).__init__()
        # Small MLP that maps the latent sample to the initial hidden state
        # of both GRU layers (hence the 2 * hidden_size output).
        self.proj1 = nn.Linear(latent_size, latent_size)
        self.proj_activation = nn.ReLU()
        self.proj2 = nn.Linear(latent_size, 2 * hidden_size)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=2, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_sample, target_tensor=None, max_length=16):
        batch_size = encoder_sample.size(0)
        # Project the latent sample and reshape it to (num_layers, batch,
        # hidden_size), the layout nn.GRU expects for its hidden state.
        decoder_hidden = self.proj1(encoder_sample)
        decoder_hidden = self.proj_activation(decoder_hidden)
        decoder_hidden = self.proj2(decoder_hidden)
        decoder_hidden = decoder_hidden.view(batch_size, 2, -1).permute(1, 0, 2).contiguous()
        if target_tensor is not None:
            # Teacher forcing: decode the whole target sequence in one pass.
            decoder_outputs, decoder_hidden = self.forward_step(target_tensor, decoder_hidden)
        else:
            # Autoregressive decoding: start from <sos> and feed each step's
            # argmax token back in as the next input.
            decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long,
                                       device=encoder_sample.device)
            decoder_outputs = []
            for _ in range(max_length):
                decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs.append(decoder_output)
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from the graph
            decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden

    def forward_step(self, input_tokens, hidden):
        # Embed the current tokens, advance the GRU one segment, and project
        # the GRU output to unnormalized vocabulary scores.
        output = self.embedding(input_tokens)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden
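
# Hedged sanity check with an untrained decoder: the sizes below are
# illustrative assumptions, not the trained configuration. Teacher forcing
# is used so the SOS_token constant defined further down is not required.
_demo = DecoderGRU(latent_size=8, hidden_size=16, output_size=10)
_logp, _h = _demo(torch.randn(4, 8), target_tensor=torch.zeros(4, 5, dtype=torch.long))
assert _logp.shape == (4, 5, 10)  # (batch, seq_len, vocab) log-probabilities
assert _h.shape == (2, 4, 16)     # (num_layers, batch, hidden) final state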

# Load the trained decoder. The checkpoint is a pickled nn.Module, so on
# PyTorch >= 2.6 weights_only=False is required; had only a state_dict been
# saved, you would instead build a DecoderGRU and call load_state_dict.
dec = torch.load('decoder.pt', map_location='cpu', weights_only=False)
dec.eval()

SOS_token = 1
EOS_token = 2
katakana = list('゠ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヰヱヲンヴヵヶヷヸヹヺ・ーヽヾヿㇰㇱㇲㇳㇴㇵㇶㇷㇸㇹㇺㇻㇼㇽㇾㇿ')
vocab = ['<pad>', '<sos>', '<eos>'] + katakana  # ids 0-2 are the specials
vocab_dict = {ch: i for i, ch in enumerate(vocab)}  # character -> token id
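# For example, vocab_dict['<sos>'] == SOS_token == 1 and vocab_dict['ア'] == 5:
# everything from id 3 onward indexes straight into the katakana list.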

h_latent = 64  # latent dimensionality the decoder was trained with
max_len = 40   # maximum number of tokens to generate per name
names = 16     # how many latent samples to decode per round

def detokenize(tokens):
    """Map token ids back to a string, or None if <eos> never appears."""
    if EOS_token in tokens:
        return ''.join(vocab[token] for token in tokens[:tokens.index(EOS_token)])
    return None
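
# A hypothetical inverse of detokenize, shown for illustration only: it
# assumes training sequences were framed as <sos> ... <eos>, which this
# script does not itself confirm.
def tokenize(name):
    return [SOS_token] + [vocab_dict[ch] for ch in name] + [EOS_token]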

while True:
    print('generating names...')
    with torch.no_grad():
        # Decode a batch of random latent vectors, then take the argmax
        # token id at every step: (names, max_len, vocab) -> (names, max_len).
        log_probs, _ = dec(torch.randn(names, h_latent), max_length=max_len)
        sequences = log_probs.argmax(dim=-1).tolist()
    # Sequences that never reached <eos> come back as None and are skipped.
    for name in map(detokenize, sequences):
        if name is not None:
            print(name)
    input("press enter to continue generation...")