# --------------------------------------------------------------------------------------------------------------------- #
# This yaml file implements 6-hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# The FuXi architecture has been modified to reduce the overall model size
# The model is trained on 6-hourly model-level ERA5 data with top-of-atmosphere solar irradiance (tsi),
# geopotential, and land-sea mask inputs
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
#
# Yingkai Sha
# ksha@ucar.edu
# --------------------------------------------------------------------------------------------------------------------- #
---
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
seed: 1000

data:
  # upper-air variables
  variables: ['U', 'V', 'T', 'Q']
  # same directory and file naming as save_loc_surface
  save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL*'

  # surface variables
  surface_variables: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']
  save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL*'

  # dynamic forcing variables
  dynamic_forcing_variables: ['tsi']
  save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'

  # diagnostic variables
  # diagnostic_variables: ['V500','U500','T500','Z500','Q500']
  # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'

  # static variables
  static_variables: ['Z_GDS4_SFC', 'LSM']
  save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'

  # mean / std path
  mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
  std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'

  # train / validation split
  # NOTE(review): both ranges contain 2018; this is only disjoint if the loader treats
  # the upper bound as exclusive -- confirm against the data loader
  train_years: [1979, 2018]
  valid_years: [2018, 2019]

  # data workflow
  scaler_type: 'std_new'

  # number of input states
  # FuXi has 2 input states
  history_len: 2
  valid_history_len: 2

  # number of forecast steps to compute loss
  # 0 for single step training / validation
  # larger than 0 for multi-step training / validation
  forecast_len: 7
  valid_forecast_len: 7

  # one_shot: True --> compute loss on the last forecast step only
  # one_shot: False --> compute loss on all forecast steps
  one_shot: True

  # 1 for an hourly model; 6 for this 6-hourly model
  lead_time_periods: 6

  # do not use skip_period
  skip_periods: null

  # compatible with the old 'std'
  static_first: True

trainer:
  type: multi-step  # <---------- change to your type
  mode: fsdp
  cpu_offload: False
  activation_checkpoint: True
  load_weights: True
  load_optimizer: True
  load_scaler: True
  # NOTE(review): 'load_sheduler' is misspelled. It is kept for compatibility, and the
  # correctly spelled 'load_scheduler' is set to the same value so the setting takes
  # effect whichever spelling the trainer reads -- confirm and drop the unused one
  load_sheduler: True
  load_scheduler: True
  skip_validation: False
  update_learning_rate: False
  save_backup_weights: True
  save_best_weights: True
  learning_rate: 1.0e-06  # <-- change to your lr
  weight_decay: 0
  train_batch_size: 1
  valid_batch_size: 1
  batches_per_epoch: 759  # full epoch: 1772
  valid_batches_per_epoch: 0
  stopping_patience: 50
  start_epoch: 0
  num_epoch: 1
  # False when switching from single-step to multi-step
  reload_epoch: True
  epochs: &epochs 20
  use_scheduler: True
  scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
  # Automatic Mixed Precision: False
  amp: False
  # rescale loss as loss = loss / grad_accum_every
  grad_accum_every: 1
  # gradient clipping
  grad_max_norm: 1.0
  # number of workers
  thread_workers: 4
  valid_thread_workers: 0

model:
  type: "fuxi"
  frames: 2                 # number of input states
  image_height: 640         # number of latitude grids
  image_width: 1280         # number of longitude grids
  levels: 16                # number of upper-air variable levels
  channels: 4               # upper-air variable channels
  surface_channels: 7       # surface variable channels
  input_only_channels: 3    # dynamic forcing, forcing, static channels
  output_only_channels: 0   # diagnostic variable channels

  # patchify layer
  patch_height: 4           # number of latitude grids in each 3D patch
  patch_width: 4            # number of longitude grids in each 3D patch
  frame_patch_size: 2       # number of input states in each 3D patch

  # hidden layers
  dim: 1024                 # dimension (default: 1536)
  num_groups: 32            # number of groups (default: 32)
  num_heads: 8              # number of heads (default: 8)
  window_size: 7            # window size (default: 7)
  depth: 16                 # number of swin transformers (default: 48)

  # use spectral norm
  use_spectral_norm: True

  # ============================================================== #
  # New
  # use interpolation to match the output size
  interp: True

  # map boundary padding
  padding_conf:
    activate: True
    mode: earth
    pad_lat: 80
    pad_lon: 80

  post_conf:
    activate: True
    tracer_fixer:
      activate: True
      denorm: True
      tracer_name: ['Q', 'Q500']
      # written as 1.0e-08 rather than 1e-8: YAML 1.1 loaders (e.g. PyYAML) only
      # resolve floats that contain a '.', so a bare 1e-8 would load as a string
      tracer_thres: [1.0e-08, 1.0e-08]

loss:
  # the main training loss
  training_loss: "mse"

  # power loss (x), spectral_loss (x)
  use_power_loss: False
  use_spectral_loss: False

  # use latitude weighting
  use_latitude_weights: True
  latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"

  # turn-off variable weighting
  use_variable_weights: False
  # NOTE(review): each per-level list below has 15 entries while model.levels is 16;
  # extend the lists to 16 values before enabling variable weighting
  # variable_weights:
  #   U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
  #   SP: 0.1
  #   t2m: 1.0
  #   V500: 0.1
  #   U500: 0.1
  #   T500: 0.1
  #   Z500: 0.1
  #   Q500: 0.1

predict:
  forecasts:
    type: "custom"        # keep it as "custom"
    start_year: 2020      # year of the first initialization (where rollout will start)
    start_month: 1        # month of the first initialization
    start_day: 1          # day of the first initialization
    start_hours: [0, 12]  # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
    # number of days to initialize, starting from the (year, mon, day) above
    # duration should be divisible by the number of GPUs
    # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
    # NOTE(review): the pbs section requests 8 nodes x 4 GPUs; if rollout runs on that
    # allocation (32 GPUs), 30 is not divisible by 32 -- confirm the rollout GPU count
    duration: 30
    days: 2               # forecast lead time as days (1 means 24-hour forecast)

  save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
  save_vars: ['SP', 't2m', 'V500', 'U500', 'T500', 'Z500', 'Q500']

  # turn-off low-pass filter
  use_laplace_filter: False

  # deprecated
  # save_format: "nc"

pbs:
  # derecho
  conda: "/glade/work/ksha/miniconda3/envs/credit"
  project: "NAML0001"
  job_name: "fuxi_6h"
  walltime: "12:00:00"
  nodes: 8
  ncpus: 64
  ngpus: 4
  mem: '480GB'
  queue: 'main'