# RLOR-TSP — models/attention_model_wrapper.py
# Author: Patrick WAN
# initial commit (52933b5)
import torch
from .nets.attention_model.attention_model import *
class Problem:
    """Minimal tag object identifying the routing problem (e.g. "tsp", "cvrp").

    The attention-model code reads the identifier via the uppercase
    ``NAME`` attribute, so that convention is kept here.
    """

    def __init__(self, name):
        self.NAME = name
class Backbone(nn.Module):
    """Encoder-decoder backbone wrapping the attention model.

    Pipeline: wrap the raw observation dict (``stateWrapper``), embed the
    nodes (``AutoEmbedding``), encode them (``GraphAttentionEncoder``), then
    run one decoding step (``Decoder.advance``).

    Parameters
    ----------
    embedding_dim : int
        Width of the node embeddings.
    problem_name : str
        Problem identifier, "tsp" or "cvrp" (selects the embedding and the
        state-wrapping behavior).
    n_encode_layers : int
        Number of graph-attention encoder layers.
    tanh_clipping : float
        Logit clipping constant passed to the decoder.
    n_heads : int
        Number of attention heads.
    device : str or torch.device
        Device the wrapped observation tensors are created on.
    """

    def __init__(
        self,
        embedding_dim=128,
        problem_name="tsp",
        n_encode_layers=3,
        tanh_clipping=10.0,
        n_heads=8,
        device="cpu",
    ):
        super(Backbone, self).__init__()
        self.device = device
        self.problem = Problem(problem_name)
        # Problem-specific input embedding (e.g. node coordinates for TSP).
        self.embedding = AutoEmbedding(self.problem.NAME, {"embedding_dim": embedding_dim})
        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=n_encode_layers,
        )
        self.decoder = Decoder(
            embedding_dim, self.embedding.context_dim, n_heads, self.problem, tanh_clipping
        )

    def forward(self, obs):
        """Encode `obs` and run one decoding step.

        Returns
        -------
        (logits, glimpse) : tuple
            Per-node action logits and the decoder glimpse embedding.
        """
        # forward was previously a verbatim copy of encode + decode; compose
        # the two methods instead so the pipeline is defined in one place.
        return self.decode(obs, self.encode(obs))

    def encode(self, obs):
        """Embed and encode `obs`; returns cached embeddings for decoding."""
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        # Renamed from `input` to avoid shadowing the builtin.
        node_features = state.states["observations"]
        embedding = self.embedding(node_features)
        encoded_inputs, _ = self.encoder(embedding)
        cached_embeddings = self.decoder._precompute(encoded_inputs)
        return cached_embeddings

    def decode(self, obs, cached_embeddings):
        """Run one decoding step of `obs` against `cached_embeddings`.

        Returns (logits, glimpse), as produced by ``Decoder.advance``.
        """
        state = stateWrapper(obs, device=self.device, problem=self.problem.NAME)
        logits, glimpse = self.decoder.advance(cached_embeddings, state)
        return logits, glimpse
class Actor(nn.Module):
    """Policy head: extracts the logits from the backbone's output pair."""

    def __init__(self):
        super(Actor, self).__init__()

    def forward(self, x):
        # x is the backbone's (logits, glimpse) pair; only the logits are
        # needed.  A .squeeze(1) would be required for non-POMO variants but
        # is intentionally omitted here.
        head_logits = x[0]
        return head_logits
class Critic(nn.Module):
    """Value head: scores the decoder glimpse with a small MLP.

    Parameters
    ----------
    hidden_size : int
        Width of the glimpse embedding (and of the MLP hidden layer).

    The previous ``*args, **kwargs`` signature hid the single real parameter
    and raised ``KeyError`` when it was missing; the explicit keyword-capable
    parameter keeps the in-file ``Critic(hidden_size=...)`` call working and
    now also accepts a positional argument.
    """

    def __init__(self, hidden_size):
        super(Critic, self).__init__()
        # Two-layer MLP: h_dim -> h_dim -> scalar value.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        """Score x[1], the glimpse: B x T x h_dim --mlp--> B x T x 1."""
        out = self.mlp(x[1])
        return out
class Agent(nn.Module):
    """Actor-critic agent combining the attention backbone with policy and
    value heads.

    Parameters
    ----------
    embedding_dim : int
        Embedding width shared by backbone and critic.
    device : str or torch.device
        Device for observation tensors.
    name : str
        Problem identifier ("tsp" or "cvrp").
    """

    def __init__(self, embedding_dim=128, device="cpu", name="tsp"):
        super().__init__()
        self.backbone = Backbone(embedding_dim=embedding_dim, device=device, problem_name=name)
        self.critic = Critic(hidden_size=embedding_dim)
        self.actor = Actor()

    def forward(self, x):
        """Greedy action selection (actor only); returns (action, logits)."""
        features = self.backbone(x)
        logits = self.actor(features)
        # Greedy decoding: argmax over the node dimension.
        action = logits.max(2)[1]
        return action, logits

    def get_value(self, x):
        """Critic value for observation `x` (full backbone pass)."""
        features = self.backbone(x)
        return self.critic(features)

    def get_action_and_value(self, x, action=None):
        """Sample an action (or score the given one).

        Returns (action, log_prob, entropy, value).
        """
        features = self.backbone(x)
        logits = self.actor(features)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(features)

    def get_value_cached(self, x, state):
        """Critic value using precomputed encoder `state` (decode only)."""
        features = self.backbone.decode(x, state)
        return self.critic(features)

    def get_action_and_value_cached(self, x, action=None, state=None):
        """Like get_action_and_value, but reuses (or lazily builds) the
        cached encoder state so only the decoder runs per step.

        Returns (action, log_prob, entropy, value, state).
        """
        if state is None:
            state = self.backbone.encode(x)
        # The original had an if/else whose two branches both ran this exact
        # decode call; collapsed into a single call after the lazy encode.
        features = self.backbone.decode(x, state)
        logits = self.actor(features)
        probs = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(features), state
class stateWrapper:
    """Adapts a dict of numpy arrays into the state object the attention
    decoder expects (tensor fields plus accessor methods)."""

    def __init__(self, states, device, problem="tsp"):
        self.device = device
        # Convert every observation field to a tensor on the target device.
        self.states = {key: torch.tensor(value, device=device) for key, value in states.items()}
        if problem == "tsp":
            self.is_initial_action = self.states["is_initial_action"].to(torch.bool)
            self.first_a = self.states["first_node_idx"]
        elif problem == "cvrp":
            # Repack the flat fields into the nested dict the CVRP embedding
            # consumes; the RHS reads the original flat "observations" before
            # it is replaced.
            self.states["observations"] = {
                "loc": self.states["observations"],
                "depot": self.states["depot"].squeeze(-1),
                "demand": self.states["demand"],
            }
            # NOTE(review): capacity 0 with negated current load — presumably
            # the environment already normalizes demand against remaining
            # load; confirm against the env/decoder implementation.
            self.VEHICLE_CAPACITY = 0
            self.used_capacity = -self.states["current_load"]

    def get_current_node(self):
        return self.states["last_node_idx"]

    def get_mask(self):
        # action_mask marks feasible actions with 1; the decoder wants
        # True = masked out, hence the inversion.
        return (1 - self.states["action_mask"]).to(torch.bool)