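"""Forward-Forward training for a GPT-style decoder.

Each transformer block in ``model.h`` is trained by its own optimizer to
raise the "goodness" (mean squared activation) of real sequences above a
threshold while pushing the goodness of negative, token-shuffled sequences
below it, following Hinton's Forward-Forward algorithm. Activations are
detached between blocks, so no gradient ever crosses a layer boundary.
"""
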
import torch
import torch.nn.functional as F

from .base_optimizer import BaseOptimizer

class ForwardForward(BaseOptimizer):
    """Trains each block of ``model.h`` with a local Forward-Forward loss."""

    def __init__(self, model, config):
        super().__init__(model, config)
        # Goodness threshold and a shared learning rate for every block.
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent AdamW optimizer per transformer block.
        self.optimizers = [
            torch.optim.AdamW(block.parameters(), lr=self.lr) for block in model.h
        ]

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let Accelerate wrap the model and all per-block optimizers at once.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        self.model.train()
        total = 0.0

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)

        # Positive stream: the real sequences, embedded as usual.
        x = self.model.drop(self.model.wte(input_ids) + self.model.wpe(pos))

        # Negative stream: independently shuffle the tokens of each sequence,
        # destroying word order while preserving unigram statistics.
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.drop(self.model.wte(neg_ids) + self.model.wpe(pos))
        # Train each block on its own local objective; inputs are detached
        # so no gradient flows across block boundaries.
        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)

            op = block(x.detach())
            on = block(x_neg.detach())

            # "Goodness" = mean squared activation at each position.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)

            # softplus(t - gp) pulls positive goodness above the threshold t;
            # softplus(gn - t) pushes negative goodness below it. softplus is
            # the numerically stable form of log(1 + exp(.)).
            loss = (F.softplus(self.threshold - gp).mean()
                    + F.softplus(gn - self.threshold).mean())
            self.accelerator.backward(loss)
            opt.step()

            total += float(loss.item())
            x = op.detach()
            x_neg = on.detach()
        # Diagnostic only: next-token cross-entropy through the head, which
        # this optimizer never updates. It measures how useful the layer-local
        # features are for language modelling and feeds no parameter update.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            proxy = F.cross_entropy(
                logits[:, :-1, :].contiguous().view(-1, V),
                labels[:, 1:].contiguous().view(-1),
            ).item()

        # Mean Forward-Forward loss across blocks.
        return total / max(1, len(self.model.h))
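
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative). It assumes a GPT-style module that
# exposes ``wte``, ``wpe``, ``drop``, ``h``, ``ln_f`` and ``lm_head`` as
# direct attributes, a ``BaseOptimizer`` that stores the accelerator on
# ``self.accelerator``, and Hugging Face Accelerate. ``MyGPT`` is a
# hypothetical placeholder for such a module.
#
#     from accelerate import Accelerator
#
#     model = MyGPT(vocab_size=50257, n_layer=12, n_embd=768, block_size=128)
#     ff = ForwardForward(model, {'threshold': 2.0, 'learning_rate': 1e-4})
#     ff.set_accelerator(Accelerator())
#
#     ids = torch.randint(0, 50257, (8, 128), device=ff.accelerator.device)
#     loss = ff.step({'input_ids': ids}, labels=ids)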