File size: 30,390 Bytes
6d70ed4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 |
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for building models."""
from typing import Mapping, Optional, Tuple
import numpy as np
from scipy.spatial import transform
import xarray
def get_graph_spatial_features(
    *, node_lat: np.ndarray, node_lon: np.ndarray,
    senders: np.ndarray, receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    sine_cosine_encoding: bool = False,
    encoding_num_freqs: int = 10,
    encoding_multiplicative_factor: float = 1.2,
    edge_normalization_factor: Optional[float] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes and edges of a graph.

  Args:
    node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes]
    node_lon: Longitudes in the [0, 360] interval of shape [num_nodes]
    senders: Sender indices of shape [num_edges]
    receivers: Receiver indices of shape [num_edges]
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)).
      Note even if this is set to False, the model may be able to infer the
      latitude from relative features, unless
      `relative_latitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
      Note even if this is set to False, the model may be able to infer the
      longitude from relative features, unless
      `relative_longitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
      edges.
    relative_longitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 latitude.
    sine_cosine_encoding: If True, we will transform the node/edge features
      with sine and cosine functions, similar to NERF. Every feature is then
      replaced by `2 * encoding_num_freqs` values.
    encoding_num_freqs: Number of frequencies used by the sine/cosine
      encoding.
    encoding_multiplicative_factor: Base of the geometric progression of
      frequencies: the k-th frequency is `encoding_multiplicative_factor**k`.
    edge_normalization_factor: Allows explicitly controlling edge
      normalization. If None, defaults to the maximum edge length. Mirrors the
      parameter of the same name in `get_bipartite_graph_spatial_features`.

  Returns:
    Tuple `(node_features, edge_features)` with arrays of shape
    [num_nodes, num_node_features] and [num_edges, num_edge_features].
  """
  num_nodes = node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = node_lat.dtype
  node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)

  # Computing some node features.
  node_features = []
  if add_node_positions:
    # Already in [-1, 1.] range.
    node_features.extend(spherical_to_cartesian(node_phi, node_theta))

  if add_node_latitude:
    # Using the cos of theta.
    # From 1. (north pole) to -1 (south pole).
    node_features.append(np.cos(node_theta))

  if add_node_longitude:
    # Using the cos and sin, which is already normalized.
    node_features.append(np.cos(node_phi))
    node_features.append(np.sin(node_phi))

  if not node_features:
    node_features = np.zeros([num_nodes, 0], dtype=dtype)
  else:
    node_features = np.stack(node_features, axis=-1)

  # Computing some edge features.
  edge_features = []

  if add_relative_positions:
    relative_position = get_relative_position_in_receiver_local_coordinates(
        node_phi=node_phi,
        node_theta=node_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates
    )

    # Note this is L2 distance in 3d space, rather than geodesic distance.
    relative_edge_distances = np.linalg.norm(
        relative_position, axis=-1, keepdims=True)

    if edge_normalization_factor is None:
      # Normalize to the maximum edge distance. Note that we expect to always
      # have an edge that goes in the opposite direction of any given edge
      # so the distribution of relative positions should be symmetric around
      # zero. So by scaling by the maximum length, we expect all relative
      # positions to fall in the [-1., 1.] interval, and all relative
      # distances to fall in the [0., 1.] interval.
      edge_normalization_factor = relative_edge_distances.max()

    edge_features.append(relative_edge_distances / edge_normalization_factor)
    edge_features.append(relative_position / edge_normalization_factor)

  if not edge_features:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)
  else:
    edge_features = np.concatenate(edge_features, axis=-1)

  if sine_cosine_encoding:
    def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
      # Geometric progression of frequencies: factor**0, factor**1, ...
      freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
      phases = freqs * x[..., None]
      x_sin = np.sin(phases)
      x_cos = np.cos(phases)
      x_cat = np.concatenate([x_sin, x_cos], axis=-1)
      # Flatten [feature, sin/cos x frequency] into a single feature axis.
      return x_cat.reshape([x.shape[0], -1])

    node_features = sine_cosine_transform(node_features)
    edge_features = sine_cosine_transform(edge_features)

  return node_features, edge_features
