File size: 3,752 Bytes
6d70ed4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data-structure for storing graphs with typed edges and nodes."""

from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar

# Alias for any array type. `Union[Any]` places no constraint on the value, so
# this serves as documentation only; candidate types are listed alongside.
ArrayLike = Union[Any]  # np.ndarray, jnp.ndarray, tf.tensor
# Alias for a nested structure whose leaves are `ArrayLike`; also unconstrained.
ArrayLikeTree = Union[Any, ArrayLike]  # Nest of ArrayLike

# Generic type variable.
# NOTE(review): no use of `_T` is visible in this part of the file -- confirm
# it is still needed.
_T = TypeVar('_T')


# All tensors have a "flat_batch_axis", which is similar to the leading
# axes of graph_tuples:
# * In the case of nodes this is simply a shared node and flat batch axis, with
# size corresponding to the total number of nodes in the flattened batch.
# * In the case of edges this is simply a shared edge and flat batch axis, with
# size corresponding to the total number of edges in the flattened batch.
# * In the case of globals this is simply the number of graphs in the flattened
# batch.

# All shapes may also have any additional leading shape "batch_shape".
# Options for building batches are:
# * Use a provided "flatten" method that takes a leading `batch_shape` and
#   folds it into the flat_batch_axis (this will be useful when using
#   `tf.Dataset`, which supports batching into RaggedTensors with a leading
#   batch shape even if graphs have different numbers of nodes and edges), so
#   the RaggedBatches can then be converted into something without ragged
#   dimensions that jax can use.
# * Directly build a "flat batch" using a provided function for batching a list
#   of graphs (how it is done in `jraph`).


class NodeSet(NamedTuple):
  """Represents a set of nodes.

  Fields (shapes as given in the inline comments):
    n_node: one entry per graph in the flattened batch; presumably the number
      of nodes each graph contributes to this set -- TODO confirm.
    features: nest of arrays whose leading axis spans all nodes of the
      flattened batch, followed by the feature shape.
  """
  n_node: ArrayLike  # [num_flat_graphs]
  features: ArrayLikeTree  # Prev. `nodes`: [num_flat_nodes] + feature_shape


class EdgesIndices(NamedTuple):
  """Represents indices to nodes adjacent to the edges.

  Both fields hold one entry per edge of the flattened batch.

  Fields:
    senders: index of the sending node of each edge.
    receivers: index of the receiving node of each edge.

  NOTE(review): the indices presumably refer to positions in the sender and
  receiver node sets named by the matching `EdgeSetKey.node_sets` -- confirm
  against the code that builds and consumes them.
  """
  senders: ArrayLike  # [num_flat_edges]
  receivers: ArrayLike  # [num_flat_edges]


class EdgeSet(NamedTuple):
  """Represents a set of edges.

  Fields (shapes as given in the inline comments):
    n_edge: one entry per graph in the flattened batch; presumably the number
      of edges each graph contributes to this set -- TODO confirm.
    indices: sender and receiver node indices, one pair per edge.
    features: nest of arrays whose leading axis spans all edges of the
      flattened batch, followed by the feature shape.
  """
  n_edge: ArrayLike  # [num_flat_graphs]
  indices: EdgesIndices
  features: ArrayLikeTree  # Prev. `edges`: [num_flat_edges] + feature_shape


class Context(NamedTuple):
  """Represents per-graph features (previously called `globals`).

  Fields (shapes as given in the inline comments):
    n_graph: one entry per graph in the flattened batch.
    features: nest of arrays whose leading axis spans the graphs of the
      flattened batch, followed by the feature shape.
  """
  # `n_graph` always contains ones but it is useful to query the leading shape
  # in case of graphs without any nodes or edges sets.
  n_graph: ArrayLike  # [num_flat_graphs]
  features: ArrayLikeTree  # Prev. `globals`: [num_flat_graphs] + feature_shape


class EdgeSetKey(NamedTuple):
  """Identifies an edge set by its name and the node sets it connects.

  Being a tuple of a string and a pair of strings, it is usable as a mapping
  key.
  """
  name: str   # Name of the EdgeSet.

  # Sender node set name and receiver node set name connected by the edge set.
  node_sets: Tuple[str, str]


class TypedGraph(NamedTuple):
  """Container for a graph whose nodes and edges are grouped into typed sets.

  Bundles a single context with a mapping of named node sets and a mapping of
  edge sets. Each edge set is keyed by an `EdgeSetKey`, which carries the edge
  set name together with the names of the sender and receiver node sets.
  """

  context: Context
  nodes: Mapping[str, NodeSet]
  edges: Mapping[EdgeSetKey, EdgeSet]

  def edge_key_by_name(self, name: str) -> EdgeSetKey:
    """Returns the one key in `edges` whose `name` field equals `name`.

    Args:
      name: edge set name to look up.

    Returns:
      The matching `EdgeSetKey`.

    Raises:
      KeyError: if no key, or more than one key, carries that name. The message
        lists the names of all available edge sets.
    """
    matches = [key for key in self.edges.keys() if key.name == name]
    if len(matches) == 1:
      return matches[0]
    # Zero or several matches: report every known edge set name.
    available = ', '.join(key.name for key in self.edges.keys())
    raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
        name, available))

  def edge_by_name(self, name: str) -> EdgeSet:
    """Returns the `EdgeSet` stored under the key whose name is `name`.

    Args:
      name: edge set name to look up.

    Raises:
      KeyError: if `name` does not identify exactly one key in `edges`.
    """
    key = self.edge_key_by_name(name)
    return self.edges[key]