# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

import pickle
import torch
from torch import nn

from detectron2.utils.file_io import PathManager

from .utils import normalize_embeddings
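
# Hypothetical reference sketch, not part of the DensePose API. It spells out
# the computation that `VertexFeatureEmbedder.forward` (defined in this module)
# performs: normalize(F @ E), where F is the [N, K] vertex feature tensor and
# E is the [K, D] feature-to-embedding tensor. It assumes that
# `normalize_embeddings` rescales each row to unit L2 norm, with the norm
# clamped from below by a small epsilon. The epsilon default here is an assumption.
#
# Usage (illustrative):
#   F = torch.randn(5, 3)  # N = 5 vertices, K = 3 feature dims
#   E = torch.randn(3, 4)  # D = 4 embedding dims
#   _reference_vertex_embeddings(F, E).shape  # torch.Size([5, 4])
# Under that assumption, and because `reset_parameters` zeroes both tensors, a
# freshly constructed embedder yields all-zero embeddings until its features
# and embeddings are loaded or trained.
import torch


def _reference_vertex_embeddings(
    features: torch.Tensor, embeddings: torch.Tensor, epsilon: float = 1e-6
) -> torch.Tensor:
    # [N, K] @ [K, D] -> [N, D]
    raw = torch.mm(features, embeddings)
    # Per-row L2 norm, clamped so all-zero rows are divided by epsilon, not 0
    norms = raw.norm(p=2, dim=1, keepdim=True).clamp(min=epsilon)
    return raw / norms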


class VertexFeatureEmbedder(nn.Module):
    """
    Class responsible for embedding vertex features. The mapping from the
    feature space to the embedding space is a tensor E of size [K, D], where
        K = number of dimensions in the feature space
        D = number of dimensions in the embedding space
    Vertex features are a tensor F of size [N, K], where
        N = number of vertices
        K = number of dimensions in the feature space
    Vertex embeddings are computed as the matrix product F @ E, a tensor of
    size [N, D], which is then normalized with `normalize_embeddings`.
    """

    def __init__(
        self, num_vertices: int, feature_dim: int, embed_dim: int, train_features: bool = False
    ):
        """
        Initialize the embedder. Both the vertex features and the embedding
        matrix are zero-initialized by `reset_parameters`.

        Args:
            num_vertices (int): number of vertices to embed
            feature_dim (int): number of dimensions in the feature space
            embed_dim (int): number of dimensions in the embedding space
            train_features (bool): determines whether vertex features should
                be trained; if False, they are stored as a non-trainable
                buffer (default: False)
        """
        super(VertexFeatureEmbedder, self).__init__()
        if train_features:
            self.features = nn.Parameter(torch.Tensor(num_vertices, feature_dim))
        else:
            self.register_buffer("features", torch.Tensor(num_vertices, feature_dim))
        self.embeddings = nn.Parameter(torch.Tensor(feature_dim, embed_dim))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        self.features.zero_()
        self.embeddings.zero_()

    def forward(self) -> torch.Tensor:
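        """
        Produce vertex embeddings, a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space

        Returns:
            Full vertex embeddings, a tensor of shape [N, D], computed as the
            product of the features and the embedding matrix and normalized
            with `normalize_embeddings`
        """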
        return normalize_embeddings(torch.mm(self.features, self.embeddings))

    @torch.no_grad()
    def load(self, fpath: str):
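        """
        Load features and/or embeddings from a pickle file. The file is
        expected to contain a dict; its "features" and "embeddings" entries,
        if present, are converted to float, moved to the device of the
        corresponding tensor and copied into it. Missing keys leave the
        current values unchanged.

        Args:
            fpath (str): file path to load data from
        """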
        with PathManager.open(fpath, "rb") as hFile:
            data = pickle.load(hFile)
            for name in ["features", "embeddings"]:
                if name in data:
                    getattr(self, name).copy_(
                        torch.tensor(data[name]).float().to(device=getattr(self, name).device)
                    )
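

# Hypothetical companion sketch, not part of the DensePose API. It writes a
# file in the layout that `VertexFeatureEmbedder.load` reads: a pickled dict
# whose optional "features" ([N, K]) and "embeddings" ([K, D]) entries hold
# array-like data. Storing float32 NumPy arrays and using the built-in `open`
# instead of `PathManager` are assumptions. `load` only needs
# `torch.tensor(value)` to succeed and the shapes to match the module's tensors.
#
# Usage (illustrative):
#   _dump_vertex_feature_data("verts.pkl", features=F_np, embeddings=E_np)
#   embedder.load("verts.pkl")  # copies both entries into the module
import pickle
from typing import Optional

import numpy as np


def _dump_vertex_feature_data(
    fpath: str,
    features: Optional[np.ndarray] = None,
    embeddings: Optional[np.ndarray] = None,
) -> None:
    data = {}
    # Keys omitted here are simply skipped by `load`
    if features is not None:
        data["features"] = np.asarray(features, dtype=np.float32)
    if embeddings is not None:
        data["embeddings"] = np.asarray(embeddings, dtype=np.float32)
    with open(fpath, "wb") as f:
        pickle.dump(data, f)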