import torch
from torch import nn
from torch.nn import functional as F

from .base_optimizer import BaseOptimizer


class ForwardForward(BaseOptimizer):
    """Layer-local Forward-Forward training for a GPT-style model.

    Each transformer block in ``model.h`` is trained by its own optimizer to
    assign high "goodness" (mean squared activation) to real token sequences
    and low goodness to negative sequences built by shuffling tokens.
    """

    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent optimizer per transformer block: gradients never
        # cross block boundaries, so each block learns from a local loss only.
        self.optimizers = [
            torch.optim.AdamW(block.parameters(), lr=self.lr) for block in model.h
        ]
        # sigmoid(goodness - threshold) gives a probability-style readout;
        # it is not used by the softplus loss below.
        self.sigmoid = nn.Sigmoid()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let Accelerate wrap the model and every per-block optimizer.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        self.model.train()
        total = 0.0

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape

        # Positive pass: embed the real sequence (token + position embeddings).
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        # Negative pass: shuffle the tokens of each sequence independently, so
        # token statistics are preserved while the ordering is destroyed.
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)

        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)

            # Detach the inputs so every block optimizes a purely local objective.
            xp = x.detach()
            xn = x_neg.detach()
            op = block(xp)
            on = block(xn)

            # Goodness = mean squared activation per position.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)

            # softplus(x) = log(1 + exp(x)), the numerically stable form of the
            # original log1p(exp(.)): push positive goodness above the
            # threshold and negative goodness below it.
            loss = F.softplus(-(gp - self.threshold)).mean() + \
                   F.softplus(gn - self.threshold).mean()

            self.accelerator.backward(loss)
            opt.step()
            total += float(loss.item())

            # Feed the detached block outputs to the next block.
            x = op.detach()
            x_neg = on.detach()

        # Diagnostic only: next-token cross-entropy from the final hidden
        # states; it does not contribute to the returned loss.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            proxy = ce(
                logits[:, :-1, :].contiguous().view(-1, V),
                labels[:, 1:].contiguous().view(-1),
            ).item()

        return total / max(1, len(self.model.h))
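

# Hypothetical usage sketch, not part of the original module. It assumes that
# BaseOptimizer.__init__(model, config) stores `self.model`/`self.config` and
# that BaseOptimizer.set_accelerator stores `self.accelerator`, since the code
# above reads those attributes. TinyGPT, MLPBlock, and every hyperparameter
# below are illustrative stand-ins, not names from this repository.
if __name__ == "__main__":
    from accelerate import Accelerator

    class MLPBlock(nn.Module):
        """Minimal stand-in for one transformer block in model.h."""

        def __init__(self, n_embd):
            super().__init__()
            self.net = nn.Sequential(
                nn.LayerNorm(n_embd),
                nn.Linear(n_embd, 4 * n_embd),
                nn.GELU(),
                nn.Linear(4 * n_embd, n_embd),
            )

        def forward(self, x):
            return x + self.net(x)

    class TinyGPT(nn.Module):
        """Exposes the attributes ForwardForward.step expects (wte, wpe, drop, h, ln_f, lm_head)."""

        def __init__(self, vocab_size=64, block_size=16, n_embd=32, n_layer=2):
            super().__init__()
            self.wte = nn.Embedding(vocab_size, n_embd)
            self.wpe = nn.Embedding(block_size, n_embd)
            self.drop = nn.Dropout(0.1)
            self.h = nn.ModuleList(MLPBlock(n_embd) for _ in range(n_layer))
            self.ln_f = nn.LayerNorm(n_embd)
            self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    model = TinyGPT()
    config = {"threshold": 2.0, "learning_rate": 1e-4}
    ff = ForwardForward(model, config)
    ff.set_accelerator(Accelerator())

    # Standard LM convention: the inputs double as labels, shifted inside step().
    ids = torch.randint(0, 64, (4, 16))  # (batch, seq_len)
    avg_ff_loss = ff.step({"input_ids": ids}, labels=ids)
    print(f"mean per-block Forward-Forward loss: {avg_ff_loss:.4f}")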