import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Actor-critic network over a fixed list of candidate words.

    s_dim: size of the flattened state vector.
    a_dim: number of actions (candidate words).
    word_list: candidate words, uppercase A-Z, each `words_width` letters long.
    words_width: number of letters per word.
    """

    def __init__(self, s_dim, a_dim, word_list, words_width):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim

        # Each word is embedded as a one-hot vector per letter position,
        # so the embedding width is 26 letters * word length.
        word_width = 26 * words_width
        layers = [
            nn.Linear(s_dim, word_width),
            nn.Tanh(),
        ]
        self.v1 = nn.Sequential(*layers)  # shared trunk
        self.v4 = nn.Linear(word_width, 1)  # critic head: scalar state value
        self.actor_head = nn.Linear(word_width, word_width)  # actor head, in word-embedding space

        self.distribution = torch.distributions.Categorical

        # Fixed (non-trainable) one-hot letter-position encoding of every
        # candidate word; column i encodes word_list[i].
        word_array = np.zeros((word_width, len(word_list)))
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[j * 26 + (ord(c) - ord("A")), i] = 1
        self.words = torch.Tensor(word_array)

    def forward(self, x):
        # Shared features from the trunk.
        values = self.v1(x.float())
        # Score each candidate word by the dot product between the actor
        # features and that word's one-hot encoding, then log-softmax to
        # get log-probabilities over the word list.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(values), self.words, dims=((1,), (0,))),
            dim=-1,
        )
        # Critic output: one value estimate per state.
        values = self.v4(values)
        return logits, values

    def choose_action(self, s):
        # Sample a word index from the categorical policy.
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        # Temporal-difference error between the bootstrapped target v_t and
        # the critic's estimate; squared for the critic loss and reused as
        # the advantage for the actor.
        td = v_t - values
        c_loss = td.pow(2)

        # Policy-gradient actor loss: negative log-probability of the taken
        # action, weighted by the detached advantage.
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss


class GreedyNet(Net):
    """Variant of Net that always picks the highest-probability word instead of sampling."""

    def choose_action(self, s):
        self.eval()
        logits, _ = self.forward(s)
        # logits are log-probabilities, so exp() recovers the probabilities.
        probabilities = logits.exp().squeeze(dim=-1)
        prob_np = probabilities.data.cpu().numpy()

        actions = np.argmax(prob_np, axis=1)
        return actions[0]
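

# --- Minimal usage sketch (illustrative only) -------------------------------
# Everything below is an assumption made for demonstration: the word list,
# state size, and zero-filled state are hypothetical, not part of the
# original training setup.
if __name__ == "__main__":
    word_list = ["APPLE", "BREAD", "CRANE"]  # hypothetical candidate words
    words_width = 5  # letters per word
    s_dim = 32  # hypothetical flattened state size
    a_dim = len(word_list)

    net = Net(s_dim, a_dim, word_list, words_width)
    state = torch.zeros((1, s_dim))  # batch of one dummy state

    logits, value = net(state)
    print("logits shape:", logits.shape)  # (1, len(word_list))
    print("value shape:", value.shape)  # (1, 1)

    # Stochastic action (exploration) vs. greedy action.
    print("sampled word index:", net.choose_action(state))
    greedy_net = GreedyNet(s_dim, a_dim, word_list, words_width)
    print("greedy word index:", greedy_net.choose_action(state))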