import torch

from .base_optimizer import BaseOptimizer


class Backpropagation(BaseOptimizer):
    """Standard backpropagation optimizer wrapping torch.optim.AdamW."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01)),
        )

    def set_accelerator(self, accelerator):
        # Let the accelerator wrap the model and optimizer for the current device setup.
        super().set_accelerator(accelerator)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        # One forward/backward pass followed by an optimizer update; returns the scalar loss.
        self.model.train()
        outputs = self.model(**inputs, labels=labels)
        loss = outputs.loss
        self.accelerator.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())