<pre> from accelerate import Accelerator accelerator = Accelerator() train_dataloader, model, optimizer, scheduler = accelerator.prepare( dataloader, model, optimizer, scheduler ) model.train() for batch in train_dataloader: inputs, targets = batch outputs = model(inputs) loss = loss_function(outputs, targets) accelerator.backward(loss) optimizer.step() scheduler.step() optimizer.zero_grad() </pre>