import torch
from torch import nn

from .base_optimizer import BaseOptimizer


class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer.

    The output head (``ln_f`` + ``lm_head``) is trained with the true gradient
    of the next-token cross-entropy loss.

    Each transformer block is trained locally:
    - its input is detached, so no gradient flows between blocks;
    - its output receives a pseudo-gradient: the gradient of the loss at the
      final hidden state, projected through a fixed random
      ``n_embd x n_embd`` feedback matrix.

    NOTE(review): the token/position embeddings (``wte``/``wpe``) are not
    passed to the optimizer and, because the first block's input is
    detached, receive no gradient here (unless they share parameters with
    ``lm_head``) — confirm this is intended.
    """

    def __init__(self, model, config):
        """Build the AdamW optimizer and the frozen random feedback matrices.

        Args:
            model: transformer exposing ``h`` (list of blocks), ``ln_f``,
                ``lm_head`` and ``config.n_embd``.
            config: mapping; ``'learning_rate'`` defaults to ``1e-4``.
        """
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        trainable = (
            [p for block in model.h for p in block.parameters()]
            + list(model.ln_f.parameters())
            + list(model.lm_head.parameters())
        )
        self.optimizer = torch.optim.AdamW(trainable, lr=self.lr)
        self.ce = nn.CrossEntropyLoss()
        # One random projection per block. They are frozen: feedback
        # alignment keeps the backward weights fixed for the whole run.
        n_embd = model.config.n_embd
        self.feedback = nn.ModuleList(
            [nn.Linear(n_embd, n_embd, bias=False) for _ in model.h]
        )
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        """Attach *accelerator* and let it prepare the model, optimizer and feedback."""
        super().set_accelerator(accelerator)
        self.model, self.optimizer, self.feedback = self.accelerator.prepare(
            self.model, self.optimizer, self.feedback
        )

    def step(self, inputs, labels):
        """Run one DFA training step and return the loss.

        Args:
            inputs: mapping whose ``'input_ids'`` entry is a 2-D tensor of
                token ids (indexed as batch x sequence below).
            labels: target ids aligned with ``input_ids``; logits at position
                ``t`` are scored against ``labels[:, t + 1]``.

        Returns:
            The cross-entropy loss as a Python float.
        """
        # NOTE(review): assumes self.model exposes wte/wpe/drop/h/ln_f/lm_head
        # directly; a model wrapped by accelerator.prepare (e.g. DDP) would
        # need unwrapping — TODO confirm for multi-process runs.
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        input_ids = inputs['input_ids']
        seq_len = input_ids.shape[1]
        pos = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.model.drop(self.model.wte(input_ids) + self.model.wpe(pos))

        # Detaching each block's input gives every block its own independent
        # autograd graph, so the per-block backward passes below need no
        # retain_graph.
        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)

        # Run the head on a detached leaf copy of the top hidden state.
        # loss.backward() then:
        # - accumulates real gradients into ln_f / lm_head (the original
        #   torch.autograd.grad call never populated them, so the head was
        #   never updated);
        # - stops at `top`, leaving d(loss)/d(hidden) in top.grad.
        top = block_outs[-1].detach().requires_grad_(True)
        logits = self.model.lm_head(self.model.ln_f(top))
        vocab = logits.size(-1)
        loss = self.ce(
            logits[:, :-1, :].contiguous().view(-1, vocab),
            labels[:, 1:].contiguous().view(-1),
        )
        # NOTE(review): gradients come from plain autograd rather than
        # accelerator.backward, so no loss scaling is applied — confirm fp16
        # mixed precision is not used with this optimizer.
        loss.backward()
        grad_final = top.grad

        # Broadcast the same top-level error to every block through its own
        # fixed random projection.
        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final)
            block_outs[i].backward(pseudo_err)

        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return float(loss.item())