Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from .base_optimizer import BaseOptimizer | |
class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer.

    Each transformer block is trained from a fixed random projection of the
    loss gradient taken at the output of the last block. This replaces the
    gradient that would normally be backpropagated through the blocks above
    it. Every block input is detached, so no gradient crosses a block
    boundary.

    The final layer norm (``ln_f``) and ``lm_head`` are trained with their
    true gradients. The embedding output is detached before the first block,
    so ``wte``/``wpe`` get no gradient through the lookup path.
    """

    def __init__(self, model, config):
        """Build the optimizer and the fixed random feedback projections.

        Args:
            model: the ResearchTransformer. ``step`` reads ``h`` (the blocks),
                ``wte``, ``wpe``, ``drop``, ``ln_f``, ``lm_head`` and
                ``config.n_embd`` from it.
            config: mapping read for ``'learning_rate'`` (default ``1e-4``).
        """
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        # Only block and head parameters are optimized. The embeddings are
        # cut off from the blocks by the detach in ``step``.
        trainable = [p for block in model.h for p in block.parameters()]
        trainable += list(model.ln_f.parameters())
        trainable += list(model.lm_head.parameters())
        self.optimizer = torch.optim.AdamW(trainable, lr=self.lr)
        self.ce = nn.CrossEntropyLoss()
        # One fixed (never trained) n_embd x n_embd feedback matrix per block.
        n_embd = model.config.n_embd
        self.feedback = nn.ModuleList(
            nn.Linear(n_embd, n_embd, bias=False) for _ in model.h
        )
        self.feedback.requires_grad_(False)

    def set_accelerator(self, accelerator):
        """Attach ``accelerator`` and let it prepare model, optimizer and feedback.

        NOTE(review): ``step`` reads ``self.model.h``, ``wte``, ``ln_f``, etc.
        directly. It also indexes ``self.feedback``. Both assume ``prepare``
        returns modules that still expose those attributes (i.e. not wrapped).
        Confirm this for distributed runs.
        """
        super().set_accelerator(accelerator)
        self.model, self.optimizer, self.feedback = self.accelerator.prepare(
            self.model, self.optimizer, self.feedback
        )

    def step(self, inputs, labels):
        """Run one DFA training step and return the loss.

        Args:
            inputs: mapping whose ``'input_ids'`` is a 2-D ``(B, T)`` tensor
                of token ids.
            labels: targets aligned with ``input_ids`` (presumably ``(B, T)``;
                confirm). ``labels[:, t + 1]`` is the target for the logits
                at position ``t``.

        Returns:
            The next-token cross-entropy loss as a Python float.

        NOTE(review): gradients are produced with raw autograd instead of
        ``accelerator.backward``, so no loss scaling is applied. Confirm the
        interaction before enabling fp16 mixed precision.
        """
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        input_ids = inputs['input_ids']
        _, seq_len = input_ids.shape
        pos = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)
        x = self.model.drop(self.model.wte(input_ids) + self.model.wpe(pos))

        # Detaching each block input isolates the blocks: a block's autograd
        # graph reaches only that block's own parameters.
        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)

        logits = self.model.lm_head(self.model.ln_f(block_outs[-1]))
        vocab_size = logits.shape[-1]
        loss = self.ce(
            logits[:, :-1, :].contiguous().view(-1, vocab_size),
            labels[:, 1:].contiguous().view(-1),
        )

        # One backward pass through the head yields two things:
        #   1. the error signal at the last block's output, and
        #   2. the true head gradients.
        # autograd.grad does not populate .grad, so the head gradients are
        # written back explicitly. Without this, ln_f/lm_head would never be
        # updated by the optimizer.
        head_params = [
            p
            for p in (*self.model.ln_f.parameters(), *self.model.lm_head.parameters())
            if p.requires_grad
        ]
        grads = torch.autograd.grad(
            loss,
            [block_outs[-1], *head_params],
            retain_graph=True,  # the last block's graph is still backpropagated below
            allow_unused=True,
        )
        grad_final = grads[0].detach()
        for param, grad in zip(head_params, grads[1:]):
            if grad is not None:
                param.grad = grad if param.grad is None else param.grad + grad

        # DFA: every block, the last one included, receives its own fixed
        # random projection of the same error signal as its output gradient.
        for i in reversed(range(len(block_outs))):
            with torch.no_grad():
                pseudo_err = self.feedback[i](grad_final)
            block_outs[i].backward(pseudo_err, retain_graph=True)

        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return float(loss.item())