""" Problem specific node embedding for dynamic feature. """ import torch.nn as nn def AutoDynamicEmbedding(problem_name, config): """ Automatically select the corresponding module according to ``problem_name`` """ mapping = { "tsp": NonDyanmicEmbedding, "cvrp": NonDyanmicEmbedding, "sdvrp": SDVRPDynamicEmbedding, "pctsp": NonDyanmicEmbedding, "op": NonDyanmicEmbedding, } embeddingClass = mapping[problem_name] embedding = embeddingClass(**config) return embedding class SDVRPDynamicEmbedding(nn.Module): """ Embedding for dynamic node feature for the split delivery vehicle routing problem. It is implemented as a linear projection of the demands left in each node. Args: embedding_dim: dimension of output Inputs: state * **state** : a class that provide ``state.demands_with_depot`` tensor Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic * **glimpse_key_dynamic** : [batch, graph_size, embedding_dim] * **glimpse_val_dynamic** : [batch, graph_size, embedding_dim] * **logit_key_dynamic** : [batch, graph_size, embedding_dim] """ def __init__(self, embedding_dim): super(SDVRPDynamicEmbedding, self).__init__() self.projection = nn.Linear(1, 3 * embedding_dim, bias=False) def forward(self, state): glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.projection( state.demands_with_depot[:, 0, :, None].clone() ).chunk(3, dim=-1) return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic class NonDyanmicEmbedding(nn.Module): """ Embedding for problems that do not have any dynamic node feature. It is implemented as simply returning zeros. Args: embedding_dim: dimension of output Inputs: state * **state** : not used, just for consistency Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic * **glimpse_key_dynamic** : [batch, graph_size, embedding_dim] * **glimpse_val_dynamic** : [batch, graph_size, embedding_dim] * **logit_key_dynamic** : [batch, graph_size, embedding_dim] """ def __init__(self, embedding_dim): super(NonDyanmicEmbedding, self).__init__() def forward(self, state): return 0, 0, 0