import torch

from .base_optimizer import BaseOptimizer


class Backpropagation(BaseOptimizer):
    """Plain backpropagation: one forward/backward pass per step, using AdamW."""

    def __init__(self, model, config):
        super().__init__(model, config)
        # Build AdamW over all model parameters; hyperparameters are read
        # from the config dict, with common fine-tuning defaults.
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01))
        )

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let Accelerate wrap the model and optimizer for the current setup
        # (device placement, mixed precision, distributed training).
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        """Run a single optimization step and return the scalar loss."""
        self.model.train()
        # The model computes its own loss when labels are passed
        # (the Hugging Face transformers convention).
        outputs = self.model(**inputs, labels=labels)
        loss = outputs.loss
        # accelerator.backward handles gradient scaling/reduction where needed.
        self.accelerator.backward(loss)
        self.optimizer.step()
        # Clear gradients after stepping so the next backward starts fresh.
        self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())
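

# Usage sketch (illustrative only): a minimal training loop showing how this
# class is intended to be driven. `load_model`, `train_batches`, and the
# config values are hypothetical stand-ins; only Backpropagation,
# set_accelerator, and step come from this module.
#
#     from accelerate import Accelerator
#
#     accelerator = Accelerator()
#     opt = Backpropagation(load_model(), {'learning_rate': 5e-5})
#     opt.set_accelerator(accelerator)
#     for inputs, labels in train_batches():
#         loss = opt.step(inputs, labels)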