def lat_lon_to_leading_axes(
    grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the "lat" and "lon" dimensions to the front of the array.

  Turns a layout of the form `leading + ["lat", "lon"] + trailing` into
  `["lat", "lon"] + leading + trailing`.

  Args:
    grid_xarray: Array with "lat" and "lon" dimensions.

  Returns:
    The transposed array.
  """
  spatial_dims = ("lat", "lon")
  # The Ellipsis stands for every dimension not named explicitly.
  return grid_xarray.transpose(*spatial_dims, ...)
def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the batch/time/level dimensions (when present) to the front.

  Turns a layout of the form
  `["lat", "lon"] + [(batch,) (time,) (level,)] + trailing` into
  `[(batch,) (time,) (level,)] + ["lat", "lon"] + trailing`.

  Args:
    grid_xarray: Array to reorder.

  Returns:
    The transposed array; all other dimensions keep their relative order.
  """
  current_dims = list(grid_xarray.dims)
  # Target order of the leading dimensions; missing ones are simply skipped.
  leading = [
      key for key in ("batch", "time", "level") if key in current_dims]
  remaining = [dim for dim in current_dims if dim not in leading]
  return grid_xarray.transpose(*leading, *remaining)
def lat_lon_deg_to_spherical(node_lat: np.ndarray,
                             node_lon: np.ndarray,
                            ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts latitude/longitude in degrees to spherical angles in radians.

  Args:
    node_lat: Latitudes in degrees.
    node_lon: Longitudes in degrees.

  Returns:
    Tuple `(phi, theta)` where `phi` is the longitude in radians and `theta`
    is the colatitude (90 - latitude) in radians.
  """
  colatitude_deg = 90 - node_lat
  return np.deg2rad(node_lon), np.deg2rad(colatitude_deg)
def spherical_to_lat_lon(phi: np.ndarray,
                         theta: np.ndarray,
                        ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts spherical angles in radians to latitude/longitude in degrees.

  Args:
    phi: Longitude angles in radians.
    theta: Colatitude angles in radians.

  Returns:
    Tuple `(lat, lon)` in degrees, with `lon` wrapped into [0, 360).
  """
  lon_deg = np.rad2deg(phi)
  colatitude_deg = np.rad2deg(theta)
  return 90 - colatitude_deg, np.mod(lon_deg, 360)
def cartesian_to_spherical(x: np.ndarray,
                           y: np.ndarray,
                           z: np.ndarray,
                          ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts cartesian coordinates on the unit sphere to spherical angles.

  Args:
    x: Cartesian x coordinates.
    y: Cartesian y coordinates.
    z: Cartesian z coordinates; the points are assumed to have unit radius.

  Returns:
    Tuple `(phi, theta)` in radians, with `phi = arctan2(y, x)` and
    `theta = arccos(z)`.
  """
  # Invalid-value warnings from arccos are silenced on purpose
  # (circumventing b/253179568).
  with np.errstate(invalid="ignore"):
    polar_angle = np.arccos(z)  # Assuming unit radius.
  azimuth = np.arctan2(y, x)
  return azimuth, polar_angle
def spherical_to_cartesian(
    phi: np.ndarray, theta: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Converts spherical angles to cartesian coordinates on the unit sphere.

  Args:
    phi: Angles in radians measured around the z axis.
    theta: Angles in radians measured from the positive z axis.

  Returns:
    Tuple `(x, y, z)` of coordinates, assuming unit radius.
  """
  sin_theta = np.sin(theta)
  x = np.cos(phi) * sin_theta
  y = np.sin(phi) * sin_theta
  z = np.cos(theta)
  return x, y, z
def get_relative_position_in_receiver_local_coordinates(
    node_phi: np.ndarray,
    node_theta: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool
    ) -> np.ndarray:
  """Returns relative position features for the edges.

  The relative positions are computed in a rotated space: a local coordinate
  system defined by the receiver of each edge. They are obtained by
  subtracting the receiver position from the sender position after both have
  been rotated in R^3 into that local coordinate system.

  Args:
    node_phi: [num_nodes] with azimuthal (longitude) angles.
    node_theta: [num_nodes] with polar (colatitude) angles.
    senders: [num_edges] with indices.
    receivers: [num_edges] with indices.
    latitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3]
  """
  xyz = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
  sender_xyz = xyz[senders]
  receiver_xyz = xyz[receivers]

  needs_rotation = latitude_local_coordinates or longitude_local_coordinates
  if not needs_rotation:
    # No local frame requested: plain difference in the global frame.
    return sender_xyz - receiver_xyz

  # One rotation matrix per node, describing that node's local space.
  per_node_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=node_phi,
      reference_theta=node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Every edge uses the rotation matrix of its receiver node.
  per_edge_rotations = per_node_rotations[receivers]

  # For the receivers it would be cheaper to rotate all nodes first and gather
  # afterwards; gathering first keeps the computation symmetric with the one
  # for the senders.
  rotated_receivers = rotate_with_matrices(per_edge_rotations, receiver_xyz)
  rotated_senders = rotate_with_matrices(per_edge_rotations, sender_xyz)

  # Because the rotated space is chosen according to the receiver:
  #  * latitude_local_coordinates = True: the receiver latitude is 0, that is
  #    its z coordinate is always 0.
  #  * longitude_local_coordinates = True: the receiver longitude is 0, that
  #    is its y coordinate is always 0.
  # The local y-z axes are parallel to a plane tangent to the sphere, but the
  # result still lives in 3d space. When both flags are True and edges are
  # short, the x difference between sender and receiver is small, so it could
  # be dropped to project onto the tangent plane; doing so would however lose
  # information about the curvature of the mesh, which may be important for
  # very coarse meshes.
  return rotated_senders - rotated_receivers
def get_rotation_matrices_to_local_coordinates(
    reference_phi: np.ndarray,
    reference_theta: np.ndarray,
    rotate_latitude: bool,
    rotate_longitude: bool) -> np.ndarray:
  """Returns a rotation matrix to rotate to a point based on a reference vector.

  The rotation matrix is built such that a vector in the same coordinate
  system at the reference point that points towards the pole before the
  rotation, continues to point towards the pole after the rotation.

  Args:
    reference_phi: [leading_axis] Azimuthal (longitude) angles of the
      reference.
    reference_theta: [leading_axis] Polar (colatitude) angles of the
      reference.
    rotate_latitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero latitude.
    rotate_longitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero longitude.

  Returns:
    Matrices of shape [leading_axis, 3, 3] such that when applied to the
    reference position with
    `rotate_with_matrices(rotation_matrices, reference_pos)`

    * phi goes to 0. if "rotate_longitude" is True.

    * theta goes to np.pi / 2 if "rotate_latitude" is True.

    The rotation consists of:
    * rotate_latitude = False, rotate_longitude = True:
        Latitude preserving rotation.
    * rotate_latitude = True, rotate_longitude = True:
        Latitude preserving rotation, followed by longitude preserving
        rotation.
    * rotate_latitude = True, rotate_longitude = False:
        Latitude preserving rotation, followed by longitude preserving
        rotation, and the inverse of the latitude preserving rotation. Note
        this is computationally different from rotating the latitude only.
        We do it like this so the polar geodesic curve continues to be
        aligned with one of the axes after the rotation.

  Raises:
    ValueError: If both `rotate_latitude` and `rotate_longitude` are False.
  """
  if not (rotate_latitude or rotate_longitude):
    raise ValueError(
        "At least one of longitude and latitude should be rotated.")

  # Rotating around the z axis by "minus the azimuthal angle" takes the
  # reference point to zero longitude.
  azimuthal_rotation = -reference_phi

  if not rotate_latitude:
    # Longitude only: the azimuthal rotation is all that is needed.
    return transform.Rotation.from_euler("z", azimuthal_rotation).as_matrix()

  # Once at longitude 0, the polar rotation can be done around the y axis:
  # "minus the polar angle plus pi/2" takes the point to zero latitude.
  polar_rotation = -reference_theta + np.pi/2

  if rotate_longitude:
    euler_axes = "zy"
    euler_angles = [azimuthal_rotation, polar_rotation]
  else:
    # Latitude only: same as above, but undoing the azimuthal rotation after
    # the polar rotation so the longitude of the reference is preserved.
    euler_axes = "zyz"
    euler_angles = [azimuthal_rotation, polar_rotation, -azimuthal_rotation]

  return transform.Rotation.from_euler(
      euler_axes, np.stack(euler_angles, axis=1)).as_matrix()
def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
                         ) -> np.ndarray:
  """Applies one rotation matrix per row to the matching position vector.

  Args:
    rotation_matrices: Array indexed as [batch, row, column].
    positions: Array indexed as [batch, column].

  Returns:
    Array indexed as [batch, row] holding the matrix-vector product of every
    matrix with its corresponding position.
  """
  # e: batch entry, r: matrix row, c: matrix column (summed over).
  batched_matvec = "erc,ec->er"
  return np.einsum(batched_matvec, rotation_matrices, positions)
def get_bipartite_graph_spatial_features(
    *,
    senders_node_lat: np.ndarray,
    senders_node_lon: np.ndarray,
    senders: np.ndarray,
    receivers_node_lat: np.ndarray,
    receivers_node_lon: np.ndarray,
    receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    edge_normalization_factor: Optional[float] = None,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes and edges of a bipartite graph.

  This function is almost identical to `get_graph_spatial_features`. The only
  difference is that sender nodes and receiver nodes can be in different
  arrays. This is necessary to enable combination with typed Graph.

  Args:
    senders_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_sender_nodes]
    senders_node_lon: Longitudes in the [0, 360] interval of shape
      [num_sender_nodes]
    senders: Sender indices of shape [num_edges], indices in [0,
      num_sender_nodes)
    receivers_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_receiver_nodes]
    receivers_node_lon: Longitudes in the [0, 360] interval of shape
      [num_receiver_nodes]
    receivers: Receiver indices of shape [num_edges], indices in [0,
      num_receiver_nodes)
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)). Note even
      if this is set to False, the model may be able to infer the latitude
      from relative features, unless `relative_latitude_local_coordinates` is
      also True, or if there is any bias on the relative edge sizes for
      different latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
      even if this is set to False, the model may be able to infer the
      longitude from relative features, unless
      `relative_longitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
      edges.
    edge_normalization_factor: Allows explicitly controlling edge
      normalization. If None, defaults to max edge length. This supports using
      pre-trained model weights with a different graph structure to what it
      was trained on.
    relative_longitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 latitude.

  Returns:
    Tuple `(senders_node_features, receivers_node_features, edge_features)`
    with arrays of shape [num_sender_nodes, num_node_features],
    [num_receiver_nodes, num_node_features] and
    [num_edges, num_edge_features].
  """
  num_senders = senders_node_lat.shape[0]
  num_receivers = receivers_node_lat.shape[0]
  num_edges = senders.shape[0]

  # Sender and receiver coordinates must share one dtype; it is also the dtype
  # of any empty feature array returned below.
  dtype = senders_node_lat.dtype
  assert receivers_node_lat.dtype == dtype

  sender_phi, sender_theta = lat_lon_deg_to_spherical(
      senders_node_lat, senders_node_lon)
  receiver_phi, receiver_theta = lat_lon_deg_to_spherical(
      receivers_node_lat, receivers_node_lon)

  def build_node_features(phi, theta, num_nodes):
    """Stacks the requested per-node features into [num_nodes, features]."""
    columns = []
    if add_node_positions:
      # Already in [-1, 1.] range.
      columns.extend(spherical_to_cartesian(phi, theta))
    if add_node_latitude:
      # Cos of theta: from 1. (north pole) to -1 (south pole).
      columns.append(np.cos(theta))
    if add_node_longitude:
      # Cos and sin are already normalized.
      columns.extend([np.cos(phi), np.sin(phi)])
    if columns:
      return np.stack(columns, axis=-1)
    return np.zeros([num_nodes, 0], dtype=dtype)

  senders_node_features = build_node_features(
      sender_phi, sender_theta, num_senders)
  receivers_node_features = build_node_features(
      receiver_phi, receiver_theta, num_receivers)

  if add_relative_positions:
    offsets = get_bipartite_relative_position_in_receiver_local_coordinates(
        senders_node_phi=sender_phi,
        senders_node_theta=sender_theta,
        receivers_node_phi=receiver_phi,
        receivers_node_theta=receiver_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates)

    # Note this is L2 distance in 3d space, rather than geodesic distance.
    lengths = np.linalg.norm(offsets, axis=-1, keepdims=True)

    if edge_normalization_factor is None:
      # Normalize to the maximum edge distance. We expect every edge to have a
      # counterpart going in the opposite direction, so the distribution of
      # relative positions should be symmetric around zero: scaling by the
      # maximum length should place all relative positions in [-1., 1.] and
      # all relative distances in [0., 1.].
      scale = lengths.max()
    else:
      scale = edge_normalization_factor

    edge_features = np.concatenate([lengths / scale, offsets / scale], axis=-1)
  else:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)

  return senders_node_features, receivers_node_features, edge_features
def get_bipartite_relative_position_in_receiver_local_coordinates(
    senders_node_phi: np.ndarray,
    senders_node_theta: np.ndarray,
    senders: np.ndarray,
    receivers_node_phi: np.ndarray,
    receivers_node_theta: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool) -> np.ndarray:
  """Returns relative position features for the edges.

  This function is equivalent to
  `get_relative_position_in_receiver_local_coordinates`, but adapted to work
  with bipartite typed graphs.

  The relative positions are computed in a rotated space: a local coordinate
  system defined by the receiver of each edge. They are obtained by
  subtracting the receiver position from the sender position after both have
  been rotated in R^3 into that local coordinate system.

  Args:
    senders_node_phi: [num_sender_nodes] with azimuthal (longitude) angles.
    senders_node_theta: [num_sender_nodes] with polar (colatitude) angles.
    senders: [num_edges] with indices into sender nodes.
    receivers_node_phi: [num_receiver_nodes] with azimuthal (longitude)
      angles.
    receivers_node_theta: [num_receiver_nodes] with polar (colatitude) angles.
    receivers: [num_edges] with indices into receiver nodes.
    latitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3]
  """
  sender_xyz = np.stack(
      spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)
  receiver_xyz = np.stack(
      spherical_to_cartesian(receivers_node_phi, receivers_node_theta),
      axis=-1)
  edge_sender_xyz = sender_xyz[senders]
  edge_receiver_xyz = receiver_xyz[receivers]

  needs_rotation = latitude_local_coordinates or longitude_local_coordinates
  if not needs_rotation:
    # No local frame requested: plain difference in the global frame.
    return edge_sender_xyz - edge_receiver_xyz

  # One rotation matrix per receiver node, describing its local space.
  per_receiver_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=receivers_node_phi,
      reference_theta=receivers_node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Every edge uses the rotation matrix of its receiver node.
  per_edge_rotations = per_receiver_rotations[receivers]

  # For the receivers it would be cheaper to rotate all nodes first and gather
  # afterwards; gathering first keeps the computation symmetric with the one
  # for the senders.
  rotated_receivers = rotate_with_matrices(
      per_edge_rotations, edge_receiver_xyz)
  rotated_senders = rotate_with_matrices(per_edge_rotations, edge_sender_xyz)

  # Because the rotated space is chosen according to the receiver:
  #  * latitude_local_coordinates = True: the receiver latitude is 0, that is
  #    its z coordinate is always 0.
  #  * longitude_local_coordinates = True: the receiver longitude is 0, that
  #    is its y coordinate is always 0.
  # The local y-z axes are parallel to a plane tangent to the sphere, but the
  # result still lives in 3d space. When both flags are True and edges are
  # short, the x difference between sender and receiver is small, so it could
  # be dropped to project onto the tangent plane; doing so would however lose
  # information about the curvature of the mesh, which may be important for
  # very coarse meshes.
  return rotated_senders - rotated_receivers
def variable_to_stacked(
    variable: xarray.Variable,
    sizes: Mapping[str, int],
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Variable:
  """Converts an xarray.Variable to preserved_dims + ("channels",).

  Any dimensions other than those included in preserved_dims get stacked into a
  final "channels" dimension. If any of the preserved_dims are missing then they
  are added, with the data broadcast/tiled to match the sizes specified in
  `sizes`.

  Args:
    variable: An xarray.Variable.
    sizes: Mapping including sizes for any dimensions which are not present in
      `variable` but are needed for the output. This may be needed for example
      for a static variable with only ("lat", "lon") dims, or if you want to
      encode just the latitude coordinates (a variable with dims ("lat",)).
    preserved_dims: dimensions of variable to not be folded in channels.

  Returns:
    An xarray.Variable with dimensions preserved_dims + ("channels",).

  Raises:
    KeyError: If a preserved dimension is missing from both `variable` and
      `sizes`.
  """
  stack_to_channels_dims = [
      d for d in variable.dims if d not in preserved_dims]
  if stack_to_channels_dims:
    variable = variable.stack(channels=stack_to_channels_dims)
  # Use an explicit membership test rather than `get(dim) or sizes[dim]`: a
  # dimension that exists with size 0 is falsy and would otherwise be looked
  # up in `sizes` as if it were missing.
  dims = {
      dim: variable.sizes[dim] if dim in variable.sizes else sizes[dim]
      for dim in preserved_dims
  }
  # A variable with nothing to stack gets a single channel.
  dims["channels"] = variable.sizes.get("channels", 1)
  return variable.set_dims(dims)
def dataset_to_stacked(
    dataset: xarray.Dataset,
    sizes: Optional[Mapping[str, int]] = None,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.DataArray:
  """Converts an xarray.Dataset to a single stacked array.

  Each constituent data_var is converted to the
  preserved_dims + ("channels",) layout with `variable_to_stacked`; the results
  are then concatenated along the channels axis, in sorted order of the
  variable names.

  Args:
    dataset: An xarray.Dataset.
    sizes: Mapping including sizes for any dimensions which are not present in
      the `dataset` but are needed for the output. See variable_to_stacked.
      When not provided (or empty), `dataset.sizes` is used.
    preserved_dims: dimensions from the dataset that should not be folded in
      the predictions channels.

  Returns:
    An xarray.DataArray with dimensions preserved_dims + ("channels",).
    Existing coordinates for preserved_dims axes will be preserved, however
    there will be no coordinates for "channels".
  """
  fallback_sizes = sizes or dataset.sizes

  stacked_vars = []
  for name in sorted(dataset.data_vars.keys()):
    stacked_vars.append(
        variable_to_stacked(
            dataset.variables[name], fallback_sizes, preserved_dims))

  # Only coordinates keyed by one of the preserved dimensions are carried over.
  kept_coords = {}
  for dim, coord in dataset.coords.items():
    if dim in preserved_dims:
      kept_coords[dim] = coord

  merged = xarray.Variable.concat(stacked_vars, dim="channels")
  return xarray.DataArray(data=merged, coords=kept_coords)
def stacked_to_dataset(
    stacked_array: xarray.Variable,
    template_dataset: xarray.Dataset,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Dataset:
  """The inverse of dataset_to_stacked.

  Requires a template dataset to demonstrate the variables/shapes/coordinates
  required.
  All variables must have preserved_dims dimensions.

  Args:
    stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
      would if it was asked to encode `template_dataset`.
    template_dataset: A template Dataset (or other mapping of DataArrays)
      demonstrating the shape of output required (variables, shapes,
      coordinates etc).
    preserved_dims: dimensions from the target_template that were not folded in
      the predictions channels. The preserved_dims need to be a subset of the
      dims of all the variables of template_dataset.

  Returns:
    An xarray.Dataset (or other mapping of DataArrays) with the same shape and
    type as template_dataset.

  Raises:
    ValueError: If a template variable lacks one of the preserved dimensions,
      or if the number of channels in `stacked_array` does not match the number
      implied by `template_dataset`.
  """
  var_names = sorted(template_dataset.keys())

  # For every variable: the sizes of the dimensions folded into "channels".
  unstack_from_channels_sizes = {}
  for name in var_names:
    template_var = template_dataset[name]
    if any(dim not in template_var.dims for dim in preserved_dims):
      raise ValueError(
          f"stacked_to_dataset requires all Variables to have {preserved_dims} "
          f"dimensions, but found only {template_var.dims}.")
    folded_sizes = {}
    for dim, size in template_var.sizes.items():
      if dim not in preserved_dims:
        folded_sizes[dim] = size
    unstack_from_channels_sizes[name] = folded_sizes

  # Number of channels taken by each variable (1 when nothing was folded).
  channels = {}
  for name, unstack_sizes in unstack_from_channels_sizes.items():
    channels[name] = np.prod(list(unstack_sizes.values()), dtype=np.int64)

  total_expected_channels = sum(channels.values())
  found_channels = stacked_array.sizes["channels"]
  if total_expected_channels != found_channels:
    raise ValueError(
        f"Expected {total_expected_channels} channels but found "
        f"{found_channels}, when trying to convert a stacked array of shape "
        f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")

  data_vars = {}
  start = 0
  for name in var_names:
    template_var = template_dataset[name]
    stop = start + channels[name]
    # Slice out this variable's channels, then undo the stacking and restore
    # the dimension order of the template.
    var = stacked_array.isel({"channels": slice(start, stop)})
    start = stop
    var = var.unstack({"channels": unstack_from_channels_sizes[name]})
    var = var.transpose(*template_var.dims)
    data_vars[name] = xarray.DataArray(
        data=var,
        coords=template_var.coords,
        # This might not always be the same as the name it's keyed under; it
        # will refer to the original variable name, whereas the key might be
        # some alias e.g. temperature_850 under which it should be logged:
        name=template_var.name,
    )
  return type(template_dataset)(data_vars)  # pytype:disable=not-callable,wrong-arg-count
|