update
Files changed:
- dist/index.html (+0 -0)
- dist/main.bundle.js (+0 -0)
- dist/main.bundle.js.map (+0 -0)
- src/index.html (+0 -0)
- src/memory.js (+76 -40)
dist/index.html (CHANGED)
The diff for this file is too large to render; see the raw diff.

dist/main.bundle.js (CHANGED)
The diff for this file is too large to render; see the raw diff.

dist/main.bundle.js.map (CHANGED)
The diff for this file is too large to render; see the raw diff.

src/index.html (CHANGED)
The diff for this file is too large to render; see the raw diff.
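The substantive change in this commit is to src/memory.js, which threads a new seq_parallel flag from the UI into the activation-memory formulas. As can be read off the hunks below (the parameter names before tp are taken from the call site in updateGraph), the function's signature after this commit looks roughly like the following sketch; the body is elided here:

export function activationMemory(
    a, b, h, h_ff, L, s, v,
    tp = 1,                 // tensor model parallelism
    mixed = true,
    recomputation = "none",
    ff_activation = "relu",
    seq_parallel = false    // new in this commit
) { /* builds the activation-memory tree; see the hunks below */ }

When seq_parallel is true, the LayerNorm, dropout, projection and cross-entropy activation terms are additionally divided by tp, following the sequence-parallelism analysis in arXiv:2205.05198.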
src/memory.js (CHANGED)

@@ -11,48 +11,78 @@ export function activationMemory(
     tp = 1, // tensor model parallelism
     mixed = true,
     recomputation = "none",
-    ff_activation = "relu"
+    ff_activation = "relu",
+    seq_parallel = false
 ) {
-    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation });
+    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });
     // https://arxiv.org/pdf/2205.05198
     const bytesPerValue = mixed ? 2 : 4;

     let oneLayerAttention;
     if (recomputation === "none" || recomputation === "full") {
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); // eq (2)
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1) + ((2 * bytesPerValue + 1) * a * s * s * b / tp); // eq (2)
+        }
     } else if (recomputation === "selective") {
+        if (seq_parallel) {
+            oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1); // table 2
+        } else {
+            oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1); // table 2
+        }
     } else {
         throw new Error("Invalid recomputation value");
     }

     let oneLayerFeedforward;
     if (ff_activation === "relu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "gelu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h / tp); // dropout
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
+                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
+                + s * b * h); // dropout
+        }
     } else if (ff_activation === "swiglu") {
+        if (seq_parallel) {
+            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h / tp); // dropout (note that dropout is lower-precision - boolean)
+        } else {
+            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
+                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
+                + s * b * h); // dropout (note that dropout is lower-precision - boolean)
+        }
     }

+    let layerNorm;
+    if (seq_parallel) {
+        layerNorm = s * b * h * bytesPerValue / tp;
+    } else {
+        layerNorm = s * b * h * bytesPerValue;
+    }

-    const inputDropout = s * b * h; // section 4.3
-    const outputLayerNorm = s * b * h * bytesPerValue;
-    const outputLayerProjection = s * b * h * bytesPerValue;
-    const outputCrossEntropy = s * b * v * 4; // In FP32
+    const inputDropout = seq_parallel ? s * b * h / tp : s * b * h; // section 4.3
+    const outputLayerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputLayerProjection = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
+    const outputCrossEntropy = seq_parallel ? s * b * v * 4 / tp : s * b * v * 4; // In FP32


-    let oneLayer;
     let data
     if (recomputation === "none" || recomputation === "selective") {

+        data = {
             name: "activationMemory",
             children: [
                 ...Array.from({ length: L }, (_, index) => ({

@@ -70,22 +100,20 @@ export function activationMemory(
             ]
         };
     } else if (recomputation === "full") {
+        data = {
             name: "activationMemory",
             children: [
-                { name: 'LayerInput', value: s * b * h * bytesPerValue * L},
+                { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
                 { name: 'Dropout', value: inputDropout },
                 { name: 'LayerNorm', value: outputLayerNorm },
                 { name: 'Projection', value: outputLayerProjection },
                 { name: 'Cross Entropy', value: outputCrossEntropy }
             ]
         };
+    } else {
         throw new Error("Invalid recomputation value");
     }

     return data;
 }

@@ -138,10 +166,11 @@ export function updateGraph() {
     const mixed = document.getElementById('mixed').checked;
     const recomputation = document.getElementById('recomputation').value;
     const ff_activation = document.getElementById('ff_activation').value;
+    const seq_parallel = document.getElementById('seq_parallel').checked;

-    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation });
+    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });

-    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation);
+    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
     const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);

     const data = {

@@ -167,7 +196,7 @@ export function updateGraph() {
     const svg = d3.select("#graph").select("svg");
     svg.selectAll("*").remove();
     svg.attr("width", width)
+        .attr("height", height + legendHeight);

     const treemap = d3.treemap()
         .size([width, height])

@@ -178,10 +207,10 @@ export function updateGraph() {

     const root = d3.hierarchy(data)
         .sum(d => d.value);
+        // .sort((a, b) => b.value - a.value);

+    // const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
+    // if (root.children[0].value < fixedSize100GB) {
     // root.value = fixedSize100GB;
     // root.children[0].value = fixedSize100GB;
     // }

@@ -191,7 +220,7 @@ export function updateGraph() {
     treemap(root);

     const color = d => {
-        switch(d.data.name) {
+        switch (d.data.name) {
             case 'Parameters': return '#4e79a7'; // Blue
             case 'Gradients': return '#f28e2c'; // Orange
             case 'OptimizerAverages': return '#e15759'; // Green

@@ -227,14 +256,14 @@ export function updateGraph() {
     cell.append("text")
         .attr("font-size", `${fontSize}px`)
         .attr("font-family", "sans-serif")
-        .each(function(d) {
+        .each(function (d) {
            if (d.depth === 0) return; // Skip root node

            const node = d3.select(this);
+
            const name = d.data.name;
            const value = formatBytes(d.value);
+
            if (d.depth === 1 || d.depth === 2) {
                node.attr("transform", `translate(${padding},${fontSize + padding})`)
                    .attr("font-weight", "bold")

@@ -294,10 +323,10 @@ function formatBytes(bytes) {
 }

 const presets = {
-    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu" },
-    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" },
-    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu" }
+    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu", seq_parallel: false },
+    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
+    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }
 };

 function setPresetValues(preset) {

@@ -345,7 +374,7 @@ function syncSliderAndInput(sliderId, inputId) {

 export const init_memory_plot = function () {
     console.log('Initializing memory plot');
+
     const sliderIds = ['a', 'b', 'h', 'h_ff', 'L', 's', 'v', 'k', 'tp', 'dp'];
     sliderIds.forEach(id => {
         const slider = document.getElementById(id);

@@ -385,6 +414,13 @@ export const init_memory_plot = function () {
         console.warn('Mixed checkbox not found');
     }

+    const seqParallelCheckbox = document.getElementById('seq_parallel');
+    if (seqParallelCheckbox) {
+        seqParallelCheckbox.addEventListener('change', updateGraph);
+    } else {
+        console.warn('Seq Parallel checkbox not found');
+    }
+
     const presetSelect = document.getElementById('presets');
     if (presetSelect) {
         presetSelect.addEventListener('change', (event) => {

@@ -398,7 +434,7 @@ export const init_memory_plot = function () {
         sliderIds.forEach(id => {
             const slider = document.getElementById(id);
             if (slider) {
-                switch(id) {
+                switch (id) {
                     case 'a': slider.max = '128'; break;
                     case 'b': slider.max = '53248'; break;
                     case 'h': slider.max = '16384'; break;