import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Actor-critic network that scores actions over a fixed word vocabulary.

    The actor produces logits over the words in ``word_list`` by projecting a
    hidden representation onto each word's one-hot letter encoding; the critic
    produces a scalar state value.

    Args:
        s_dim: dimensionality of the input state vector.
        a_dim: action-space size (stored for callers; not used internally).
        word_list: iterable of uppercase A-Z words, each of length
            ``words_width`` (assumed — TODO confirm against caller).
        words_width: number of letters per word.
    """

    def __init__(self, s_dim, a_dim, word_list, words_width):
        super().__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        # 26 one-hot slots per letter position.
        word_width = 26 * words_width
        # Shared trunk: state -> hidden representation of size word_width.
        self.v1 = nn.Sequential(
            nn.Linear(s_dim, word_width),
            nn.Tanh(),
        )
        # Critic head: hidden -> scalar state value.
        self.v4 = nn.Linear(word_width, 1)
        # Actor head: hidden -> letter-space embedding matched against words.
        self.actor_head = nn.Linear(word_width, word_width)
        self.distribution = torch.distributions.Categorical

        # Build the (word_width, n_words) one-hot letter-position matrix:
        # column i encodes word i, row j*26 + (c - 'A') flags letter c at
        # position j.
        word_array = np.zeros((word_width, len(word_list)))
        for i, word in enumerate(word_list):
            for j, c in enumerate(word):
                word_array[j * 26 + (ord(c) - ord("A")), i] = 1
        # register_buffer (instead of a plain attribute) so the matrix follows
        # the module across .to(device)/.cuda() and is saved in state_dict.
        # NOTE(review): this adds a "words" key to state_dict; old checkpoints
        # load with strict=False.
        self.register_buffer("words", torch.Tensor(word_array))

    def forward(self, x):
        """Return ``(logits, values)``: log-probabilities over words and the
        critic's state value, for a batch of states ``x``."""
        hidden = self.v1(x.float())
        # Score each word by the dot product of the actor embedding with the
        # word's one-hot letter encoding, then normalize to log-probs.
        logits = torch.log_softmax(
            torch.tensordot(self.actor_head(hidden), self.words, dims=((1,), (0,))),
            dim=-1,
        )
        values = self.v4(hidden)
        return logits, values

    def choose_action(self, s):
        """Sample an action index from the policy for state batch ``s``."""
        self.eval()
        # no_grad replaces the deprecated ``.data`` detach: no graph is built
        # during action selection.
        with torch.no_grad():
            logits, _ = self.forward(s)
            prob = F.softmax(logits, dim=1)
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        """Combined A2C loss: critic squared TD error plus actor policy loss.

        Args:
            s: state batch.
            a: taken-action indices (LongTensor).
            v_t: bootstrapped value targets, shape matching the critic output.

        Returns:
            Scalar loss tensor.
        """
        self.train()
        logits, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        # Policy-gradient term; TD error is detached so the actor does not
        # backprop through the critic target.
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        return (c_loss + a_loss).mean()


class GreedyNet(Net):
    """Variant of :class:`Net` that always picks the arg-max action."""

    def choose_action(self, s):
        """Return the highest-probability action index for state batch ``s``."""
        self.eval()
        with torch.no_grad():
            logits, _ = self.forward(s)
        probabilities = logits.exp().squeeze(dim=-1)
        prob_np = probabilities.cpu().numpy()
        actions = np.argmax(prob_np, axis=1)
        return actions[0]