# training_bench/algorithms/synthetic_gradients.py
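"""Synthetic-gradient training (decoupled backward passes).

Each transformer block gets a small synthesizer network that predicts the
gradient of the task loss with respect to that block's output, in the spirit
of Decoupled Neural Interfaces (Jaderberg et al.). The main model is updated
from these predicted gradients; a second, standard backward pass then supplies
the true gradients used as regression targets for training the synthesizers.
"""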
import torch
from torch import nn
import torch.nn.functional as F
from .base_optimizer import BaseOptimizer

class GradientSynthesizer(nn.Module):
    """Predicts the gradient of the task loss w.r.t. a block's output activation."""

    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        # Detach the activation so the prediction never backpropagates into the main model.
        return self.net(x.detach())
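
# Standalone example of the synthesizer (hypothetical shapes, not tied to any
# particular model): the predicted gradient always has the same shape as the
# activation it was computed from.
#
#   synth = GradientSynthesizer(hidden_size=768)
#   act = torch.randn(4, 128, 768)   # (batch, seq_len, hidden) block output
#   g_hat = synth(act)               # predicted dL/d(act), shape (4, 128, 768)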

class SyntheticGradients(BaseOptimizer):
    """Trains the main model with per-block synthetic gradients."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.hidden = model.config.n_embd
        self.main_lr = float(config.get('main_learning_rate', 1e-5))
        self.synth_lr = float(config.get('synth_learning_rate', 1e-4))
        self.model_opt = torch.optim.AdamW(model.parameters(), lr=self.main_lr)
        # One synthesizer and one optimizer per block in model.h.
        self.synths = nn.ModuleList([GradientSynthesizer(self.hidden) for _ in model.h])
        self.synth_opts = [torch.optim.AdamW(s.parameters(), lr=self.synth_lr) for s in self.synths]
        self.ce = nn.CrossEntropyLoss()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Hand the model, the synthesizers, and every optimizer to accelerate.
        packs = [self.model, self.model_opt, self.synths] + self.synth_opts
        prepped = self.accelerator.prepare(*packs)
        self.model, self.model_opt, self.synths, *self.synth_opts = prepped

    def step(self, inputs, labels):
        self.model.train()
        for opt in self.synth_opts:
            opt.zero_grad(set_to_none=True)
        self.model_opt.zero_grad(set_to_none=True)

        # Pass 1: update the main model from *predicted* gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for i, out in enumerate(block_outs):
            pred_grad = self.synths[i](out)
            # Inject the synthetic gradient at this block's output; retain_graph because
            # later block outputs reuse the earlier part of the graph.
            out.backward(pred_grad.detach(), retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.model_opt.step()
        self.model_opt.zero_grad(set_to_none=True)

        # Pass 2: run a standard backward to obtain the *true* gradients at each
        # block output, which serve as regression targets for the synthesizers.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for out in block_outs:
            out.retain_grad()  # non-leaf tensors do not keep .grad unless asked to
        V = logits.size(-1)
        task_loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))
        self.accelerator.backward(task_loss, retain_graph=True)

        for i, out in enumerate(block_outs):
            if out.grad is not None:
                true_grad = out.grad.detach()
                pred_grad = self.synths[i](out.detach())
                synth_loss = F.mse_loss(pred_grad, true_grad)
                self.accelerator.backward(synth_loss)
                self.synth_opts[i].step()
                self.synth_opts[i].zero_grad(set_to_none=True)
        return float(task_loss.item())
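
# A minimal smoke-test sketch (hypothetical; not part of the training bench). It assumes
# only the interface this file already relies on: a model exposing `.h` (an iterable of
# blocks) and `.config.n_embd`, a forward that accepts `return_activations=True` and
# returns (logits, per-block activations), and a BaseOptimizer whose set_accelerator
# stores the Accelerator on `self.accelerator`. The toy model, vocabulary size, and
# learning rates below are made up purely for illustration.
if __name__ == "__main__":
    from types import SimpleNamespace
    from accelerate import Accelerator

    class _ToyBlock(nn.Module):
        def __init__(self, hidden):
            super().__init__()
            self.ff = nn.Linear(hidden, hidden)

        def forward(self, x):
            return torch.relu(self.ff(x))

    class _ToyModel(nn.Module):
        def __init__(self, vocab=32, hidden=16, n_blocks=2):
            super().__init__()
            self.config = SimpleNamespace(n_embd=hidden)
            self.emb = nn.Embedding(vocab, hidden)
            self.h = nn.ModuleList(_ToyBlock(hidden) for _ in range(n_blocks))
            self.head = nn.Linear(hidden, vocab)

        def forward(self, input_ids, return_activations=False):
            x = self.emb(input_ids)
            acts = []
            for block in self.h:
                x = block(x)
                acts.append(x)
            logits = self.head(x)
            return (logits, acts) if return_activations else logits

    accel = Accelerator()
    trainer = SyntheticGradients(_ToyModel(), {'main_learning_rate': 1e-3, 'synth_learning_rate': 1e-3})
    trainer.set_accelerator(accel)
    ids = torch.randint(0, 32, (2, 8), device=accel.device)
    print('task loss:', trainer.step({'input_ids': ids}, ids))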