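"""Forward-Forward training for a GPT-style transformer.

Each transformer block is trained greedily with its own optimizer, using the
goodness objective from Hinton's Forward-Forward algorithm (2022): real
("positive") sequences should score high goodness (mean squared activation),
while corrupted ("negative") sequences should score low. No gradient ever
crosses a block boundary.
"""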
import torch
from torch import nn
from torch.nn import functional as F

from .base_optimizer import BaseOptimizer


class ForwardForward(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent AdamW per transformer block: each block is trained
        # locally, so optimizer state is never shared across blocks.
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # `prepare` (huggingface accelerate) returns the wrapped objects in
        # the order they were passed in: model first, then the optimizers.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])
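
    # One training step: build a positive (real) and a negative (shuffled)
    # activation stream once, then train every block greedily on the
    # goodness objective.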
    def step(self, inputs, labels):
        self.model.train()
        total = 0.0
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        # The embeddings are never updated here (activations are detached
        # before each block), so both streams can be built without autograd.
        with torch.no_grad():
            x = self.model.drop(self.model.wte(input_ids) + self.model.wpe(pos))
            # Negative samples: independently shuffle each sequence's tokens,
            # destroying word order while preserving the token distribution.
            perm = torch.argsort(torch.rand(B, T, device=device), dim=1)
            neg_ids = torch.gather(input_ids, 1, perm)
            x_neg = self.model.drop(self.model.wte(neg_ids) + self.model.wpe(pos))
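        # Token shuffling is one simple choice of negative data for text;
        # Hinton's paper builds negatives from corrupted (hybrid) inputs, and
        # other corruption schemes would work here as well.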
        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)
            # Detach the inputs so gradients stay local to this block.
            op = block(x.detach())
            on = block(x_neg.detach())
            # Goodness: mean squared activation at each position.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)
            # softplus(z) is the numerically stable log(1 + exp(z)); the loss
            # pushes positive goodness above the threshold and negative
            # goodness below it.
            loss = F.softplus(-(gp - self.threshold)).mean() + F.softplus(gn - self.threshold).mean()
            self.accelerator.backward(loss)
            opt.step()
            total += loss.item()
            # The detached outputs become the next block's inputs.
            x = op.detach()
            x_neg = on.detach()
        # Monitoring only: next-token cross-entropy through the LM head
        # (which no optimizer updates), stored on the instance so callers
        # can log it.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            self.last_proxy_ce = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, V),
                labels[:, 1:].reshape(-1),
            ).item()
        return total / max(1, len(self.model.h))
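

# A minimal smoke test: a sketch, assuming a nanoGPT-style model exposing
# .h, .wte, .wpe, .drop, .ln_f and .lm_head, that the accelerator is a
# huggingface `accelerate.Accelerator`, and that BaseOptimizer stores the
# model and accelerator the way this file uses them. All sizes, the toy
# blocks, and the config values below are made up for illustration. Because
# of the relative import above, run it as a module (e.g.
# `python -m <your_package>.forward_forward`, package name hypothetical).
if __name__ == '__main__':
    from accelerate import Accelerator

    V, T, D = 256, 16, 32  # toy vocab, sequence length, hidden size
    toy = nn.Module()
    toy.wte = nn.Embedding(V, D)
    toy.wpe = nn.Embedding(T, D)
    toy.drop = nn.Dropout(0.1)
    # Stand-in "transformer blocks": any module mapping (B, T, D) -> (B, T, D).
    toy.h = nn.ModuleList(
        nn.Sequential(nn.Linear(D, D), nn.ReLU()) for _ in range(2)
    )
    toy.ln_f = nn.LayerNorm(D)
    toy.lm_head = nn.Linear(D, V, bias=False)

    ff = ForwardForward(toy, {'threshold': 2.0, 'learning_rate': 1e-3})
    ff.set_accelerator(Accelerator())
    ids = torch.randint(0, V, (4, T))
    mean_loss = ff.step({'input_ids': ids}, labels=ids)
    print(f'mean block loss {mean_loss:.4f} | proxy CE {ff.last_proxy_ce:.4f}')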