# fuxi_6h / model_predict_cpu.yml
# djgagne's picture
# Added model files
# d6d123e
# --------------------------------------------------------------------------------------------------------------------- #
# This YAML file configures the 6-hourly FuXi model on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu).
# The FuXi architecture has been modified to reduce the overall model size.
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
#
# Yingkai Sha
# [email protected]
# --------------------------------------------------------------------------------------------------------------------- #
# Top-level location for this run.
# NOTE(review): presumably the directory where run artifacts are read/written — confirm against the consumer.
# NOTE(review): a second `save_loc` key appears later for the dataset file glob; it must stay
# nested under `data`, otherwise it is a duplicate top-level key and overrides this one.
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
# Random seed — presumably used to make runs reproducible; confirm.
seed: 1000
# Dataset locations, variable lists, normalization files, and sampling options.
# Every key below is nested under `data`; without the nesting, `save_loc` here
# would be a duplicate of the top-level `save_loc` and silently override it.
data:
  # upper-air variables
  variables: ['U', 'V', 'T', 'Q']
  save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # surface variables (read from the same file glob as the upper-air variables)
  surface_variables: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # dynamic forcing variables
  dynamic_forcing_variables: ['tsi']
  save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'

  # static variables
  static_variables: ['Z_GDS4_SFC', 'LSM']
  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

  # mean / std path
  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

  # train / validation split
  # NOTE(review): 2018 appears in both ranges; presumably the end year is
  # exclusive — confirm against the data loader.
  train_years: [1979, 2018]
  valid_years: [2018, 2019]

  # data workflow
  scaler_type: 'std_new'

  # number of input states
  # FuXi has 2 input states
  history_len: 2
  valid_history_len: 2

  # number of forecast steps to compute loss
  # 0 for single step training / validation
  # larger than 0 for multi-step training / validation
  forecast_len: 0
  valid_forecast_len: 0

  # lead time of one step: 6 for this 6-hourly model (1 for an hourly model)
  lead_time_periods: 6

  # do not use skip_period
  skip_periods: null

  # compatible with the old 'std'
  static_first: true
# Trainer selection. Both keys are nested under `trainer` so that `type` does
# not collide with the `type` keys of other sections.
trainer:
  type: 'standard'
  # The plain string "none" (not YAML null, which is spelled `null` or `~`).
  # NOTE(review): presumably selects the execution/distribution mode — confirm.
  mode: 'none'
# FuXi network definition. `dim` and `depth` are below the defaults noted in
# their comments, i.e. this is a reduced-size variant.
model:
  type: "fuxi"
  frames: 2                # number of input states (same value as `history_len`)
  image_height: 640        # number of latitude grids
  image_width: 1280        # number of longitude grids
  levels: 16               # number of upper-air variable levels
  channels: 4              # upper-air variable channels (U, V, T, Q)
  surface_channels: 7      # surface variable channels (SP, t2m, and five 500 hPa fields)
  input_only_channels: 3   # dynamic forcing, forcing, static channels (tsi, Z_GDS4_SFC, LSM)
  output_only_channels: 0  # diagnostic variable channels

  # patchify layer
  patch_height: 4          # number of latitude grids in each 3D patch
  patch_width: 4           # number of longitude grids in each 3D patch
  frame_patch_size: 2      # number of input states in each 3D patch

  # hidden layers
  dim: 1024                # dimension (default: 1536)
  num_groups: 32           # number of groups (default: 32)
  num_heads: 8             # number of heads (default: 8)
  window_size: 7           # window size (default: 7)
  depth: 16                # number of swin transformers (default: 48)

  # map boundary padding
  pad_lon: 80              # number of grids to pad on 0 and 360 deg lon
  pad_lat: 80              # number of grids to pad on -90 and 90 deg lat

  # use spectral norm
  use_spectral_norm: true
# Loss options: enable latitude weighting and point to the file that holds
# the weights.
loss:
  use_latitude_weights: true
  latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
# Rollout / inference settings.
# NOTE(review): the initialization schedule (`type` through `days`) is grouped
# under `forecasts`; the output options after it are assumed to sit directly
# under `predict` — confirm against the consuming code.
predict:
  forecasts:
    type: "custom"    # keep it as "custom"
    start_year: 2020  # year of the first initialization (where rollout will start)
    start_month: 1    # month of the first initialization
    start_day: 1      # day of the first initialization
    start_hours: [0]  # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
    duration: 1       # number of days to initialize, starting from the (year, mon, day) above
    # duration should be divisible by the number of GPUs
    # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
    days: 1           # forecast lead time as days (1 means 24-hour forecast)

  # output directory and the subset of variables written to it
  save_forecast: '/glade/derecho/scratch/ksha/CREDIT/RAW_OUTPUT/fuxi_6h_collins/'
  save_vars: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']

  # turn-off low-pass filter
  use_laplace_filter: false