import torch
from torch import nn

from .base_optimizer import BaseOptimizer


class ForwardForward(BaseOptimizer):
    """Layer-local Forward-Forward training: each transformer block gets its own
    optimizer and is trained to produce high "goodness" (mean squared activation)
    on real sequences and low goodness on negative sequences."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent AdamW optimizer per transformer block.
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]
        self.sigmoid = nn.Sigmoid()  # not used in step()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let accelerate wrap the model and every per-block optimizer.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        self.model.train()
        total = 0.0

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)

        # Positive stream: token + position embeddings of the real sequence.
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        # Negative stream: the same tokens with their order shuffled
        # independently per sample, destroying the sequential structure.
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)
        # Train each block greedily on its own local objective; gradients never
        # flow between blocks because each block's inputs are detached.
        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)
            xp = x.detach()
            xn = x_neg.detach()
            op = block(xp)
            on = block(xn)
            # "Goodness" = mean squared activation per token.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)
            # Push positive goodness above the threshold and negative goodness
            # below it; softplus(z) is the numerically stable form of log(1 + exp(z)).
            loss = (nn.functional.softplus(-(gp - self.threshold)).mean()
                    + nn.functional.softplus(gn - self.threshold).mean())
            self.accelerator.backward(loss)
            opt.step()
            total += float(loss.item())
            # Feed the (detached) block outputs to the next block.
            x = op.detach()
            x_neg = on.detach()
        # Diagnostic only: standard next-token cross-entropy on the final
        # (detached) positive activations; computed here but not returned.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            proxy = ce(logits[:, :-1, :].contiguous().view(-1, V),
                       labels[:, 1:].contiguous().view(-1)).item()

        # Average per-block Forward-Forward loss.
        return total / max(1, len(self.model.h))
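

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: a minimal sketch
    # assuming a nanoGPT-style model that exposes wte, wpe, drop, h, ln_f and
    # lm_head directly, that BaseOptimizer stores the model and the accelerator
    # passed to set_accelerator(), and that Hugging Face `accelerate` is
    # installed. The _TinyGPT stub and all sizes below are hypothetical.
    # Run as a module (python -m <package>.forward_forward) so the relative
    # import above resolves.
    from accelerate import Accelerator

    class _TinyGPT(nn.Module):
        # Hypothetical stand-in with the attribute layout ForwardForward expects.
        def __init__(self, vocab=64, block_size=16, dim=32, layers=2):
            super().__init__()
            self.wte = nn.Embedding(vocab, dim)
            self.wpe = nn.Embedding(block_size, dim)
            self.drop = nn.Dropout(0.1)
            self.h = nn.ModuleList(
                nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(layers)
            )
            self.ln_f = nn.LayerNorm(dim)
            self.lm_head = nn.Linear(dim, vocab, bias=False)

    model = _TinyGPT()
    ff = ForwardForward(model, {'threshold': 2.0, 'learning_rate': 1e-4})
    ff.set_accelerator(Accelerator(cpu=True))
    ids = torch.randint(0, 64, (4, 16))
    print("avg per-block FF loss:", ff.step({'input_ids': ids}, ids))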