Commit 143e8bd by thomwolf (HF staff)
Parent(s): 9905b93
dist/index.html CHANGED
The diff for this file is too large to render. See raw diff
 
dist/main.bundle.js CHANGED
The diff for this file is too large to render. See raw diff
 
dist/main.bundle.js.map CHANGED
The diff for this file is too large to render. See raw diff
 
src/index.html CHANGED
The diff for this file is too large to render. See raw diff
 
src/memory.js CHANGED
@@ -11,48 +11,78 @@ export function activationMemory(
     tp = 1, // tensor model parallelism
     mixed = true,
     recomputation = "none",
-    ff_activation = "relu"
+    ff_activation = "relu",
+    seq_parallel = false
 ) {
-    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation });
+    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });
     // https://arxiv.org/pdf/2205.05198
     const bytesPerValue = mixed ? 2 : 4;
 
     let oneLayerAttention;
     if (recomputation === "none" || recomputation === "full") {
-        oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); // eq (2)
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); // eq (2)
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1) + ((2 * bytesPerValue + 1) * a * s * s * b / tp); // eq (2)
+        }
     } else if (recomputation === "selective") {
-        oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1); // table 2
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1); // table 2
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1); // table 2
+        }
     } else {
         throw new Error("Invalid recomputation value");
     }
 
     let oneLayerFeedforward;
     if (ff_activation === "relu") {
-        oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
-            + s * b * h); // dropout
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "gelu") {
-        oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
-            + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
-            + s * b * h); // dropout
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "swiglu") {
-        oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
-            + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
-            + s * b * h); // dropout (note that dropout is lower-precision - boolean)
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h / tp); // dropout (note that dropout is lower-precision - boolean)
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h); // dropout (note that dropout is lower-precision - boolean)
+        }
     }
 
-    const layerNorm = s * b * h * bytesPerValue;
+    let layerNorm;
+    if (seq_parallel) {
+        layerNorm = s * b * h * bytesPerValue / tp;
+    } else {
+        layerNorm = s * b * h * bytesPerValue;
+    }
 
-    const inputDropout = s * b * h; // section 4.3
-    const outputLayerNorm = s * b * h * bytesPerValue;
-    const outputLayerProjection = s * b * h * bytesPerValue;
-    const outputCrossEntropy = s * b * v * 4; // In FP32
+    const inputDropout = seq_parallel ? s * b * h / tp : s * b * h; // section 4.3
+    const outputLayerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputLayerProjection = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputCrossEntropy = seq_parallel ? s * b * v * 4 / tp : s * b * v * 4; // In FP32
 
 
-    let oneLayer;
     let data
     if (recomputation === "none" || recomputation === "selective") {
 
-            data = {
+        data = {
             name: "activationMemory",
             children: [
                 ...Array.from({ length: L }, (_, index) => ({
@@ -70,22 +100,20 @@ export function activationMemory(
             ]
         };
     } else if (recomputation === "full") {
-            data = {
+        data = {
             name: "activationMemory",
             children: [
-                { name: 'LayerInput', value: s * b * h * bytesPerValue * L},
+                { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
                 { name: 'Dropout', value: inputDropout },
                 { name: 'LayerNorm', value: outputLayerNorm },
                 { name: 'Projection', value: outputLayerProjection },
                 { name: 'Cross Entropy', value: outputCrossEntropy }
             ]
         };
-        } else {
+    } else {
         throw new Error("Invalid recomputation value");
     }
 
-
-
     return data;
 }
@@ -138,10 +166,11 @@ export function updateGraph() {
     const mixed = document.getElementById('mixed').checked;
     const recomputation = document.getElementById('recomputation').value;
    const ff_activation = document.getElementById('ff_activation').value;
+    const seq_parallel = document.getElementById('seq_parallel').checked;
 
-    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation });
+    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });
 
-    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation);
+    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
     const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);
 
     const data = {
@@ -167,7 +196,7 @@ export function updateGraph() {
     const svg = d3.select("#graph").select("svg");
     svg.selectAll("*").remove();
     svg.attr("width", width)
-         .attr("height", height + legendHeight);
+        .attr("height", height + legendHeight);
 
     const treemap = d3.treemap()
         .size([width, height])
@@ -178,10 +207,10 @@ export function updateGraph() {
 
     const root = d3.hierarchy(data)
         .sum(d => d.value);
-        // .sort((a, b) => b.value - a.value);
+    // .sort((a, b) => b.value - a.value);
 
-      // const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
-      // if (root.children[0].value < fixedSize100GB) {
+    // const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
+    // if (root.children[0].value < fixedSize100GB) {
     //     root.value = fixedSize100GB;
     //     root.children[0].value = fixedSize100GB;
     // }
@@ -191,7 +220,7 @@ export function updateGraph() {
     treemap(root);
 
     const color = d => {
-        switch(d.data.name) {
+        switch (d.data.name) {
             case 'Parameters': return '#4e79a7'; // Blue
             case 'Gradients': return '#f28e2c'; // Orange
             case 'OptimizerAverages': return '#e15759'; // Green
@@ -227,14 +256,14 @@ export function updateGraph() {
     cell.append("text")
         .attr("font-size", `${fontSize}px`)
         .attr("font-family", "sans-serif")
-        .each(function(d) {
+        .each(function (d) {
            if (d.depth === 0) return; // Skip root node
 
            const node = d3.select(this);
-            
+
            const name = d.data.name;
            const value = formatBytes(d.value);
-            
+
            if (d.depth === 1 || d.depth === 2) {
                node.attr("transform", `translate(${padding},${fontSize + padding})`)
                    .attr("font-weight", "bold")
@@ -294,10 +323,10 @@ function formatBytes(bytes) {
 }
 
 const presets = {
-    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu" },
-    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" }
+    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu", seq_parallel: false },
+    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }
 };
 
 function setPresetValues(preset) {
@@ -345,7 +374,7 @@ function syncSliderAndInput(sliderId, inputId) {
 
 export const init_memory_plot = function () {
     console.log('Initializing memory plot');
-    
+
     const sliderIds = ['a', 'b', 'h', 'h_ff', 'L', 's', 'v', 'k', 'tp', 'dp'];
     sliderIds.forEach(id => {
         const slider = document.getElementById(id);
@@ -385,6 +414,13 @@ export const init_memory_plot = function () {
         console.warn('Mixed checkbox not found');
     }
 
+    const seqParallelCheckbox = document.getElementById('seq_parallel');
+    if (seqParallelCheckbox) {
+        seqParallelCheckbox.addEventListener('change', updateGraph);
+    } else {
+        console.warn('Seq Parallel checkbox not found');
+    }
+
     const presetSelect = document.getElementById('presets');
     if (presetSelect) {
         presetSelect.addEventListener('change', (event) => {
@@ -398,7 +434,7 @@ export const init_memory_plot = function () {
     sliderIds.forEach(id => {
         const slider = document.getElementById(id);
         if (slider) {
-            switch(id) {
+            switch (id) {
                 case 'a': slider.max = '128'; break;
                 case 'b': slider.max = '53248'; break;
                 case 'h': slider.max = '16384'; break;
  case 'h': slider.max = '16384'; break;