import torch
from torch import nn
import torch.nn.functional as F

from .base_optimizer import BaseOptimizer


class GradientSynthesizer(nn.Module):
    """Small MLP that predicts the gradient of the task loss w.r.t. a block's output."""

    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        # Detach the activation so synthesizer training never backpropagates into the main model.
        return self.net(x.detach())


class SyntheticGradients(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.hidden = model.config.n_embd
        self.main_lr = float(config.get('main_learning_rate', 1e-5))
        self.synth_lr = float(config.get('synth_learning_rate', 1e-4))
        self.model_opt = torch.optim.AdamW(model.parameters(), lr=self.main_lr)
        # One gradient synthesizer (and one optimizer for it) per transformer block in model.h.
        self.synths = nn.ModuleList([GradientSynthesizer(self.hidden) for _ in model.h])
        self.synth_opts = [torch.optim.AdamW(s.parameters(), lr=self.synth_lr) for s in self.synths]
        self.ce = nn.CrossEntropyLoss()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Prepare the model, its optimizer, the synthesizer bank, and every synthesizer optimizer.
        packs = [self.model, self.model_opt, self.synths] + self.synth_opts
        prepped = self.accelerator.prepare(*packs)
        self.model, self.model_opt, self.synths, *self.synth_opts = prepped

    def step(self, inputs, labels):
        self.model.train()
        for opt in self.synth_opts:
            opt.zero_grad(set_to_none=True)
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 1: update the main model using synthetic (predicted) gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for i, out in enumerate(block_outs):
            pred_grad = self.synths[i](out)
            # Inject the predicted gradient at this block's output; retain the graph because
            # earlier blocks are traversed again by the backward calls of later blocks.
            out.backward(pred_grad.detach(), retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.model_opt.step()
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 2: recompute activations and train each synthesizer against the true gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for out in block_outs:
            out.retain_grad()  # block outputs are non-leaf tensors; without this their .grad stays None
        B, T, V = logits.shape
        # Standard next-token objective: predict token t+1 from positions up to t.
        task_loss = self.ce(
            logits[:, :-1, :].contiguous().view(-1, V),
            labels[:, 1:].contiguous().view(-1),
        )
        self.accelerator.backward(task_loss, retain_graph=True)
        for i, out in enumerate(block_outs):
            if out.grad is not None:
                true_grad = out.grad.detach()
                pred_grad = self.synths[i](out.detach())
                # Regress each synthesizer's prediction onto the true gradient of its block output.
                synth_loss = F.mse_loss(pred_grad, true_grad)
                self.accelerator.backward(synth_loss)
                self.synth_opts[i].step()
                self.synth_opts[i].zero_grad(set_to_none=True)
        return float(task_loss.item())
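

# Usage sketch (illustrative only, not part of the original module): it assumes a
# GPT-2-style model exposing `config.n_embd`, a block list `model.h`, and a forward
# pass accepting `return_activations=True`, plus an `accelerate.Accelerator`. The
# config keys match the defaults read in __init__; `dataloader` is a placeholder.
#
#     from accelerate import Accelerator
#
#     optimizer = SyntheticGradients(model, {'main_learning_rate': 1e-5,
#                                            'synth_learning_rate': 1e-4})
#     optimizer.set_accelerator(Accelerator())
#     for batch in dataloader:
#         loss = optimizer.step({'input_ids': batch['input_ids']}, batch['input_ids'])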