# training_bench/algorithms/forward_forward.py
import torch
from torch import nn
import torch.nn.functional as F

from .base_optimizer import BaseOptimizer


class ForwardForward(BaseOptimizer):
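    """Layer-local training in the spirit of the Forward-Forward algorithm.

    Each transformer block gets its own optimizer and is trained to push the
    "goodness" (mean squared activation) of positive (real) sequences above a
    threshold and the goodness of negative (token-shuffled) sequences below it.
    No gradient crosses block boundaries. The model is assumed to be GPT-style,
    exposing wte, wpe, drop, h, ln_f and lm_head.
    """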
    def __init__(self, model, config):
        super().__init__(model, config)
        # Goodness threshold used in the per-block Forward-Forward objective.
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        # One independent optimizer per transformer block: each block is
        # updated from its own local loss only.
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]
        self.sigmoid = nn.Sigmoid()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let Accelerate wrap the model and every per-block optimizer so that
        # device placement and mixed precision are handled consistently.
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        """Run one positive pass and one negative pass, updating every block locally."""
        self.model.train()
        total = 0.0
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        # Positive pass: embed the real sequences (token + position embeddings).
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)
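        # Negative pass: corrupt each sequence by shuffling its token order,
        # destroying the structure while keeping the same token statistics.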
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)
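        # Train each block on the Forward-Forward objective, one block at a time.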
        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)
            # Detach the inputs so gradients never propagate across blocks.
            xp = x.detach()
            xn = x_neg.detach()
            op = block(xp)
            on = block(xn)
            # Goodness = mean squared activation per position.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)
            # Push positive goodness above the threshold and negative goodness
            # below it. softplus(x) == log(1 + exp(x)) but avoids the overflow
            # of the naive log1p(exp(x)) form.
            loss = F.softplus(-(gp - self.threshold)).mean() + F.softplus(gn - self.threshold).mean()
            self.accelerator.backward(loss)
            opt.step()
            total += float(loss.item())
            # The detached outputs become the inputs to the next block.
            x = op.detach()
            x_neg = on.detach()
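        # Diagnostic only: a next-token cross-entropy computed from the final
        # positive activations. It is not used for the update; the method
        # returns the mean Forward-Forward loss across blocks.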
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            proxy = ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1)).item()
        return total / max(1, len(self.model.h))