rider-provider-777's picture
Upload 5 files
e016a55 verified
raw
history blame
891 Bytes
import torch
from .base_optimizer import BaseOptimizer
class Backpropagation(BaseOptimizer):
    """Gradient-based optimizer: backpropagate the model's loss, then apply AdamW.

    Hyperparameters read from ``config``:

    * ``learning_rate`` -- AdamW learning rate (default ``5e-5``)
    * ``weight_decay`` -- AdamW weight decay (default ``0.01``)

    Both values are passed through ``float()``, so numeric strings such as
    ``"5e-5"`` are accepted as well as numbers.

    ``step`` reads ``self.accelerator``, so ``set_accelerator`` must be called
    before training.
    """

    def __init__(self, model, config):
        """Create an AdamW optimizer over all parameters of the model.

        Args:
            model: The module to train; its ``parameters()`` are optimized.
            config: Mapping supporting ``.get`` with the keys listed on the class.
        """
        super().__init__(model, config)
        # NOTE(review): relies on BaseOptimizer.__init__ storing ``model`` as
        # ``self.model`` -- confirm in base_optimizer.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01)),
        )

    def set_accelerator(self, accelerator):
        """Attach an accelerator and wrap the model and optimizer with it.

        After this call ``self.model`` and ``self.optimizer`` refer to the
        objects returned by ``accelerator.prepare`` instead of the originals.

        Args:
            accelerator: Presumably an ``accelerate.Accelerator``; it must
                provide ``prepare``, ``backward`` and ``accumulate``.
        """
        super().set_accelerator(accelerator)
        # Presumably the base class stores the argument as ``self.accelerator``
        # (it is read here and in ``step``) -- confirm in base_optimizer.
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        """Run one training step on a batch and return its loss.

        Args:
            inputs: Mapping of keyword arguments forwarded to the model.
            labels: Passed to the model as the ``labels`` keyword argument.
                Assumes a Hugging Face-style model whose output exposes
                ``.loss``.

        Returns:
            The batch loss as a Python float. This is the unscaled model loss,
            not the copy divided for gradient accumulation.
        """
        self.model.train()
        # ``accumulate`` is the documented accelerate training-step pattern.
        # With gradient_accumulation_steps == 1 it syncs and steps on every
        # call (the previous behavior). With > 1 it defers gradient sync and
        # the prepared optimizer's step/zero_grad until the accumulation
        # boundary. That matches the loss division applied by
        # ``accelerator.backward``; without it the optimizer would step on
        # every micro-batch using down-scaled gradients.
        with self.accelerator.accumulate(self.model):
            outputs = self.model(**inputs, labels=labels)
            loss = outputs.loss
            self.accelerator.backward(loss)
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())