# Climate-ML-Foundation-Models / aurora_utils.py
# Repository page header: qq1990, commit 4df22b4 ("roll back")
import streamlit as st
import torch
from aurora import Aurora, Batch, Metadata
import numpy as np
from datetime import datetime
# Markdown shown at the top of the Aurora input panel. Kept at module level so
# the text starts at column 0 and the UI function stays readable.
_AURORA_MODEL_GUIDE_MD = """
**Available Models & Usage:**
Aurora provides several pretrained and fine-tuned models at 0.25° and 0.1° resolutions.
Models and weights are available through the HuggingFace repository: [microsoft/aurora](https://huggingface.co/microsoft/aurora).
**Aurora 0.25° Pretrained**
- Trained on a variety of data.
- Suitable if no fine-tuned version exists for your dataset or to fine-tune Aurora yourself.
- Use if your dataset is ERA5 at 0.25° resolution (721x1440).
**Aurora 0.25° Pretrained Small**
- A smaller version of the pretrained model for debugging purposes.
**Aurora 0.25° Fine-Tuned**
- Fine-tuned on IFS HRES T0.
- Best performance at 0.25° but should only be used for IFS HRES T0 data.
- May not give optimal results for other datasets.
**Aurora 0.1° Fine-Tuned**
- For IFS HRES T0 at 0.1° resolution (1801x3600).
- Best performing at 0.1° resolution.
- Data must match IFS HRES T0 conditions.
**Required Variables & Pressure Levels:**
For all Aurora models at these resolutions, the following inputs are required:
- **Surface-level variables:** 2t, 10u, 10v, msl
- **Static variables:** lsm, slt, z
- **Atmospheric variables:** t, u, v, q, z
- **Pressure levels (hPa):** 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000
Latitudes should decrease from 90° (90°N) to -90° (90°S), and longitudes should increase from 0° to 360° (excluding 360°). Data should be in single precision float32.
**Data Format (Batch):**
Data should be provided as a `aurora.Batch` object:
- `surf_vars` dict with shape (b, t, h, w)
- `static_vars` dict with shape (h, w)
- `atmos_vars` dict with shape (b, t, c, h, w)
- `metadata` containing lat, lon, time, and atmos_levels.
For detailed instructions and examples, refer to the official Aurora documentation and code repository.
"""

# Markdown shown below the uploader: links plus a minimal usage snippet.
_AURORA_REFERENCES_MD = """
- **HuggingFace Repository:** [microsoft/aurora](https://huggingface.co/microsoft/aurora)
- **Model Usage Examples:**
```python
from aurora import Aurora
model = Aurora()
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
```
- **API & Documentation:** Refer to the Aurora official GitHub and HuggingFace pages for detailed instructions.
"""


def aurora_config_ui():
    """Render the Aurora data-input panel and return the uploaded files.

    Widgets are emitted in this order: a subheader, the model/data guide, the
    NetCDF uploader (widget key ``"aurora_uploader"``, multiple files allowed,
    extensions ``nc``/``netcdf``/``nc4``), a horizontal rule, and the
    references section.

    Returns:
        Whatever ``st.file_uploader`` returns for the multi-file uploader.
    """
    st.subheader("Aurora Model Data Input")
    st.markdown(_AURORA_MODEL_GUIDE_MD)

    # File uploader for Aurora data.
    st.markdown("### Upload Your Input Data Files for Aurora")
    st.markdown("Upload the NetCDF files (e.g., `.nc`, `.netcdf`, `.nc4`) containing the required variables.")
    uploaded_files = st.file_uploader(
        "Drag and drop or select multiple .nc files",
        accept_multiple_files=True,
        key="aurora_uploader",
        type=["nc", "netcdf", "nc4"],
    )

    st.markdown("---")
    st.markdown("### References & Resources")
    st.markdown(_AURORA_REFERENCES_MD)
    return uploaded_files
