# training_bench/algorithms/feedback_alignment.py
import torch
from torch import nn
from .base_optimizer import BaseOptimizer


class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer: each block is
    trained with a local error signal obtained by projecting the loss gradient at
    the final block's output through a fixed random feedback matrix, instead of
    backpropagating through the blocks."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        # Train only the blocks, final layer norm, and LM head; the embeddings never receive a gradient.
        params = ([p for b in model.h for p in b.parameters()]
                  + list(model.ln_f.parameters()) + list(model.lm_head.parameters()))
        self.optimizer = torch.optim.AdamW(params, lr=self.lr)
        self.ce = nn.CrossEntropyLoss()
        # One fixed random feedback matrix per block: these project the final-block
        # error back to every block and are never trained.
        self.feedback = nn.ModuleList([
            nn.Linear(model.config.n_embd, model.config.n_embd, bias=False) for _ in model.h
        ])
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer, self.feedback = self.accelerator.prepare(
            self.model, self.optimizer, self.feedback
        )

    def step(self, inputs, labels):
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        # Manual forward pass: detaching each block's input prevents autograd from
        # carrying errors across blocks, so every block relies on its feedback signal.
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)
        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)
        x_final = self.model.ln_f(block_outs[-1])
        logits = self.model.lm_head(x_final)
        # Next-token prediction loss: shift logits left and labels right by one position.
        B, T, V = logits.shape
        loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V),
                       labels[:, 1:].contiguous().view(-1))
        # True error signal at the last block's output.
        grad_final, = torch.autograd.grad(loss, block_outs[-1], retain_graph=True)
        # The readout path (ln_f + lm_head) is trained with the true gradient; its
        # parameters are in the optimizer but would otherwise never receive grads.
        loss.backward(retain_graph=True, inputs=list(self.model.ln_f.parameters())
                      + list(self.model.lm_head.parameters()))
        # DFA: project the final error through each block's fixed random feedback
        # matrix and use the result as that block's output "gradient".
        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final.detach())
            block_outs[i].backward(pseudo_err, retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()
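

# A minimal usage sketch (illustrative only, not part of the benchmark). It assumes a
# GPT-style ResearchTransformer exposing `wte`, `wpe`, `drop`, `h`, `ln_f`, `lm_head`
# and `config.n_embd`, plus a Hugging Face `accelerate.Accelerator`; the model,
# config, and dataloader names below are hypothetical.
#
#     from accelerate import Accelerator
#
#     model = ResearchTransformer(model_config)           # hypothetical model class
#     fa = FeedbackAlignment(model, {'learning_rate': 3e-4})
#     fa.set_accelerator(Accelerator())
#     for batch in dataloader:                            # hypothetical dataloader
#         loss = fa.step({'input_ids': batch['input_ids']}, batch['labels'])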