rider-provider-777 committed
Commit e016a55 · verified · 1 Parent(s): 26a82c5

Upload 5 files

algorithms/backprop.py ADDED
@@ -0,0 +1,24 @@
import torch
from .base_optimizer import BaseOptimizer

class Backpropagation(BaseOptimizer):
    """Standard end-to-end backpropagation with AdamW."""
    def __init__(self, model, config):
        super().__init__(model, config)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=float(config.get('learning_rate', 5e-5)),
            weight_decay=float(config.get('weight_decay', 0.01))
        )

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        # Let accelerate handle device placement and any configured mixed precision.
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)

    def step(self, inputs, labels):
        self.model.train()
        outputs = self.model(**inputs, labels=labels)
        loss = outputs.loss
        self.accelerator.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        return float(loss.item())
algorithms/base_optimizer.py ADDED
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod

class BaseOptimizer(ABC):
    """Abstract base class for all training algorithms."""
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.accelerator = None

    def set_accelerator(self, accelerator):
        self.accelerator = accelerator

    @abstractmethod
    def step(self, inputs, labels):
        """Performs a single training step; must return a Python float loss."""
        raise NotImplementedError
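
For orientation, a training script might drive this interface roughly as in the sketch below. This is a hedged usage sketch, not part of the commit: the model, config dict, dataloader yielding (inputs, labels) batches, and the train() helper are illustrative assumptions.

# Illustrative usage sketch (assumed names and data format; not from this commit).
from accelerate import Accelerator
from algorithms.backprop import Backpropagation

def train(model, dataloader, config, num_steps=1000):
    accelerator = Accelerator()
    algo = Backpropagation(model, config)      # any BaseOptimizer subclass works here
    algo.set_accelerator(accelerator)          # the algorithm prepares its model/optimizers
    dataloader = accelerator.prepare(dataloader)

    for step_idx, (inputs, labels) in enumerate(dataloader):
        loss = algo.step(inputs, labels)       # contract: returns a Python float
        if accelerator.is_main_process and step_idx % 50 == 0:
            print(f"step {step_idx}: loss {loss:.4f}")
        if step_idx >= num_steps:
            break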
algorithms/feedback_alignment.py ADDED
@@ -0,0 +1,54 @@
import torch
from torch import nn
from .base_optimizer import BaseOptimizer

class FeedbackAlignment(BaseOptimizer):
    """Direct Feedback Alignment on the custom ResearchTransformer."""
    def __init__(self, model, config):
        super().__init__(model, config)
        self.lr = float(config.get('learning_rate', 1e-4))
        # Trainable parameters: the transformer blocks, final layer norm, and LM head.
        self.optimizer = torch.optim.AdamW(
            [p for b in model.h for p in b.parameters()]
            + list(model.ln_f.parameters())
            + list(model.lm_head.parameters()),
            lr=self.lr
        )
        self.ce = nn.CrossEntropyLoss()
        # Fixed (non-trainable) random feedback projections, one per block.
        self.feedback = nn.ModuleList([
            nn.Linear(model.config.n_embd, model.config.n_embd, bias=False) for _ in model.h
        ])
        for fb in self.feedback:
            for p in fb.parameters():
                p.requires_grad_(False)

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        self.model, self.optimizer, self.feedback = self.accelerator.prepare(
            self.model, self.optimizer, self.feedback
        )

    def step(self, inputs, labels):
        self.model.train()
        self.optimizer.zero_grad(set_to_none=True)

        # Manual forward pass; each block sees a detached input, so gradients
        # never flow between blocks (no end-to-end backpropagation).
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        block_outs = []
        for block in self.model.h:
            x = block(x.detach())
            block_outs.append(x)

        x_final = self.model.ln_f(block_outs[-1])
        logits = self.model.lm_head(x_final)

        # Next-token cross-entropy.
        B, T, V = logits.shape
        loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))

        # True error signal at the final block output...
        grad_final, = torch.autograd.grad(loss, block_outs[-1], retain_graph=True)

        # ...projected through each block's fixed random feedback matrix and
        # applied as that block's pseudo-gradient.
        for i in reversed(range(len(block_outs))):
            pseudo_err = self.feedback[i](grad_final.detach())
            block_outs[i].backward(pseudo_err, retain_graph=True)

        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return float(loss.item())
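
As a standalone illustration of the update rule in step above: the top-level error is projected through a fixed random matrix and used in place of the block's true gradient. The shapes and names below are assumptions for this toy sketch only, not part of the commit.

# Toy illustration of the pseudo-gradient routing used in FeedbackAlignment.step
# (assumed shapes; illustrative only).
import torch
from torch import nn

block = nn.Linear(16, 16)                    # stands in for one transformer block
x_in = torch.randn(4, 16)                    # detached input to the block
h = block(x_in)                              # the block's output

top_error = torch.randn(4, 16)               # stands in for dL/d(final block output)
B = nn.Linear(16, 16, bias=False)            # fixed random feedback matrix
for p in B.parameters():
    p.requires_grad_(False)

h.backward(B(top_error))                     # block.weight.grad and block.bias.grad now
                                             # hold the DFA pseudo-gradients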
algorithms/forward_forward.py ADDED
@@ -0,0 +1,63 @@
import torch
from torch import nn
from .base_optimizer import BaseOptimizer

class ForwardForward(BaseOptimizer):
    """Layer-local Forward-Forward training: each block has its own optimizer and
    is trained to assign high goodness to real sequences and low goodness to
    negative (token-shuffled) sequences."""
    def __init__(self, model, config):
        super().__init__(model, config)
        self.threshold = float(config.get('threshold', 2.0))
        self.lr = float(config.get('learning_rate', 1e-4))
        self.optimizers = [torch.optim.AdamW(b.parameters(), lr=self.lr) for b in model.h]
        self.sigmoid = nn.Sigmoid()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        prepared = self.accelerator.prepare(self.model, *self.optimizers)
        self.model = prepared[0]
        self.optimizers = list(prepared[1:])

    def step(self, inputs, labels):
        self.model.train()
        total = 0.0

        # Positive pass: embed the real token sequence.
        input_ids = inputs['input_ids']
        device = input_ids.device
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=device).unsqueeze(0)
        x = self.model.wte(input_ids) + self.model.wpe(pos)
        x = self.model.drop(x)

        # Negative pass: shuffle each sequence's tokens to destroy its structure.
        neg_ids = input_ids.clone()
        for b in range(B):
            idx = torch.randperm(T, device=device)
            neg_ids[b] = neg_ids[b, idx]
        x_neg = self.model.wte(neg_ids) + self.model.wpe(pos)
        x_neg = self.model.drop(x_neg)

        for i, block in enumerate(self.model.h):
            opt = self.optimizers[i]
            opt.zero_grad(set_to_none=True)

            # Detach the inputs so each block is trained purely locally.
            xp = x.detach()
            xn = x_neg.detach()

            op = block(xp)
            on = block(xn)

            # Goodness = mean squared activation per position.
            gp = op.pow(2).mean(dim=-1)
            gn = on.pow(2).mean(dim=-1)

            # log1p(exp(z)) is softplus(z): push positive goodness above the
            # threshold and negative goodness below it.
            loss = (
                torch.log1p(torch.exp(-(gp - self.threshold))).mean()
                + torch.log1p(torch.exp(gn - self.threshold)).mean()
            )
            self.accelerator.backward(loss)
            opt.step()
            total += float(loss.item())

            x = op.detach()
            x_neg = on.detach()

        # Proxy next-token cross-entropy through the untrained head, computed for
        # reference only; it is not part of the returned value.
        with torch.no_grad():
            logits = self.model.lm_head(self.model.ln_f(x))
            V = logits.size(-1)
            ce = nn.CrossEntropyLoss()
            proxy = ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1)).item()
        return total / max(1, len(self.model.h))
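
For reference, the per-block objective above can be written compactly with softplus, since log1p(exp(z)) equals softplus(z). The helper below is an illustrative restatement under that identity; the function name and defaults are assumptions, not part of the commit.

# Standalone restatement of the per-block Forward-Forward loss used above
# (illustrative only; assumed helper name and default threshold).
import torch
import torch.nn.functional as F

def ff_block_loss(pos_act, neg_act, threshold=2.0):
    g_pos = pos_act.pow(2).mean(dim=-1)          # goodness of positive samples
    g_neg = neg_act.pow(2).mean(dim=-1)          # goodness of negative samples
    # softplus(z) equals log1p(exp(z)) and is the numerically stable form
    return F.softplus(threshold - g_pos).mean() + F.softplus(g_neg - threshold).mean()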
algorithms/synthetic_gradients.py ADDED
@@ -0,0 +1,62 @@
import torch
from torch import nn
import torch.nn.functional as F
from .base_optimizer import BaseOptimizer

class GradientSynthesizer(nn.Module):
    """Small MLP that predicts dL/d(activation) from the activation itself."""
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )
    def forward(self, x):
        return self.net(x.detach())

class SyntheticGradients(BaseOptimizer):
    def __init__(self, model, config):
        super().__init__(model, config)
        self.hidden = model.config.n_embd
        self.main_lr = float(config.get('main_learning_rate', 1e-5))
        self.synth_lr = float(config.get('synth_learning_rate', 1e-4))
        self.model_opt = torch.optim.AdamW(model.parameters(), lr=self.main_lr)
        self.synths = nn.ModuleList([GradientSynthesizer(self.hidden) for _ in model.h])
        self.synth_opts = [torch.optim.AdamW(s.parameters(), lr=self.synth_lr) for s in self.synths]
        self.ce = nn.CrossEntropyLoss()

    def set_accelerator(self, accelerator):
        super().set_accelerator(accelerator)
        packs = [self.model, self.model_opt, self.synths] + self.synth_opts
        prepped = self.accelerator.prepare(*packs)
        self.model, self.model_opt, self.synths, *self.synth_opts = prepped

    def step(self, inputs, labels):
        self.model.train()
        for opt in self.synth_opts:
            opt.zero_grad(set_to_none=True)
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 1: update the model using predicted (synthetic) gradients.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for i, out in enumerate(block_outs):
            pred_grad = self.synths[i](out)
            out.backward(pred_grad, retain_graph=True)
        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
        self.model_opt.step()
        self.model_opt.zero_grad(set_to_none=True)

        # Phase 2: recompute the task loss and train each synthesizer to match
        # the true gradient of its block output.
        logits, block_outs = self.model(inputs['input_ids'], return_activations=True)
        for out in block_outs:
            out.retain_grad()  # block outputs are non-leaf; keep .grad so the targets below exist
        B, T, V = logits.shape
        task_loss = self.ce(logits[:, :-1, :].contiguous().view(-1, V), labels[:, 1:].contiguous().view(-1))
        self.accelerator.backward(task_loss, retain_graph=True)

        for i, out in enumerate(block_outs):
            if out.grad is not None:
                true_grad = out.grad.detach()
                pred_grad = self.synths[i](out.detach())
                synth_loss = F.mse_loss(pred_grad, true_grad)
                self.accelerator.backward(synth_loss)
                self.synth_opts[i].step()
                self.synth_opts[i].zero_grad(set_to_none=True)

        return float(task_loss.item())
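
For intuition, the two-phase update above reduces to the following toy pattern: apply a predicted gradient to update a layer immediately, then regress the synthesizer toward the true gradient once it is available. The shapes, layer, and dummy loss below are assumptions for this sketch only, not part of the commit.

# Toy sketch of the decoupled update used by SyntheticGradients.step
# (assumed shapes; illustrative only).
import torch
from torch import nn
import torch.nn.functional as F

layer = nn.Linear(16, 16)                      # stands in for one transformer block
synth = nn.Linear(16, 16)                      # stands in for its GradientSynthesizer
layer_opt = torch.optim.SGD(layer.parameters(), lr=1e-2)
synth_opt = torch.optim.SGD(synth.parameters(), lr=1e-2)

x = torch.randn(4, 16)
h = layer(x)

# Phase 1: update the layer immediately using the predicted gradient of h.
h.backward(synth(h.detach()), retain_graph=True)
layer_opt.step()
layer_opt.zero_grad(set_to_none=True)

# Phase 2: once a true gradient of h is available (here from a dummy loss),
# train the synthesizer to match it.
h.retain_grad()                                # h is non-leaf; keep its .grad
h.sum().backward()                             # dummy task loss providing a true dL/dh
synth_loss = F.mse_loss(synth(h.detach()), h.grad.detach())
synth_loss.backward()
synth_opt.step()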