import torch

from .base_optimizer import BaseOptimizer


class Backpropagation(BaseOptimizer):
    """Standard backpropagation training with an AdamW optimizer."""

    def __init__(self, model, config):
        super().__init__(model, config)
        # AdamW over all model parameters; hyperparameters come from the
        # config dict, with conventional defaults.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01)),
        )

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let accelerate wrap the model and optimizer for the current device
        # and distributed setup.
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        # One training step: forward pass, backward pass, parameter update.
        self.model.train()
        outputs = self.model(**inputs, labels=labels)
        loss = outputs.loss
        self.accelerator.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())
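

# A minimal usage sketch (an illustration, not part of the original module):
# it assumes the `transformers` and `accelerate` packages and a small causal
# LM ("gpt2"), none of which this file provides, and that BaseOptimizer stores
# the accelerator as self.accelerator. Because of the relative import above,
# it only runs when the module is executed as part of its package
# (e.g. via `python -m ...`), not as a standalone script.
if __name__ == "__main__":
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    optimizer = Backpropagation(lm, {"learning_rate": 5e-5, "weight_decay": 0.01})
    accelerator = Accelerator()
    optimizer.set_accelerator(accelerator)

    # For causal-LM training the input ids are reused as the labels.
    batch = tokenizer(["hello world"], return_tensors="pt").to(accelerator.device)
    loss = optimizer.step(batch, batch["input_ids"])
    print(f"step loss: {loss:.4f}")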