import torch
from torch import nn

from .base_optimizer import BaseOptimizer


class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        # Train the transformer blocks, the final layer norm, and the LM head.
        self.optimizer = torch.optim.AdamW(
            [p for b in model.h for p in b.parameters()]
            + list(model.ln_f.parameters())
            + list(model.lm_head.parameters()),
            lr=self.lr,
        )
        self.ce = nn.CrossEntropyLoss()
        # One fixed random feedback projection per block; frozen so the
        # feedback path itself is never updated.
        self.feedback = nn.ModuleList([
            nn.Linear(model.config.n_embd, model.config.n_embd, bias=False)
            for _ in model.h
        ])
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        # The frozen feedback matrices carry no trainable parameters, so they
        # only need to be moved to the right device, not prepared.
        self.feedback = self.feedback.to(self.accelerator.device)

    def step(self, inputs, labels):
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        # Manual forward pass through the blocks; each block sees a detached
        # input so gradients cannot flow between blocks (DFA supplies the
        # inter-block error signal instead).
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)
        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)
        x_final = self.model.ln_f(block_outs[-1])
        logits = self.model.lm_head(x_final)
        # Next-token loss: shift logits and labels by one position.
        B, T, V = logits.shape
        loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))
        # Global error at the last block output, plus true gradients for the
        # readout (ln_f, lm_head), which otherwise never receive gradients.
        head_params = list(self.model.ln_f.parameters()) + list(self.model.lm_head.parameters())
        grad_final, *head_grads = torch.autograd.grad(loss, [block_outs[-1]] + head_params, retain_graph=True)
        for p, g in zip(head_params, head_grads):
            p.grad = g
        # Update each block from the global error projected through its fixed
        # random feedback matrix rather than a backpropagated gradient.
        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final.detach())
            block_outs[i].backward(pseudo_err, retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
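

# Usage sketch (illustrative only). Assumes a ResearchTransformer exposing
# wte/wpe/drop/h/ln_f/lm_head, a plain config dict, and an accelerate
# Accelerator; the constructor, import path, and dataloader below are
# placeholders, not part of this module.
#
#     from accelerate import Accelerator
#     from .research_transformer import ResearchTransformer
#
#     model = ResearchTransformer(model_config)
#     dfa = FeedbackAlignment(model, {'learning_rate': 1e-4})
#     dfa.set_accelerator(Accelerator())
#     for batch in dataloader:
#         loss = dfa.step({'input_ids': batch['input_ids']}, batch['input_ids'])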