### Training

Some settings follow those in the [AlphaFold 3](https://www.nature.com/articles/s41586-024-07487-w) paper. The table below shows the training settings for the different fine-tuning stages:

| Arguments                                | Initial training | Fine tuning 1 | Fine tuning 2 | Fine tuning 3 |
|------------------------------------------|--------|---------|-------|-----|
| `train_crop_size`                        | 384    | 640     | 768   | 768 |
| `diffusion_batch_size`                   | 48     | 32      | 32    | 32  |
| `loss.weight.alpha_pae`                  | 0      | 0       | 0     | 1.0 |
| `loss.weight.alpha_bond`                 | 0      | 1.0     | 1.0   | 0   |
| `loss.weight.smooth_lddt`                | 1.0    | 0       | 0     | 0   |
| `loss.weight.alpha_confidence`           | 1e-4   | 1e-4    | 1e-4  | 1e-4|
| `loss.weight.alpha_diffusion`            | 4.0    | 4.0     | 4.0   | 0   |
| `loss.weight.alpha_distogram`            | 0.03   | 0.03    | 0.03  | 0   |
| `train_confidence_only`                  | False  | False   | False | True|
| Full BF16-mixed speed (A100, s/step)     | ~12    | ~30     | ~44   | ~13 |
| Full BF16-mixed peak memory (GB)         | ~34    | ~35     | ~48   | ~24 |

We recommend training on A100-80G or H20/H100 GPUs. With full BF16-mixed precision, the initial training stage can also be run on A800-40G GPUs. On GPUs with less memory, such as the A30, you will need to reduce the model size, e.g., by decreasing `model.pairformer.nblocks` and `diffusion_batch_size`.

### Inference

By default, the model runs inference in BF16 mixed precision, while the `SampleDiffusion` and `ConfidenceHead` parts still run in FP32. Below are reference examples of CUDA memory usage (GB):

| Ntoken | Natom | Default | Full BF16 Mixed |
|--------|-------|---------|-----------------|
| 500    | 10000 | 5.6     | 5.1             |
| 1500   | 30000 | 24.8    | 19.2            |
| 2500   | 25000 | 52.2    | 34.8            |
| 3500   | 35000 | 67.6    | 38.2            |
| 4500   | 45000 | 77.0    | 59.2            |
| 5000   | 50000 | OOM     | 72.8            |

The script in [runner/inference.py](../runner/inference.py) automatically changes the precision used to compute `SampleDiffusion` and `ConfidenceHead` to avoid OOM, as follows:

```python
from typing import Any


def update_inference_configs(configs: Any, N_token: int):
    # Set the default inference configs for different N_token and N_atom.
    # When N_token is larger than 3000, the default config might OOM even on
    # an A100-80G GPU.
    if N_token > 3840:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = False
    elif N_token > 2560:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = True
    else:
        configs.skip_amp.confidence_head = True
        configs.skip_amp.sample_diffusion = True
    return configs
```
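A hypothetical usage sketch follows. The `SimpleNamespace` stand-in models only the `skip_amp` fields that `update_inference_configs` touches, not the repository's full configs object:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the real configs object; only the skip_amp fields
# touched by update_inference_configs are modeled here.
configs = SimpleNamespace(
    skip_amp=SimpleNamespace(confidence_head=True, sample_diffusion=True)
)

configs = update_inference_configs(configs, N_token=3000)
# 2560 < 3000 <= 3840, so the confidence head falls back to BF16 while
# diffusion sampling stays in FP32:
print(configs.skip_amp.confidence_head)   # False -> ConfidenceHead runs in BF16
print(configs.skip_amp.sample_diffusion)  # True  -> SampleDiffusion stays in FP32
```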
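Operationally, setting a `skip_amp` flag to `True` means the corresponding submodule is evaluated outside the BF16 autocast region, i.e., in FP32. A minimal PyTorch sketch of this pattern (the function and argument names here are illustrative, not the repository's actual API):

```python
import torch

def run_confidence_head(confidence_head, trunk_out, skip_amp: bool):
    # Illustrative only: selectively disable BF16 autocast for one submodule.
    if skip_amp:
        # Leave the autocast region and upcast inputs so the head runs in FP32.
        with torch.autocast(device_type="cuda", enabled=False):
            return confidence_head(trunk_out.float())
    # Otherwise keep the head inside the BF16 autocast region.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return confidence_head(trunk_out)
```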