File size: 791 Bytes
ed920f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pytorch_lightning as L
from pytorch_lightning.strategies import DDPStrategy
from configs.config import Config
from utils.data_loader import get_dataloaders
from models.diffusion import Diffusion

# Build the dataloaders from the project config. get_dataloaders returns three
# values; only the first two (train and validation) are used by this script.
train_loader, val_loader, _ = get_dataloaders(Config)

# Initialize model
latent_diffusion_model = Diffusion(Config, latent_dim=Config.latent_dim)

# Resolve the hardware selection.
# The `gpus=` Trainer argument was deprecated in PyTorch Lightning 1.7 and
# removed in 2.0, so the configured value is translated to the
# `accelerator` / `devices` pair, which is accepted by every release that also
# provides `pytorch_lightning.strategies.DDPStrategy`.
_requested_gpus = Config.training["gpus"]
if _requested_gpus:
    _accelerator = "gpu"
    _devices = _requested_gpus
else:
    # A falsy value (0, None, empty list) previously meant "run on CPU";
    # keep that behavior with a single CPU process.
    _accelerator = "cpu"
    _devices = 1

# Initialize trainer
trainer = L.Trainer(
    max_epochs=Config.training["epochs"],
    accelerator=_accelerator,
    devices=_devices,
    precision=Config.training["precision"],
    # find_unused_parameters=False skips DDP's per-step search for unused
    # parameters. NOTE(review): this assumes every model parameter takes part
    # in each forward pass -- confirm against the Diffusion model.
    strategy=DDPStrategy(find_unused_parameters=False),
    accumulate_grad_batches=Config.training["accumulate_grad_batches"],
    default_root_dir=Config.training["save_dir"],
)

# Train the model, validating on val_loader.
trainer.fit(latent_diffusion_model, train_loader, val_loader)