### Training
Some settings follow those in the [AlphaFold 3](https://www.nature.com/articles/s41586-024-07487-w) paper. The table below shows the training settings for the different fine-tuning stages:
| Arguments | Initial training | Fine tuning 1 | Fine tuning 2 | Fine tuning 3 |
|-----------------------------------------|--------|---------|-------|-----|
| `train_crop_size` | 384 | 640 | 768 | 768 |
| `diffusion_batch_size` | 48 | 32 | 32 | 32 |
| `loss.weight.alpha_pae` | 0 | 0 | 0 | 1.0 |
| `loss.weight.alpha_bond` | 0 | 1.0 | 1.0 | 0 |
| `loss.weight.smooth_lddt` | 1.0 | 0 | 0 | 0 |
| `loss.weight.alpha_confidence` | 1e-4 | 1e-4 | 1e-4 | 1e-4|
| `loss.weight.alpha_diffusion` | 4.0 | 4.0 | 4.0 | 0 |
| `loss.weight.alpha_distogram` | 0.03 | 0.03 | 0.03 | 0 |
| `train_confidence_only` | False | False | False | True|
| Full BF16-mixed speed (A100, s/step) | ~12 | ~30 | ~44 | ~13 |
| Full BF16-mixed peak memory (GB) | ~34 | ~35 | ~48 | ~24 |
We recommend training on A100-80G or H20/H100 GPUs. With full BF16-mixed precision training, the initial training stage can also run on A800-40G GPUs. On GPUs with less memory, such as the A30, you will need to reduce the model size, for example by decreasing `model.pairformer.nblocks` and `diffusion_batch_size`.
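For reference, the per-stage settings above can be expressed as plain override dictionaries and applied to a nested configs object. The sketch below is illustrative only: `STAGE_OVERRIDES` and `apply_stage` are hypothetical helpers (not part of this repository), and only two of the four stages are spelled out; the others follow the same pattern.

```python
from typing import Any

# Illustrative per-stage overrides mirroring two columns of the table above.
STAGE_OVERRIDES = {
    "initial_training": {
        "train_crop_size": 384,
        "diffusion_batch_size": 48,
        "loss.weight.alpha_pae": 0.0,
        "loss.weight.alpha_bond": 0.0,
        "loss.weight.smooth_lddt": 1.0,
        "loss.weight.alpha_confidence": 1e-4,
        "loss.weight.alpha_diffusion": 4.0,
        "loss.weight.alpha_distogram": 0.03,
        "train_confidence_only": False,
    },
    "fine_tuning_3": {
        "train_crop_size": 768,
        "diffusion_batch_size": 32,
        "loss.weight.alpha_pae": 1.0,
        "loss.weight.alpha_bond": 0.0,
        "loss.weight.smooth_lddt": 0.0,
        "loss.weight.alpha_confidence": 1e-4,
        "loss.weight.alpha_diffusion": 0.0,
        "loss.weight.alpha_distogram": 0.0,
        "train_confidence_only": True,
    },
}

def apply_stage(configs: Any, stage: str) -> Any:
    # Walk dotted keys such as "loss.weight.alpha_pae" and set the leaf attribute.
    for dotted_key, value in STAGE_OVERRIDES[stage].items():
        *parents, leaf = dotted_key.split(".")
        node = configs
        for name in parents:
            node = getattr(node, name)
        setattr(node, leaf, value)
    return configs
```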
### Inference
By **default**, the model runs inference in BF16 mixed precision, while the `SampleDiffusion` and `ConfidenceHead` parts are still computed in FP32.
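Conceptually, this selective precision corresponds to running the trunk under BF16 autocast while disabling autocast for the precision-sensitive parts. The sketch below shows only the general PyTorch pattern; the module names (`trunk`, `sample_diffusion`, `confidence_head`) are placeholders, not the exact interfaces used in this repository.

```python
import torch

@torch.no_grad()
def mixed_precision_inference(model, batch):
    # Bulk of the network: BF16 autocast (mixed precision).
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        trunk_out = model.trunk(batch)

    # Precision-sensitive parts: disable autocast and cast activations to FP32.
    with torch.autocast(device_type="cuda", enabled=False):
        trunk_out = {k: v.float() for k, v in trunk_out.items()}
        coords = model.sample_diffusion(trunk_out)
        confidence = model.confidence_head(trunk_out, coords)
    return coords, confidence
```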
Below are reference examples of CUDA memory usage (GB).
| N_token | N_atom | Default | Full BF16 mixed |
|--------|-------|-------|------------------|
| 500 | 10000 | 5.6 | 5.1 |
| 1500 | 30000 | 24.8 | 19.2 |
| 2500 | 25000 | 52.2 | 34.8 |
| 3500 | 35000 | 67.6 | 38.2 |
| 4500 | 45000 | 77.0 | 59.2 |
| 5000 | 50000 | OOM | 72.8 |
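If you want to reproduce these numbers on your own hardware, peak CUDA memory can be measured with standard PyTorch utilities. A minimal sketch, where `run_model` stands in for whatever inference entry point you are profiling:

```python
import torch

def report_peak_memory(run_model, *args, **kwargs):
    # Reset allocator statistics, run one inference call, and report
    # the peak allocated CUDA memory in GB.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        out = run_model(*args, **kwargs)
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Peak CUDA memory: {peak_gb:.1f} GB")
    return out
```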
The script in [runner/inference.py](../runner/inference.py) automatically adjusts the precision used for `SampleDiffusion` and `ConfidenceHead` to avoid OOM, as follows:
```python
from typing import Any


def update_inference_configs(configs: Any, N_token: int):
    # Choose inference precision settings based on N_token: above ~3000 tokens,
    # the default configs (FP32 SampleDiffusion/ConfidenceHead) may OOM even on
    # an A100-80G GPU, so BF16 autocast is enabled for those parts instead.
    if N_token > 3840:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = False
    elif N_token > 2560:
        configs.skip_amp.confidence_head = False
        configs.skip_amp.sample_diffusion = True
    else:
        configs.skip_amp.confidence_head = True
        configs.skip_amp.sample_diffusion = True
    return configs
```
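As a quick sanity check of the thresholds, the function can be exercised with a `SimpleNamespace` standing in for the real configs object (illustrative only):

```python
from types import SimpleNamespace

# Minimal stand-in for the configs object; only the skip_amp sub-config is needed here.
configs = SimpleNamespace(
    skip_amp=SimpleNamespace(confidence_head=True, sample_diffusion=True)
)
configs = update_inference_configs(configs, N_token=3000)
# 2560 < 3000 <= 3840: ConfidenceHead switches to BF16 autocast, SampleDiffusion stays in FP32.
print(configs.skip_amp.confidence_head)   # False
print(configs.skip_amp.sample_diffusion)  # True
```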