File size: 4,957 Bytes
52933b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import torch
from .nets.attention_model.attention_model import *
class Problem:
    """Minimal holder for the name of the problem being solved.

    The name is exposed through the upper-case attribute ``NAME``.
    """

    def __init__(self, name):
        """Store ``name`` (for example ``"tsp"``) as ``self.NAME``."""
        self.NAME = name
class Backbone(nn.Module):
    """Embedding, graph-attention encoder and decoder chained together.

    The module can run the whole pipeline in one call (``forward``) or in two
    stages: ``encode`` produces the decoder's precomputed embeddings once and
    ``decode`` reuses them for a later observation.
    """

    def __init__(
        self,
        embedding_dim=128,
        problem_name="tsp",
        n_encode_layers=3,
        tanh_clipping=10.0,
        n_heads=8,
        device="cpu",
    ):
        """Build the embedding, encoder and decoder sub-modules.

        Args:
            embedding_dim: Width handed to the embedding, encoder and decoder.
            problem_name: Problem identifier; wrapped in a ``Problem`` object.
            n_encode_layers: Number of layers of the graph-attention encoder.
            tanh_clipping: Forwarded unchanged to the decoder.
            n_heads: Number of attention heads for encoder and decoder.
            device: Device on which observations are turned into tensors.
        """
        super().__init__()
        self.device = device
        self.problem = Problem(problem_name)

        embedding_config = {"embedding_dim": embedding_dim}
        self.embedding = AutoEmbedding(self.problem.NAME, embedding_config)

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads, embed_dim=embedding_dim, n_layers=n_encode_layers
        )

        self.decoder = Decoder(
            embedding_dim,
            self.embedding.context_dim,
            n_heads,
            self.problem,
            tanh_clipping,
        )

    def _wrap(self, obs):
        """Turn a raw observation dict into a ``stateWrapper`` on ``self.device``."""
        return stateWrapper(obs, device=self.device, problem=self.problem.NAME)

    def _precompute(self, state):
        """Embed and encode the wrapped observations, then precompute for decoding."""
        node_features = self.embedding(state.states["observations"])
        # The encoder returns a pair; only its first element is used here.
        encoded_nodes, _ = self.encoder(node_features)
        return self.decoder._precompute(encoded_nodes)

    def forward(self, obs):
        """Run encoding and one decoding step; return ``(logits, glimpse)``."""
        state = self._wrap(obs)
        cache = self._precompute(state)
        scores, glimpse = self.decoder.advance(cache, state)
        return scores, glimpse

    def encode(self, obs):
        """Return the decoder's precomputed embeddings for ``obs``."""
        return self._precompute(self._wrap(obs))

    def decode(self, obs, cached_embeddings):
        """Advance the decoder on ``obs`` using embeddings from ``encode``.

        Returns:
            The ``(logits, glimpse)`` pair produced by ``decoder.advance``.
        """
        state = self._wrap(obs)
        scores, glimpse = self.decoder.advance(cached_embeddings, state)
        return scores, glimpse
class Actor(nn.Module):
    """Head without sub-modules that selects the logits from its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x[0]``, the first element of the indexable input ``x``."""
        # Original author's note: a `.squeeze(1)` is not needed for pomo.
        return x[0]
