Spaces:
Sleeping
Sleeping
import torch

from .base_optimizer import BaseOptimizer
class Backpropagation(BaseOptimizer):
    """Plain backpropagation trainer: forward pass, backward pass, AdamW update.

    Wraps ``torch.optim.AdamW`` over all parameters of the model. Hyperparameters
    are read from ``config``:

    * ``learning_rate`` (default ``5e-5``)
    * ``weight_decay`` (default ``0.01``)

    Both values go through ``float()``, so string values (e.g. ``"5e-5"`` as
    produced by some YAML/JSON config loaders) are accepted.
    """

    def __init__(self, model, config):
        """Create the AdamW optimizer for ``model``.

        Args:
            model: A ``torch.nn.Module``. Relies on ``BaseOptimizer.__init__``
                to store it as ``self.model``.
            config: Mapping supporting ``.get``; see the class docstring for
                the keys read here.
        """
        super().__init__(model, config)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01)),
        )

    def set_accelerator(self, accelerator):
        """Attach an accelerator and let it wrap the model and optimizer.

        The base class is expected to store ``accelerator`` as
        ``self.accelerator``; its ``prepare`` call then returns wrapped
        versions of the model and optimizer, which replace the originals.
        """
        super().set_accelerator(accelerator)
        self.model, self.optimizer = self.accelerator.prepare(
            self.model, self.optimizer
        )

    def step(self, inputs, labels):
        """Run one training step on a batch and return its loss.

        Args:
            inputs: Mapping of keyword arguments for the model's forward call
                (presumably tokenizer output such as ``input_ids`` /
                ``attention_mask`` -- confirm against callers).
            labels: Passed to the model as ``labels=``. The model must then
                return an object whose ``.loss`` attribute holds the loss.

        Returns:
            The batch loss as a Python ``float``.

        Raises:
            ValueError: If the model output carries no loss (``.loss is None``).
        """
        self.model.train()
        try:
            outputs = self.model(**inputs, labels=labels)
            loss = outputs.loss
            if loss is None:
                raise ValueError(
                    "Model output has no loss; make sure the model computes "
                    "a loss when `labels` are passed."
                )

            # Use the accelerator's backward (handles mixed precision /
            # distributed scaling) when one was attached via set_accelerator;
            # otherwise fall back to plain autograd.
            accelerator = getattr(self, 'accelerator', None)
            if accelerator is not None:
                accelerator.backward(loss)
            else:
                loss.backward()

            self.optimizer.step()
            return float(loss.item())
        finally:
            # Always drop gradients, even if the forward/backward/update
            # raised, so partial gradients never leak into the next step and
            # their memory is released between steps.
            self.optimizer.zero_grad(set_to_none=True)