import timeit |
from typing import Sequence |
from absl import logging |
from absl.testing import absltest |
from absl.testing import parameterized |
from graphcast import solar_radiation |
import numpy as np |
import pandas as pd |
import xarray as xa |
def _get_grid_lat_lon_coords( |
num_lat: int, num_lon: int |
) -> tuple[np.ndarray, np.ndarray]: |
"""Generates a linear latitude-longitude grid of the given size. |
Args: |
num_lat: Size of the latitude dimension of the grid. |
num_lon: Size of the longitude dimension of the grid. |
Returns: |
A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude |
coordinates in degrees of the generated grid. |
""" |
lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True) |
lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False) |
return lat, lon |
class SolarRadiationTest(parameterized.TestCase):
  """Tests the xarray entry point of `solar_radiation` and benchmarks TISR."""

  def setUp(self):
    super().setUp()
    # Seed NumPy's global RNG so the random payloads below are reproducible.
    np.random.seed(0)

  def test_missing_dim_raises_value_error(self):
    # The array has no `lat` dimension, which the wrapper should reject.
    data = xa.DataArray(
        np.random.randn(2, 2),
        dims=["lon", "x"],
        coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])],
    )
    error_pattern = r".* dimensions are missing in `data_array_like`."
    with self.assertRaisesRegex(ValueError, error_pattern):
      solar_radiation.get_toa_incident_solar_radiation_for_xarray(
          data, integration_period="1h", num_integration_bins=360
      )

  def test_missing_coordinate_raises_value_error(self):
    # Only `lat`/`lon` coordinates are attached. Unlike the shape tests in this
    # class, no `datetime` coordinate is supplied.
    payload = np.random.randn(2, 3, 2)
    lat_values = np.array([0.0, 0.1, 0.2])
    lon_values = np.array([0.0, 0.5])
    data = xa.Dataset(
        data_vars={"var1": (["x", "lat", "lon"], payload)},
        coords={"lat": lat_values, "lon": lon_values},
    )
    error_pattern = r".* coordinates are missing in `data_array_like`."
    with self.assertRaisesRegex(ValueError, error_pattern):
      solar_radiation.get_toa_incident_solar_radiation_for_xarray(
          data, integration_period="1h", num_integration_bins=360
      )

  def test_shape_multiple_timestamps(self):
    # `datetime` is an auxiliary coordinate laid out along the `time` dim.
    payload = np.random.randn(2, 4, 2)
    datetimes = xa.Variable("time", np.array([10, 20], dtype="datetime64[D]"))
    data = xa.Dataset(
        data_vars={"var1": (["time", "lat", "lon"], payload)},
        coords={
            "lat": np.array([0.0, 0.1, 0.2, 0.3]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": datetimes,
        },
    )
    result = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
        data, integration_period="1h", num_integration_bins=2
    )
    self.assertEqual(("time", "lat", "lon"), result.dims)
    self.assertEqual((2, 4, 2), result.shape)

  def test_shape_single_timestamp(self):
    # A scalar `datetime` coordinate: the output should keep only lat/lon dims.
    payload = np.random.randn(4, 2)
    data = xa.Dataset(
        data_vars={"var1": (["lat", "lon"], payload)},
        coords={
            "lat": np.array([0.0, 0.1, 0.2, 0.3]),
            "lon": np.array([0.0, 0.5]),
            "datetime": np.datetime64(10, "D"),
        },
    )
    result = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
        data, integration_period="1h", num_integration_bins=2
    )
    self.assertEqual(("lat", "lon"), result.dims)
    self.assertEqual((4, 2), result.shape)

  # Tuple form: (name suffix, periods, repeats, use_jit).
  @parameterized.named_parameters(
      ("one_timestamp_jitted", 1, 3, True),
      ("one_timestamp_non_jitted", 1, 3, False),
      ("ten_timestamps_non_jitted", 10, 1, False),
  )
  def test_full_spatial_resolution(
      self, periods: int, repeats: int, use_jit: bool
  ):
    """Logs wall-clock timings of TISR on a 721 x 1440 (0.25-degree) grid.

    This is a benchmark. It reports timings through `logging.info` and makes
    no assertions about the computed values.
    """
    timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h")
    lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440)

    def compute_tisr() -> None:
      result = solar_radiation.get_toa_incident_solar_radiation(
          timestamps,
          lat,
          lon,
          integration_period="1h",
          num_integration_bins=360,
          use_jit=use_jit,
      )
      # Wait for the result so the timing covers the computation itself. This
      # presumably matters because the result is a JAX array with async
      # dispatch.
      result.block_until_ready()

    timings = timeit.repeat(compute_tisr, repeat=repeats, number=1)
    grid_shape = (len(timestamps), len(lat), len(lon))
    logging.info(
        "Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s",
        *grid_shape,
        np.array2string(np.array(timings), precision=1),
    )
