Preparing the datasets
To download the wwPDB dataset and preprocessed training data, you need at least 1 TB of disk space.
Use the following commands to download the preprocessed wwPDB training databases:
wget -P /af3-dev/release_data/ https://af3-dev.tos-cn-beijing.volces.com/release_data.tar.gz
tar -xzvf /af3-dev/release_data/release_data.tar.gz -C /af3-dev/release_data/
rm /af3-dev/release_data/release_data.tar.gz
The data should be placed in the /af3-dev/release_data/ directory. You can also download it to a different directory, but remember to modify DATA_ROOT_DIR in configs/configs_data.py accordingly (see the sketch after the directory listing below). The data hierarchy after extraction is as follows:
├── components.v20240608.cif [408M] # CCD source file
├── components.v20240608.cif.rdkit_mol.pkl [121M] # RDKit Mol objects generated from the CCD source file
├── indices [33M] # chain or interface entries
├── mmcif [283G] # raw mmCIF data
├── mmcif_bioassembly [36G] # preprocessed wwPDB structural data
├── mmcif_msa [450G] # MSA files
├── posebusters_bioassembly [42M] # preprocessed PoseBusters structural data
├── posebusters_mmcif [361M] # raw mmCIF data
├── recentPDB_bioassembly [1.5G] # preprocessed RecentPDB structural data
└── seq_to_pdb_index.json [45M] # sequence-to-PDB-ID mapping file
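If you extract the data to a different location, the change in configs/configs_data.py amounts to pointing one constant at the new directory. A minimal sketch (the real file defines many more per-dataset paths, which are omitted here; the directory value is only an example):

# configs/configs_data.py (sketch)
DATA_ROOT_DIR = "/my-disk/release_data"  # set this to wherever you extracted release_data.tar.gz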
Data processing scripts have also been released. You can refer to prepare_training_data.md for generating {dataset}_bioassembly and indices, and to msa_pipeline.md for the pipelines that produce mmcif_msa and seq_to_pdb_index.json.
Training demo
After the installation and data preparation, you can run the following command to train the model from scratch:
bash train_demo.sh
Key arguments in this script are explained as follows:
dtype: the data type used in training. Valid options include "bf16" and "fp32".
  --dtype fp32: the model is trained in full FP32 precision.
  --dtype bf16: the model is trained in BF16 mixed precision. By default, the SampleDiffusion, ConfidenceHead, Mini-rollout and Loss parts are still computed in FP32 precision (see the illustrative sketch after this list). If you want to train and infer the model in full BF16 mixed precision, pass the following arguments to train_demo.sh:
    --skip_amp.sample_diffusion_training false \
    --skip_amp.confidence_head false \
    --skip_amp.sample_diffusion false \
    --skip_amp.loss false \
ema_decay: the decay rate of the exponential moving average (EMA) of the model weights; the default is 0.999 (an illustrative update rule is also sketched after this list).
sample_diffusion.N_step: during evaluation, the number of diffusion steps is reduced to 20 to improve efficiency.
data.train_sets / data.test_sets: the datasets used for training and evaluation. If there are multiple datasets, separate them with commas.
Some settings follow those in the AlphaFold 3 paper. The table in model_performance.md shows the training settings and memory usage for different training stages.
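For intuition, the sketch below shows how part of a model can be kept in FP32 while the rest runs under BF16 autocast, which is conceptually what the skip_amp.* flags control. The module names (trunk, confidence_head) and tensor shapes are illustrative stand-ins, not Protenix code:

import torch
from torch import nn

# Stand-in modules and inputs; the real model is far larger.
trunk = nn.Linear(64, 64).cuda()
confidence_head = nn.Linear(64, 1).cuda()
features = torch.randn(8, 64, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    x = trunk(features)                                  # runs in BF16 mixed precision
    with torch.autocast(device_type="cuda", enabled=False):
        confidence = confidence_head(x.float())          # kept in full FP32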
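Similarly, ema_decay can be read as the decay factor in a standard exponential-moving-average update of the weights. A minimal sketch under that assumption (not the actual implementation in runner/train.py):

import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.999):
    # ema_weights <- decay * ema_weights + (1 - decay) * current_weights
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)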
In this version, we do not use the template or RNA MSA features for training. This corresponds to the default settings in configs/configs_base.py and configs/configs_data.py:
--model.template_embedder.n_blocks 0 \
--data.msa.enable_rna_msa false \
Support for these features will be considered in future work.
The model also supports distributed training with PyTorch's torchrun. For example, if you're running distributed training on a single node with 4 GPUs, you can use:
torchrun --nproc_per_node=4 runner/train.py
You can also pass other arguments to this command with --<ARGS_KEY> <ARGS_VALUE> as needed.
If you want to speed up training, see the setting up kernels documentation.
Finetune demo
If you want to fine-tune the model on a specific subset, such as an antibody dataset, you only need to provide a PDB list file and load the pretrained weights as finetune_demo.sh shows:
# wget -P /af3-dev/release_model/ https://af3-dev.tos-cn-beijing.volces.com/release_model/model_v0.2.0.pt
checkpoint_path="/af3-dev/release_model/model_v0.2.0.pt"
...
--load_checkpoint_path ${checkpoint_path} \
--load_checkpoint_ema_path ${checkpoint_path} \
--data.weightedPDB_before2109_wopb_nometalc_0925.base_info.pdb_list examples/subset.txt \
Here, subset.txt is a file containing PDB IDs, one per line, for example:
6hvq
5mqc
5zin
3ew0
5akv
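If you assemble such a subset programmatically, any script that writes one lower-case PDB ID per line will do. A minimal sketch using the IDs above:

# Write a fine-tuning subset file in the one-ID-per-line format shown above.
pdb_ids = ["6hvq", "5mqc", "5zin", "3ew0", "5akv"]  # replace with your own selection
with open("examples/subset.txt", "w") as fh:
    fh.write("\n".join(pdb_ids) + "\n")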