def prepare_aurora_batch(ds, time_index=6, history_offset=6):
    """Assemble an ``aurora.Batch`` from a dataset with a ``lev`` dimension.

    The dataset is subset to the pressure levels in ``desired_levels``. The
    1000 hPa slice of ``T``/``U``/``V`` is used for the surface entries
    ``2t``/``10u``/``10v``, ``SLP`` is used for ``msl``, the first time step of
    ``PHIS`` is used for the static ``z``, and the remaining levels of
    ``T``/``U``/``V`` fill the atmospheric entries. Every surface and
    atmospheric entry holds two time steps: ``time_index - history_offset``
    and ``time_index``.

    Args:
        ds: Dataset exposing ``dims``, ``sel``, the coordinates ``lat``,
            ``lon``, ``lev``, ``time`` and the variables ``T``, ``U``, ``V``,
            ``SLP``, ``PHIS``. Presumably an ``xarray.Dataset`` whose variables
            have time as their leading axis — TODO confirm against the caller.
        time_index: Index of the "current" time step. Defaults to 6, the value
            that was previously hard-coded.
        history_offset: Number of steps between the previous and the current
            time step. Defaults to 6, the value that was previously hard-coded.

    Returns:
        The ``Batch`` built from the surface, static and atmospheric tensors
        plus their ``Metadata``.

    Raises:
        ValueError: If ``history_offset`` is not positive, ``ds`` has no
            ``lev`` dimension, or any desired pressure level is missing after
            nearest-neighbour selection.
        IndexError: If ``time_index`` or ``time_index - history_offset`` falls
            outside the time axis.
    """
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    if history_offset < 1:
        raise ValueError("history_offset must be a positive number of time steps.")

    # Ensure that the 'lev' dimension exists.
    if 'lev' not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
        """Stack the previous and current time steps of ``x`` behind a batch axis."""
        # List-indexing axis 0 picks the two steps; ``[None]`` prepends a
        # batch axis of length 1.
        selected = x[[i - history_offset, i]][None]
        # Copy so the tensor is backed by its own contiguous buffer.
        return torch.from_numpy(selected.copy())

    # NOTE(review): this only relabels the coordinate arrays (latitude negated,
    # longitude shifted by 180). The data arrays below are not flipped or
    # rolled to match — confirm this is the intended mapping for the source grid.
    lat = ds.lat.values * -1
    lon = ds.lon.values + 180

    # Subset to the desired pressure levels, then verify that nearest-neighbour
    # selection actually landed on every one of them.
    ds_subset = ds.sel(lev=desired_levels, method="nearest")
    lev = ds_subset.lev.values  # Pressure levels in hPa
    missing_levels = set(desired_levels) - set(lev)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    # Locate the 1000 hPa slice that stands in for the surface fields.
    matches_1000 = np.flatnonzero(lev == 1000)
    if matches_1000.size == 0:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
    lev_index_1000 = int(matches_1000[0])

    # Validate the time window before any data is computed. The previous step
    # is read at ``time_index - history_offset``; a negative value there would
    # silently wrap around to the end of the time axis.
    time_values = ds_subset.time.values
    num_times = ds_subset.time.size
    if time_index >= num_times or time_index < history_offset:
        raise IndexError(
            "Time index i is out of bounds. "
            f"Got i={time_index} with history offset {history_offset}; "
            f"the dataset has {num_times} time steps."
        )
    current_time = np.datetime64(time_values[time_index]).astype('datetime64[s]').astype(datetime)

    # Variables are read by item access so that a variable named "T" can never
    # be confused with an attribute of the dataset object.
    T_surface = ds_subset["T"].isel(lev=lev_index_1000).compute()
    U_surface = ds_subset["U"].isel(lev=lev_index_1000).compute()
    V_surface = ds_subset["V"].isel(lev=lev_index_1000).compute()
    SLP = ds_subset["SLP"].compute()

    # Static field: take the first time index to drop the time dimension.
    PHIS = ds_subset["PHIS"].isel(time=0).compute()

    # Atmospheric variables use every desired level except 1000 hPa.
    atmos_levels = [int(level) for level in lev if level != 1000]
    T_atm = ds_subset["T"].sel(lev=atmos_levels).compute()
    U_atm = ds_subset["U"].sel(lev=atmos_levels).compute()
    V_atm = ds_subset["V"].sel(lev=atmos_levels).compute()

    surf_vars = {
        "2t": _prepare(T_surface.values, time_index),   # Two-meter temperature
        "10u": _prepare(U_surface.values, time_index),  # Ten-meter eastward wind
        "10v": _prepare(V_surface.values, time_index),  # Ten-meter northward wind
        "msl": _prepare(SLP.values, time_index),        # Mean sea-level pressure
    }

    # NOTE(review): only 'z' is populated here; 'lsm' and 'slt' are not.
    static_vars = {
        "z": torch.from_numpy(PHIS.values.copy()),  # Geopotential
    }

    # NOTE(review): only t/u/v are populated here; 'q' and 'z' are not.
    atmos_vars = {
        "t": _prepare(T_atm.values, time_index),  # Temperature at desired levels
        "u": _prepare(U_atm.values, time_index),  # Eastward wind at desired levels
        "v": _prepare(V_atm.values, time_index),  # Northward wind at desired levels
    }

    metadata = Metadata(
        lat=torch.from_numpy(lat.copy()),
        lon=torch.from_numpy(lon.copy()),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),  # Only the levels used in atmos_vars
    )

    return Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
def initialize_aurora_model(device):
    """Build an Aurora model with its pretrained weights on ``device``.

    Constructs ``Aurora(use_lora=False)``, loads the
    ``aurora-0.25-pretrained.ckpt`` checkpoint from ``microsoft/aurora``, and
    returns the result of moving the model to ``device``.
    """
    aurora_net = Aurora(use_lora=False)
    # The checkpoint load is unconditional; any failure propagates to the caller.
    aurora_net.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    return aurora_net.to(device)