|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A predictor that runs multiple graph neural networks on mesh data. |
|
|
|
It learns to interpolate between the grid and the mesh nodes, with the loss |
|
and the rollouts ultimately computed at the grid level. |
|
|
|
It uses ideas similar to those in Keisler (2022): |
|
|
|
Reference: |
|
https://arxiv.org/pdf/2202.07575.pdf |
|
|
|
It assumes data across time and level is stacked, and operates only in
|
a 2D mesh over latitudes and longitudes. |
|
""" |
|
|
|
from typing import Any, Callable, Mapping, Optional |
|
|
|
import chex |
|
from graphcast import deep_typed_graph_net |
|
from graphcast import grid_mesh_connectivity |
|
from graphcast import icosahedral_mesh |
|
from graphcast import losses |
|
from graphcast import model_utils |
|
from graphcast import predictor_base |
|
from graphcast import typed_graph |
|
from graphcast import xarray_jax |
|
import jax.numpy as jnp |
|
import jraph |
|
import numpy as np |
|
import xarray |
|
|
|
# Type alias for a mapping of keyword arguments (string keys, arbitrary values).
# NOTE(review): not referenced within this module; presumably kept as a public
# alias for importers.
Kwargs = Mapping[str, Any]

# Type alias for a graph network: a callable mapping a `jraph.GraphsTuple` to
# another `jraph.GraphsTuple`. Not referenced within this module either.
GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple]
|
|
|
|
|
|
|
# Standard sets of pressure levels, in ascending order (presumably in hPa —
# confirm against the datasets). Each set is named after its source and its
# level count.

# Full 37-level set.
PRESSURE_LEVELS_ERA5_37 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70,
    100, 125, 150, 175, 200, 225, 250, 300, 350, 400,
    450, 500, 550, 600, 650, 700, 750, 775, 800, 825,
    850, 875, 900, 925, 950, 975, 1000,
)

# 25-level set.
PRESSURE_LEVELS_HRES_25 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70,
    100, 150, 200, 250, 300, 400, 500, 600, 700, 800,
    850, 900, 925, 950, 1000,
)

# 13-level set.
PRESSURE_LEVELS_WEATHERBENCH_13 = (
    50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000,
)

# Lookup from number of levels to the matching level set (keys 13, 25, 37, in
# that insertion order). Keying on `len(...)` guarantees every key agrees with
# the length of its value.
PRESSURE_LEVELS = {
    len(levels): levels
    for levels in (
        PRESSURE_LEVELS_WEATHERBENCH_13,
        PRESSURE_LEVELS_HRES_25,
        PRESSURE_LEVELS_ERA5_37,
    )
}
|
|
|
|
|
|
|
# Every pressure-level (atmospheric) variable name the model recognizes.
# When sizing the decoder output, `GraphCast.__init__` counts a target variable
# found here once per pressure level, and any target NOT found here as a
# single-channel surface variable.
ALL_ATMOSPHERIC_VARS = (
    "potential_vorticity", "specific_rain_water_content",
    "specific_snow_water_content", "geopotential", "temperature",
    "u_component_of_wind", "v_component_of_wind", "specific_humidity",
    "vertical_velocity", "vorticity", "divergence", "relative_humidity",
    "ozone_mass_mixing_ratio", "specific_cloud_liquid_water_content",
    "specific_cloud_ice_water_content", "fraction_of_cloud_cover",
)
|
|
|
# Surface (single-level) variables predicted by the model.
TARGET_SURFACE_VARS = (
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_v_component_of_wind",
    "10m_u_component_of_wind",
    "total_precipitation_6hr",
)

# Same as TARGET_SURFACE_VARS, in the same order, without precipitation.
TARGET_SURFACE_NO_PRECIP_VARS = tuple(
    name for name in TARGET_SURFACE_VARS if name != "total_precipitation_6hr")

# Atmospheric (per-pressure-level) variables predicted by the model.
TARGET_ATMOSPHERIC_VARS = (
    "temperature",
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "vertical_velocity",
    "specific_humidity",
)

# Same as TARGET_ATMOSPHERIC_VARS, in the same order, without vertical
# velocity ("w").
TARGET_ATMOSPHERIC_NO_W_VARS = tuple(
    name for name in TARGET_ATMOSPHERIC_VARS if name != "vertical_velocity")
|
# Forcing variables, split into an externally provided group and a group of
# sin/cos-encoded time-of-year / time-of-day features (presumably computed from
# the timestamps — confirm in the data pipeline).
EXTERNAL_FORCING_VARS = ("toa_incident_solar_radiation",)
GENERATED_FORCING_VARS = (
    "year_progress_sin", "year_progress_cos",
    "day_progress_sin", "day_progress_cos",
)

# All forcings: external ones first, then generated ones.
FORCING_VARS = (*EXTERNAL_FORCING_VARS, *GENERATED_FORCING_VARS)

# Variables used as inputs alongside the forcings in the TASK configs.
STATIC_VARS = ("geopotential_at_surface", "land_sea_mask")
|
|
|
|
|
@chex.dataclass(frozen=True, eq=True)
class TaskConfig:
    """Defines inputs and targets on which a model is trained and/or evaluated.

    Attributes:
      input_variables: Names of every variable fed to the model (e.g. the TASK
        configs in this module combine target, forcing and static variables).
      target_variables: Variables which the model is expected to predict.
      forcing_variables: Names of the forcing variables.
      pressure_levels: Pressure levels of the atmospheric variables, e.g. one
        of the PRESSURE_LEVELS_* tuples. `GraphCast.__init__` uses its length
        to size the per-level output channels.
      input_duration: Duration string such as "12h"; presumably the span of
        input history the model sees — confirm against the data-loading code.
    """
    input_variables: tuple[str, ...]
    # Target variables which the model is expected to predict.
    target_variables: tuple[str, ...]
    forcing_variables: tuple[str, ...]
    pressure_levels: tuple[int, ...]
    input_duration: str
|
|
|
# Default task: predict all surface and atmospheric targets on the 37-level
# set, from the same targets plus forcings and static fields.
TASK = TaskConfig(
    input_variables=(
        *TARGET_SURFACE_VARS,
        *TARGET_ATMOSPHERIC_VARS,
        *FORCING_VARS,
        *STATIC_VARS,
    ),
    target_variables=(*TARGET_SURFACE_VARS, *TARGET_ATMOSPHERIC_VARS),
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_ERA5_37,
    input_duration="12h",
)

# Same as TASK, but on the 13-level pressure set.
TASK_13 = TaskConfig(
    input_variables=(
        *TARGET_SURFACE_VARS,
        *TARGET_ATMOSPHERIC_VARS,
        *FORCING_VARS,
        *STATIC_VARS,
    ),
    target_variables=(*TARGET_SURFACE_VARS, *TARGET_ATMOSPHERIC_VARS),
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)

# Like TASK_13, except precipitation is only an output: it is predicted as a
# target but is not among the inputs.
TASK_13_PRECIP_OUT = TaskConfig(
    input_variables=(
        *TARGET_SURFACE_NO_PRECIP_VARS,
        *TARGET_ATMOSPHERIC_VARS,
        *FORCING_VARS,
        *STATIC_VARS,
    ),
    target_variables=(*TARGET_SURFACE_VARS, *TARGET_ATMOSPHERIC_VARS),
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)
|
|
|
|
|
@chex.dataclass(frozen=True, eq=True)
class ModelConfig:
    """Defines the architecture of the GraphCast neural network.

    Attributes:
      resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0).
      mesh_size: How many refinements to do on the multi-mesh.
      gnn_msg_steps: How many Graph Network message passing steps to do.
      latent_size: How many latent features to include in the various MLPs.
      hidden_layers: How many hidden layers for each MLP.
      radius_query_fraction_edge_length: Scalar that will be multiplied by the
        length of the longest edge of the finest mesh to define the radius of
        connectivity to use in the Grid2Mesh graph. Reasonable values are
        between 0.6 and 1. 0.6 reduces the number of grid points feeding into
        multiple mesh nodes and therefore reduces edge count and memory use,
        but 1 gives better predictions.
      mesh2grid_edge_normalization_factor: Allows explicitly controlling edge
        normalization for mesh2grid edges. If None, defaults to max edge
        length. This supports using pre-trained model weights with a different
        graph structure to what it was trained on.
    """
    resolution: float
    mesh_size: int
    latent_size: int
    gnn_msg_steps: int
    hidden_layers: int
    radius_query_fraction_edge_length: float
    # Optional; forwarded as `edge_normalization_factor` when building the
    # Mesh2Grid edge features.
    mesh2grid_edge_normalization_factor: Optional[float] = None
|
|
|
|
|
@chex.dataclass(frozen=True, eq=True)
class CheckPoint:
    """Bundles a parameter dict with its model and task configurations.

    Attributes:
      params: Model parameters keyed by name (presumably a Haiku-style nested
        parameter dict — TODO confirm against the loading code).
      model_config: The `ModelConfig` describing the architecture.
      task_config: The `TaskConfig` describing inputs, targets and levels.
      description: Free-text description.
      license: License text.
    """
    params: dict[str, Any]
    model_config: ModelConfig
    task_config: TaskConfig
    description: str
    license: str
|
|
|
|
|
class GraphCast(predictor_base.Predictor):
    """GraphCast Predictor.

    The model works on graphs that take into account:
    * Mesh nodes: nodes for the vertices of the mesh.
    * Grid nodes: nodes for the points of the grid.
    * Nodes: When referring to just "nodes", this means the joint set of
      both mesh nodes, concatenated with grid nodes.

    The model works with 3 graphs:
    * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly
      bipartite with edges going from grid nodes to mesh nodes using a
      fixed radius query. The grid2mesh_gnn will operate in this graph. The
      output of this stage will be a latent representation for the mesh nodes,
      and a latent representation for the grid nodes.
    * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will
      operate in this graph. It will update the latent state of the mesh nodes
      only.
    * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly
      bipartite with edges going from mesh nodes to grid nodes such that each
      grid node is connected to the 3 nodes of the mesh triangular face that
      contains the grid point. The mesh2grid_gnn will operate in this graph.
      It will process the updated latent state of the mesh nodes, and the
      latent state of the grid nodes, to produce the final output for the grid
      nodes.

    The model is built on top of `TypedGraph`s so the different types of nodes
    and edges can be stored and treated separately.

    The graph structures, and the static lat/lon-derived features on them, are
    built lazily and only once, from the coordinates of the first `inputs`
    passed to `__call__` (see `_maybe_init`).
    """

    def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
        """Initializes the predictor.

        Constructs the mesh hierarchy and the three GNN modules. Everything
        that depends on the grid coordinates is deferred to `_maybe_init`.

        Args:
          model_config: Architecture settings (mesh refinements, latent/MLP
            sizes, processor message-passing steps, Grid2Mesh query-radius
            fraction and optional Mesh2Grid edge normalization factor).
          task_config: Task definition. Only `target_variables` and
            `pressure_levels` are read here, to size the decoder output.
        """
        # Options for the structural (lat/lon-derived) node and edge features;
        # the same options are used for all three graphs.
        self._spatial_features_kwargs = dict(
            add_node_positions=False,
            add_node_latitude=True,
            add_node_longitude=True,
            add_relative_positions=True,
            relative_longitude_local_coordinates=True,
            relative_latitude_local_coordinates=True,
        )

        # Hierarchy of triangular meshes with `mesh_size` refinements. The last
        # entry is the finest mesh (`_finest_mesh`); all entries are merged into
        # the multi-mesh in `_init_mesh_graph`.
        self._meshes = (
            icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
                splits=model_config.mesh_size))

        # Encoder: a single message-passing step moving data from the grid
        # nodes to the mesh nodes over the Grid2Mesh graph.
        self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            embed_nodes=True,  # Embed raw features of the grid and mesh nodes.
            embed_edges=True,  # Embed raw features of the grid2mesh edges.
            edge_latent_size=dict(grid2mesh=model_config.latent_size),
            node_latent_size=dict(
                mesh_nodes=model_config.latent_size,
                grid_nodes=model_config.latent_size),
            mlp_hidden_size=model_config.latent_size,
            mlp_num_hidden_layers=model_config.hidden_layers,
            num_message_passing_steps=1,
            use_layer_norm=True,
            include_sent_messages_in_node_update=False,
            activation="swish",
            f32_aggregation=True,
            aggregate_normalization=None,
            name="grid2mesh_gnn",
        )

        # Processor: `gnn_msg_steps` message-passing steps on the multi-mesh.
        self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            embed_nodes=False,  # Mesh nodes arrive as latents from the encoder.
            embed_edges=True,  # Embed raw features of the multi-mesh edges.
            node_latent_size=dict(mesh_nodes=model_config.latent_size),
            edge_latent_size=dict(mesh=model_config.latent_size),
            mlp_hidden_size=model_config.latent_size,
            mlp_num_hidden_layers=model_config.hidden_layers,
            num_message_passing_steps=model_config.gnn_msg_steps,
            use_layer_norm=True,
            include_sent_messages_in_node_update=False,
            activation="swish",
            f32_aggregation=False,
            name="mesh_gnn",
        )

        # Output channels per grid node: one per surface target plus one per
        # (atmospheric target, pressure level) pair. Any target not listed in
        # ALL_ATMOSPHERIC_VARS is counted as a surface variable.
        num_surface_vars = len(
            set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS))
        num_atmospheric_vars = len(
            set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS))
        num_outputs = (num_surface_vars +
                       len(task_config.pressure_levels) * num_atmospheric_vars)

        # Decoder: a single message-passing step moving data from the mesh
        # nodes back to the grid nodes over the Mesh2Grid graph.
        self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            # Request exactly `num_outputs` output channels for the grid nodes.
            node_output_size=dict(grid_nodes=num_outputs),
            embed_nodes=False,  # Both node sets arrive as latents.
            embed_edges=True,  # Embed raw features of the mesh2grid edges.
            edge_latent_size=dict(mesh2grid=model_config.latent_size),
            node_latent_size=dict(
                mesh_nodes=model_config.latent_size,
                grid_nodes=model_config.latent_size),
            mlp_hidden_size=model_config.latent_size,
            mlp_num_hidden_layers=model_config.hidden_layers,
            num_message_passing_steps=1,
            use_layer_norm=True,
            include_sent_messages_in_node_update=False,
            activation="swish",
            f32_aggregation=False,
            name="mesh2grid_gnn",
        )

        # Absolute Grid2Mesh query radius (in the units of the mesh vertex
        # coordinates): the longest edge of the finest mesh, scaled by the
        # configured fraction.
        self._query_radius = (_get_max_edge_distance(self._finest_mesh)
                              * model_config.radius_query_fraction_edge_length)
        self._mesh2grid_edge_normalization_factor = (
            model_config.mesh2grid_edge_normalization_factor
        )

        # The remaining state is filled in by `_maybe_init` on the first call,
        # once sample inputs reveal the grid's lat/lon values.
        self._initialized = False

        # Set by `_init_mesh_properties` from the finest-mesh vertices. This
        # does not depend on the inputs but is delayed too, for consistency.
        self._num_mesh_nodes = None  # num_mesh_nodes
        self._mesh_nodes_lat = None  # [num_mesh_nodes]
        self._mesh_nodes_lon = None  # [num_mesh_nodes]

        # Set by `_init_grid_properties`.
        self._grid_lat = None  # [num_lat_points]
        self._grid_lon = None  # [num_lon_points]
        self._num_grid_nodes = None  # num_lat_points * num_lon_points
        self._grid_nodes_lat = None  # [num_grid_nodes]
        self._grid_nodes_lon = None  # [num_grid_nodes]

        # Set by `_init_{grid2mesh,mesh,mesh2grid}_graph`.
        self._grid2mesh_graph_structure = None
        self._mesh_graph_structure = None
        self._mesh2grid_graph_structure = None

    @property
    def _finest_mesh(self):
        """The last (most refined) mesh of the hierarchy in `self._meshes`."""
        return self._meshes[-1]

    def __call__(self,
                 inputs: xarray.Dataset,
                 targets_template: xarray.Dataset,
                 forcings: xarray.Dataset,
                 is_training: bool = False,
                 ) -> xarray.Dataset:
        """Runs encoder, processor and decoder to produce a prediction.

        Args:
          inputs: Input variables. On the first call only, its `lat`/`lon`
            coordinates also define the grid used to build all graphs.
          targets_template: Dataset whose structure the stacked outputs are
            unpacked into (via `model_utils.stacked_to_dataset`).
          forcings: Forcing variables, concatenated with `inputs` along the
            channel axis.
          is_training: Accepted for interface compatibility; not read here.

        Returns:
          The prediction as an `xarray.Dataset` built from `targets_template`.
        """
        self._maybe_init(inputs)

        # Flatten inputs + forcings into per-grid-node feature vectors:
        # [num_grid_nodes, batch, num_channels].
        grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)

        # Encoder (grid -> mesh). Returns latents for both node sets:
        # [num_mesh_nodes, batch, ...] and [num_grid_nodes, batch, ...].
        (latent_mesh_nodes, latent_grid_nodes
         ) = self._run_grid2mesh_gnn(grid_node_features)

        # Processor: message passing on the multi-mesh.
        updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)

        # Decoder (mesh -> grid): [num_grid_nodes, batch, num_outputs].
        output_grid_nodes = self._run_mesh2grid_gnn(
            updated_latent_mesh_nodes, latent_grid_nodes)

        # Unflatten the per-grid-node outputs into a Dataset.
        return self._grid_node_outputs_to_prediction(
            output_grid_nodes, targets_template)

    def loss_and_predictions(
        self,
        inputs: xarray.Dataset,
        targets: xarray.Dataset,
        forcings: xarray.Dataset,
    ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
        """Computes the training loss together with the predictions.

        Runs the forward pass with `is_training=True`, using `targets` as the
        output template, then scores the predictions against `targets` with
        `losses.weighted_mse_per_level`.

        Returns:
          A `(loss, predictions)` tuple.
        """
        predictions = self(
            inputs, targets_template=targets, forcings=forcings,
            is_training=True)

        loss = losses.weighted_mse_per_level(
            predictions, targets,
            per_variable_weights={
                # 2m temperature keeps full weight; the other surface variables
                # are down-weighted. NOTE(review): how variables missing from
                # this dict are weighted is decided inside
                # `losses.weighted_mse_per_level` — confirm there.
                "2m_temperature": 1.0,
                "10m_u_component_of_wind": 0.1,
                "10m_v_component_of_wind": 0.1,
                "mean_sea_level_pressure": 0.1,
                "total_precipitation_6hr": 0.1,
            })
        return loss, predictions

    def loss(
        self,
        inputs: xarray.Dataset,
        targets: xarray.Dataset,
        forcings: xarray.Dataset,
    ) -> predictor_base.LossAndDiagnostics:
        """Returns only the loss part of `loss_and_predictions`."""
        loss, _ = self.loss_and_predictions(inputs, targets, forcings)
        return loss

    def _maybe_init(self, sample_inputs: xarray.Dataset):
        """Inits everything that has a dependency on the input coordinates.

        Runs only once per instance: later calls reuse the graphs built from
        the first sample, even if its lat/lon coordinates differ.
        """
        if not self._initialized:
            self._init_mesh_properties()
            self._init_grid_properties(
                grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon)
            self._grid2mesh_graph_structure = self._init_grid2mesh_graph()
            self._mesh_graph_structure = self._init_mesh_graph()
            self._mesh2grid_graph_structure = self._init_mesh2grid_graph()

            self._initialized = True

    def _init_mesh_properties(self):
        """Inits static properties that have to do with mesh nodes.

        Converts the Cartesian vertices of the finest mesh to spherical
        coordinates and then to lat/lon, stored as float32 arrays.
        """
        self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
        mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
            self._finest_mesh.vertices[:, 0],
            self._finest_mesh.vertices[:, 1],
            self._finest_mesh.vertices[:, 2])
        (
            mesh_nodes_lat,
            mesh_nodes_lon,
        ) = model_utils.spherical_to_lat_lon(
            phi=mesh_phi, theta=mesh_theta)

        self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
        self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)

    def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
        """Inits static properties that have to do with grid nodes.

        Args:
          grid_lat: 1-D latitude values of the grid. (`_maybe_init` passes
            `sample_inputs.lat`, an xarray coordinate rather than a bare numpy
            array.)
          grid_lon: 1-D longitude values of the grid.
        """
        self._grid_lat = grid_lat.astype(np.float32)
        self._grid_lon = grid_lon.astype(np.float32)
        self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]

        # `np.meshgrid(lon, lat)` (default "xy" indexing) yields
        # [num_lat, num_lon] arrays, so after flattening the grid node index is
        # lat_index * num_lon + lon_index. This must match the (lat, lon)
        # flattening in `_inputs_to_grid_node_features` and the reshape in
        # `_grid_node_outputs_to_prediction`.
        grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
        self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
        self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)

    def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph:
        """Build Grid2Mesh graph.

        Edges connect each grid node to every finest-mesh node within
        `self._query_radius`, as returned by
        `grid_mesh_connectivity.radius_query_indices`.
        """
        assert self._grid_lat is not None and self._grid_lon is not None
        (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
            grid_latitude=self._grid_lat,
            grid_longitude=self._grid_lon,
            mesh=self._finest_mesh,
            radius=self._query_radius)

        # Edges send information from grid to mesh.
        senders = grid_indices
        receivers = mesh_indices

        # Precompute structural node and edge features, i.e. those depending
        # only on the fixed lat/lon of the nodes. Grid2Mesh edges use the
        # default normalization (`edge_normalization_factor=None`).
        assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
        (senders_node_features, receivers_node_features,
         edge_features) = model_utils.get_bipartite_graph_spatial_features(
             senders_node_lat=self._grid_nodes_lat,
             senders_node_lon=self._grid_nodes_lon,
             receivers_node_lat=self._mesh_nodes_lat,
             receivers_node_lon=self._mesh_nodes_lon,
             senders=senders,
             receivers=receivers,
             edge_normalization_factor=None,
             **self._spatial_features_kwargs,
         )

        n_grid_node = np.array([self._num_grid_nodes])
        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([mesh_indices.shape[0]])
        grid_node_set = typed_graph.NodeSet(
            n_node=n_grid_node, features=senders_node_features)
        mesh_node_set = typed_graph.NodeSet(
            n_node=n_mesh_node, features=receivers_node_features)
        edge_set = typed_graph.EdgeSet(
            n_edge=n_edge,
            indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
            features=edge_features)
        nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
        edges = {
            typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")):
                edge_set
        }
        grid2mesh_graph = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=np.array([1]), features=()),
            nodes=nodes,
            edges=edges)
        return grid2mesh_graph

    def _init_mesh_graph(self) -> typed_graph.TypedGraph:
        """Build Mesh graph.

        Edges are those of all meshes in the hierarchy merged together (the
        multi-mesh); nodes are the finest-mesh vertices.
        """
        merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)

        # Work directly on the edges of the merged mesh's faces.
        senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)

        # Precompute structural node and edge features, i.e. those depending
        # only on the fixed lat/lon of the nodes.
        assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
        node_features, edge_features = model_utils.get_graph_spatial_features(
            node_lat=self._mesh_nodes_lat,
            node_lon=self._mesh_nodes_lon,
            senders=senders,
            receivers=receivers,
            **self._spatial_features_kwargs,
        )

        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([senders.shape[0]])
        # Sanity check: one structural feature row per mesh node.
        assert n_mesh_node == len(node_features)
        mesh_node_set = typed_graph.NodeSet(
            n_node=n_mesh_node, features=node_features)
        edge_set = typed_graph.EdgeSet(
            n_edge=n_edge,
            indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
            features=edge_features)
        nodes = {"mesh_nodes": mesh_node_set}
        edges = {
            typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
        }
        mesh_graph = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=np.array([1]), features=()),
            nodes=nodes,
            edges=edges)

        return mesh_graph

    def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
        """Build Mesh2Grid graph.

        Edges connect grid nodes to the finest-mesh nodes returned by
        `grid_mesh_connectivity.in_mesh_triangle_indices`.
        """
        # Same precondition as `_init_grid2mesh_graph`: grid coordinates must
        # have been initialized by `_init_grid_properties`.
        assert self._grid_lat is not None and self._grid_lon is not None
        (grid_indices,
         mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
             grid_latitude=self._grid_lat,
             grid_longitude=self._grid_lon,
             mesh=self._finest_mesh)

        # Edges send information from mesh to grid.
        senders = mesh_indices
        receivers = grid_indices

        # Precompute structural node and edge features. Unlike Grid2Mesh, the
        # edge normalization factor may be overridden via the model config.
        assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
        (senders_node_features, receivers_node_features,
         edge_features) = model_utils.get_bipartite_graph_spatial_features(
             senders_node_lat=self._mesh_nodes_lat,
             senders_node_lon=self._mesh_nodes_lon,
             receivers_node_lat=self._grid_nodes_lat,
             receivers_node_lon=self._grid_nodes_lon,
             senders=senders,
             receivers=receivers,
             edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
             **self._spatial_features_kwargs,
         )

        n_grid_node = np.array([self._num_grid_nodes])
        n_mesh_node = np.array([self._num_mesh_nodes])
        n_edge = np.array([senders.shape[0]])
        grid_node_set = typed_graph.NodeSet(
            n_node=n_grid_node, features=receivers_node_features)
        mesh_node_set = typed_graph.NodeSet(
            n_node=n_mesh_node, features=senders_node_features)
        edge_set = typed_graph.EdgeSet(
            n_edge=n_edge,
            indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
            features=edge_features)
        nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
        edges = {
            typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
                edge_set
        }
        mesh2grid_graph = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=np.array([1]), features=()),
            nodes=nodes,
            edges=edges)
        return mesh2grid_graph

    def _run_grid2mesh_gnn(self, grid_node_features: chex.Array,
                           ) -> tuple[chex.Array, chex.Array]:
        """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes.

        Args:
          grid_node_features: [num_grid_nodes, batch, num_channels].

        Returns:
          `(latent_mesh_nodes, latent_grid_nodes)`, the encoder's node features
          for each node set (node axis first, batch axis second).
        """
        batch_size = grid_node_features.shape[1]

        grid2mesh_graph = self._grid2mesh_graph_structure
        assert grid2mesh_graph is not None
        grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
        mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
        # Grid node inputs: data features followed by the structural features,
        # broadcast over the batch axis.
        new_grid_nodes = grid_nodes._replace(
            features=jnp.concatenate([
                grid_node_features,
                _add_batch_second_axis(
                    grid_nodes.features.astype(grid_node_features.dtype),
                    batch_size)
            ],
                                     axis=-1))

        # Mesh nodes have no data features, so zeros with the same channel
        # count as the grid data features are prepended. This keeps the mesh
        # and grid node inputs the same width (assuming both node sets get
        # equally wide structural features).
        dummy_mesh_node_features = jnp.zeros(
            (self._num_mesh_nodes,) + grid_node_features.shape[1:],
            dtype=grid_node_features.dtype)
        new_mesh_nodes = mesh_nodes._replace(
            features=jnp.concatenate([
                dummy_mesh_node_features,
                _add_batch_second_axis(
                    mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
                    batch_size)
            ],
                                     axis=-1))

        # Broadcast edge structural features to the batch size.
        grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
        edges = grid2mesh_graph.edges[grid2mesh_edges_key]

        new_edges = edges._replace(
            features=_add_batch_second_axis(
                edges.features.astype(dummy_mesh_node_features.dtype), batch_size))

        # Use the asserted local (same object as the attribute), as the other
        # `_run_*` methods do.
        input_graph = grid2mesh_graph._replace(
            edges={grid2mesh_edges_key: new_edges},
            nodes={
                "grid_nodes": new_grid_nodes,
                "mesh_nodes": new_mesh_nodes
            })

        grid2mesh_out = self._grid2mesh_gnn(input_graph)
        latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features
        latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features
        return latent_mesh_nodes, latent_grid_nodes

    def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
        """Runs the mesh_gnn, extracting updated latent mesh nodes.

        Args:
          latent_mesh_nodes: Encoder output for the mesh nodes; axis 1 is the
            batch axis.

        Returns:
          The processor's updated mesh-node features.
        """
        batch_size = latent_mesh_nodes.shape[1]

        mesh_graph = self._mesh_graph_structure
        assert mesh_graph is not None
        mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
        edges = mesh_graph.edges[mesh_edges_key]

        # Only the single "mesh" edge set is replaced below, so any other edge
        # set would be silently dropped; guard against that.
        msg = ("The setup currently requires to only have one kind of edge in the"
               " mesh GNN.")
        assert len(mesh_graph.edges) == 1, msg

        new_edges = edges._replace(
            features=_add_batch_second_axis(
                edges.features.astype(latent_mesh_nodes.dtype), batch_size))

        nodes = mesh_graph.nodes["mesh_nodes"]
        nodes = nodes._replace(features=latent_mesh_nodes)

        input_graph = mesh_graph._replace(
            edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})

        return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features

    def _run_mesh2grid_gnn(self,
                           updated_latent_mesh_nodes: chex.Array,
                           latent_grid_nodes: chex.Array,
                           ) -> chex.Array:
        """Runs the mesh2grid_gnn, extracting the output grid nodes.

        Args:
          updated_latent_mesh_nodes: Processor output; axis 1 is the batch axis.
          latent_grid_nodes: Encoder output for the grid nodes.

        Returns:
          [num_grid_nodes, batch, num_outputs] grid-node outputs.
        """
        batch_size = updated_latent_mesh_nodes.shape[1]

        mesh2grid_graph = self._mesh2grid_graph_structure
        assert mesh2grid_graph is not None
        mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
        grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
        new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
        new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
        mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
        edges = mesh2grid_graph.edges[mesh2grid_key]

        new_edges = edges._replace(
            features=_add_batch_second_axis(
                edges.features.astype(latent_grid_nodes.dtype), batch_size))

        input_graph = mesh2grid_graph._replace(
            edges={mesh2grid_key: new_edges},
            nodes={
                "mesh_nodes": new_mesh_nodes,
                "grid_nodes": new_grid_nodes
            })

        output_graph = self._mesh2grid_gnn(input_graph)
        output_grid_nodes = output_graph.nodes["grid_nodes"].features

        return output_grid_nodes

    def _inputs_to_grid_node_features(
        self,
        inputs: xarray.Dataset,
        forcings: xarray.Dataset,
    ) -> chex.Array:
        """xarrays -> [num_grid_nodes, batch, num_channels]."""
        # Stack all variables of inputs and forcings into one "channels" axis.
        stacked_inputs = model_utils.dataset_to_stacked(inputs)
        stacked_forcings = model_utils.dataset_to_stacked(forcings)
        stacked_inputs = xarray.concat(
            [stacked_inputs, stacked_forcings], dim="channels")

        # Move lat/lon to the front and merge them into a single node axis
        # (presumably in (lat, lon) order, matching `_init_grid_properties`).
        grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
            stacked_inputs)
        return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
            (-1,) + grid_xarray_lat_lon_leading.data.shape[2:])

    def _grid_node_outputs_to_prediction(
        self,
        grid_node_outputs: chex.Array,
        targets_template: xarray.Dataset,
    ) -> xarray.Dataset:
        """[num_grid_nodes, batch, num_outputs] -> xarray."""
        # Split the node axis back into (lat, lon).
        assert self._grid_lat is not None and self._grid_lon is not None
        grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
        grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
            grid_shape + grid_node_outputs.shape[1:])
        dims = ("lat", "lon", "batch", "channels")
        grid_xarray_lat_lon_leading = xarray_jax.DataArray(
            data=grid_outputs_lat_lon_leading,
            dims=dims)
        grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)

        # Unstack the "channels" axis into the variables of `targets_template`.
        return model_utils.stacked_to_dataset(
            grid_xarray.variable, targets_template)
|
|
|
|
|
def _add_batch_second_axis(data, batch_size):
    """Repeats a 2-D array along a new batch axis inserted at position 1.

    Args:
      data: Array of shape [leading_dim, trailing_dim].
      batch_size: Number of copies along the new axis.

    Returns:
      A JAX array of shape [leading_dim, batch_size, trailing_dim] whose every
      slice along axis 1 equals `data`, in `data`'s dtype.
    """
    assert data.ndim == 2
    leading_dim, trailing_dim = data.shape
    with_unit_batch = data[:, None]  # [leading_dim, 1, trailing_dim]
    return jnp.broadcast_to(
        with_unit_batch, (leading_dim, batch_size, trailing_dim))
|
|
|
|
|
def _get_max_edge_distance(mesh):
    """Returns the length of the longest edge of `mesh`.

    Edges are taken from `icosahedral_mesh.faces_to_edges(mesh.faces)`; each
    length is the Euclidean norm of the difference between the two endpoint
    rows of `mesh.vertices`.
    """
    edge_starts, edge_ends = icosahedral_mesh.faces_to_edges(mesh.faces)
    vertices = mesh.vertices
    offsets = vertices[edge_starts] - vertices[edge_ends]
    lengths = np.linalg.norm(offsets, axis=-1)
    return lengths.max()
|
|