class GetTsiTest(parameterized.TestCase):
  """Tests `solar_radiation.get_tsi` lookups against real and synthetic data."""

  @parameterized.named_parameters(
      dict(
          testcase_name="reference_tsi_data",
          loader=solar_radiation.reference_tsi_data,
          expected_tsi=np.array([1361.0]),
      ),
      dict(
          testcase_name="era5_tsi_data",
          loader=solar_radiation.era5_tsi_data,
          expected_tsi=np.array([1360.9440]),
      ),
  )
  def test_mid_2020_lookup(
      self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray
  ):
    """Looks up TSI at 2020-07-02T00:00 using each bundled TSI data loader.

    2020 is a leap year (366 days), so this timestamp falls exactly half way
    through it (day 183), i.e. at fractional year 2020.5.
    """
    tsi_data = loader()
    tsi = solar_radiation.get_tsi(
        [np.datetime64("2020-07-02T00:00:00")], tsi_data
    )
    # `assert_allclose` takes (actual, desired). The relative tolerance is
    # scaled by `desired`, and failure messages label the arrays accordingly.
    np.testing.assert_allclose(tsi, expected_tsi)

  @parameterized.named_parameters(
      dict(
          testcase_name="beginning_2020_left_boundary",
          timestamps=[np.datetime64("2020-01-01T00:00:00")],
          expected_tsi=np.array([1000.0]),
      ),
      dict(
          testcase_name="mid_2020_exact",
          timestamps=[np.datetime64("2020-07-02T00:00:00")],
          expected_tsi=np.array([1000.0]),
      ),
      dict(
          testcase_name="beginning_2021_interpolated",
          timestamps=[np.datetime64("2021-01-01T00:00:00")],
          expected_tsi=np.array([1150.0]),
      ),
      dict(
          testcase_name="mid_2021_lookup",
          timestamps=[np.datetime64("2021-07-02T12:00:00")],
          expected_tsi=np.array([1300.0]),
      ),
      dict(
          testcase_name="beginning_2022_interpolated",
          timestamps=[np.datetime64("2022-01-01T00:00:00")],
          expected_tsi=np.array([1250.0]),
      ),
      dict(
          testcase_name="mid_2022_lookup",
          timestamps=[np.datetime64("2022-07-02T12:00:00")],
          expected_tsi=np.array([1200.0]),
      ),
      dict(
          testcase_name="beginning_2023_right_boundary",
          timestamps=[np.datetime64("2023-01-01T00:00:00")],
          expected_tsi=np.array([1200.0]),
      ),
  )
  def test_interpolation(
      self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray
  ):
    """Checks `get_tsi` against a synthetic three-point TSI series.

    The series is indexed by fractional year: 1000 at 2020.5, 1300 at 2021.5
    and 1200 at 2022.5. The cases expect the following:
      * Exact matches return the stored value.
      * Timestamps between points get the linearly interpolated value.
      * Timestamps before the first or after the last point get the nearest
        boundary value.
    """
    tsi_data = xa.DataArray(
        np.array([1000.0, 1300.0, 1200.0]),
        dims=["time"],
        coords={"time": np.array([2020.5, 2021.5, 2022.5])},
    )
    tsi = solar_radiation.get_tsi(timestamps, tsi_data)
    # Argument order is (actual, desired); see `test_mid_2020_lookup`.
    np.testing.assert_allclose(tsi, expected_tsi)
if __name__ == "__main__":
  # Entry point when executed as a script: hand off to absl's test runner.
  absltest.main()