class Critic(nn.Module):
    """Value head: Linear -> ReLU -> Linear applied to the second input element."""

    def __init__(self, *args, **kwargs):
        """Build the MLP.

        Keyword Args:
            hidden_size: Required. Input width and hidden width of the MLP.

        Raises:
            KeyError: If ``hidden_size`` is not given as a keyword argument.
        """
        super().__init__()
        width = kwargs["hidden_size"]
        layers = [nn.Linear(width, width), nn.ReLU(), nn.Linear(width, 1)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x[1]`` and return the result."""
        # Author's shape annotation: B x T x h_dim --mlp--> B x T x 1
        return self.mlp(x[1])
class Agent(nn.Module):
    """Actor-critic agent: a shared backbone feeding an actor and a critic head."""

    def __init__(self, embedding_dim=128, device="cpu", name="tsp"):
        """Create the backbone, critic and actor.

        Args:
            embedding_dim: Backbone embedding width and critic hidden size.
            device: Forwarded to the backbone.
            name: Problem name forwarded to the backbone as ``problem_name``.
        """
        super().__init__()
        self.backbone = Backbone(
            embedding_dim=embedding_dim, device=device, problem_name=name
        )
        self.critic = Critic(hidden_size=embedding_dim)
        self.actor = Actor()

    def _act_and_evaluate(self, features, action):
        """Shared tail of the ``get_action_and_value*`` methods.

        Builds a categorical policy from the actor logits, samples an action
        when none is supplied, and returns
        ``(action, log_prob, entropy, value)``.
        """
        policy = torch.distributions.Categorical(logits=self.actor(features))
        if action is None:
            action = policy.sample()
        log_prob = policy.log_prob(action)
        entropy = policy.entropy()
        value = self.critic(features)
        return action, log_prob, entropy, value

    def forward(self, x):  # only actor
        """Return ``(action, logits)`` where action is the argmax over dim 2."""
        logits = self.actor(self.backbone(x))
        _, action = logits.max(2)
        return action, logits

    def get_value(self, x):
        """Return the critic's output for observation ``x``."""
        return self.critic(self.backbone(x))

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observation ``x``.

        When ``action`` is ``None`` one is sampled from the policy; otherwise
        the given action is evaluated.
        """
        return self._act_and_evaluate(self.backbone(x), action)

    def get_value_cached(self, x, state):
        """Like ``get_value`` but decoding with previously encoded ``state``."""
        return self.critic(self.backbone.decode(x, state))

    def get_action_and_value_cached(self, x, action=None, state=None):
        """Like ``get_action_and_value`` but reusing cached encoder output.

        If ``state`` is ``None`` it is first computed with ``backbone.encode``.
        The state that was used is returned as a fifth tuple element so the
        caller can pass it back in on the next call.
        """
        if state is None:
            state = self.backbone.encode(x)
        features = self.backbone.decode(x, state)
        return (*self._act_and_evaluate(features, action), state)
class stateWrapper:
    """Wrap a dict of array-like observations as tensors plus helper accessors.

    Every value of the incoming dict is converted with ``torch.tensor`` on the
    requested device and stored in ``self.states``. Depending on ``problem``
    a few extra attributes are derived:

    * ``"tsp"``: ``is_initial_action`` (bool tensor) and ``first_a``.
    * ``"cvrp"``: ``states["observations"]`` is replaced by a dict with keys
      ``"loc"``, ``"depot"`` and ``"demand"``; ``VEHICLE_CAPACITY`` is set to 0
      and ``used_capacity`` to the negated ``current_load``.

    Any other problem name only gets the tensor conversion.
    """

    def __init__(self, states, device, problem="tsp"):
        """Convert ``states`` to tensors and derive problem-specific fields.

        Args:
            states: Mapping from field name to array-like data.
            device: Device passed to ``torch.tensor`` for every field.
            problem: ``"tsp"`` or ``"cvrp"`` to enable the extras described on
                the class; other values add nothing.

        Raises:
            KeyError: If a field required by the selected problem is missing.
        """
        self.device = device
        self.states = {
            key: torch.tensor(value, device=self.device)
            for key, value in states.items()
        }
        if problem == "tsp":
            self.is_initial_action = self.states["is_initial_action"].to(torch.bool)
            self.first_a = self.states["first_node_idx"]
        elif problem == "cvrp":
            # Named `cvrp_input` rather than `input` to avoid shadowing the builtin.
            cvrp_input = {
                "loc": self.states["observations"],
                "depot": self.states["depot"].squeeze(-1),
                "demand": self.states["demand"],
            }
            self.states["observations"] = cvrp_input
            self.VEHICLE_CAPACITY = 0
            self.used_capacity = -self.states["current_load"]

    def get_current_node(self):
        """Return the ``last_node_idx`` tensor."""
        return self.states["last_node_idx"]

    def get_mask(self):
        """Return the inverted action mask as a bool tensor.

        For numeric masks this is ``1 - action_mask`` cast to bool, which for a
        0/1 mask is its logical inverse. Boolean masks are inverted with ``~``
        because PyTorch does not support subtraction on bool tensors.
        """
        action_mask = self.states["action_mask"]
        if action_mask.dtype == torch.bool:
            return ~action_mask
        return (1 - action_mask).to(torch.bool)
|