# fuxi_6h / model_single.yml
# djgagne's picture
# Added model files
# d6d123e
# --------------------------------------------------------------------------------------------------------------------- #
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# the FuXi architecture has been modified to reduce the overall model size
# The model is trained on 6 hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
#
# Yingkai Sha
# [email protected]
# --------------------------------------------------------------------------------------------------------------------- #
# top-level save location and seed
save_loc: "/glade/work/ksha/CREDIT_runs/fuxi_6h/"
seed: 1000
# data section: variable lists, file locations, normalization files, and sampling settings
data:
  # upper-air variables
  variables: ['U', 'V', 'T', 'Q']
  save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # surface variables
  surface_variables: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # dynamic forcing variables
  dynamic_forcing_variables: ['tsi']
  save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'

  # diagnostic variables
  # (disabled here; the same 500 hPa names are listed under surface_variables above)
  # diagnostic_variables: ['V500','U500','T500','Z500','Q500']
  # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # static variables
  static_variables: ['Z_GDS4_SFC', 'LSM']
  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

  # mean / std path
  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

  # train / validation split
  # NOTE(review): 2018 appears in both lists; confirm how the consumer treats
  # the second value (inclusive vs. exclusive) so the splits do not overlap
  train_years: [1979, 2018]
  valid_years: [2018, 2019]

  # data workflow
  scaler_type: 'std_new'

  # number of input states
  # FuXi has 2 input states
  history_len: 2
  valid_history_len: 2

  # number of forecast steps to compute loss
  # 0 for single step training / validation
  # larger than 0 for multi-step training / validation
  forecast_len: 0
  valid_forecast_len: 0

  # one_shot: True --> compute loss on the last forecast step only
  # one_shot: False --> compute loss on all forecast steps
  one_shot: True

  # 1 for hourly model; set to 6 here for the 6 hourly model
  lead_time_periods: 6

  # do not use skip_period
  skip_periods: null

  # compatible with the old 'std'
  static_first: True
# trainer section: training mode, checkpoint reload flags, optimizer and scheduler settings
trainer:
  type: standard  # <---------- change to your type
  mode: fsdp
  cpu_offload: False
  activation_checkpoint: True

  load_weights: True
  load_optimizer: True
  load_scaler: True
  # NOTE(review): this key is spelled 'load_sheduler' (not 'load_scheduler').
  # It is kept unchanged because the spelling the consumer reads is not
  # visible from this file -- confirm before renaming.
  load_sheduler: True

  skip_validation: False
  update_learning_rate: False

  save_backup_weights: True
  save_best_weights: True

  learning_rate: 1.0e-03  # <-- change to your lr
  weight_decay: 0

  train_batch_size: 1
  valid_batch_size: 1

  batches_per_epoch: 0
  valid_batches_per_epoch: 0
  stopping_patience: 50

  start_epoch: 0
  num_epoch: 2
  reload_epoch: True
  # anchored so that the scheduler's T_max below always matches this value
  epochs: &epochs 70

  use_scheduler: True
  scheduler:
    scheduler_type: 'cosine-annealing'
    T_max: *epochs
    last_epoch: -1

  # Automatic Mixed Precision: False
  amp: False

  # rescale loss as loss = loss / grad_accum_every
  grad_accum_every: 1

  # gradient clipping
  grad_max_norm: 1.0

  # number of workers
  thread_workers: 4
  valid_thread_workers: 0
# model section: FuXi input/output sizes and architecture hyperparameters
model:
  type: "fuxi"

  frames: 2                # number of input states
  image_height: 640        # number of latitude grids
  image_width: 1280        # number of longitude grids
  levels: 16               # number of upper-air variable levels
  channels: 4              # upper-air variable channels
  surface_channels: 7      # surface variable channels
  input_only_channels: 3   # dynamic forcing, forcing, static channels
  output_only_channels: 0  # diagnostic variable channels

  # patchify layer
  patch_height: 4          # number of latitude grids in each 3D patch
  patch_width: 4           # number of longitude grids in each 3D patch
  frame_patch_size: 2      # number of input states in each 3D patch

  # hidden layers
  dim: 1024                # dimension (default: 1536)
  num_groups: 32           # number of groups (default: 32)
  num_heads: 8             # number of heads (default: 8)
  window_size: 7           # window size (default: 7)
  depth: 16                # number of swin transformers (default: 48)

  # map boundary padding
  pad_lon: 80              # number of grids to pad on 0 and 360 deg lon
  pad_lat: 80              # number of grids to pad on -90 and 90 deg lat

  # use spectral norm
  use_spectral_norm: True
# loss section: training loss selection and optional weighting
loss:
  # the main training loss
  training_loss: "mse"

  # power loss (x), spectral_loss (x)
  use_power_loss: False
  use_spectral_loss: False

  # use latitude weighting
  use_latitude_weights: True
  latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"

  # turn-off variable weighting
  use_variable_weights: False
  # NOTE(review): each per-level list below has 15 entries while model.levels
  # is 16 -- update the lists before re-enabling variable weighting.
  # variable_weights:
  #   U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   SP: 0.1
  #   t2m: 1.0
  #   V500: 0.1
  #   U500: 0.1
  #   T500: 0.1
  #   Z500: 0.1
  #   Q500: 0.1
# predict section: rollout initialization schedule and forecast output settings
predict:
  forecasts:
    type: "custom"        # keep it as "custom"
    start_year: 2020      # year of the first initialization (where rollout will start)
    start_month: 1        # month of the first initialization
    start_day: 1          # day of the first initialization
    start_hours: [0, 12]  # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
    duration: 30          # number of days to initialize, starting from the (year, mon, day) above
    # duration should be divisible by the number of GPUs
    # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
    days: 2               # forecast lead time as days (1 means 24-hour forecast)

  save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
  save_vars: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']

  # turn-off low-pass filter
  use_laplace_filter: False

  # deprecated
  # save_format: "nc"
# pbs section: batch job settings
pbs:  # derecho
  conda: "/glade/work/ksha/miniconda3/envs/credit"
  project: "NAML0001"
  job_name: "fuxi_6h"
  walltime: "12:00:00"
  nodes: 8
  ncpus: 64
  ngpus: 4
  mem: '480GB'
  queue: 'main'