"""Actor-critic agent built on an attention-model backbone for routing problems (TSP / CVRP)."""

import torch
import torch.nn as nn

from .nets.attention_model.attention_model import *


class Problem:
    def __init__(self, name):
        self.NAME = name


class Backbone(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        problem_name="tsp",
        n_encode_layers=3,
        tanh_clipping=10.0,
        n_heads=8,
        device="cpu",
    ):
        super(Backbone, self).__init__()
        self.device = device
        self.problem = Problem(problem_name)
        self.embedding = AutoEmbedding(self.problem.NAME, {"embedding_dim": embedding_dim})

        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=n_encode_layers,
        )

        self.decoder = Decoder(
            embedding_dim, self.embedding.context_dim, n_heads, self.problem, tanh_clipping
        )

    def forward(self, obs):
        """Encode the instance and decode a single step, returning per-node logits and the decoder glimpse."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        inputs = state.states["observations"]
        embedding = self.embedding(inputs)
        encoded_inputs, _ = self.encoder(embedding)

        cached_embeddings = self.decoder._precompute(encoded_inputs)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)

        return logits, glimpse

    def encode(self, obs):
        """Run the embedding and encoder once and return the decoder's precomputed (cached) embeddings."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        inputs = state.states["observations"]
        embedding = self.embedding(inputs)
        encoded_inputs, _ = self.encoder(embedding)
        cached_embeddings = self.decoder._precompute(encoded_inputs)
        return cached_embeddings

    def decode(self, obs, cached_embeddings):
        """Decode a single step from previously cached embeddings without re-running the encoder."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)

        return logits, glimpse


class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()

    def forward(self, x):
        # The backbone already produces per-node logits; the actor just selects them.
        logits = x[0]
        return logits


class Critic(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Critic, self).__init__()
        hidden_size = kwargs["hidden_size"]
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        # The value head reads the decoder glimpse (second element of the backbone output).
        out = self.mlp(x[1])
        return out


class Agent(nn.Module):
    def __init__(self, embedding_dim=128, device="cpu", name="tsp"):
        super().__init__()
        self.backbone = Backbone(embedding_dim=embedding_dim, device=device, problem_name=name)
        self.critic = Critic(hidden_size=embedding_dim)
        self.actor = Actor()

    def forward(self, x):
        x = self.backbone(x)
        logits = self.actor(x)
        # Greedy decoding: pick the node with the highest logit.
        action = logits.max(2)[1]
        return action, logits

    def get_value(self, x):
        x = self.backbone(x)
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        x = self.backbone(x)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        # Sample when no action is given (rollout); otherwise evaluate the provided action.
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

    def get_value_cached(self, x, state):
        x = self.backbone.decode(x, state)
        return self.critic(x)

    def get_action_and_value_cached(self, x, action=None, state=None):
        # Encode the instance only once; subsequent calls reuse the cached embeddings.
        if state is None:
            state = self.backbone.encode(x)
        x = self.backbone.decode(x, state)
        logits = self.actor(x)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x), state


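# Sketch of how the cached path is meant to be used (an assumption based on the
# methods above, not an API defined elsewhere in this file): the encoder runs once
# per episode and every later step only re-runs the decoder on the cached node
# embeddings, which is valid as long as the node set is fixed within an episode.
# `env` below is a hypothetical vectorized gym-style environment.
#
#   agent = Agent(embedding_dim=128, device="cpu", name="tsp")
#   obs = env.reset()
#   cache = None
#   done = False
#   while not done:
#       action, logprob, entropy, value, cache = agent.get_action_and_value_cached(
#           obs, state=cache
#       )
#       obs, reward, done, info = env.step(action.cpu().numpy())

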
class stateWrapper:
    """
    Wraps a dict of numpy arrays into an object that exposes the tensors and the
    helper methods the decoder expects.
    """

    def __init__(self, states, device, problem="tsp"):
        self.device = device
        self.states = {k: torch.tensor(v, device=self.device) for k, v in states.items()}
        if problem == "tsp":
            self.is_initial_action = self.states["is_initial_action"].to(torch.bool)
            self.first_a = self.states["first_node_idx"]
        elif problem == "cvrp":
            # Regroup the CVRP observation into the dict layout the embedding expects.
            inputs = {
                "loc": self.states["observations"],
                "depot": self.states["depot"].squeeze(-1),
                "demand": self.states["demand"],
            }
            self.states["observations"] = inputs
            self.VEHICLE_CAPACITY = 0
            self.used_capacity = -self.states["current_load"]

    def get_current_node(self):
        return self.states["last_node_idx"]

    def get_mask(self):
        # Invert the environment's action mask into the boolean mask format the decoder expects.
        return (1 - self.states["action_mask"]).to(torch.bool)
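
# For reference, a minimal sketch of the TSP observation dict that stateWrapper
# expects. Key names are taken from this file; the shapes and exact semantics are
# assumptions and depend on the environment producing the observations.
#
#   obs = {
#       "observations": coords,        # node coordinates, e.g. [batch, n_nodes, 2]
#       "action_mask": mask,           # assumed 1 for nodes that may still be visited
#       "first_node_idx": first_node,  # index of the first node on the tour
#       "last_node_idx": last_node,    # index of the most recently visited node
#       "is_initial_action": flag,     # assumed nonzero before the first decision
#   }
#
# For CVRP the wrapper additionally reads "depot", "demand" and "current_load".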