"""
Problem specific node embedding for dynamic feature.
"""

import torch.nn as nn


def AutoDynamicEmbedding(problem_name, config):
    """
    Automatically select the dynamic embedding class corresponding to
    ``problem_name`` and instantiate it with ``config``.
    """
    mapping = {
        "tsp": NonDynamicEmbedding,
        "cvrp": NonDynamicEmbedding,
        "sdvrp": SDVRPDynamicEmbedding,
        "pctsp": NonDynamicEmbedding,
        "op": NonDynamicEmbedding,
    }
    embedding_cls = mapping[problem_name]
    return embedding_cls(**config)
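
# Usage sketch (hypothetical call, not part of the original module; the config
# dict below assumes its keys match the selected class's constructor):
#
#   dynamic_embedding = AutoDynamicEmbedding("sdvrp", {"embedding_dim": 128})
#   glimpse_key, glimpse_val, logit_key = dynamic_embedding(state)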


class SDVRPDynamicEmbedding(nn.Module):
    """
    Embedding of the dynamic node feature for the split delivery vehicle routing problem (SDVRP).

    It is implemented as a linear projection of the demands left in each node.

    Args:
        embedding_dim: dimension of output
    Inputs: state
        * **state** : an object that provides a ``demands_with_depot`` tensor
    Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic
        * **glimpse_key_dynamic** : [batch, graph_size, embedding_dim]
        * **glimpse_val_dynamic** : [batch, graph_size, embedding_dim]
        * **logit_key_dynamic** : [batch, graph_size, embedding_dim]

    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.projection = nn.Linear(1, 3 * embedding_dim, bias=False)

    def forward(self, state):
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.projection(
            state.demands_with_depot[:, 0, :, None].clone()
        ).chunk(3, dim=-1)
        return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic
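
# Shape walk-through for SDVRPDynamicEmbedding.forward, an illustrative trace
# assuming ``state.demands_with_depot`` has shape [batch, 1, graph_size]:
#
#   demands_with_depot[:, 0, :, None]  ->  [batch, graph_size, 1]
#   self.projection(...)               ->  [batch, graph_size, 3 * embedding_dim]
#   .chunk(3, dim=-1)                  ->  3 tensors of [batch, graph_size, embedding_dim]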


class NonDynamicEmbedding(nn.Module):
    """
    Embedding for problems that do not have any dynamic node feature.

    It is implemented as simply returning scalar zeros.

    Args:
        embedding_dim: unused, accepted only to keep the constructor
            signature consistent with the other embedding classes
    Inputs: state
        * **state** : not used, just for consistency
    Outputs: glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic
        * each output is the scalar ``0``, which broadcasts against tensors
          of shape [batch, graph_size, embedding_dim]

    """

    def __init__(self, embedding_dim):
        super().__init__()

    def forward(self, state):
        return 0, 0, 0
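

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. The fake
    # ``state`` is a stand-in object that only provides the
    # ``demands_with_depot`` tensor read by SDVRPDynamicEmbedding, with an
    # assumed shape of [batch, 1, graph_size].
    from types import SimpleNamespace

    import torch

    batch, graph_size, embedding_dim = 2, 11, 128
    state = SimpleNamespace(demands_with_depot=torch.rand(batch, 1, graph_size))

    dynamic = AutoDynamicEmbedding("sdvrp", {"embedding_dim": embedding_dim})
    glimpse_key, glimpse_val, logit_key = dynamic(state)
    assert glimpse_key.shape == (batch, graph_size, embedding_dim)
    assert glimpse_val.shape == (batch, graph_size, embedding_dim)
    assert logit_key.shape == (batch, graph_size, embedding_dim)

    # NonDynamicEmbedding returns scalar zeros, which broadcast against
    # tensors of the expected shape.
    static = AutoDynamicEmbedding("tsp", {"embedding_dim": embedding_dim})
    assert static(state) == (0, 0, 0)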