Spaces:
Sleeping
Sleeping
Update pages/15_Graphs.py
Browse files- pages/15_Graphs.py +3 -3
pages/15_Graphs.py
CHANGED
|
@@ -13,9 +13,9 @@ os.environ['TF_USE_LEGACY_KERAS'] = '1'
|
|
| 13 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
| 14 |
graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
|
| 15 |
|
| 16 |
-
# Encode input features
|
| 17 |
graph = tfgnn.keras.layers.MapFeatures(
|
| 18 |
-
node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(
|
| 19 |
)(graph)
|
| 20 |
|
| 21 |
# For each round of message passing...
|
|
@@ -39,7 +39,7 @@ def create_sample_graph():
|
|
| 39 |
graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
|
| 40 |
|
| 41 |
# Create a GraphTensor
|
| 42 |
-
node_features = tf.random.normal((num_nodes,
|
| 43 |
edge_features = tf.random.normal((num_edges, 32))
|
| 44 |
|
| 45 |
graph_tensor = tfgnn.GraphTensor.from_pieces(
|
|
|
|
| 13 |
def model_fn(graph_tensor_spec: tfgnn.GraphTensorSpec):
|
| 14 |
graph = inputs = tf.keras.Input(type_spec=graph_tensor_spec)
|
| 15 |
|
| 16 |
+
# Encode input features to match the required output shape of 128
|
| 17 |
graph = tfgnn.keras.layers.MapFeatures(
|
| 18 |
+
node_sets_fn=lambda node_set, node_set_name: tf.keras.layers.Dense(128)(node_set['features'])
|
| 19 |
)(graph)
|
| 20 |
|
| 21 |
# For each round of message passing...
|
|
|
|
| 39 |
graph = nx.gnm_random_graph(num_nodes, num_edges, directed=True)
|
| 40 |
|
| 41 |
# Create a GraphTensor
|
| 42 |
+
node_features = tf.random.normal((num_nodes, 128)) # Match the dense layer output
|
| 43 |
edge_features = tf.random.normal((num_edges, 32))
|
| 44 |
|
| 45 |
graph_tensor = tfgnn.GraphTensor.from_pieces(
|