import torch
from torch import nn
import torch.nn.functional as F

from .base_optimizer import BaseOptimizer


class GradientSynthesizer(nn.Module):
    """Small MLP that predicts the gradient of the loss w.r.t. a block's output."""

    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        # Detach the activation so synthesizer training never backpropagates
        # into the main model.
        return self.net(x.detach())


class SyntheticGradients(BaseOptimizer):
    """Updates the model with synthetic (predicted) gradients, then fits the
    synthesizers against the true gradients from a task-loss backward pass."""

    def __init__(self, model, config):
        super().__init__(model, config)
        self.hidden = model.config.n_embd
        self.main_lr = float(config.get('main_learning_rate', 1e-5))
        self.synth_lr = float(config.get('synth_learning_rate', 1e-4))
        self.model_opt = torch.optim.AdamW(model.parameters(), lr=self.main_lr)
        # One synthesizer (and one optimizer) per transformer block.
        self.synths = nn.ModuleList([GradientSynthesizer(self.hidden) for _ in model.h])
        self.synth_opts = [torch.optim.AdamW(s.parameters(), lr=self.synth_lr) for s in self.synths]
        self.ce = nn.CrossEntropyLoss()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        packs = [self.model, self.model_opt, self.synths] + self.synth_opts
        prepped = self.accelerator.prepare(*packs)
        self.model, self.model_opt, self.synths, *self.synth_opts = prepped

    def step(self, inputs, labels):
        self.model.train()
        for opt in self.synth_opts:
            opt.zero_grad(set_to_none=True)
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 1: update the main model using the predicted gradients.
        _, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for i, out in enumerate(block_outs):
            # Detach the prediction: it is used purely as an upstream gradient.
            pred_grad = self.synths[i](out).detach()
            # Each call propagates through every block up to i, so earlier
            # blocks accumulate contributions from all later synthesizers.
            out.backward(pred_grad, retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.model_opt.step()
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 2: recompute activations with the updated weights, backprop the
        # real task loss, and regress each synthesizer onto the true gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        V = logits.size(-1)
        # Standard next-token shift: position t predicts token t + 1.
        task_loss = self.ce(
            logits[:, :-1, :].contiguous().view(-1, V),
            labels[:, 1:].contiguous().view(-1),
        )
        for out in block_outs:
            # Non-leaf activations do not keep .grad unless explicitly retained.
            out.retain_grad()
        # This backward also fills the model's parameter grads; they are
        # discarded by the zero_grad calls at the top of the next step.
        self.accelerator.backward(task_loss)

        for i, out in enumerate(block_outs):
            if out.grad is None:
                continue
            true_grad = out.grad.detach()
            pred_grad = self.synths[i](out.detach())
            synth_loss = F.mse_loss(pred_grad, true_grad)
            self.accelerator.backward(synth_loss)
            self.synth_opts[i].step()
            self.synth_opts[i].zero_grad(set_to_none=True)

        return float(task_loss.item())
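

# ---------------------------------------------------------------------------
# Usage sketch (illustrative smoke test, not part of the optimizer). It leans
# on assumptions this file does not define: that `BaseOptimizer.__init__`
# accepts (model, config), that `BaseOptimizer.set_accelerator` stores the
# accelerator on `self.accelerator`, that Hugging Face `accelerate` is
# installed, and that the model exposes `config.n_embd`, a block list `.h`,
# and `model(input_ids, return_activations=True) -> (logits, block_outs)`.
# The toy model below exists only to exercise that interface. Because of the
# relative import above, run it as a module, e.g. `python -m <package>.<this
# module>` (the package path depends on the repo layout).
if __name__ == "__main__":
    from types import SimpleNamespace

    from accelerate import Accelerator

    class _ToyBlock(nn.Module):
        def __init__(self, d):
            super().__init__()
            self.ff = nn.Linear(d, d)

        def forward(self, x):
            return torch.relu(self.ff(x))

    class _ToyModel(nn.Module):
        """Tiny stand-in with the interface SyntheticGradients expects."""

        def __init__(self, vocab=64, d=32, n_layer=2):
            super().__init__()
            self.config = SimpleNamespace(n_embd=d)
            self.wte = nn.Embedding(vocab, d)
            self.h = nn.ModuleList([_ToyBlock(d) for _ in range(n_layer)])
            self.head = nn.Linear(d, vocab)

        def forward(self, input_ids, return_activations=False):
            x = self.wte(input_ids)
            acts = []
            for block in self.h:
                x = block(x)
                acts.append(x)
            logits = self.head(x)
            return (logits, acts) if return_activations else logits

    model = _ToyModel()
    optimizer = SyntheticGradients(model, {})  # defaults for both learning rates
    optimizer.set_accelerator(Accelerator())
    ids = torch.randint(0, 64, (2, 16))
    # Labels are the input ids themselves; step() applies the next-token shift.
    loss = optimizer.step({'input_ids': ids}, ids.clone())
    print(f"smoke-test task loss: {loss:.4f}")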