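"""Forward-Forward training for a GPT-style language model.

A layer-local variant of Hinton's Forward-Forward algorithm (2022): every
transformer block gets its own optimizer and is trained to produce high
"goodness" (mean squared activation) on real token sequences and low
goodness on negatives built by shuffling tokens within each sequence.
No gradient ever crosses a block boundary.
"""
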
import torch
from torch import nn
from torch.nn import functional as F

from .base_optimizer import BaseOptimizer

class ForwardForward(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent optimizer per transformer block: Forward-Forward is
        # layer-local, so no optimizer spans more than one block.
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Wrap the model and every per-block optimizer in a single prepare()
        # call so device placement and mixed precision stay consistent.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
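        # One step: (1) embed real and shuffled sequences, (2) train each
        # block locally on the goodness objective, (3) compute a no-grad
        # cross-entropy read-out for monitoring.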
        self.model.train()
        total = 0.0

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        # Embed the positive (real) sequences. Because every block detaches
        # its input below, wte/wpe receive no gradient here and are never
        # updated; only the blocks themselves are trained.
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        # Negative samples: independently permute the tokens of each sequence,
        # destroying word order while keeping the marginal token statistics.
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)

        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)

            # Detach the inputs so each block optimizes only its own local
            # objective; no gradient flows into earlier blocks.
            xp = x.detach()
            xn = x_neg.detach()

            op = block(xp)
            on = block(xn)

            # "Goodness" of each token position: mean squared activation.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)

            # Push positive goodness above the threshold and negative goodness
            # below it. F.softplus(z) == log(1 + exp(z)) but is numerically
            # stable; the explicit log1p(exp(.)) form overflows for large z.
            loss = F.softplus(-(gp - self.threshold)).mean() + F.softplus(gn - self.threshold).mean()
            self.accelerator.backward(loss)
            opt.step()
            total += loss.item()

            # Pass this block's outputs forward as the next block's inputs.
            # (Hinton's formulation also length-normalizes activations between
            # layers so later layers cannot read goodness straight off the
            # magnitude; this variant passes them through unchanged.)
            x = op.detach()
            x_neg = on.detach()

        # Read-out diagnostic: next-token cross-entropy of the (frozen) LM head
        # on the final positive activations. It contributes no gradients and is
        # stashed on the instance so callers can log it.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            self.last_proxy_ce = ce(logits[:, :-1, :].contiguous().view(-1, V),
                                    labels[:, 1:].contiguous().view(-1)).item()
        # Mean Forward-Forward loss across blocks.
        return total / max(1, len(self.model.h))
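
# ---------------------------------------------------------------------------
# Smoke-test sketch (hypothetical, not part of the original training harness):
# exercises step() against a minimal stand-in exposing the GPT attributes this
# class expects (.wte, .wpe, .drop, .h, .ln_f, .lm_head). It assumes
# BaseOptimizer.__init__ simply stores model/config, and fakes the one
# Accelerator method step() uses (backward). Run it as a module, e.g.
# `python -m <package>.forward_forward`, so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    V, T, D = 100, 16, 32
    gpt = nn.Module()
    gpt.wte = nn.Embedding(V, D)
    gpt.wpe = nn.Embedding(T, D)
    gpt.drop = nn.Dropout(0.1)
    gpt.h = nn.ModuleList(nn.Sequential(nn.Linear(D, D), nn.ReLU()) for _ in range(2))
    gpt.ln_f = nn.LayerNorm(D)
    gpt.lm_head = nn.Linear(D, V)

    ff = ForwardForward(gpt, {'threshold': 2.0, 'learning_rate': 1e-3})
    ff.accelerator = types.SimpleNamespace(backward=lambda loss: loss.backward())

    ids = torch.randint(0, V, (4, T))
    mean_loss = ff.step({'input_ids': ids}, ids)
    print(f"mean per-block FF loss: {mean_loss:.4f}  proxy CE: {ff.last_proxy_ce:.4f}")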