Commit d6d123e by djgagne · 1 Parent(s): f59a072

Added model files

README.md CHANGED
@@ -1,3 +1,22 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - en
+ tags:
+ - weather
+ - climate
+ - global
+ ---
+
+ # NSF NCAR Community Research Earth Digital Intelligence Twin (CREDIT) FuXi 6-Hour Model Weights and Configuration
+ This repository contains the PyTorch checkpoint weights and the data and model configuration files for the CREDIT FuXi 6-hour model.
+ More information about the training and verification of this model can be found in the Schreck et al. (2024) arXiv [preprint](https://arxiv.org/abs/2411.07814).
+
+ ## Data Access
+ Our model is trained on ERA5 reanalysis data, using a subset of 16 of the 137 hybrid sigma-pressure vertical levels.
+ The raw data are available on the NSF NCAR [Research Data Archive](https://rda.ucar.edu/datasets/d633006/).
+ Processed data can be accessed on the [CREDIT ERA5 Zarr Files](https://app.globus.org/file-manager?origin_id=2fc90d8f-10b7-44e1-a6a5-cf844112822e&origin_path=%2F) Globus collection.
+
+ ## Running the Model
+ To run the model, first install the [CREDIT package](https://github.com/NCAR/miles-credit) from GitHub.
+ Then modify the paths in the `finetune_final/model_predict.yml` configuration file to point to the appropriate ERA5 and scaler directories.
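The path edit described in the README can be sketched as a plain-text substitution. This is only an illustration: the helper name `retarget_paths` and the local directories `/data/credit` and `/data/credit_runs` are hypothetical, and the CREDIT package may offer its own way to override paths.

```python
def retarget_paths(config_text: str, mapping: dict[str, str]) -> str:
    """Return config_text with each old path prefix replaced by its new prefix."""
    # Replace longer prefixes first so a nested prefix is not clobbered by its parent.
    for old in sorted(mapping, key=len, reverse=True):
        config_text = config_text.replace(old, mapping[old])
    return config_text


# Two lines copied from the configuration files in this commit.
example = (
    "mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'\n"
    "save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'\n"
)
local = retarget_paths(
    example,
    {
        "/glade/derecho/scratch/ksha/CREDIT_data": "/data/credit",  # hypothetical local path
        "/glade/work/ksha/CREDIT_runs": "/data/credit_runs",        # hypothetical local path
    },
)
print(local)
```

Editing the text rather than parsing the YAML keeps the comments and key order of the original file intact.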
backup_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f23a63c6843906694ae529948df4d4ac347b4a103eafe7bf939331257fbb921
+ size 1260
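Each `.pt` entry in this commit is a Git LFS pointer (spec version, SHA-256 object ID, and size in bytes) rather than the checkpoint itself. A minimal sketch of how a downloaded file could be checked against its pointer follows; the function names are illustrative and are not part of Git LFS or CREDIT.

```python
import hashlib


def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(data: bytes, pointer: dict) -> bool:
    """Check that data has the size and SHA-256 digest recorded in the pointer."""
    if pointer["algo"] != "sha256" or len(data) != pointer["size"]:
        return False
    return hashlib.sha256(data).hexdigest() == pointer["digest"]


pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:1f23a63c6843906694ae529948df4d4ac347b4a103eafe7bf939331257fbb921\n"
    "size 1260\n"
)
print(pointer["algo"], pointer["size"])
```

For the checkpoints that are about 1 GB or larger, hashing the file in chunks instead of loading it into memory at once would be the more practical choice.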
backup_final/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ce0b399e07427b2ad9bd4d44d03801866f58bc236229341cdf2bce1dfcdd2551
+ size 1132
backup_final/best_model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3333a5bca6a8ca87285866c05998ffe00bd726156a70814b208dcfab0f9c86d8
+ size 1044714782
backup_final/best_optimizer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8fcbe3d191c3999d8652133c32ca2487924ff34a5a8238947b0e0ac795c1d78
+ size 1683978368
backup_final/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b25c034739e99cc1ad34d56332f3f29a00a30086874e631ffaca50294c0b663
+ size 1132
backup_final/model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af8ba6f9618c1a075b21c0fa6d741e511a933caac4466f74b7c5b88c3fdc1595
+ size 1044714782
backup_final/optimizer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c304aa2c70e1b98e5c88c0a4ed76024e6b7ce0681d1fb94e5cfb40d264ddde49
+ size 1683978368
backup_final/training_log.csv ADDED
@@ -0,0 +1,19 @@
+ index,epoch,train_forecast_len,valid_forecast_len,train_loss,valid_loss,train_acc,valid_acc,train_mae,valid_mae,lr
+ 0,0,8.0,,0.1387606938719278,0.1108131468296051,0.970562528166061,0.9694464921951294,0.1099035362854148,0.1114028289914131,1e-06
+ 1,1,8.0,,0.1377466195342726,0.1100945279002189,0.9708313843792448,0.9694475412368776,0.1095289197089015,0.111414648592472,9.938441702975689e-07
+ 2,2,8.0,,0.1385083928684629,0.1117167651653289,0.970662198676108,0.9694490909576416,0.1096920592413432,0.1113834738731384,9.755282581475769e-07
+ 3,3,8.0,,0.1384558184501839,0.1127898842096328,0.9706702200948644,0.969444227218628,0.1096835084190795,0.1113786697387695,9.45503262094184e-07
+ 4,4,8.0,,0.1384486823842145,0.1118623703718185,0.9707273427048534,0.9694671511650086,0.1096120545829551,0.1113623306155204,9.045084971874738e-07
+ 5,5,8.0,,0.1381992981407168,0.1129628434777259,0.9707307068726762,0.9694624781608582,0.1096099303527311,0.1113620564341545,8.535533905932738e-07
+ 6,6,8.0,,0.1391399535737018,0.1127516105771064,0.970510085502318,0.9694764852523804,0.1099141280661301,0.11133803576231,7.938926261462367e-07
+ 7,7,8.0,,0.1381134882824537,0.1126476496458053,0.970679276152876,0.9694833755493164,0.1096669749230421,0.1113172695040702,7.269952498697736e-07
+ 8,8,8.0,,0.1384780509630525,0.1114308580756187,0.970642456109973,0.9694918870925904,0.1097221110256607,0.1112886816263198,6.54508497187474e-07
+ 9,9,8.0,,0.1387791881958643,0.109854482114315,0.9706856229088524,0.9695030689239502,0.1096888491503491,0.1112985372543335,5.782172325201157e-07
+ 10,10,8.0,,0.1381753478910926,0.1100143820047378,0.9706785721269992,0.9695019960403444,0.1096319330461097,0.1112829759716987,5.000000000000002e-07
+ 11,11,8.0,,0.1397529813965318,0.1093554720282554,0.9704045921917488,0.9695140957832336,0.1100279053032633,0.1112509205937385,4.217827674798849e-07
+ 12,12,8.0,,0.1388648276822212,0.1110095202922821,0.970521897509478,0.969509732723236,0.1098963392895986,0.1112472981214523,3.454915028125265e-07
+ 13,13,8.0,,0.139924709661833,0.1111257791519165,0.9704050614899796,0.9695259213447572,0.1100835978435119,0.111248242855072,2.730047501302268e-07
+ 14,14,8.0,,0.1386912109745035,0.1115610241889953,0.9705360937809598,0.9695113182067872,0.1098389108229688,0.1112457856535911,2.061073738537636e-07
+ 15,15,8.0,,0.1394790709411359,0.1099427804350852,0.9703982228975208,0.9695213079452516,0.1101159101182764,0.1112276285886764,1.4644660940672637e-07
+ 16,16,8.0,,0.1392562538778043,0.1106211960315704,0.9705332976241984,0.9695265769958497,0.1098723398013548,0.1112342774868011,9.549150281252641e-08
+ 17,17,8.0,,0.1382005311872648,0.11191920638084411,0.9706875172214232,0.9695238590240478,0.10971980983678531,0.11123556792736053,5.449673790581614e-08
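The `lr` column above is consistent with a cosine-annealing schedule. The sketch below uses the closed-form expression with a base rate of 1.0e-06 and `T_max` of 20, the values that appear in `model_multi.yml`. A minimum rate of 0 is an assumption, as is the idea that this log came from a run with those settings.

```python
import math


def cosine_annealing_lr(base_lr: float, epoch: int, t_max: int, eta_min: float = 0.0) -> float:
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2."""
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max)) / 2.0


# base_lr and t_max mirror learning_rate and epochs in model_multi.yml;
# eta_min = 0 is an assumption.
schedule = [cosine_annealing_lr(1.0e-06, epoch, 20) for epoch in range(18)]
for epoch, lr in enumerate(schedule):
    print(epoch, lr)
```

Comparing the printed values with the logged column is a quick check that a resumed run picked up the scheduler state at the intended epoch.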
backup_model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:63ce28c57eb54f0e5d4a4229cfe9dfc26327f7812a73c5ac0eb657fc617335c3
+ size 1044714782
backup_optimizer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc8d079d97419fa481b7bf2cda4b24a33fb43c5f3b357940b3e017f9b95910a7
+ size 1683978368
best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:521fa1b735f45041505f3d1fccfae8b7d18625e5a59f52069452a9cab31f183e
+ size 1260
best_model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1211699e85d387cf89e913a0b86d92e58986e8222322ef1cb1fb7dc30c23a790
+ size 1044714782
best_optimizer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca87b2643346de107cd09e11fe130fbaa33b8669b139c0cc462b65dcc41c180e
+ size 1683978368
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97de0e29c19c96993c73f700e695f0bf31ebdf7ba4297c3bda664e3e317a224a
+ size 1260
l.sh ADDED
@@ -0,0 +1,44 @@
+ #!/bin/bash
+ #PBS -A NAML0001
+ #PBS -N wxformer_6h
+ #PBS -l walltime=12:00:00
+ #PBS -l select=8:ncpus=64:ngpus=4:mem=480GB
+ #PBS -q main
+ #PBS -j oe
+ #PBS -k eod
+ # Load modules
+ module purge
+ module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
+ conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
+ # Export environment variables
+ export LSCRATCH=/glade/derecho/scratch/schreck/
+ export LOGLEVEL=INFO
+ export NCCL_DEBUG=INFO
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
+ export NCCL_SOCKET_IFNAME=hsn
+ export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
+ export MPICH_OFI_NIC_POLICY=GPU
+ export MPICH_GPU_SUPPORT_ENABLED=1
+ export NCCL_IB_DISABLE=1
+ export NCCL_CROSS_NIC=1
+ export NCCL_NCHANNELS_PER_NET_PEER=4
+ export MPICH_RDMA_ENABLED_CUDA=1
+ export NCCL_NET="AWS Libfabric"
+ export NCCL_NET_GDR_LEVEL=PBH
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_OPTIMIZED_MRS=false
+ export FI_MR_CACHE_MONITOR=userfaultfd
+ export FI_CXI_DEFAULT_CQ_SIZE=131072
+ # Print the job layout
+ echo "Number of nodes: 8"
+ echo "Number of GPUs per node: 4"
+ echo "Total number of GPUs: 32"
+ # Log in to WandB if needed
+ # wandb login 02d2b1af00b5df901cb2bee071872de774781520
+ # Launch MPIs
+ nodes=( $( cat $PBS_NODEFILE ) )
+ # Print every allocated node, not only the first array element
+ echo "nodes: ${nodes[*]}"
+ # Find headnode's IP:
+ head_node=${nodes[0]}
+ head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
+ MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/derecho/scratch/schreck/CREDIT_runs/test_ben_env/miles-credit/applications/train.py -c model.yml --backend nccl
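The launch line starts 32 processes at 4 per node, which matches the 8-node, 4-GPU request in the PBS header. The sketch below shows how a global rank would map to a node and a local GPU index. It assumes ranks are placed in consecutive blocks per node, which is a common default but is ultimately decided by the MPI launcher. The helper `rank_layout` is hypothetical.

```python
def rank_layout(nodes: int, procs_per_node: int) -> list[tuple[int, int, int]]:
    """Return (global_rank, node_index, local_rank) for every process.

    Assumes block placement: the first procs_per_node ranks run on node 0,
    the next procs_per_node on node 1, and so on.
    """
    return [
        (rank, rank // procs_per_node, rank % procs_per_node)
        for rank in range(nodes * procs_per_node)
    ]


layout = rank_layout(nodes=8, procs_per_node=4)
print(len(layout))  # world size
print(layout[5])    # placement of global rank 5 under the block assumption
```

Under this assumption the local rank is the index a process would use to pick one of the four devices exposed through `CUDA_VISIBLE_DEVICES=0,1,2,3`.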
launch_multi.sh ADDED
@@ -0,0 +1,46 @@
+ #!/bin/bash
+ #PBS -A NCIS0010
+ #PBS -N fx6h_multi
+ #PBS -l walltime=12:00:00
+ #PBS -l select=8:ncpus=64:ngpus=4
+ #PBS -q main
+ #PBS -j oe
+ #PBS -k eod
+ #PBS -r n
+ # Load modules
+ module purge
+ module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
+ conda activate /glade/work/ksha/miniconda3/envs/credit-derecho
+ # conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
+ # Export environment variables
+ export LSCRATCH=/glade/derecho/scratch/ksha/
+ export LOGLEVEL=INFO
+ export NCCL_DEBUG=INFO
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
+ export NCCL_SOCKET_IFNAME=hsn
+ export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
+ export MPICH_OFI_NIC_POLICY=GPU
+ export MPICH_GPU_SUPPORT_ENABLED=1
+ export NCCL_IB_DISABLE=1
+ export NCCL_CROSS_NIC=1
+ export NCCL_NCHANNELS_PER_NET_PEER=4
+ export MPICH_RDMA_ENABLED_CUDA=1
+ export NCCL_NET="AWS Libfabric"
+ export NCCL_NET_GDR_LEVEL=PBH
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_OPTIMIZED_MRS=false
+ export FI_MR_CACHE_MONITOR=userfaultfd
+ export FI_CXI_DEFAULT_CQ_SIZE=131072
+ # Print the job layout
+ echo "Number of nodes: 8"
+ echo "Number of GPUs per node: 4"
+ echo "Total number of GPUs: 32"
+ # Log in to WandB if needed
+ # wandb login 02d2b1af00b5df901cb2bee071872de774781520
+ # Launch MPIs
+ nodes=( $( cat $PBS_NODEFILE ) )
+ # Print every allocated node, not only the first array element
+ echo "nodes: ${nodes[*]}"
+ # Find headnode's IP:
+ head_node=${nodes[0]}
+ head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
+ MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/u/home/ksha/miles-credit/applications/train_multistep.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_multi.yml --backend nccl
launch_predict.sh ADDED
@@ -0,0 +1,45 @@
+ #!/bin/bash
+ #PBS -A NCIS0010
+ #PBS -N fx6h_pred
+ #PBS -l walltime=12:00:00
+ #PBS -l select=8:ncpus=64:ngpus=4
+ #PBS -q main
+ #PBS -j oe
+ #PBS -k eod
+ #PBS -r n
+ # Load modules
+ module purge
+ module load nvhpc cuda cray-mpich conda
+ conda activate /glade/work/ksha/miniconda3/envs/credit
+ # Get a list of allocated nodes
+ nodes=( $( cat $PBS_NODEFILE ) )
+ head_node=${nodes[0]}
+ head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
+ # Export environment variables
+ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+ export LSCRATCH=/glade/derecho/scratch/ksha/
+ export LOGLEVEL=INFO
+ #export NCCL_DEBUG=INFO
+
+ export NCCL_SOCKET_IFNAME=hsn
+ export NCCL_HOME=/glade/u/home/dhoward/work/nccl-ofi-plugin/install
+ export LD_LIBRARY_PATH=$NCCL_HOME/lib:$NCCL_HOME/plugin/lib:$LD_LIBRARY_PATH
+
+ export NCCL_NCHANNELS_PER_NET_PEER=4
+ export MPICH_GPU_SUPPORT_ENABLED=1
+ export MPICH_OFI_NIC_POLICY=GPU
+ export MPICH_RDMA_ENABLED_CUDA=1
+ # NCCL reads NCCL_IB_DISABLE (as in the training scripts), not NCCL_DISABLE_IB
+ export NCCL_IB_DISABLE=1
+ export NCCL_CROSS_NIC=1
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_OPTIMIZED_MRS=false
+
+ # Print the job layout
+ echo "Number of nodes: 8"
+ echo "Number of GPUs per node: 4"
+ echo "Total number of GPUs: 32"
+ # Log in to WandB if needed
+ # wandb login 02d2b1af00b5df901cb2bee071872de774781520
+
+ # Launch MPIs
+ mpiexec -n 8 --ppn 1 --cpu-bind none torchrun --nnodes=8 --nproc-per-node=4 --rdzv-backend=c10d --rdzv-endpoint=$head_node_ip /glade/u/home/ksha/miles-credit/applications/rollout_to_netcdf.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_predict.yml
launch_rollout.sh ADDED
@@ -0,0 +1,45 @@
+ #!/bin/bash
+ #PBS -A NAML0001
+ #PBS -N fx6h_roll
+ #PBS -l walltime=12:00:00
+ #PBS -l select=8:ncpus=64:ngpus=4
+ #PBS -q main
+ #PBS -j oe
+ #PBS -k eod
+ #PBS -r n
+ # Load modules
+ module purge
+ module load nvhpc cuda cray-mpich conda
+ conda activate /glade/work/ksha/miniconda3/envs/credit
+ # Get a list of allocated nodes
+ nodes=( $( cat $PBS_NODEFILE ) )
+ head_node=${nodes[0]}
+ head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
+ # Export environment variables
+ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+ export LSCRATCH=/glade/derecho/scratch/ksha/
+ export LOGLEVEL=INFO
+ #export NCCL_DEBUG=INFO
+
+ export NCCL_SOCKET_IFNAME=hsn
+ export NCCL_HOME=/glade/u/home/dhoward/work/nccl-ofi-plugin/install
+ export LD_LIBRARY_PATH=$NCCL_HOME/lib:$NCCL_HOME/plugin/lib:$LD_LIBRARY_PATH
+
+ export NCCL_NCHANNELS_PER_NET_PEER=4
+ export MPICH_GPU_SUPPORT_ENABLED=1
+ export MPICH_OFI_NIC_POLICY=GPU
+ export MPICH_RDMA_ENABLED_CUDA=1
+ # NCCL reads NCCL_IB_DISABLE (as in the training scripts), not NCCL_DISABLE_IB
+ export NCCL_IB_DISABLE=1
+ export NCCL_CROSS_NIC=1
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_OPTIMIZED_MRS=false
+
+ # Print the job layout
+ echo "Number of nodes: 8"
+ echo "Number of GPUs per node: 4"
+ echo "Total number of GPUs: 32"
+ # Log in to WandB if needed
+ # wandb login 02d2b1af00b5df901cb2bee071872de774781520
+
+ # Launch MPIs
+ mpiexec -n 8 --ppn 1 --cpu-bind none torchrun --nnodes=8 --nproc-per-node=4 --rdzv-backend=c10d --rdzv-endpoint=$head_node_ip /glade/u/home/ksha/miles-credit/applications/rollout_metrics.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_predict.yml
launch_single.sh ADDED
@@ -0,0 +1,46 @@
+ #!/bin/bash
+ #PBS -A NCIS0010
+ #PBS -N fuxi_6h
+ #PBS -l walltime=12:00:00
+ #PBS -l select=8:ncpus=64:ngpus=4
+ #PBS -q main
+ #PBS -j oe
+ #PBS -k eod
+ #PBS -r n
+ # Load modules
+ module purge
+ module load gcc craype cray-mpich cuda cudnn/8.8.1.3-12 conda
+ conda activate /glade/work/ksha/miniconda3/envs/credit-derecho
+ # conda activate /glade/u/home/schreck/.conda/envs/credit-derecho
+ # Export environment variables
+ export LSCRATCH=/glade/derecho/scratch/ksha/
+ export LOGLEVEL=INFO
+ export NCCL_DEBUG=INFO
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
+ export NCCL_SOCKET_IFNAME=hsn
+ export MPICH_GPU_MANAGED_MEMORY_SUPPORT_ENABLED=1
+ export MPICH_OFI_NIC_POLICY=GPU
+ export MPICH_GPU_SUPPORT_ENABLED=1
+ export NCCL_IB_DISABLE=1
+ export NCCL_CROSS_NIC=1
+ export NCCL_NCHANNELS_PER_NET_PEER=4
+ export MPICH_RDMA_ENABLED_CUDA=1
+ export NCCL_NET="AWS Libfabric"
+ export NCCL_NET_GDR_LEVEL=PBH
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_OPTIMIZED_MRS=false
+ export FI_MR_CACHE_MONITOR=userfaultfd
+ export FI_CXI_DEFAULT_CQ_SIZE=131072
+ # Print the job layout
+ echo "Number of nodes: 8"
+ echo "Number of GPUs per node: 4"
+ echo "Total number of GPUs: 32"
+ # Log in to WandB if needed
+ # wandb login 02d2b1af00b5df901cb2bee071872de774781520
+ # Launch MPIs
+ nodes=( $( cat $PBS_NODEFILE ) )
+ # Print every allocated node, not only the first array element
+ echo "nodes: ${nodes[*]}"
+ # Find headnode's IP:
+ head_node=${nodes[0]}
+ head_node_ip=$(ssh $head_node hostname -i | awk '{print $1}')
+ MASTER_ADDR=$head_node_ip MASTER_PORT=1234 mpiexec -n 32 --ppn 4 --cpu-bind none python /glade/u/home/ksha/miles-credit/applications/train.py -c /glade/work/ksha/CREDIT_runs/fuxi_6h/model_single.yml --backend nccl
model.yml ADDED
File without changes
model_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3137668291b82d20baf4b412769f6782a62db0a75cdcd213c7f03e02fba027ab
+ size 1044714782
model_multi.yml ADDED
@@ -0,0 +1,224 @@
+ # --------------------------------------------------------------------------------------------------------------------- #
+ # This yaml file implements 6-hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
+ # The FuXi architecture has been modified to reduce the overall model size
+ # The model is trained on 6-hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
+ # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
+ #
+ # Yingkai Sha
+ # --------------------------------------------------------------------------------------------------------------------- #
+ save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
+ seed: 1000
+
+ data:
+     # upper-air variables
+     variables: ['U','V','T','Q']
+     save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL*'
+
+     # surface variables
+     surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
+     save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y_TOTAL*'
+
+     # dynamic forcing variables
+     dynamic_forcing_variables: ['tsi']
+     save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
+
+     # diagnostic variables
+     # diagnostic_variables: ['V500','U500','T500','Z500','Q500']
+     # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
+
+     # static variables
+     static_variables: ['Z_GDS4_SFC','LSM']
+     save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
+
+     # mean / std path
+     mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
+     std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
+
+     # train / validation split
+     train_years: [1979, 2018]
+     valid_years: [2018, 2019]
+
+     # data workflow
+     scaler_type: 'std_new'
+
+     # number of input states
+     # FuXi has 2 input states
+     history_len: 2
+     valid_history_len: 2
+
+     # number of forecast steps to compute loss
+     # 0 for single step training / validation
+     # larger than 0 for multi-step training / validation
+     forecast_len: 7
+     valid_forecast_len: 7
+
+     # one_shot: True --> compute loss on the last forecast step only
+     # one_shot: False --> compute loss on all forecast steps
+     one_shot: True
+
+     # 1 for an hourly model, 6 for this 6-hourly model
+     lead_time_periods: 6
+
+     # do not use skip_period
+     skip_periods: null
+
+     # compatible with the old 'std'
+     static_first: True
+
+ trainer:
+     type: multi-step # <---------- change to your type
+
+     mode: fsdp
+     cpu_offload: False
+     activation_checkpoint: True
+
+     load_weights: True
+     load_optimizer: True
+     load_scaler: True
+     load_sheduler: True
+
+     skip_validation: False
+     update_learning_rate: False
+
+     save_backup_weights: True
+     save_best_weights: True
+
+     learning_rate: 1.0e-06 # <-- change to your lr
+     weight_decay: 0
+
+     train_batch_size: 1
+     valid_batch_size: 1
+
+     batches_per_epoch: 759 # full epoch: 1772
+     valid_batches_per_epoch: 0
+     stopping_patience: 50
+
+     start_epoch: 0
+     num_epoch: 1
+     # False when switching from single-step to multi-step
+     reload_epoch: True
+     epochs: &epochs 20
+
+     use_scheduler: True
+     scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
+
+     # Automatic Mixed Precision: False
+     amp: False
+
+     # rescale loss as loss = loss / grad_accum_every
+     grad_accum_every: 1
+     # gradient clipping
+     grad_max_norm: 1.0
+
+     # number of workers
+     thread_workers: 4
+     valid_thread_workers: 0
+
+ model:
+     type: "fuxi"
+
+     frames: 2 # number of input states
+     image_height: 640 # number of latitude grids
+     image_width: 1280 # number of longitude grids
+     levels: 16 # number of upper-air variable levels
+     channels: 4 # upper-air variable channels
+     surface_channels: 7 # surface variable channels
+     input_only_channels: 3 # dynamic forcing, forcing, static channels
+     output_only_channels: 0 # diagnostic variable channels
+
+     # patchify layer
+     patch_height: 4 # number of latitude grids in each 3D patch
+     patch_width: 4 # number of longitude grids in each 3D patch
+     frame_patch_size: 2 # number of input states in each 3D patch
+
+     # hidden layers
+     dim: 1024 # dimension (default: 1536)
+     num_groups: 32 # number of groups (default: 32)
+     num_heads: 8 # number of heads (default: 8)
+     window_size: 7 # window size (default: 7)
+     depth: 16 # number of swin transformers (default: 48)
+
+     # use spectral norm
+     use_spectral_norm: True
+
+     # ============================================================== #
+     # New
+
+     # use interpolation to match the output size
+     interp: True
+
+     # map boundary padding
+     padding_conf:
+         activate: True
+         mode: earth
+         pad_lat: 80
+         pad_lon: 80
+
+     post_conf:
+         activate: True
+
+         tracer_fixer:
+             activate: True
+             denorm: True
+             tracer_name: ['Q', 'Q500']
+             tracer_thres: [1e-8, 1e-8]
+
+ loss:
+     # the main training loss
+     training_loss: "mse"
+
+     # power loss (x), spectral_loss (x)
+     use_power_loss: False
+     use_spectral_loss: False
+
+     # use latitude weighting
+     use_latitude_weights: True
+     latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
+
+     # turn-off variable weighting
+     use_variable_weights: False
+     # variable_weights:
+     #     U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
+     #     V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
+     #     T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
+     #     Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
+     #     SP: 0.1
+     #     t2m: 1.0
+     #     V500: 0.1
+     #     U500: 0.1
+     #     T500: 0.1
+     #     Z500: 0.1
+     #     Q500: 0.1
+
+ predict:
+     forecasts:
+         type: "custom" # keep it as "custom"
+         start_year: 2020 # year of the first initialization (where rollout will start)
+         start_month: 1 # month of the first initialization
+         start_day: 1 # day of the first initialization
+         start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
+         duration: 30 # number of days to initialize, starting from the (year, mon, day) above
+         # duration should be divisible by the number of GPUs
+         # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
+         days: 2 # forecast lead time as days (1 means 24-hour forecast)
+
+     save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
+     save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
+
+     # turn-off low-pass filter
+     use_laplace_filter: False
+
+     # deprecated
+     # save_format: "nc"
+
+ pbs: #derecho
+     conda: "/glade/work/ksha/miniconda3/envs/credit"
+     project: "NAML0001"
+     job_name: "fuxi_6h"
+     walltime: "12:00:00"
+     nodes: 8
+     ncpus: 64
+     ngpus: 4
+     mem: '480GB'
+     queue: 'main'
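The `forecast_len` and `one_shot` comments in this configuration describe which rollout steps enter the loss. A small sketch of that selection follows. It assumes the rollout has `forecast_len + 1` steps, since the comments say 0 means single-step training; that reading also matches the `train_forecast_len` of 8.0 in the training log, but it is not confirmed here. The function `loss_steps` is hypothetical and is not CREDIT code.

```python
def loss_steps(forecast_len: int, one_shot: bool) -> list[int]:
    """Return the 1-based rollout steps that contribute to the loss."""
    # Assumption: forecast_len counts the steps beyond the first one,
    # so forecast_len = 0 is ordinary single-step training.
    total_steps = forecast_len + 1
    if one_shot:
        # Loss on the last forecast step only.
        return [total_steps]
    # Loss on every forecast step.
    return list(range(1, total_steps + 1))


print(loss_steps(7, True))   # settings used in model_multi.yml
print(loss_steps(7, False))
print(loss_steps(0, False))  # single-step case
```

With `lead_time_periods: 6`, each step in the returned list would correspond to a further 6 hours of lead time.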
model_predict.yml ADDED
@@ -0,0 +1,136 @@
+ # --------------------------------------------------------------------------------------------------------------------- #
+ # This yaml file implements 6-hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
+ # The FuXi architecture has been modified to reduce the overall model size
+ # The model is trained on 6-hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
+ # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
+ #
+ # Yingkai Sha
+ # --------------------------------------------------------------------------------------------------------------------- #
+ save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
+ seed: 1000
+
+ data:
+     # upper-air variables
+     variables: ['U','V','T','Q']
+     save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y*'
+
+     # surface variables
+     surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
+     save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/ERA5_mlevel_arXiv/SixHourly_y*'
+
+     # dynamic forcing variables
+     dynamic_forcing_variables: ['tsi']
+     save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
+
+     # static variables
+     static_variables: ['Z_GDS4_SFC','LSM']
+     save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
+
+     # mean / std path
+     mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
+     std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
+
+     # train / validation split
+     train_years: [1979, 2018]
+     valid_years: [2018, 2019]
+
+     # data workflow
+     scaler_type: 'std_new'
+
+     # number of input states
+     # FuXi has 2 input states
+     history_len: 2
+     valid_history_len: 2
+
+     # number of forecast steps to compute loss
+     # 0 for single step training / validation
+     # larger than 0 for multi-step training / validation
+     forecast_len: 0
+     valid_forecast_len: 0
+
+     # 1 for an hourly model, 6 for this 6-hourly model
+     lead_time_periods: 6
+
+     # do not use skip_period
+     skip_periods: null
+
+     # compatible with the old 'std'
+     static_first: True
+
+ trainer:
+     type: standard
+     mode: fsdp
+
+ model:
+     type: "fuxi"
+
+     frames: 2 # number of input states
+     image_height: 640 # number of latitude grids
+     image_width: 1280 # number of longitude grids
+     levels: 16 # number of upper-air variable levels
+     channels: 4 # upper-air variable channels
+     surface_channels: 7 # surface variable channels
+     input_only_channels: 3 # dynamic forcing, forcing, static channels
+     output_only_channels: 0 # diagnostic variable channels
+
+     # patchify layer
+     patch_height: 4 # number of latitude grids in each 3D patch
+     patch_width: 4 # number of longitude grids in each 3D patch
+     frame_patch_size: 2 # number of input states in each 3D patch
+
+     # hidden layers
+     dim: 1024 # dimension (default: 1536)
+     num_groups: 32 # number of groups (default: 32)
+     num_heads: 8 # number of heads (default: 8)
+     window_size: 7 # window size (default: 7)
+     depth: 16 # number of swin transformers (default: 48)
+
+     # use spectral norm
+     use_spectral_norm: True
+
+     # ============================================================== #
+     # New
+
+     # use interpolation to match the output size
+     interp: True
+
+     # map boundary padding
+     padding_conf:
+         activate: True
+         mode: earth
+         pad_lat: 80
+         pad_lon: 80
+
+     post_conf:
+         activate: True
+
+         tracer_fixer:
+             activate: True
+             denorm: True
+             tracer_name: ['Q', 'Q500']
+             tracer_thres: [1e-8, 1e-8]
+
+
+ loss:
+     use_latitude_weights: True
+     latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
+
+ predict:
+     forecasts:
+         type: "custom" # keep it as "custom"
+         start_year: 2021 # year of the first initialization (where rollout will start)
+         start_month: 12 # month of the first initialization
+         start_day: 31 # day of the first initialization
+         start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
+         duration: 384 # number of days to initialize, starting from the (year, mon, day) above
+         # duration should be divisible by the number of GPUs
+         # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
+         days: 10 # forecast lead time as days (1 means 24-hour forecast)
+
+     metadata: '/glade/u/home/ksha/miles-credit/credit/metadata/era5.yaml'
+     save_forecast: '/glade/derecho/scratch/ksha/CREDIT/RAW_OUTPUT/fuxi_6h_test/'
+
+     # turn-off low-pass filter
+     use_laplace_filter: False
+
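The `predict.forecasts` block defines a start date, a number of initialization days, and the hours of day at which forecasts start. The sketch below enumerates the initialization times implied by the values in this file and checks the divisibility note from the comments. The function names are illustrative, and how CREDIT actually splits initializations across GPUs is not shown here.

```python
from datetime import datetime, timedelta


def initialization_times(start: datetime, duration_days: int, start_hours: list[int]) -> list[datetime]:
    """List every forecast initialization time: one per start hour per day."""
    return [
        start + timedelta(days=day, hours=hour)
        for day in range(duration_days)
        for hour in start_hours
    ]


def splits_evenly(duration_days: int, n_gpus: int) -> bool:
    """True when the initialization days divide evenly across the GPUs."""
    return duration_days % n_gpus == 0


# Values from model_predict.yml: start 2021-12-31, duration 384, start_hours [0, 12].
inits = initialization_times(datetime(2021, 12, 31), 384, [0, 12])
print(len(inits), inits[0], inits[-1])
print(splits_evenly(384, 32))  # 32 GPUs, as requested in the launch scripts
```

A duration of 365 would not split evenly over 32 GPUs, which is why the config comment suggests rounding up to 384.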
model_predict_cpu.yml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------------------------------------------------------------------- #
2
+ # This yaml file implements 6 hourly FuXi on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
3
+ # the FuXi architecture has been modified to reduce the overall model size
4
+ # The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
5
+ # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
6
+ #
7
+ # Yingkai Sha
8
9
+ # --------------------------------------------------------------------------------------------------------------------- #
10
+ save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
11
+ seed: 1000
12
+
13
+ data:
14
+ # upper-air variables
15
+ variables: ['U','V','T','Q']
16
+ save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
17
+
18
+ # surface variables
19
+ surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
20
+ save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
21
+
22
+ # dynamic forcing variables
23
+ dynamic_forcing_variables: ['tsi']
24
+ save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
25
+
26
+ # static variables
27
+ static_variables: ['Z_GDS4_SFC','LSM']
28
+ save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
29
+
30
+ # mean / std path
31
+ mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
32
+ std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
33
+
34
+ # train / validation split
35
+ train_years: [1979, 2018]
36
+ valid_years: [2018, 2019]
37
+
38
+ # data workflow
39
+ scaler_type: 'std_new'
40
+
41
+ # number of input states
42
+ # FuXi has 2 input states
43
+ history_len: 2
44
+ valid_history_len: 2
45
+
46
+ # number of forecast steps to compute loss
47
+ # 0 for single step training / validation
48
+ # larger than 0 for multi-step training / validation
49
+ forecast_len: 0
50
+ valid_forecast_len: 0
51
+
52
+ # 1 for the hourly model, 6 for the 6-hourly model
53
+ lead_time_periods: 6
54
+
55
+ # do not use skip_periods
56
+ skip_periods: null
57
+
58
+ # compatible with the old 'std'
59
+ static_first: True
60
+
61
+ trainer:
62
+ type: standard
63
+ mode: none
64
+
65
+ model:
66
+ type: "fuxi"
67
+
68
+ frames: 2 # number of input states
69
+ image_height: 640 # number of latitude grids
70
+ image_width: 1280 # number of longitude grids
71
+ levels: 16 # number of upper-air variable levels
72
+ channels: 4 # upper-air variable channels
73
+ surface_channels: 7 # surface variable channels
74
+ input_only_channels: 3 # dynamic forcing, forcing, static channels
75
+ output_only_channels: 0 # diagnostic variable channels
76
+
77
+ # patchify layer
78
+ patch_height: 4 # number of latitude grids in each 3D patch
79
+ patch_width: 4 # number of longitude grids in each 3D patch
80
+ frame_patch_size: 2 # number of input states in each 3D patch
81
+
82
+ # hidden layers
83
+ dim: 1024 # dimension (default: 1536)
84
+ num_groups: 32 # number of groups (default: 32)
85
+ num_heads: 8 # number of heads (default: 8)
86
+ window_size: 7 # window size (default: 7)
87
+ depth: 16 # number of swin transformers (default: 48)
88
+
89
+ # map boundary padding
90
+ pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
91
+ pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
92
+
93
+ # use spectral norm
94
+ use_spectral_norm: True
95
+
96
+ loss:
97
+ use_latitude_weights: True
98
+ latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
99
+
100
+ predict:
101
+ forecasts:
102
+ type: "custom" # keep it as "custom"
103
+ start_year: 2020 # year of the first initialization (where rollout will start)
104
+ start_month: 1 # month of the first initialization
105
+ start_day: 1 # day of the first initialization
106
+ start_hours: [0,] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
107
+ duration: 1 # number of days to initialize, starting from the (year, mon, day) above
108
+ # duration should be divisible by the number of GPUs
109
+ # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
110
+ days: 1 # forecast lead time as days (1 means 24-hour forecast)
111
+
112
+ save_forecast: '/glade/derecho/scratch/ksha/CREDIT/RAW_OUTPUT/fuxi_6h_collins/'
113
+ save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
114
+
115
+ # turn off low-pass filter
116
+ use_laplace_filter: False
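
The `predict.forecasts` block above can be read as a small rollout plan. The sketch below is only an illustration of that reading, not CREDIT's own code: `rollout_plan` is a hypothetical helper, and it assumes that one initialization is created per entry of `start_hours` on each of `duration` consecutive days, and that the number of autoregressive steps is `days * 24 / lead_time_periods`.

```python
from datetime import datetime, timedelta


def rollout_plan(start, start_hours, duration, days, lead_time_hours):
    """Return (initialization times, steps per forecast) for a predict block.

    Hypothetical helper: the mapping from config keys to these arguments is an
    assumption based on the comments in the YAML file, not on CREDIT's code.
    """
    if (days * 24) % lead_time_hours != 0:
        raise ValueError("forecast length must be a multiple of the lead time")
    steps = days * 24 // lead_time_hours
    inits = [
        start + timedelta(days=d, hours=h)
        for d in range(duration)
        for h in start_hours
    ]
    return inits, steps


# Values from model_predict_cpu.yml: one 00Z initialization, 24-hour forecast.
cpu_inits, cpu_steps = rollout_plan(datetime(2020, 1, 1), [0], 1, 1, 6)

# Values from model_single.yml: 00Z and 12Z for 30 days, 48-hour forecasts.
gpu_inits, gpu_steps = rollout_plan(datetime(2020, 1, 1), [0, 12], 30, 2, 6)
```

Under these assumptions the CPU configuration produces a single four-step forecast, which is consistent with its role as a small test case.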
model_single.yml ADDED
@@ -0,0 +1,205 @@
1
+ # --------------------------------------------------------------------------------------------------------------------- #
2
+ # This YAML file implements the 6-hourly FuXi model on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
3
+ # The FuXi architecture has been modified to reduce the overall model size
4
+ # The model is trained on 6-hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
5
+ # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
6
+ #
7
+ # Yingkai Sha
8
9
+ # --------------------------------------------------------------------------------------------------------------------- #
10
+ save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
11
+ seed: 1000
12
+
13
+ data:
14
+ # upper-air variables
15
+ variables: ['U','V','T','Q']
16
+ save_loc: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
17
+
18
+ # surface variables
19
+ surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
20
+ save_loc_surface: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
21
+
22
+ # dynamic forcing variables
23
+ dynamic_forcing_variables: ['tsi']
24
+ save_loc_dynamic_forcing: '/glade/derecho/scratch/dgagne/credit_solar_6h_0.25deg/*.nc'
25
+
26
+ # diagnostic variables
27
+ # diagnostic_variables: ['V500','U500','T500','Z500','Q500']
28
+ # save_loc_diagnostic: '/glade/derecho/scratch/wchapman/SixHourly_y_TOTAL*'
29
+
30
+ # static variables
31
+ static_variables: ['Z_GDS4_SFC','LSM']
32
+ save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
33
+
34
+ # mean / std path
35
+ mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
36
+ std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
37
+
38
+ # train / validation split
39
+ train_years: [1979, 2018]
40
+ valid_years: [2018, 2019]
41
+
42
+ # data workflow
43
+ scaler_type: 'std_new'
44
+
45
+ # number of input states
46
+ # FuXi has 2 input states
47
+ history_len: 2
48
+ valid_history_len: 2
49
+
50
+ # number of forecast steps to compute loss
51
+ # 0 for single step training / validation
52
+ # larger than 0 for multi-step training / validation
53
+ forecast_len: 0
54
+ valid_forecast_len: 0
55
+
56
+ # one_shot: True --> compute loss on the last forecast step only
57
+ # one_shot: False --> compute loss on all forecast steps
58
+ one_shot: True
59
+
60
+ # 1 for the hourly model, 6 for the 6-hourly model
61
+ lead_time_periods: 6
62
+
63
+ # do not use skip_periods
64
+ skip_periods: null
65
+
66
+ # compatible with the old 'std'
67
+ static_first: True
68
+
69
+ trainer:
70
+ type: standard # <---------- change to your type
71
+
72
+ mode: fsdp
73
+ cpu_offload: False
74
+ activation_checkpoint: True
75
+
76
+ load_weights: True
77
+ load_optimizer: True
78
+ load_scaler: True
79
+ load_sheduler: True
80
+
81
+ skip_validation: False
82
+ update_learning_rate: False
83
+
84
+ save_backup_weights: True
85
+ save_best_weights: True
86
+
87
+ learning_rate: 1.0e-03 # <-- change to your lr
88
+ weight_decay: 0
89
+
90
+ train_batch_size: 1
91
+ valid_batch_size: 1
92
+
93
+ batches_per_epoch: 0
94
+ valid_batches_per_epoch: 0
95
+ stopping_patience: 50
96
+
97
+ start_epoch: 0
98
+ num_epoch: 2
99
+ reload_epoch: True
100
+ epochs: &epochs 70
101
+
102
+ use_scheduler: True
103
+ scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
104
+
105
+ # Automatic Mixed Precision: False
106
+ amp: False
107
+
108
+ # rescale loss as loss = loss / grad_accum_every
109
+ grad_accum_every: 1
110
+ # gradient clipping
111
+ grad_max_norm: 1.0
112
+
113
+ # number of workers
114
+ thread_workers: 4
115
+ valid_thread_workers: 0
116
+
117
+ model:
118
+ type: "fuxi"
119
+
120
+ frames: 2 # number of input states
121
+ image_height: 640 # number of latitude grids
122
+ image_width: 1280 # number of longitude grids
123
+ levels: 16 # number of upper-air variable levels
124
+ channels: 4 # upper-air variable channels
125
+ surface_channels: 7 # surface variable channels
126
+ input_only_channels: 3 # dynamic forcing, forcing, static channels
127
+ output_only_channels: 0 # diagnostic variable channels
128
+
129
+ # patchify layer
130
+ patch_height: 4 # number of latitude grids in each 3D patch
131
+ patch_width: 4 # number of longitude grids in each 3D patch
132
+ frame_patch_size: 2 # number of input states in each 3D patch
133
+
134
+ # hidden layers
135
+ dim: 1024 # dimension (default: 1536)
136
+ num_groups: 32 # number of groups (default: 32)
137
+ num_heads: 8 # number of heads (default: 8)
138
+ window_size: 7 # window size (default: 7)
139
+ depth: 16 # number of swin transformers (default: 48)
140
+
141
+ # map boundary padding
142
+ pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
143
+ pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
144
+
145
+ # use spectral norm
146
+ use_spectral_norm: True
147
+
148
+ loss:
149
+ # the main training loss
150
+ training_loss: "mse"
151
+
152
+ # power loss and spectral loss are both disabled
153
+ use_power_loss: False
154
+ use_spectral_loss: False
155
+
156
+ # use latitude weighting
157
+ use_latitude_weights: True
158
+ latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
159
+
160
+ # turn off variable weighting
161
+ use_variable_weights: False
162
+ # variable_weights:
163
+ # U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
164
+ # V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
165
+ # T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
166
+ # Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
167
+ # SP: 0.1
168
+ # t2m: 1.0
169
+ # V500: 0.1
170
+ # U500: 0.1
171
+ # T500: 0.1
172
+ # Z500: 0.1
173
+ # Q500: 0.1
174
+
175
+ predict:
176
+ forecasts:
177
+ type: "custom" # keep it as "custom"
178
+ start_year: 2020 # year of the first initialization (where rollout will start)
179
+ start_month: 1 # month of the first initialization
180
+ start_day: 1 # day of the first initialization
181
+ start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
182
+ duration: 30 # number of days to initialize, starting from the (year, mon, day) above
183
+ # duration should be divisible by the number of GPUs
184
+ # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
185
+ days: 2 # forecast lead time as days (1 means 24-hour forecast)
186
+
187
+ save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
188
+ save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
189
+
190
+ # turn off low-pass filter
191
+ use_laplace_filter: False
192
+
193
+ # deprecated
194
+ # save_format: "nc"
195
+
196
+ pbs: #derecho
197
+ conda: "/glade/work/ksha/miniconda3/envs/credit"
198
+ project: "NAML0001"
199
+ job_name: "fuxi_6h"
200
+ walltime: "12:00:00"
201
+ nodes: 8
202
+ ncpus: 64
203
+ ngpus: 4
204
+ mem: '480GB'
205
+ queue: 'main'
model_single_cached.yml ADDED
@@ -0,0 +1,200 @@
1
+ # --------------------------------------------------------------------------------------------------------------------- #
2
+ # This YAML file implements the 6-hourly FuXi model on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
3
+ # The FuXi architecture has been modified to reduce the overall model size
4
+ # The model is trained on 6-hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask inputs
5
+ # Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
6
+ #
7
+ # Yingkai Sha
8
9
+ # --------------------------------------------------------------------------------------------------------------------- #
10
+ save_loc: '/glade/work/ksha/CREDIT_runs/fuxi_6h/'
11
+ seed: 1000
12
+ data:
13
+ # upper-air variables
14
+ variables: ['U','V','T','Q']
15
+ save_loc: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
16
+
17
+ # surface variables
18
+ surface_variables: ['SP','t2m','V500','U500','T500','Z500','Q500']
19
+ save_loc_surface: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
20
+
21
+ # dynamic forcing variables
22
+ dynamic_forcing_variables: ['tsi']
23
+ save_loc_dynamic_forcing: '/glade/derecho/scratch/ksha/CREDIT_data/arXiv_cached/cache_arXiv_6h_*'
24
+
25
+ # static variables
26
+ static_variables: ['Z_GDS4_SFC','LSM']
27
+ save_loc_static: '/glade/derecho/scratch/ksha/CREDIT_data/static_norm_old.nc'
28
+
29
+ # mean / std path
30
+ mean_path: '/glade/derecho/scratch/ksha/CREDIT_data/mean_6h_1979_2018_16lev_0.25deg.nc'
31
+ std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_residual_6h_1979_2018_16lev_0.25deg.nc'
32
+
33
+ # train / validation split
34
+ train_years: [1979, 2018]
35
+ valid_years: [2018, 2019]
36
+
37
+ # data workflow
38
+ scaler_type: 'std_cached'
39
+
40
+ # number of input states
41
+ # FuXi has 2 input states
42
+ history_len: 2
43
+ valid_history_len: 2
44
+
45
+ # number of forecast steps to compute loss
46
+ # 0 for single step training / validation
47
+ # larger than 0 for multi-step training / validation
48
+ forecast_len: 0
49
+ valid_forecast_len: 0
50
+
51
+ # one_shot: True --> compute loss on the last forecast step only
52
+ # one_shot: False --> compute loss on all forecast steps
53
+ one_shot: True
54
+
55
+ # 1 for the hourly model, 6 for the 6-hourly model
56
+ lead_time_periods: 6
57
+
58
+ # do not use skip_periods
59
+ skip_periods: null
60
+
61
+ # compatible with the old 'std'
62
+ static_first: True
63
+
64
+ trainer:
65
+ type: standard # <---------- change to your type
66
+
67
+ mode: fsdp
68
+ cpu_offload: False
69
+ activation_checkpoint: True
70
+
71
+ load_weights: True
72
+ load_optimizer: True
73
+ load_scaler: True
74
+ load_sheduler: True
75
+
76
+ skip_validation: False
77
+ update_learning_rate: False
78
+
79
+ save_backup_weights: True
80
+ save_best_weights: True
81
+
82
+ learning_rate: 1.0e-03 # <-- change to your lr
83
+ weight_decay: 0
84
+
85
+ train_batch_size: 1
86
+ valid_batch_size: 1
87
+
88
+ batches_per_epoch: 0
89
+ valid_batches_per_epoch: 0
90
+ stopping_patience: 50
91
+
92
+ start_epoch: 0
93
+ #num_epoch: 5
94
+ reload_epoch: True
95
+ epochs: &epochs 70
96
+
97
+ use_scheduler: True
98
+ scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}
99
+
100
+ # Automatic Mixed Precision: False
101
+ amp: False
102
+
103
+ # rescale loss as loss = loss / grad_accum_every
104
+ grad_accum_every: 1
105
+ # gradient clipping
106
+ grad_max_norm: 1.0
107
+
108
+ # number of workers
109
+ thread_workers: 4
110
+ valid_thread_workers: 0
111
+
112
+ model:
113
+ type: "fuxi"
114
+
115
+ frames: 2 # number of input states
116
+ image_height: 640 # number of latitude grids
117
+ image_width: 1280 # number of longitude grids
118
+ levels: 16 # number of upper-air variable levels
119
+ channels: 4 # upper-air variable channels
120
+ surface_channels: 7 # surface variable channels
121
+ input_only_channels: 3 # dynamic forcing, forcing, static channels
122
+ output_only_channels: 0 # diagnostic variable channels
123
+
124
+ # patchify layer
125
+ patch_height: 4 # number of latitude grids in each 3D patch
126
+ patch_width: 4 # number of longitude grids in each 3D patch
127
+ frame_patch_size: 2 # number of input states in each 3D patch
128
+
129
+ # hidden layers
130
+ dim: 1024 # dimension (default: 1536)
131
+ num_groups: 32 # number of groups (default: 32)
132
+ num_heads: 8 # number of heads (default: 8)
133
+ window_size: 7 # window size (default: 7)
134
+ depth: 16 # number of swin transformers (default: 48)
135
+
136
+ # map boundary padding
137
+ pad_lon: 80 # number of grids to pad on 0 and 360 deg lon
138
+ pad_lat: 80 # number of grids to pad on -90 and 90 deg lat
139
+
140
+ # use spectral norm
141
+ use_spectral_norm: True
142
+
143
+ loss:
144
+ # the main training loss
145
+ training_loss: "mse"
146
+
147
+ # power loss and spectral loss are both disabled
148
+ use_power_loss: False
149
+ use_spectral_loss: False
150
+
151
+ # use latitude weighting
152
+ use_latitude_weights: True
153
+ latitude_weights: "/glade/u/home/wchapman/MLWPS/DataLoader/LSM_static_variables_ERA5_zhght.nc"
154
+
155
+ # turn off variable weighting
156
+ use_variable_weights: False
157
+ # variable_weights:
158
+ # U: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
159
+ # V: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
160
+ # T: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
161
+ # Q: [0.132, 0.123, 0.113, 0.104, 0.095, 0.085, 0.076, 0.067, 0.057, 0.048, 0.039, 0.029, 0.02 , 0.011, 0.005]
162
+ # SP: 0.1
163
+ # t2m: 1.0
164
+ # V500: 0.1
165
+ # U500: 0.1
166
+ # T500: 0.1
167
+ # Z500: 0.1
168
+ # Q500: 0.1
169
+
170
+ predict:
171
+ forecasts:
172
+ type: "custom" # keep it as "custom"
173
+ start_year: 2020 # year of the first initialization (where rollout will start)
174
+ start_month: 1 # month of the first initialization
175
+ start_day: 1 # day of the first initialization
176
+ start_hours: [0, 12] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
177
+ duration: 30 # number of days to initialize, starting from the (year, mon, day) above
178
+ # duration should be divisible by the number of GPUs
179
+ # (e.g., duration: 384 for 365-day rollout using 32 GPUs)
180
+ days: 2 # forecast lead time as days (1 means 24-hour forecast)
181
+
182
+ save_forecast: '/glade/derecho/scratch/ksha/CREDIT/fuxi_6h/'
183
+ save_vars: ['SP','t2m','V500','U500','T500','Z500','Q500']
184
+
185
+ # turn off low-pass filter
186
+ use_laplace_filter: False
187
+
188
+ # deprecated
189
+ # save_format: "nc"
190
+
191
+ pbs: #derecho
192
+ conda: "/glade/work/ksha/miniconda3/envs/credit"
193
+ project: "NAML0001"
194
+ job_name: "fuxi_6h"
195
+ walltime: "12:00:00"
196
+ nodes: 8
197
+ ncpus: 64
198
+ ngpus: 4
199
+ mem: '480GB'
200
+ queue: 'main'
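
The `trainer` block in the two training configs pairs `learning_rate: 1.0e-03` with a `cosine-annealing` scheduler whose `T_max` is tied to `epochs` (70) through the YAML anchor. The sketch below evaluates the standard closed-form cosine annealing curve for those values. It assumes that the scheduler follows the usual cosine annealing formula with a minimum learning rate of 0; the config does not state the minimum, and `cosine_annealing_lr` is a hypothetical helper rather than a CREDIT function.

```python
import math


def cosine_annealing_lr(epoch, base_lr=1.0e-3, t_max=70, eta_min=0.0):
    """Closed-form cosine annealing: decays base_lr to eta_min over t_max epochs.

    eta_min=0.0 is an assumption; the config only gives base_lr and T_max.
    """
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


# Learning rate at the start, midpoint, and end of the 70-epoch schedule.
schedule = {epoch: cosine_annealing_lr(epoch) for epoch in (0, 35, 70)}
```

Because `T_max` equals the total number of epochs, the curve completes exactly half a cosine period, so under this assumption the learning rate decreases monotonically and never restarts.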
optimizer_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ae6c78578cc62ec39b838f86701f8e22d5238d97a3b2fd16daa2513fdfeebb7
3
+ size 1683978368