Upload 5 files
- algorithms/backprop.py +24 -0
- algorithms/base_optimizer.py +16 -0
- algorithms/feedback_alignment.py +54 -0
- algorithms/forward_forward.py +63 -0
- algorithms/synthetic_gradients.py +62 -0
algorithms/backprop.py
ADDED
@@ -0,0 +1,24 @@
import torch
from .base_optimizer import BaseOptimizer

class Backpropagation(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01))
        )

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        self.model.train()
        outputs = self.model(**inputs, labels=labels)
        loss = outputs.loss
        self.accelerator.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())
algorithms/base_optimizer.py
ADDED
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod

class BaseOptimizer(ABC):
    """Abstract base class for all training algorithms."""
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.accelerator = None

    def set_accelerator(self, accelerator):
        self.accelerator = accelerator

    @abstractmethod
    def step(self, inputs, labels):
        """Performs a single training step; must return a Python float loss."""
        raise NotImplementedError
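Not part of this commit, but useful context: a minimal driver sketch showing how a BaseOptimizer subclass is meant to be wired up (construct it with a model and a plain config dict, attach an accelerate.Accelerator, then call step() with tokenized inputs and labels). The "gpt2" checkpoint, the single-sentence batch, and the import path are assumptions for illustration; only the BaseOptimizer/Backpropagation API comes from the files above.

# Hypothetical usage sketch (not in this commit). Assumes the `transformers`
# and `accelerate` packages, and that the repository root is on sys.path so
# that `algorithms` imports as a package.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

from algorithms.backprop import Backpropagation

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

accelerator = Accelerator()
algo = Backpropagation(model, {"learning_rate": 5e-5, "weight_decay": 0.01})
algo.set_accelerator(accelerator)

batch = tokenizer("hello world", return_tensors="pt").to(accelerator.device)
# For causal-LM training the labels are simply the input ids; per the
# BaseOptimizer contract, step() returns a plain Python float.
loss = algo.step(
    {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
    batch["input_ids"],
)
print(f"loss: {loss:.4f}")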
algorithms/feedback_alignment.py
ADDED
@@ -0,0 +1,54 @@
import torch
from torch import nn
from .base_optimizer import BaseOptimizer

class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer."""
    def __init__(self, model, config):
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        self.optimizer = torch.optim.AdamW([p for b in model.h for p in b.parameters()] + list(model.ln_f.parameters()) + list(model.lm_head.parameters()), lr=self.lr)
        self.ce = nn.CrossEntropyLoss()
        self.feedback = nn.ModuleList([
            nn.Linear(model.config.n_embd, model.config.n_embd, bias=False) for _ in model.h
        ])
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer, self.feedback = self.accelerator.prepare(self.model, self.optimizer, self.feedback)

    def step(self, inputs, labels):
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)

        x_final = self.model.ln_f(block_outs[-1])
        logits = self.model.lm_head(x_final)

        B, T, V = logits.shape
        loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))

        grad_final, = torch.autograd.grad(loss, block_outs[-1], retain_graph=True)

        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final.detach())
            block_outs[i].backward(pseudo_err, retain_graph=True)

        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return float(loss.item())
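To make the mechanism above concrete: FeedbackAlignment.step detaches the input to each block, computes the exact gradient only at the last block's output, and routes that signal to every block through frozen random nn.Linear feedback layers instead of true backpropagation. Below is a toy, self-contained sketch of the same idea on a two-layer MLP (not part of the commit, independent of the ResearchTransformer); all names here are illustrative.

# Toy direct-feedback-alignment sketch (illustrative only).
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
layer1, layer2 = nn.Linear(10, 32), nn.Linear(32, 3)
feedback = torch.randn(3, 32) * 0.1   # fixed random feedback matrix (never trained)
opt = torch.optim.SGD(list(layer1.parameters()) + list(layer2.parameters()), lr=0.1)

x, y = torch.randn(64, 10), torch.randint(0, 3, (64,))
for _ in range(200):
    h = torch.relu(layer1(x))
    logits = layer2(h.detach())       # detach: no true backprop into layer1
    loss = F.cross_entropy(logits, y)

    with torch.no_grad():             # error signal at the output (softmax minus one-hot)
        err = torch.softmax(logits, dim=-1)
        err[torch.arange(64), y] -= 1.0

    h.backward(err @ feedback)        # pseudo-gradient for the hidden layer
    loss.backward()                   # trains layer2 only (h was detached)
    opt.step()
    opt.zero_grad(set_to_none=True)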
algorithms/forward_forward.py
ADDED
@@ -0,0 +1,63 @@
import torch
from torch import nn
from .base_optimizer import BaseOptimizer

class ForwardForward(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]
        self.sigmoid = nn.Sigmoid()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        self.model.train()
        total = 0.0

        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)

        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)

            xp = x.detach()
            xn = x_neg.detach()

            op = block(xp)
            on = block(xn)

            gp = (op.pow(2).mean(dim=-1))
            gn = (on.pow(2).mean(dim=-1))

            loss = torch.log1p(torch.exp(-(gp - self.threshold))).mean() + torch.log1p(torch.exp(gn - self.threshold)).mean()
            self.accelerator.backward(loss)
            opt.step()
            total += float(loss.item())

            x = op.detach()
            x_neg = on.detach()

        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            proxy = ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1)).item()
        return total / max(1, len(self.model.h))
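A note on the objective above: torch.log1p(torch.exp(z)) is softplus(z), so each block's loss is softplus(threshold - g_pos) + softplus(g_neg - threshold), which pushes the positive goodness (mean squared activation of the real tokens) above the threshold and the negative goodness (token-shuffled inputs) below it. The no_grad block at the end appears to compute a language-modeling cross-entropy only as a reference value; the method returns the average Forward-Forward loss. A quick sanity check of the softplus identity (not part of the commit; gp, gn, theta are placeholder tensors):

import torch
import torch.nn.functional as F

gp, gn, theta = torch.rand(8), torch.rand(8), 2.0
a = torch.log1p(torch.exp(-(gp - theta))) + torch.log1p(torch.exp(gn - theta))
b = F.softplus(theta - gp) + F.softplus(gn - theta)
assert torch.allclose(a, b)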
algorithms/synthetic_gradients.py
ADDED
@@ -0,0 +1,62 @@
import torch
from torch import nn
import torch.nn.functional as F
from .base_optimizer import BaseOptimizer

class GradientSynthesizer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
    def forward(self, x):
        return self.net(x.detach())

class SyntheticGradients(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.hidden = model.config.n_embd
        self.main_lr = float(config.get('main_learning_rate', 1e-5))
        self.synth_lr = float(config.get('synth_learning_rate', 1e-4))
        self.model_opt = torch.optim.AdamW(model.parameters(), lr=self.main_lr)
        self.synths = nn.ModuleList([GradientSynthesizer(self.hidden) for _ in model.h])
        self.synth_opts = [torch.optim.AdamW(s.parameters(), lr=self.synth_lr) for s in self.synths]
        self.ce = nn.CrossEntropyLoss()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        packs = [self.model, self.model_opt, self.synths] + self.synth_opts
        prepped = self.accelerator.prepare(*packs)
        self.model, self.model_opt, self.synths, *self.synth_opts = prepped

    def step(self, inputs, labels):
        self.model.train()
        for opt in self.synth_opts:
            opt.zero_grad(set_to_none=True)
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 1: update the model with the synthesizers' predicted gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for i, out in enumerate(block_outs):
            pred_grad = self.synths[i](out)
            out.backward(pred_grad, retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.model_opt.step()
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 2: fit each synthesizer to the true gradient of the task loss.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for out in block_outs:
            out.retain_grad()  # non-leaf activations only populate .grad when asked to
        B, T, V = logits.shape
        task_loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))
        self.accelerator.backward(task_loss, retain_graph=True)

        for i, out in enumerate(block_outs):
            if out.grad is not None:
                true_grad = out.grad.detach()
                pred_grad = self.synths[i](out.detach())
                synth_loss = F.mse_loss(pred_grad, true_grad)
                self.accelerator.backward(synth_loss)
                self.synth_opts[i].step()
                self.synth_opts[i].zero_grad(set_to_none=True)

        return float(task_loss.item())
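For context (not part of the commit), here is the same two-phase decoupled-gradient pattern on a toy two-layer MLP, mirroring SyntheticGradients.step: phase 1 updates the model from the synthesizer's predicted gradient, phase 2 fits the synthesizer to the true gradient of the task loss. Note the retain_grad() call, which is what makes .grad available on a non-leaf activation (the same call is needed in step above for the second phase to do anything). All names below are illustrative.

# Toy synthetic-gradient sketch, single iteration, illustrative only.
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
layer1, layer2 = nn.Linear(10, 32), nn.Linear(32, 3)
synth = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
opt_main = torch.optim.AdamW(list(layer1.parameters()) + list(layer2.parameters()), lr=1e-3)
opt_synth = torch.optim.AdamW(synth.parameters(), lr=1e-3)

x, y = torch.randn(64, 10), torch.randint(0, 3, (64,))

# Phase 1: update the model with the synthesizer's predicted gradient.
h = torch.relu(layer1(x))
h.backward(synth(h.detach()).detach())       # decoupled update for layer1
loss = F.cross_entropy(layer2(h.detach()), y)
loss.backward()                              # output layer still trains on the real loss
opt_main.step()
opt_main.zero_grad(set_to_none=True)

# Phase 2: fit the synthesizer to the true gradient at h.
h = torch.relu(layer1(x))
h.retain_grad()                              # non-leaf .grad is only kept on request
loss = F.cross_entropy(layer2(h), y)
loss.backward()
synth_loss = F.mse_loss(synth(h.detach()), h.grad.detach())
synth_loss.backward()
opt_synth.step()
opt_main.zero_grad(set_to_none=True)         # discard model grads from phase 2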