Added model files
Browse files- README.md +22 -3
- backup_checkpoint.pt +3 -0
- backup_final/best_checkpoint.pt +3 -0
- backup_final/best_model_checkpoint.pt +3 -0
- backup_final/best_optimizer_checkpoint.pt +3 -0
- backup_final/checkpoint.pt +3 -0
- backup_final/model_checkpoint.pt +3 -0
- backup_final/optimizer_checkpoint.pt +3 -0
- backup_final/training_log.csv +19 -0
- backup_model_checkpoint.pt +3 -0
- backup_optimizer_checkpoint.pt +3 -0
- best_checkpoint.pt +3 -0
- best_model_checkpoint.pt +3 -0
- best_optimizer_checkpoint.pt +3 -0
- checkpoint.pt +3 -0
- l.sh +44 -0
- launch_multi.sh +46 -0
- launch_predict.sh +45 -0
- launch_rollout.sh +45 -0
- launch_single.sh +46 -0
- model.yml +0 -0
- model_checkpoint.pt +3 -0
- model_multi.yml +224 -0
- model_predict.yml +136 -0
- model_predict_cpu.yml +116 -0
- model_single.yml +205 -0
- model_single_cached.yml +200 -0
- optimizer_checkpoint.pt +3 -0
README.md
CHANGED
@@ -1,3 +1,22 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
tags:
|
6 |
+
- weather
|
7 |
+
- climate
|
8 |
+
- global
|
9 |
+
---
|
10 |
+
|
11 |
+
# NSF NCAR Community Research Earth Digital Intelligence Twin (CREDIT) FuXi 6-Hour Model Weights and Configuration
|
12 |
+
This repository contains the PyTorch checkpoint weights and data/model configuration files for the CREDIT WXFormer 6-hour model.
|
13 |
+
More information about the training and verification of this model can be found in the Schreck et al. (2024) ArXiv [preprint](https://arxiv.org/abs/2411.07814).
|
14 |
+
|
15 |
+
## Data Access
|
16 |
+
Our model is trained on ERA5 Reanalysis Data from a subset of 16 of the 137 hybrid sigma-pressure vertical levels.
|
17 |
+
The raw data are available on the NSF NCAR [Research Data Archive](https://rda.ucar.edu/datasets/d633006/).
|
18 |
+
Processed data can be accessed on the [CREDIT ERA5 Zarr Files](https://app.globus.org/file-manager?origin_id=2fc90d8f-10b7-44e1-a6a5-cf844112822e&origin_path=%2F) globus collection.
|
19 |
+
|
20 |
+
## Running the Model
|
21 |
+
To run the model, first install the [CREDIT package](https://github.com/NCAR/miles-credit) from github.
|
22 |
+
Modify the paths in the `finetune_final/model_predict.yml` configuration file to point to the appropriate ERA5 and scaler directories.
|
backup_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1f23a63c6843906694ae529948df4d4ac347b4a103eafe7bf939331257fbb921
|
3 |
+
size 1260
|
backup_final/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ce0b399e07427b2ad9bd4d44d03801866f58bc236229341cdf2bce1dfcdd2551
|
3 |
+
size 1132
|
backup_final/best_model_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3333a5bca6a8ca87285866c05998ffe00bd726156a70814b208dcfab0f9c86d8
|
3 |
+
size 1044714782
|
backup_final/best_optimizer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a8fcbe3d191c3999d8652133c32ca2487924ff34a5a8238947b0e0ac795c1d78
|
3 |
+
size 1683978368
|
backup_final/checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8b25c034739e99cc1ad34d56332f3f29a00a30086874e631ffaca50294c0b663
|
3 |
+
size 1132
|
backup_final/model_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:af8ba6f9618c1a075b21c0fa6d741e511a933caac4466f74b7c5b88c3fdc1595
|
3 |
+
size 1044714782
|
backup_final/optimizer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c304aa2c70e1b98e5c88c0a4ed76024e6b7ce0681d1fb94e5cfb40d264ddde49
|
3 |
+
size 1683978368
|
backup_final/training_log.csv
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
index,epoch,train_forecast_len,valid_forecast_len,train_loss,valid_loss,train_acc,valid_acc,train_mae,valid_mae,lr
|
2 |
+
0,0,8.0,,0.1387606938719278,0.1108131468296051,0.970562528166061,0.9694464921951294,0.1099035362854148,0.1114028289914131,1e-06
|
3 |
+
1,1,8.0,,0.1377466195342726,0.1100945279002189,0.9708313843792448,0.9694475412368776,0.1095289197089015,0.111414648592472,9.938441702975689e-07
|
4 |
+
2,2,8.0,,0.1385083928684629,0.1117167651653289,0.970662198676108,0.9694490909576416,0.1096920592413432,0.1113834738731384,9.755282581475769e-07
|
5 |
+
3,3,8.0,,0.1384558184501839,0.1127898842096328,0.9706702200948644,0.969444227218628,0.1096835084190795,0.1113786697387695,9.45503262094184e-07
|
6 |
+
4,4,8.0,,0.1384486823842145,0.1118623703718185,0.9707273427048534,0.9694671511650086,0.1096120545829551,0.1113623306155204,9.045084971874738e-07
|
7 |
+
5,5,8.0,,0.1381992981407168,0.1129628434777259,0.9707307068726762,0.9694624781608582,0.1096099303527311,0.1113620564341545,8.535533905932738e-07
|
8 |
+
6,6,8.0,,0.1391399535737018,0.1127516105771064,0.970510085502318,0.9694764852523804,0.1099141280661301,0.11133803576231,7.938926261462367e-07
|
9 |
+
7,7,8.0,,0.1381134882824537,0.1126476496458053,0.970679276152876,0.9694833755493164,0.1096669749230421,0.1113172695040702,7.269952498697736e-07
|
10 |
+
8,8,8.0,,0.1384780509630525,0.1114308580756187,0.970642456109973,0.9694918870925904,0.1097221110256607,0.1112886816263198,6.54508497187474e-07
|
11 |
+
9,9,8.0,,0.1387791881958643,0.109854482114315,0.9706856229088524,0.9695030689239502,0.1096888491503491,0.1112985372543335,5.782172325201157e-07
|
12 |
+
10,10,8.0,,0.1381753478910926,0.1100143820047378,0.9706785721269992,0.9695019960403444,0.1096319330461097,0.1112829759716987,5.000000000000002e-07
|
13 |
+
11,11,8.0,,0.1397529813965318,0.1093554720282554,0.9704045921917488,0.9695140957832336,0.1100279053032633,0.1112509205937385,4.217827674798849e-07
|
14 |
+
12,12,8.0,,0.1388648276822212,0.1110095202922821,0.970521897509478,0.969509732723236,0.1098963392895986,0.1112472981214523,3.454915028125265e-07
|
15 |
+
13,13,8.0,,0.139924709661833,0.1111257791519165,0.9704050614899796,0.9695259213447572,0.1100835978435119,0.111248242855072,2.730047501302268e-07
|
16 |
+
14,14,8.0,,0.1386912109745035,0.1115610241889953,0.9705360937809598,0.9695113182067872,0.1098389108229688,0.1112457856535911,2.061073738537636e-07
|
17 |
+
15,15,8.0,,0.1394790709411359,0.1099427804350852,0.9703982228975208,0.9695213079452516,0.1101159101182764,0.1112276285886764,1.4644660940672637e-07
|
18 |
+
16,16,8.0,,0.1392562538778043,0.1106211960315704,0.9705332976241984,0.9695265769958497,0.1098723398013548,0.1112342774868011,9.549150281252641e-08
|
19 |
+
17,17,8.0,,0.1382005311872648,0.11191920638084411,0.9706875172214232,0.9695238590240478,0.10971980983678531,0.11123556792736053,5.449673790581614e-08
|
backup_model_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63ce28c57eb54f0e5d4a4229cfe9dfc26327f7812a73c5ac0eb657fc617335c3
|
3 |
+
size 1044714782
|
backup_optimizer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc8d079d97419fa481b7bf2cda4b24a33fb43c5f3b357940b3e017f9b95910a7
|
3 |
+
size 1683978368
|
best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:521fa1b735f45041505f3d1fccfae8b7d18625e5a59f52069452a9cab31f183e
|
3 |
+
size 1260
|
best_model_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1211699e85d387cf89e913a0b86d92e58986e8222322ef1cb1fb7dc30c23a790
|
3 |
+
size 1044714782
|
best_optimizer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca87b2643346de107cd09e11fe130fbaa33b8669b139c0cc462b65dcc41c180e
|
3 |
+
size 1683978368
|
checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:97de0e29c19c96993c73f700e695f0bf31ebdf7ba4297c3bda664e3e317a224a
|
3 |
+
size 1260
|
l.sh
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#PBS -A NAML0001
|
3 |
+
#PBS -N wxformer_6h
|
4 |
+
#PBS -l walltime=12:00:00
|
5 |
+
#PBS -l select=8:ncpus=64:ngpus=4:mem=480GB
|
6 |
+
#PBS -q main
|
7 |
+
#PBS -j oe
|
8 |
+
#PBS -k eod
|
9 |
+
# Load modules
|
10 |
+
module purge
|
11 |
+
module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
|
12 |
+
conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
|
13 |
+
# Export environment variables
|
14 |
+
export LSCRATCH=/glade/derecho/scratch/schreck/
|
15 |
+
export LOGLEVEL=INFO
|
16 |
+
export NCCL_DEBUG=INFO
|
17 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
18 |
+
export NCCL_SOCKET_IFNAME=hsn
|
19 |
+
export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
|
20 |
+
export MPICH_OFI_NIC_POLICY=GPU
|
21 |
+
export MPICH_GPU_SUPPORT_ENABLED=1
|
22 |
+
export NCCL_IB_DISABLE=1
|
23 |
+
export NCCL_CROSS_NIC=1
|
24 |
+
export NCCL_NCHANNELS_PER_NET_PEER=4
|
25 |
+
export MPICH_RDMA_ENABLED_CUDA=1
|
26 |
+
export NCCL_NET="AWS Libfabric"
|
27 |
+
export NCCL_NET_GDR_LEVEL=PBH
|
28 |
+
export FI_CXI_DISABLE_HOST_REGISTER=1
|
29 |
+
export FI_CXI_OPTIMIZED_MRS=false
|
30 |
+
export FI_MR_CACHE_MONITOR=userfaultfd
|
31 |
+
export FI_CXI_DEFAULT_CQ_SIZE=131072
|
32 |
+
# logger.info the results
|
33 |
+
echo "Number of nodes: 8"
|
34 |
+
echo "Number of GPUs per node: 4"
|
35 |
+
echo "Total number of GPUs: 32"
|
36 |
+
# Log in to WandB if needed
|
37 |
+
# wandb login 02d2b1af00b5df901cb2bee071872de774781520
|
38 |
+
# Launch MPIs
|
39 |
+
nodes=( $( cat $PBS_NODEFILE ) )
|
40 |
+
echo nodes: $nodes
|
41 |
+
# Find headnode's IP:
|
42 |
+
head_node=${nodes[0]}
|
43 |
+
head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
|
44 |
+
MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/derecho/scratch/schreck/CREDIT_runs/test_ben_env/miles-credit/applications/train.py -c model.yml --backend nccl
|
launch_multi.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#PBS -A NCIS0010
|
3 |
+
#PBS -N fx6h_multi
|
4 |
+
#PBS -l walltime=12:00:00
|
5 |
+
#PBS -l select=8:ncpus=64:ngpus=4
|
6 |
+
#PBS -q main
|
7 |
+
#PBS -j oe
|
8 |
+
#PBS -k eod
|
9 |
+
#PBS -r n
|
10 |
+
# Load modules
|
11 |
+
module purge
|
12 |
+
module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
|
13 |
+
conda activate /glade/work/ksha/miniconda3/envs/credit-derecho
|
14 |
+
# conda conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
|
15 |
+
# Export environment variables
|
16 |
+
export LSCRATCH=/glade/derecho/scratch/ksha/
|
17 |
+
export LOGLEVEL=INFO
|
18 |
+
export NCCL_DEBUG=INFO
|
19 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
20 |
+
export NCCL_SOCKET_IFNAME=hsn
|
21 |
+
export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
|
22 |
+
export MPICH_OFI_NIC_POLICY=GPU
|
23 |
+
export MPICH_GPU_SUPPORT_ENABLED=1
|
24 |
+
export NCCL_IB_DISABLE=1
|
25 |
+
export NCCL_CROSS_NIC=1
|
26 |
+
export NCCL_NCHANNELS_PER_NET_PEER=4
|
27 |
+
export MPICH_RDMA_ENABLED_CUDA=1
|
28 |
+
export NCCL_NET="AWS Libfabric"
|
29 |
+
export NCCL_NET_GDR_LEVEL=PBH
|
30 |
+
export FI_CXI_DISABLE_HOST_REGISTER=1
|
31 |
+
export FI_CXI_OPTIMIZED_MRS=false
|
32 |
+
export FI_MR_CACHE_MONITOR=userfaultfd
|
33 |
+
export FI_CXI_DEFAULT_CQ_SIZE=131072
|
34 |
+
# logger.info the results
|
35 |
+
echo "Number of nodes: 8"
|
36 |
+
echo "Number of GPUs per node: 4"
|
37 |
+
echo "Total number of GPUs: 32"
|
38 |
+
# Log in to WandB if needed
|
39 |
+
# wandb login 02d2b1af00b5df901cb2bee071872de774781520
|
40 |
+
# Launch MPIs
|
41 |
+
nodes=( $( cat $PBS_NODEFILE ) )
|
42 |
+
echo nodes: $nodes
|
43 |
+
# Find headnode's IP:
|
44 |
+
head_node=${nodes[0]}
|
45 |
+
head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
|
46 |
+
MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/u/home/ksha/miles-credit/applications/train_multistep.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_multi.yml --backend nccl
|
launch_predict.sh
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#PBS -A NCIS0010
|
3 |
+
#PBS -N fx6h_pred
|
4 |
+
#PBS -l walltime=12:00:00
|
5 |
+
#PBS -l select=8:ncpus=64:ngpus=4
|
6 |
+
#PBS -q main
|
7 |
+
#PBS -j oe
|
8 |
+
#PBS -k eod
|
9 |
+
#PBS -r n
|
10 |
+
# Load modules
|
11 |
+
module purge
|
12 |
+
module load nvhpc cuda cray-mpich conda
|
13 |
+
conda activate /glade/work/ksha/miniconda3/envs/credit
|
14 |
+
# Get a list of allocated nodes
|
15 |
+
nodes=( $( cat $PBS_NODEFILE ) )
|
16 |
+
head_node=${nodes[0]}
|
17 |
+
head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
|
18 |
+
# Export environment variables
|
19 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
20 |
+
export LSCRATCH=/glade/derecho/scratch/ksha/
|
21 |
+
export LOGLEVEL=INFO
|
22 |
+
#export NCCL_DEBUG=INFO
|
23 |
+
|
24 |
+
export NCCL_SOCKET_IFNAME=hsn
|
25 |
+
export NCCL_HOME=/glade/u/home/dhoward/work/nccl-ofi-plugin/install
|
26 |
+
export LD_LIBRARY_PATH=$NCCL_HOME/lib:$NCCL_HOME/plugin/lib:$LD_LIBRARY_PATH
|
27 |
+
|
28 |
+
export NCCL_NCHANNELS_PER_NET_PEER=4
|
29 |
+
export MPICH_GPU_SUPPORT_ENABLED=1
|
30 |
+
export MPICH_OFI_NIC_POLICY=GPU
|
31 |
+
export MPICH_RDMA_ENABLED_CUDA=1
|
32 |
+
export NCCL_DISABLE_IB=1
|
33 |
+
export NCCL_CROSS_NIC=1
|
34 |
+
export FI_CXI_DISABLE_HOST_REGISTER=1
|
35 |
+
export FI_CXI_OPTIMIZED_MRS=false
|
36 |
+
|
37 |
+
# Print the results
|
38 |
+
echo "Number of nodes: 8"
|
39 |
+
echo "Number of GPUs per node: 4"
|
40 |
+
echo "Total number of GPUs: 32"
|
41 |
+
# Log in to WandB if needed
|
42 |
+
# wandb login 02d2b1af00b5df901cb2bee071872de774781520
|
43 |
+
|
44 |
+
# Launch MPIs
|
45 |
+
mpiexec -n 8 --ppn 1 --cpu-bind none torchrun --nnodes=8 --nproc-per-node=4 --rdzv-backend=c10d --rdzv-endpoint=$head_node_ip /glade/u/home/ksha/miles-credit/applications/rollout_to_netcdf.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_predict.yml
|
launch_rollout.sh
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#PBS -A NAML0001
|
3 |
+
#PBS -N fx6h_roll
|
4 |
+
#PBS -l walltime=12:00:00
|
5 |
+
#PBS -l select=8:ncpus=64:ngpus=4
|
6 |
+
#PBS -q main
|
7 |
+
#PBS -j oe
|
8 |
+
#PBS -k eod
|
9 |
+
#PBS -r n
|
10 |
+
# Load modules
|
11 |
+
module purge
|
12 |
+
module load nvhpc cuda cray-mpich conda
|
13 |
+
conda activate /glade/work/ksha/miniconda3/envs/credit
|
14 |
+
# Get a list of allocated nodes
|
15 |
+
nodes=( $( cat $PBS_NODEFILE ) )
|
16 |
+
head_node=${nodes[0]}
|
17 |
+
head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
|
18 |
+
# Export environment variables
|
19 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
20 |
+
export LSCRATCH=/glade/derecho/scratch/ksha/
|
21 |
+
export LOGLEVEL=INFO
|
22 |
+
#export NCCL_DEBUG=INFO
|
23 |
+
|
24 |
+
export NCCL_SOCKET_IFNAME=hsn
|
25 |
+
export NCCL_HOME=/glade/u/home/dhoward/work/nccl-ofi-plugin/install
|
26 |
+
export LD_LIBRARY_PATH=$NCCL_HOME/lib:$NCCL_HOME/plugin/lib:$LD_LIBRARY_PATH
|
27 |
+
|
28 |
+
export NCCL_NCHANNELS_PER_NET_PEER=4
|
29 |
+
export MPICH_GPU_SUPPORT_ENABLED=1
|
30 |
+
export MPICH_OFI_NIC_POLICY=GPU
|
31 |
+
export MPICH_RDMA_ENABLED_CUDA=1
|
32 |
+
export NCCL_DISABLE_IB=1
|
33 |
+
export NCCL_CROSS_NIC=1
|
34 |
+
export FI_CXI_DISABLE_HOST_REGISTER=1
|
35 |
+
export FI_CXI_OPTIMIZED_MRS=false
|
36 |
+
|
37 |
+
# Print the results
|
38 |
+
echo "Number of nodes: 8"
|
39 |
+
echo "Number of GPUs per node: 4"
|
40 |
+
echo "Total number of GPUs: 32"
|
41 |
+
# Log in to WandB if needed
|
42 |
+
# wandb login 02d2b1af00b5df901cb2bee071872de774781520
|
43 |
+
|
44 |
+
# Launch MPIs
|
45 |
+
mpiexec -n 8 --ppn 1 --cpu-bind none torchrun --nnodes=8 --nproc-per-node=4 --rdzv-backend=c10d --rdzv-endpoint=$head_node_ip /glade/u/home/ksha/miles-credit/applications/rollout_metrics.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_predict.yml
|
launch_single.sh
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#PBS -A NCIS0010
|
3 |
+
#PBS -N fuxi_6h
|
4 |
+
#PBS -l walltime=12:00:00
|
5 |
+
#PBS -l select=8:ncpus=64:ngpus=4
|
6 |
+
#PBS -q main
|
7 |
+
#PBS -j oe
|
8 |
+
#PBS -k eod
|
9 |
+
#PBS -r n
|
10 |
+
# Load modules
|
11 |
+
module purge
|
12 |
+
module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
|
13 |
+
conda activate /glade/work/ksha/miniconda3/envs/credit-derecho
|
14 |
+
# conda conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
|
15 |
+
# Export environment variables
|
16 |
+
export LSCRATCH=/glade/derecho/scratch/ksha/
|
17 |
+
export LOGLEVEL=INFO
|
18 |
+
export NCCL_DEBUG=INFO
|
19 |
+
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
20 |
+
export NCCL_SOCKET_IFNAME=hsn
|
21 |
+
export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
|
22 |
+
export MPICH_OFI_NIC_POLICY=GPU
|
23 |
+
export MPICH_GPU_SUPPORT_ENABLED=1
|
24 |
+
export NCCL_IB_DISABLE=1
|
25 |
+
export NCCL_CROSS_NIC=1
|
26 |
+
export NCCL_NCHANNELS_PER_NET_PEER=4
|
27 |
+
export MPICH_RDMA_ENABLED_CUDA=1
|
28 |
+
export NCCL_NET="AWS Libfabric"
|
29 |
+
export NCCL_NET_GDR_LEVEL=PBH
|
30 |
+
export FI_CXI_DISABLE_HOST_REGISTER=1
|
31 |
+
export FI_CXI_OPTIMIZED_MRS=false
|
32 |
+
export FI_MR_CACHE_MONITOR=userfaultfd
|
33 |
+
export FI_CXI_DEFAULT_CQ_SIZE=131072
|
34 |
+
# logger.info the results
|
35 |
+
echo "Number of nodes: 8"
|
36 |
+
echo "Number of GPUs per node: 4"
|
37 |
+
echo "Total number of GPUs: 32"
|
38 |
+
# Log in to WandB if needed
|
39 |
+
# wandb login 02d2b1af00b5df901cb2bee071872de774781520
|
40 |
+
# Launch MPIs
|
41 |
+
nodes=( $( cat $PBS_NODEFILE ) )
|
42 |
+
echo nodes: $nodes
|
43 |
+
# Find headnode's IP:
|
44 |
+
head_node=${nodes[0]}
|
45 |
+
head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
|
46 |
+
MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/u/home/ksha/miles-credit/applications/train.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_single.yml --backend nccl
|
model.yml
ADDED
File without changes
|
model_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3137668291b82d20baf4b412769f6782a62db0a75cdcd213c7f03e02fba027ab
|
3 |
+
size 1044714782
|
model_multi.yml
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
2 |
+
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
|
3 |
+
# the FuXi architecture has been modified to reduce the overall model size
|
4 |
+
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
|
5 |
+
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
|
6 |
+
#
|
7 |
+
# Yingkai Sha
|
8 | |
9 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
10 |
+
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
|
11 |
+
seed: 1000
|
12 |
+
|
13 |
+
data:
|
14 |
+
# upper-air variables
|
15 |
+
variables: ['U','V','T','Q']
|
16 |
+
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/Sixiourly_y_TOTAL*'
|
17 |
+
|
18 |
+
# surface variables
|
19 |
+
surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
20 |
+
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL*'
|
21 |
+
|
22 |
+
# dynamic forcing variables
|
23 |
+
dynamic_forcing_variables: ['tsi']
|
24 |
+
save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
|
25 |
+
|
26 |
+
# diagnostic variables
|
27 |
+
# diagnostic_variables: ['V500','U500','T500','Z500','Q500']
|
28 |
+
# save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
29 |
+
|
30 |
+
# static variables
|
31 |
+
static_variables: ['Z_GDS4_SFC','LSM']
|
32 |
+
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
|
33 |
+
|
34 |
+
# mean / std path
|
35 |
+
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
|
36 |
+
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
|
37 |
+
|
38 |
+
# train / validation split
|
39 |
+
train_years: [1979, 2018]
|
40 |
+
valid_years: [2018, 2019]
|
41 |
+
|
42 |
+
# data workflow
|
43 |
+
scaler_type: 'std_new'
|
44 |
+
|
45 |
+
# number of input states
|
46 |
+
# FuXi has 2 input states
|
47 |
+
history_len: 2
|
48 |
+
valid_history_len: 2
|
49 |
+
|
50 |
+
# number of forecast steps to compute loss
|
51 |
+
# 0 for single step training / validation
|
52 |
+
# larger than 0 for multi-step training / validation
|
53 |
+
forecast_len: 7
|
54 |
+
valid_forecast_len: 7
|
55 |
+
|
56 |
+
# one_shot: True --> compute loss on the last forecast step only
|
57 |
+
# one_shot: False --> compute loss on all forecast steps
|
58 |
+
one_shot: True
|
59 |
+
|
60 |
+
# 1 for hourly model
|
61 |
+
lead_time_periods: 6
|
62 |
+
|
63 |
+
# do not use skip_period
|
64 |
+
skip_periods: null
|
65 |
+
|
66 |
+
# compatible with the old 'std'
|
67 |
+
static_first: True
|
68 |
+
|
69 |
+
trainer:
|
70 |
+
type: multi-step # <---------- change to your type
|
71 |
+
|
72 |
+
mode: fsdp
|
73 |
+
cpu_offload: False
|
74 |
+
activation_checkpoint: True
|
75 |
+
|
76 |
+
load_weights: True
|
77 |
+
load_optimizer: True
|
78 |
+
load_scaler: True
|
79 |
+
load_sheduler: True
|
80 |
+
|
81 |
+
skip_validation: False
|
82 |
+
update_learning_rate: False
|
83 |
+
|
84 |
+
save_backup_weights: True
|
85 |
+
save_best_weights: True
|
86 |
+
|
87 |
+
learning_rate: 1.0e-06 # <-- change to your lr
|
88 |
+
weight_decay: 0
|
89 |
+
|
90 |
+
train_batch_size: 1
|
91 |
+
valid_batch_size: 1
|
92 |
+
|
93 |
+
batches_per_epoch: 759 # full epoch: 1772
|
94 |
+
valid_batches_per_epoch: 0
|
95 |
+
stopping_patience: 50
|
96 |
+
|
97 |
+
start_epoch: 0
|
98 |
+
num_epoch: 1
|
99 |
+
# False when switching from single-step to multi-step
|
100 |
+
reload_epoch: True
|
101 |
+
epochs: &epochs 20
|
102 |
+
|
103 |
+
use_scheduler: True
|
104 |
+
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
|
105 |
+
|
106 |
+
# Automatic Mixed Precision: False
|
107 |
+
amp: False
|
108 |
+
|
109 |
+
# rescale loss as loss = loss / grad_accum_every
|
110 |
+
grad_accum_every: 1
|
111 |
+
# gradient clipping
|
112 |
+
grad_max_norm: 1.0
|
113 |
+
|
114 |
+
# number of workers
|
115 |
+
thread_workers: 4
|
116 |
+
valid_thread_workers: 0
|
117 |
+
|
118 |
+
model:
|
119 |
+
type: "fuxi"
|
120 |
+
|
121 |
+
frames: 2 # number of input states
|
122 |
+
image_height: 640 # number of latitude grids
|
123 |
+
image_width: 1280 # number of longitude grids
|
124 |
+
levels: 16 # number of upper-air variable levels
|
125 |
+
channels: 4 # upper-air variable channels
|
126 |
+
surface_channels: 7 # surface variable channels
|
127 |
+
input_only_channels: 3 # dynamic forcing, forcing, static channels
|
128 |
+
output_only_channels: 0 # diagnostic variable channels
|
129 |
+
|
130 |
+
# patchify layer
|
131 |
+
patch_height: 4 # number of latitude grids in each 3D patch
|
132 |
+
patch_width: 4 # number of longitude grids in each 3D patch
|
133 |
+
frame_patch_size: 2 # number of input states in each 3D patch
|
134 |
+
|
135 |
+
# hidden layers
|
136 |
+
dim: 1024 # dimension (default: 1536)
|
137 |
+
num_groups: 32 # number of groups (default: 32)
|
138 |
+
num_heads: 8 # number of heads (default: 8)
|
139 |
+
window_size: 7 # window size (default: 7)
|
140 |
+
depth: 16 # number of swin transformers (default: 48)
|
141 |
+
|
142 |
+
# use spectral norm
|
143 |
+
use_spectral_norm: True
|
144 |
+
|
145 |
+
# ============================================================== #
|
146 |
+
# New
|
147 |
+
|
148 |
+
# use interpolation to match the output size
|
149 |
+
interp: True
|
150 |
+
|
151 |
+
# map boundary padding
|
152 |
+
padding_conf:
|
153 |
+
activate: True
|
154 |
+
mode: earth
|
155 |
+
pad_lat: 80
|
156 |
+
pad_lon: 80
|
157 |
+
|
158 |
+
post_conf:
|
159 |
+
activate: True
|
160 |
+
|
161 |
+
tracer_fixer:
|
162 |
+
activate: True
|
163 |
+
denorm: True
|
164 |
+
tracer_name: ['Q', 'Q500']
|
165 |
+
tracer_thres: [1e-8, 1e-8]
|
166 |
+
|
167 |
+
loss:
|
168 |
+
# the main training loss
|
169 |
+
training_loss: "mse"
|
170 |
+
|
171 |
+
# power loss (x), spectral_loss (x)
|
172 |
+
use_power_loss: False
|
173 |
+
use_spectral_loss: False
|
174 |
+
|
175 |
+
# use latitude weighting
|
176 |
+
use_latitude_weights: True
|
177 |
+
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
|
178 |
+
|
179 |
+
# turn-off variable weighting
|
180 |
+
use_variable_weights: False
|
181 |
+
# variable_weights:
|
182 |
+
# U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
183 |
+
# V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
184 |
+
# T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
185 |
+
# Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
186 |
+
# SP: 0.1
|
187 |
+
# t2m: 1.0
|
188 |
+
# V500: 0.1
|
189 |
+
# U500: 0.1
|
190 |
+
# T500: 0.1
|
191 |
+
# Z500: 0.1
|
192 |
+
# Q500: 0.1
|
193 |
+
|
194 |
+
predict:
|
195 |
+
forecasts:
|
196 |
+
type: "custom" # keep it as "custom"
|
197 |
+
start_year: 2020 # year of the first initialization (where rollout will start)
|
198 |
+
start_month: 1 # month of the first initialization
|
199 |
+
start_day: 1 # day of the first initialization
|
200 |
+
start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
|
201 |
+
duration: 30 # number of days to initialize, starting from the (year, mon, day) above
|
202 |
+
# duration should be divisible by the number of GPUs
|
203 |
+
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
|
204 |
+
days: 2 # forecast lead time as days (1 means 24-hour forecast)
|
205 |
+
|
206 |
+
save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
|
207 |
+
save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
208 |
+
|
209 |
+
# turn-off low-pass filter
|
210 |
+
use_laplace_filter: False
|
211 |
+
|
212 |
+
# deprecated
|
213 |
+
# save_format: "nc"
|
214 |
+
|
215 |
+
pbs: #derecho
|
216 |
+
conda: "/glade/work/ksha/miniconda3/envs/credit"
|
217 |
+
project: "NAML0001"
|
218 |
+
job_name: "fuxi_6h"
|
219 |
+
walltime: "12:00:00"
|
220 |
+
nodes: 8
|
221 |
+
ncpus: 64
|
222 |
+
ngpus: 4
|
223 |
+
mem: '480GB'
|
224 |
+
queue: 'main'
|
model_predict.yml
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
2 |
+
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
|
3 |
+
# the FuXi architecture has been modified to reduce the overall model size
|
4 |
+
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
|
5 |
+
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
|
6 |
+
#
|
7 |
+
# Yingkai Sha
|
8 | |
9 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
10 |
+
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
|
11 |
+
seed: 1000
|
12 |
+
|
13 |
+
data:
|
14 |
+
# upper-air variables
|
15 |
+
variables: ['U','V','T','Q']
|
16 |
+
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y*'
|
17 |
+
|
18 |
+
# surface variables
|
19 |
+
surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
20 |
+
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y*'
|
21 |
+
|
22 |
+
# dynamic forcing variables
|
23 |
+
dynamic_forcing_variables: ['tsi']
|
24 |
+
save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
|
25 |
+
|
26 |
+
# static variables
|
27 |
+
static_variables: ['Z_GDS4_SFC','LSM']
|
28 |
+
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
|
29 |
+
|
30 |
+
# mean / std path
|
31 |
+
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
|
32 |
+
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
|
33 |
+
|
34 |
+
# train / validation split
|
35 |
+
train_years: [1979, 2018]
|
36 |
+
valid_years: [2018, 2019]
|
37 |
+
|
38 |
+
# data workflow
|
39 |
+
scaler_type: 'std_new'
|
40 |
+
|
41 |
+
# number of input states
|
42 |
+
# FuXi has 2 input states
|
43 |
+
history_len: 2
|
44 |
+
valid_history_len: 2
|
45 |
+
|
46 |
+
# number of forecast steps to compute loss
|
47 |
+
# 0 for single step training / validation
|
48 |
+
# larger than 0 for multi-step training / validation
|
49 |
+
forecast_len: 0
|
50 |
+
valid_forecast_len: 0
|
51 |
+
|
52 |
+
# 1 for hourly model
|
53 |
+
lead_time_periods: 6
|
54 |
+
|
55 |
+
# do not use skip_period
|
56 |
+
skip_periods: null
|
57 |
+
|
58 |
+
# compatible with the old 'std'
|
59 |
+
static_first: True
|
60 |
+
|
61 |
+
trainer:
|
62 |
+
type: standard
|
63 |
+
mode: fsdp
|
64 |
+
|
65 |
+
model:
|
66 |
+
type: "fuxi"
|
67 |
+
|
68 |
+
frames: 2 # number of input states
|
69 |
+
image_height: 640 # number of latitude grids
|
70 |
+
image_width: 1280 # number of longitude grids
|
71 |
+
levels: 16 # number of upper-air variable levels
|
72 |
+
channels: 4 # upper-air variable channels
|
73 |
+
surface_channels: 7 # surface variable channels
|
74 |
+
input_only_channels: 3 # dynamic forcing, forcing, static channels
|
75 |
+
output_only_channels: 0 # diagnostic variable channels
|
76 |
+
|
77 |
+
# patchify layer
|
78 |
+
patch_height: 4 # number of latitude grids in each 3D patch
|
79 |
+
patch_width: 4 # number of longitude grids in each 3D patch
|
80 |
+
frame_patch_size: 2 # number of input states in each 3D patch
|
81 |
+
|
82 |
+
# hidden layers
|
83 |
+
dim: 1024 # dimension (default: 1536)
|
84 |
+
num_groups: 32 # number of groups (default: 32)
|
85 |
+
num_heads: 8 # number of heads (default: 8)
|
86 |
+
window_size: 7 # window size (default: 7)
|
87 |
+
depth: 16 # number of swin transformers (default: 48)
|
88 |
+
|
89 |
+
# use spectral norm
|
90 |
+
use_spectral_norm: True
|
91 |
+
|
92 |
+
# ============================================================== #
|
93 |
+
# New
|
94 |
+
|
95 |
+
# use interpolation to match the output size
|
96 |
+
interp: True
|
97 |
+
|
98 |
+
# map boundary padding
|
99 |
+
padding_conf:
|
100 |
+
activate: True
|
101 |
+
mode: earth
|
102 |
+
pad_lat: 80
|
103 |
+
pad_lon: 80
|
104 |
+
|
105 |
+
post_conf:
|
106 |
+
activate: True
|
107 |
+
|
108 |
+
tracer_fixer:
|
109 |
+
activate: True
|
110 |
+
denorm: True
|
111 |
+
tracer_name: ['Q', 'Q500']
|
112 |
+
tracer_thres: [1e-8, 1e-8]
|
113 |
+
|
114 |
+
|
115 |
+
loss:
|
116 |
+
use_latitude_weights: True
|
117 |
+
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
|
118 |
+
|
119 |
+
predict:
|
120 |
+
forecasts:
|
121 |
+
type: "custom" # keep it as "custom"
|
122 |
+
start_year: 2021 # year of the first initialization (where rollout will start)
|
123 |
+
start_month: 12 # month of the first initialization
|
124 |
+
start_day: 31 # day of the first initialization
|
125 |
+
start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
|
126 |
+
duration: 384 # number of days to initialize, starting from the (year, mon, day) above
|
127 |
+
# duration should be divisible by the number of GPUs
|
128 |
+
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
|
129 |
+
days: 10 # forecast lead time as days (1 means 24-hour forecast)
|
130 |
+
|
131 |
+
metadata: '/glade/u/home/ksha/miles-credit/credit/metadata/era5.yaml'
|
132 |
+
save_forecast: '/glade/derecho/scratch/ksha/CREDIT/RAW_OUTPUT/fuxi_6h_test/'
|
133 |
+
|
134 |
+
# turn-off low-pass filter
|
135 |
+
use_laplace_filter: False
|
136 |
+
|
model_predict_cpu.yml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
2 |
+
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
|
3 |
+
# the FuXi architecture has been modified to reduce the overall model size
|
4 |
+
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
|
5 |
+
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
|
6 |
+
#
|
7 |
+
# Yingkai Sha
|
8 | |
9 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
10 |
+
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
|
11 |
+
seed: 1000
|
12 |
+
|
13 |
+
data:
|
14 |
+
# upper-air variables
|
15 |
+
variables: ['U','V','T','Q']
|
16 |
+
save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
17 |
+
|
18 |
+
# surface variables
|
19 |
+
surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
20 |
+
save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
21 |
+
|
22 |
+
# dynamic forcing variables
|
23 |
+
dynamic_forcing_variables: ['tsi']
|
24 |
+
save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
|
25 |
+
|
26 |
+
# static variables
|
27 |
+
static_variables: ['Z_GDS4_SFC','LSM']
|
28 |
+
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
|
29 |
+
|
30 |
+
# mean / std path
|
31 |
+
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
|
32 |
+
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
|
33 |
+
|
34 |
+
# train / validation split
|
35 |
+
train_years: [1979, 2018]
|
36 |
+
valid_years: [2018, 2019]
|
37 |
+
|
38 |
+
# data workflow
|
39 |
+
scaler_type: 'std_new'
|
40 |
+
|
41 |
+
# number of input states
|
42 |
+
# FuXi has 2 input states
|
43 |
+
history_len: 2
|
44 |
+
valid_history_len: 2
|
45 |
+
|
46 |
+
# number of forecast steps to compute loss
|
47 |
+
# 0 for single step training / validation
|
48 |
+
# larger than 0 for multi-step training / validation
|
49 |
+
forecast_len: 0
|
50 |
+
valid_forecast_len: 0
|
51 |
+
|
52 |
+
# 1 for hourly model
|
53 |
+
lead_time_periods: 6
|
54 |
+
|
55 |
+
# do not use skip_period
|
56 |
+
skip_periods: null
|
57 |
+
|
58 |
+
# compatible with the old 'std'
|
59 |
+
static_first: True
|
60 |
+
|
61 |
+
trainer:
|
62 |
+
type: standard
|
63 |
+
mode: none
|
64 |
+
|
65 |
+
model:
|
66 |
+
type: "fuxi"
|
67 |
+
|
68 |
+
frames: 2 # number of input states
|
69 |
+
image_height: 640 # number of latitude grids
|
70 |
+
image_width: 1280 # number of longitude grids
|
71 |
+
levels: 16 # number of upper-air variable levels
|
72 |
+
channels: 4 # upper-air variable channels
|
73 |
+
surface_channels: 7 # surface variable channels
|
74 |
+
input_only_channels: 3 # dynamic forcing, forcing, static channels
|
75 |
+
output_only_channels: 0 # diagnostic variable channels
|
76 |
+
|
77 |
+
# patchify layer
|
78 |
+
patch_height: 4 # number of latitude grids in each 3D patch
|
79 |
+
patch_width: 4 # number of longitude grids in each 3D patch
|
80 |
+
frame_patch_size: 2 # number of input states in each 3D patch
|
81 |
+
|
82 |
+
# hidden layers
|
83 |
+
dim: 1024 # dimension (default: 1536)
|
84 |
+
num_groups: 32 # number of groups (default: 32)
|
85 |
+
num_heads: 8 # number of heads (default: 8)
|
86 |
+
window_size: 7 # window size (default: 7)
|
87 |
+
depth: 16 # number of swin transformers (default: 48)
|
88 |
+
|
89 |
+
# map boundary padding
|
90 |
+
pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
|
91 |
+
pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
|
92 |
+
|
93 |
+
# use spectral norm
|
94 |
+
use_spectral_norm: True
|
95 |
+
|
96 |
+
loss:
|
97 |
+
use_latitude_weights: True
|
98 |
+
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
|
99 |
+
|
100 |
+
predict:
|
101 |
+
forecasts:
|
102 |
+
type: "custom" # keep it as "custom"
|
103 |
+
start_year: 2020 # year of the first initialization (where rollout will start)
|
104 |
+
start_month: 1 # month of the first initialization
|
105 |
+
start_day: 1 # day of the first initialization
|
106 |
+
start_hours: [0,] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
|
107 |
+
duration: 1 # number of days to initialize, starting from the (year, mon, day) above
|
108 |
+
# duration should be divisible by the number of GPUs
|
109 |
+
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
|
110 |
+
days: 1 # forecast lead time as days (1 means 24-hour forecast)
|
111 |
+
|
112 |
+
save_forecast: '/glade/derecho/scratch/ksha/CREDIT/RAW_OUTPUT/fuxi_6h_collins/'
|
113 |
+
save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
114 |
+
|
115 |
+
# turn-off low-pass filter
|
116 |
+
use_laplace_filter: False
|
model_single.yml
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
2 |
+
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
|
3 |
+
# the FuXi architecture has been modified to reduce the overall model size
|
4 |
+
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
|
5 |
+
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
|
6 |
+
#
|
7 |
+
# Yingkai Sha
|
8 | |
9 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
10 |
+
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
|
11 |
+
seed: 1000
|
12 |
+
|
13 |
+
data:
|
14 |
+
# upper-air variables
|
15 |
+
variables: ['U','V','T','Q']
|
16 |
+
save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
17 |
+
|
18 |
+
# surface variables
|
19 |
+
surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
20 |
+
save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
21 |
+
|
22 |
+
# dynamic forcing variables
|
23 |
+
dynamic_forcing_variables: ['tsi']
|
24 |
+
save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
|
25 |
+
|
26 |
+
# diagnostic variables
|
27 |
+
# diagnostic_variables: ['V500','U500','T500','Z500','Q500']
|
28 |
+
# save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
|
29 |
+
|
30 |
+
# static variables
|
31 |
+
static_variables: ['Z_GDS4_SFC','LSM']
|
32 |
+
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
|
33 |
+
|
34 |
+
# mean / std path
|
35 |
+
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
|
36 |
+
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
|
37 |
+
|
38 |
+
# train / validation split
|
39 |
+
train_years: [1979, 2018]
|
40 |
+
valid_years: [2018, 2019]
|
41 |
+
|
42 |
+
# data workflow
|
43 |
+
scaler_type: 'std_new'
|
44 |
+
|
45 |
+
# number of input states
|
46 |
+
# FuXi has 2 input states
|
47 |
+
history_len: 2
|
48 |
+
valid_history_len: 2
|
49 |
+
|
50 |
+
# number of forecast steps to compute loss
|
51 |
+
# 0 for single step training / validation
|
52 |
+
# larger than 0 for multi-step training / validation
|
53 |
+
forecast_len: 0
|
54 |
+
valid_forecast_len: 0
|
55 |
+
|
56 |
+
# one_shot: True --> compute loss on the last forecast step only
|
57 |
+
# one_shot: False --> compute loss on all forecast steps
|
58 |
+
one_shot: True
|
59 |
+
|
60 |
+
# 1 for hourly model
|
61 |
+
lead_time_periods: 6
|
62 |
+
|
63 |
+
# do not use skip_period
|
64 |
+
skip_periods: null
|
65 |
+
|
66 |
+
# compatible with the old 'std'
|
67 |
+
static_first: True
|
68 |
+
|
69 |
+
trainer:
|
70 |
+
type: standard # <---------- change to your type
|
71 |
+
|
72 |
+
mode: fsdp
|
73 |
+
cpu_offload: False
|
74 |
+
activation_checkpoint: True
|
75 |
+
|
76 |
+
load_weights: True
|
77 |
+
load_optimizer: True
|
78 |
+
load_scaler: True
|
79 |
+
load_sheduler: True
|
80 |
+
|
81 |
+
skip_validation: False
|
82 |
+
update_learning_rate: False
|
83 |
+
|
84 |
+
save_backup_weights: True
|
85 |
+
save_best_weights: True
|
86 |
+
|
87 |
+
learning_rate: 1.0e-03 # <-- change to your lr
|
88 |
+
weight_decay: 0
|
89 |
+
|
90 |
+
train_batch_size: 1
|
91 |
+
valid_batch_size: 1
|
92 |
+
|
93 |
+
batches_per_epoch: 0
|
94 |
+
valid_batches_per_epoch: 0
|
95 |
+
stopping_patience: 50
|
96 |
+
|
97 |
+
start_epoch: 0
|
98 |
+
num_epoch: 2
|
99 |
+
reload_epoch: True
|
100 |
+
epochs: &epochs 70
|
101 |
+
|
102 |
+
use_scheduler: True
|
103 |
+
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
|
104 |
+
|
105 |
+
# Automatic Mixed Precision: False
|
106 |
+
amp: False
|
107 |
+
|
108 |
+
# rescale loss as loss = loss / grad_accum_every
|
109 |
+
grad_accum_every: 1
|
110 |
+
# gradient clipping
|
111 |
+
grad_max_norm: 1.0
|
112 |
+
|
113 |
+
# number of workers
|
114 |
+
thread_workers: 4
|
115 |
+
valid_thread_workers: 0
|
116 |
+
|
117 |
+
model:
|
118 |
+
type: "fuxi"
|
119 |
+
|
120 |
+
frames: 2 # number of input states
|
121 |
+
image_height: 640 # number of latitude grids
|
122 |
+
image_width: 1280 # number of longitude grids
|
123 |
+
levels: 16 # number of upper-air variable levels
|
124 |
+
channels: 4 # upper-air variable channels
|
125 |
+
surface_channels: 7 # surface variable channels
|
126 |
+
input_only_channels: 3 # dynamic forcing, forcing, static channels
|
127 |
+
output_only_channels: 0 # diagnostic variable channels
|
128 |
+
|
129 |
+
# patchify layer
|
130 |
+
patch_height: 4 # number of latitude grids in each 3D patch
|
131 |
+
patch_width: 4 # number of longitude grids in each 3D patch
|
132 |
+
frame_patch_size: 2 # number of input states in each 3D patch
|
133 |
+
|
134 |
+
# hidden layers
|
135 |
+
dim: 1024 # dimension (default: 1536)
|
136 |
+
num_groups: 32 # number of groups (default: 32)
|
137 |
+
num_heads: 8 # number of heads (default: 8)
|
138 |
+
window_size: 7 # window size (default: 7)
|
139 |
+
depth: 16 # number of swin transformers (default: 48)
|
140 |
+
|
141 |
+
# map boundary padding
|
142 |
+
pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
|
143 |
+
pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
|
144 |
+
|
145 |
+
# use spectral norm
|
146 |
+
use_spectral_norm: True
|
147 |
+
|
148 |
+
loss:
|
149 |
+
# the main training loss
|
150 |
+
training_loss: "mse"
|
151 |
+
|
152 |
+
# power loss (x), spectral_loss (x)
|
153 |
+
use_power_loss: False
|
154 |
+
use_spectral_loss: False
|
155 |
+
|
156 |
+
# use latitude weighting
|
157 |
+
use_latitude_weights: True
|
158 |
+
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
|
159 |
+
|
160 |
+
# turn-off variable weighting
|
161 |
+
use_variable_weights: False
|
162 |
+
# variable_weights:
|
163 |
+
# U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
164 |
+
# V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
165 |
+
# T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
166 |
+
# Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
167 |
+
# SP: 0.1
|
168 |
+
# t2m: 1.0
|
169 |
+
# V500: 0.1
|
170 |
+
# U500: 0.1
|
171 |
+
# T500: 0.1
|
172 |
+
# Z500: 0.1
|
173 |
+
# Q500: 0.1
|
174 |
+
|
175 |
+
predict:
|
176 |
+
forecasts:
|
177 |
+
type: "custom" # keep it as "custom"
|
178 |
+
start_year: 2020 # year of the first initialization (where rollout will start)
|
179 |
+
start_month: 1 # month of the first initialization
|
180 |
+
start_day: 1 # day of the first initialization
|
181 |
+
start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
|
182 |
+
duration: 30 # number of days to initialize, starting from the (year, mon, day) above
|
183 |
+
# duration should be divisible by the number of GPUs
|
184 |
+
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
|
185 |
+
days: 2 # forecast lead time as days (1 means 24-hour forecast)
|
186 |
+
|
187 |
+
save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
|
188 |
+
save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
189 |
+
|
190 |
+
# turn-off low-pass filter
|
191 |
+
use_laplace_filter: False
|
192 |
+
|
193 |
+
# deprecated
|
194 |
+
# save_format: "nc"
|
195 |
+
|
196 |
+
pbs: #derecho
|
197 |
+
conda: "/glade/work/ksha/miniconda3/envs/credit"
|
198 |
+
project: "NAML0001"
|
199 |
+
job_name: "fuxi_6h"
|
200 |
+
walltime: "12:00:00"
|
201 |
+
nodes: 8
|
202 |
+
ncpus: 64
|
203 |
+
ngpus: 4
|
204 |
+
mem: '480GB'
|
205 |
+
queue: 'main'
|
model_single_cached.yml
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
2 |
+
# This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
|
3 |
+
# the FuXi architecture has been modified to reduce the overall model size
|
4 |
+
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
|
5 |
+
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
|
6 |
+
#
|
7 |
+
# Yingkai Sha
|
8 | |
9 |
+
# --------------------------------------------------------------------------------------------------------------------- #
|
10 |
+
save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
|
11 |
+
seed: 1000
|
12 |
+
data:
|
13 |
+
# upper-air variables
|
14 |
+
variables: ['U','V','T','Q']
|
15 |
+
save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
|
16 |
+
|
17 |
+
# surface variables
|
18 |
+
surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
19 |
+
save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
|
20 |
+
|
21 |
+
# dynamic forcing variables
|
22 |
+
dynamic_forcing_variables: ['tsi']
|
23 |
+
save_loc_dynamic_forcing: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
|
24 |
+
|
25 |
+
# static variables
|
26 |
+
static_variables: ['Z_GDS4_SFC','LSM']
|
27 |
+
save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
|
28 |
+
|
29 |
+
# mean / std path
|
30 |
+
mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
|
31 |
+
std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
|
32 |
+
|
33 |
+
# train / validation split
|
34 |
+
train_years: [1979, 2018]
|
35 |
+
valid_years: [2018, 2019]
|
36 |
+
|
37 |
+
# data workflow
|
38 |
+
scaler_type: 'std_cached'
|
39 |
+
|
40 |
+
# number of input states
|
41 |
+
# FuXi has 2 input states
|
42 |
+
history_len: 2
|
43 |
+
valid_history_len: 2
|
44 |
+
|
45 |
+
# number of forecast steps to compute loss
|
46 |
+
# 0 for single step training / validation
|
47 |
+
# larger than 0 for multi-step training / validation
|
48 |
+
forecast_len: 0
|
49 |
+
valid_forecast_len: 0
|
50 |
+
|
51 |
+
# one_shot: True --> compute loss on the last forecast step only
|
52 |
+
# one_shot: False --> compute loss on all forecast steps
|
53 |
+
one_shot: True
|
54 |
+
|
55 |
+
# 1 for hourly model
|
56 |
+
lead_time_periods: 6
|
57 |
+
|
58 |
+
# do not use skip_period
|
59 |
+
skip_periods: null
|
60 |
+
|
61 |
+
# compatible with the old 'std'
|
62 |
+
static_first: True
|
63 |
+
|
64 |
+
trainer:
|
65 |
+
type: standard # <---------- change to your type
|
66 |
+
|
67 |
+
mode: fsdp
|
68 |
+
cpu_offload: False
|
69 |
+
activation_checkpoint: True
|
70 |
+
|
71 |
+
load_weights: True
|
72 |
+
load_optimizer: True
|
73 |
+
load_scaler: True
|
74 |
+
load_sheduler: True
|
75 |
+
|
76 |
+
skip_validation: False
|
77 |
+
update_learning_rate: False
|
78 |
+
|
79 |
+
save_backup_weights: True
|
80 |
+
save_best_weights: True
|
81 |
+
|
82 |
+
learning_rate: 1.0e-03 # <-- change to your lr
|
83 |
+
weight_decay: 0
|
84 |
+
|
85 |
+
train_batch_size: 1
|
86 |
+
valid_batch_size: 1
|
87 |
+
|
88 |
+
batches_per_epoch: 0
|
89 |
+
valid_batches_per_epoch: 0
|
90 |
+
stopping_patience: 50
|
91 |
+
|
92 |
+
start_epoch: 0
|
93 |
+
#num_epoch: 5
|
94 |
+
reload_epoch: True
|
95 |
+
epochs: &epochs 70
|
96 |
+
|
97 |
+
use_scheduler: True
|
98 |
+
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
|
99 |
+
|
100 |
+
# Automatic Mixed Precision: False
|
101 |
+
amp: False
|
102 |
+
|
103 |
+
# rescale loss as loss = loss / grad_accum_every
|
104 |
+
grad_accum_every: 1
|
105 |
+
# gradient clipping
|
106 |
+
grad_max_norm: 1.0
|
107 |
+
|
108 |
+
# number of workers
|
109 |
+
thread_workers: 4
|
110 |
+
valid_thread_workers: 0
|
111 |
+
|
112 |
+
model:
|
113 |
+
type: "fuxi"
|
114 |
+
|
115 |
+
frames: 2 # number of input states
|
116 |
+
image_height: 640 # number of latitude grids
|
117 |
+
image_width: 1280 # number of longitude grids
|
118 |
+
levels: 16 # number of upper-air variable levels
|
119 |
+
channels: 4 # upper-air variable channels
|
120 |
+
surface_channels: 7 # surface variable channels
|
121 |
+
input_only_channels: 3 # dynamic forcing, forcing, static channels
|
122 |
+
output_only_channels: 0 # diagnostic variable channels
|
123 |
+
|
124 |
+
# patchify layer
|
125 |
+
patch_height: 4 # number of latitude grids in each 3D patch
|
126 |
+
patch_width: 4 # number of longitude grids in each 3D patch
|
127 |
+
frame_patch_size: 2 # number of input states in each 3D patch
|
128 |
+
|
129 |
+
# hidden layers
|
130 |
+
dim: 1024 # dimension (default: 1536)
|
131 |
+
num_groups: 32 # number of groups (default: 32)
|
132 |
+
num_heads: 8 # number of heads (default: 8)
|
133 |
+
window_size: 7 # window size (default: 7)
|
134 |
+
depth: 16 # number of swin transformers (default: 48)
|
135 |
+
|
136 |
+
# map boundary padding
|
137 |
+
pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
|
138 |
+
pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
|
139 |
+
|
140 |
+
# use spectral norm
|
141 |
+
use_spectral_norm: True
|
142 |
+
|
143 |
+
loss:
|
144 |
+
# the main training loss
|
145 |
+
training_loss: "mse"
|
146 |
+
|
147 |
+
# power loss (x), spectral_loss (x)
|
148 |
+
use_power_loss: False
|
149 |
+
use_spectral_loss: False
|
150 |
+
|
151 |
+
# use latitude weighting
|
152 |
+
use_latitude_weights: True
|
153 |
+
latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
|
154 |
+
|
155 |
+
# turn-off variable weighting
|
156 |
+
use_variable_weights: False
|
157 |
+
# variable_weights:
|
158 |
+
# U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
159 |
+
# V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
160 |
+
# T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
161 |
+
# Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
|
162 |
+
# SP: 0.1
|
163 |
+
# t2m: 1.0
|
164 |
+
# V500: 0.1
|
165 |
+
# U500: 0.1
|
166 |
+
# T500: 0.1
|
167 |
+
# Z500: 0.1
|
168 |
+
# Q500: 0.1
|
169 |
+
|
170 |
+
predict:
|
171 |
+
forecasts:
|
172 |
+
type: "custom" # keep it as "custom"
|
173 |
+
start_year: 2020 # year of the first initialization (where rollout will start)
|
174 |
+
start_month: 1 # month of the first initialization
|
175 |
+
start_day: 1 # day of the first initialization
|
176 |
+
start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
|
177 |
+
duration: 30 # number of days to initialize, starting from the (year, mon, day) above
|
178 |
+
# duration should be divisible by the number of GPUs
|
179 |
+
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
|
180 |
+
days: 2 # forecast lead time as days (1 means 24-hour forecast)
|
181 |
+
|
182 |
+
save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
|
183 |
+
save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
|
184 |
+
|
185 |
+
# turn-off low-pass filter
|
186 |
+
use_laplace_filter: False
|
187 |
+
|
188 |
+
# deprecated
|
189 |
+
# save_format: "nc"
|
190 |
+
|
191 |
+
pbs: #derecho
|
192 |
+
conda: "/glade/work/ksha/miniconda3/envs/credit"
|
193 |
+
project: "NAML0001"
|
194 |
+
job_name: "fuxi_6h"
|
195 |
+
walltime: "12:00:00"
|
196 |
+
nodes: 8
|
197 |
+
ncpus: 64
|
198 |
+
ngpus: 4
|
199 |
+
mem: '480GB'
|
200 |
+
queue: 'main'
|
optimizer_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ae6c78578cc62ec39b838f86701f8e22d5238d97a3b2fd16daa2513fdfeebb7
|
3 |
+
size 1683978368
|