update
- dist/index.html +0 -0
- dist/main.bundle.js +0 -0
- dist/main.bundle.js.map +0 -0
- src/index.html +0 -0
- src/memory.js +76 -40
dist/index.html
CHANGED
The diff for this file is too large to render.
dist/main.bundle.js
CHANGED
The diff for this file is too large to render.
dist/main.bundle.js.map
CHANGED
The diff for this file is too large to render.
src/index.html
CHANGED
The diff for this file is too large to render.
src/memory.js
CHANGED
@@ -11,48 +11,78 @@ export function activationMemory(
     tp = 1, // tensor model parallelism
     mixed = true,
     recomputation = "none",
-    ff_activation = "relu"
+    ff_activation = "relu",
+    seq_parallel = false
 ) {
-    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation });
+    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });
     // https://arxiv.org/pdf/2205.05198
     const bytesPerValue = mixed ? 2 : 4;

     let oneLayerAttention;
     if (recomputation === "none" || recomputation === "full") {
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); // eq (2)
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1) + ((2 * bytesPerValue + 1) * a * s * s * b / tp); // eq (2)
+        }
     } else if (recomputation === "selective") {
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1); // table 2
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1); // table 2
+        }
     } else {
         throw new Error("Invalid recomputation value");
     }

     let oneLayerFeedforward;
     if (ff_activation === "relu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "gelu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "swiglu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h / tp); // dropout (note that dropout is lower-precision - boolean)
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h); // dropout (note that dropout is lower-precision - boolean)
+        }
     }

+    let layerNorm;
+    if (seq_parallel) {
+        layerNorm = s * b * h * bytesPerValue / tp;
+    } else {
+        layerNorm = s * b * h * bytesPerValue;
+    }

-    const inputDropout = s * b * h; // section 4.3
-    const outputLayerNorm = s * b * h * bytesPerValue;
-    const outputLayerProjection = s * b * h * bytesPerValue;
-    const outputCrossEntropy = s * b * v * 4; // In FP32
+    const inputDropout = seq_parallel ? s * b * h / tp : s * b * h; // section 4.3
+    const outputLayerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputLayerProjection = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputCrossEntropy = seq_parallel ? s * b * v * 4 / tp : s * b * v * 4; // In FP32


-    let oneLayer;
     let data
     if (recomputation === "none" || recomputation === "selective") {

+        data = {
            name: "activationMemory",
            children: [
                ...Array.from({ length: L }, (_, index) => ({

@@ -70,22 +100,20 @@ export function activationMemory(
            ]
        };
     } else if (recomputation === "full") {
+        data = {
            name: "activationMemory",
            children: [
-                { name: 'LayerInput', value: s * b * h * bytesPerValue * L},
+                { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
                { name: 'Dropout', value: inputDropout },
                { name: 'LayerNorm', value: outputLayerNorm },
                { name: 'Projection', value: outputLayerProjection },
                { name: 'Cross Entropy', value: outputCrossEntropy }
            ]
        };
+    } else {
         throw new Error("Invalid recomputation value");
     }

     return data;
 }

@@ -138,10 +166,11 @@ export function updateGraph() {
     const mixed = document.getElementById('mixed').checked;
     const recomputation = document.getElementById('recomputation').value;
     const ff_activation = document.getElementById('ff_activation').value;
+    const seq_parallel = document.getElementById('seq_parallel').checked;

-    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation });
+    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });

-    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation);
+    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
     const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);

     const data = {

@@ -167,7 +196,7 @@ export function updateGraph() {
     const svg = d3.select("#graph").select("svg");
     svg.selectAll("*").remove();
     svg.attr("width", width)
+        .attr("height", height + legendHeight);

     const treemap = d3.treemap()
         .size([width, height])

@@ -178,10 +207,10 @@ export function updateGraph() {

     const root = d3.hierarchy(data)
         .sum(d => d.value);
+    // .sort((a, b) => b.value - a.value);

+    // const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
+    // if (root.children[0].value < fixedSize100GB) {
     // root.value = fixedSize100GB;
     // root.children[0].value = fixedSize100GB;
     // }

@@ -191,7 +220,7 @@ export function updateGraph() {
     treemap(root);

     const color = d => {
-        switch(d.data.name) {
+        switch (d.data.name) {
             case 'Parameters': return '#4e79a7'; // Blue
             case 'Gradients': return '#f28e2c'; // Orange
             case 'OptimizerAverages': return '#e15759'; // Green

@@ -227,14 +256,14 @@ export function updateGraph() {
     cell.append("text")
         .attr("font-size", `${fontSize}px`)
         .attr("font-family", "sans-serif")
-        .each(function(d) {
+        .each(function (d) {
            if (d.depth === 0) return; // Skip root node

            const node = d3.select(this);
+
            const name = d.data.name;
            const value = formatBytes(d.value);
+
            if (d.depth === 1 || d.depth === 2) {
                node.attr("transform", `translate(${padding},${fontSize + padding})`)
                    .attr("font-weight", "bold")

@@ -294,10 +323,10 @@ function formatBytes(bytes) {
 }

 const presets = {
-    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu" },
-    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" }
+    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu", seq_parallel: false },
+    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }
 };

 function setPresetValues(preset) {

@@ -345,7 +374,7 @@ function syncSliderAndInput(sliderId, inputId) {

 export const init_memory_plot = function () {
     console.log('Initializing memory plot');
+
     const sliderIds = ['a', 'b', 'h', 'h_ff', 'L', 's', 'v', 'k', 'tp', 'dp'];
     sliderIds.forEach(id => {
         const slider = document.getElementById(id);

@@ -385,6 +414,13 @@ export const init_memory_plot = function () {
         console.warn('Mixed checkbox not found');
     }

+    const seqParallelCheckbox = document.getElementById('seq_parallel');
+    if (seqParallelCheckbox) {
+        seqParallelCheckbox.addEventListener('change', updateGraph);
+    } else {
+        console.warn('Seq Parallel checkbox not found');
+    }
+
     const presetSelect = document.getElementById('presets');
     if (presetSelect) {
         presetSelect.addEventListener('change', (event) => {

@@ -398,7 +434,7 @@ export const init_memory_plot = function () {
     sliderIds.forEach(id => {
         const slider = document.getElementById(id);
         if (slider) {
-            switch(id) {
+            switch (id) {
                 case 'a': slider.max = '128'; break;
                 case 'b': slider.max = '53248'; break;
                 case 'h': slider.max = '16384'; break;
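To see what the new seq_parallel flag changes, here is a minimal standalone sketch, not part of this commit, that mirrors the new oneLayerAttention branch for recomputation = "none" / "full"; attentionActivation is a hypothetical helper name, and bytesPerValue = 2 assumes the mixed-precision default. With tp = 1 there is nothing to shard across tensor-parallel ranks, so both branches give the same number; the values below use the "Llama 3 8B" preset from the diff.

// Hypothetical helper (not in src/memory.js) mirroring the new oneLayerAttention
// formula for recomputation = "none" / "full", as coded in the diff above
// (eq (2) of https://arxiv.org/pdf/2205.05198).
function attentionActivation(a, b, h, s, tp, seqParallel, bytesPerValue = 2) {
    if (seqParallel) {
        // with sequence parallelism the s * b * h term is divided by tp
        return s * b * h / tp * (bytesPerValue * 5 + 1)
            + (2 * bytesPerValue + 1) * a * s * s * b;
    }
    // without sequence parallelism only the tensor-parallel parts are divided by tp
    return s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1)
        + (2 * bytesPerValue + 1) * a * s * s * b / tp;
}

// "Llama 3 8B" preset: a = 32, b = 32, h = 4096, s = 256, tp = 1
console.log(attentionActivation(32, 32, 4096, 256, 1, false)); // 704643072 bytes (~672 MiB)
console.log(attentionActivation(32, 32, 4096, 256, 1, true));  // 704643072 bytes (~672 MiB)

With tp > 1 the two branches differ, which is what toggling the new Seq Parallel checkbox changes in the rendered treemap.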