import torch
from torch import nn
from .base_optimizer import BaseOptimizer

class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer."""
    def __init__(self, model, config):
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        trainable = (
            [p for b in model.h for p in b.parameters()]
            + list(model.ln_f.parameters())
            + list(model.lm_head.parameters())
        )
        self.optimizer = torch.optim.AdamW(trainable, lr=self.lr)
        self.ce = nn.CrossEntropyLoss()
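        # Fixed random feedback projections, one per block; DFA keeps these
        # frozen and uses them only to map the output error back to each block.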
        self.feedback = nn.ModuleList([
            nn.Linear(model.config.n_embd, model.config.n_embd, bias=False) for _ in model.h
        ])
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        # The feedback projections have no trainable parameters, so they only
        # need to be moved to the right device, not wrapped by prepare().
        self.feedback = self.feedback.to(self.accelerator.device)

    def step(self, inputs, labels):
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        input_ids = inputs['input_ids']
        device = input_ids.device
        T = input_ids.size(1)
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        # Manual forward pass: each block sees a detached input, so no true
        # gradient can flow backwards across block boundaries.
        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)

        # Run the head on a detached copy of the trunk output: loss.backward()
        # then yields true gradients for ln_f / lm_head, and x_top.grad is the
        # output error dL/d(last block output) used by the feedback projections.
        x_top = block_outs[-1].detach().requires_grad_(True)
        x_final = self.model.ln_f(x_top)
        logits = self.model.lm_head(x_final)

        V = logits.size(-1)
        loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V),
                       labels[:, 1:].contiguous().view(-1))

        loss.backward()
        grad_final = x_top.grad

        # Direct feedback: project the output error through each block's fixed
        # random matrix and inject it as that block's output "gradient". Block
        # graphs are disjoint (inputs were detached), so each backward call
        # reaches only that block's own parameters.
        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final)
            block_outs[i].backward(pseudo_err)

        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return float(loss.item())