// lvwerra's picture
// lvwerra HF staff
// memory-layout-widget (#12)
// d4506c4 verified
// raw
// history blame
// 19.1 kB
import * as d3 from 'd3';
/**
 * Estimate transformer activation memory, in bytes, as a hierarchy suitable
 * for a treemap. The per-layer accounting follows "Reducing Activation
 * Recomputation in Large Transformer Models" (https://arxiv.org/pdf/2205.05198).
 *
 * @param {number} a - attention heads
 * @param {number} b - micro batch size
 * @param {number} h - hidden dimension size
 * @param {number} h_ff - feedforward dimension size (often h_ff = 4h)
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {number} [tp=1] - tensor model parallelism
 * @param {boolean} [mixed=true] - true: 2 bytes per stored value, false: 4 bytes
 * @param {"none"|"selective"|"full"} [recomputation="none"] - activation recomputation mode
 * @param {"relu"|"gelu"|"swiglu"} [ff_activation="relu"] - feedforward activation function
 * @param {boolean} [seq_parallel=false] - when true, the terms that are otherwise
 *   kept whole are also divided by tp
 * @returns {{name: string, children: Array<object>}} tree whose leaves are
 *   `{ name, value }` with `value` in bytes
 * @throws {Error} if `recomputation` or `ff_activation` is not a supported value
 */
export function activationMemory(
  a,
  b,
  h,
  h_ff,
  L,
  s,
  v,
  tp = 1,
  mixed = true,
  recomputation = "none",
  ff_activation = "relu",
  seq_parallel = false
) {
  console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });

  // Validate up front so every later branch can assume a known mode.
  if (recomputation !== "none" && recomputation !== "selective" && recomputation !== "full") {
    throw new Error("Invalid recomputation value");
  }

  const bytesPerValue = mixed ? 2 : 4;

  // Terms that are only split across tp ranks when sequence parallelism is on.
  const shardIfSeqParallel = (bytes) => (seq_parallel ? bytes / tp : bytes);

  // Attention block, part that is linear in s (table 2 of the paper).
  let oneLayerAttention = seq_parallel
    ? s * b * h / tp * (bytesPerValue * 5 + 1)
    : s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1);
  if (recomputation !== "selective") {
    // Attention-score term of eq (2): softmax input/output plus the 1-byte
    // dropout mask. It is split across tensor-parallel ranks with or without
    // sequence parallelism (sbh/t * (34 + 5as/h) in the paper), so it is
    // divided by tp in both cases. Selective recomputation drops this term.
    oneLayerAttention += (2 * bytesPerValue + 1) * a * s * s * b / tp;
  }

  // Feedforward block.
  const ffHiddenInput = s * b * h_ff * bytesPerValue / tp; // input of the 2nd linear layer
  let ffActivationInputs; // inputs kept for the activation function itself
  switch (ff_activation) {
    case "relu":
      ffActivationInputs = 0; // not really necessary for ReLU
      break;
    case "gelu":
      ffActivationInputs = s * b * h_ff * bytesPerValue / tp;
      break;
    case "swiglu":
      ffActivationInputs = s * b * h_ff * bytesPerValue * 3 / tp;
      break;
    default:
      // Previously an unknown value silently produced an undefined
      // Feedforward entry; fail loudly like the recomputation check does.
      throw new Error("Invalid ff_activation value");
  }
  const oneLayerFeedforward = shardIfSeqParallel(s * b * h * bytesPerValue) // input of the 1st linear layer
    + ffHiddenInput
    + ffActivationInputs
    + shardIfSeqParallel(s * b * h); // dropout mask (1 byte per element)

  const layerNorm = shardIfSeqParallel(s * b * h * bytesPerValue);

  // Activations outside the repeated transformer layers.
  const inputDropout = shardIfSeqParallel(s * b * h); // section 4.3
  const outputLayerNorm = shardIfSeqParallel(s * b * h * bytesPerValue);
  const outputLayerProjection = shardIfSeqParallel(s * b * h * bytesPerValue);
  const outputCrossEntropy = shardIfSeqParallel(s * b * v * 4); // In FP32

  const outputChildren = [
    { name: 'Dropout', value: inputDropout },
    { name: 'LayerNorm', value: outputLayerNorm },
    { name: 'Projection', value: outputLayerProjection },
    { name: 'Cross Entropy', value: outputCrossEntropy }
  ];

  if (recomputation === "full") {
    // Full recomputation: only the input of each layer is accounted for.
    return {
      name: "Activation Memory",
      children: [
        { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
        ...outputChildren
      ]
    };
  }

  // "none" and "selective": one subtree per layer.
  const layers = Array.from({ length: L }, (_, index) => ({
    name: `Layer ${index + 1}`,
    children: [
      { name: 'Attention', value: oneLayerAttention },
      { name: 'Feedforward', value: oneLayerFeedforward },
      { name: 'LayerNorm', value: 2 * layerNorm } // two layer norms per layer
    ]
  }));

  return {
    name: "Activation Memory",
    children: [...layers, ...outputChildren]
  };
}
/**
 * Estimate memory, in bytes, for model parameters, gradients and optimizer
 * state, as a hierarchy suitable for a treemap.
 *
 * Numeric arguments are coerced with `Number(...)`, so raw DOM strings such as
 * "8" are accepted; without that, `k += 4` and `v + s` would concatenate.
 *
 * @param {number|string} h - hidden dimension size
 * @param {number|string} L - number of layers
 * @param {number|string} s - sequence length
 * @param {number|string} v - vocab size
 * @param {number|string} [k=8] - optimizer bytes per parameter
 *   (Adam: 8 = 4 bytes moments + 4 bytes variance)
 * @param {number|string} [dp=1] - data parallelism
 * @param {number|string} [zero=0] - ZeRO stage (0, 1, 2 or 3)
 * @param {boolean} [mixed=true] - mixed precision training
 * @returns {{name: string, children: Array<{name: string, value: number}>}}
 */
export function paramGradsOpt(h, L, s, v, k = 8, dp = 1, zero = 0, mixed = true) {
  console.log('paramGradsOpt called with:', { h, L, s, v, k, dp, zero, mixed });

  const hidden = Number(h);
  const layers = Number(L);
  const seqLen = Number(s);
  const vocab = Number(v);
  const dataParallel = Number(dp);
  const zeroStage = Number(zero);

  // Parameter count: embeddings + transformer layers + two extra vectors of size h.
  const emb = hidden * (vocab + seqLen);
  const oneLayer = 12 * hidden ** 2 + 13 * hidden;
  const other = 2 * hidden;
  const n = emb + layers * oneLayer + other;

  // Mixed precision adds 4 bytes per parameter to the optimizer bucket
  // (presumably the FP32 master copy of the weights - confirm).
  const optimizerBytes = Number(k) + (mixed ? 4 : 0);
  const bytesPerParameter = mixed ? 2 : 4;

  // Each ZeRO stage shards one more bucket across the dp ranks:
  // stage >= 1 optimizer state, stage >= 2 gradients, stage >= 3 parameters.
  const data = {
    name: "Parameters / Gradients / Optimizer States",
    children: [
      { name: 'Parameters', value: zeroStage >= 3 ? bytesPerParameter * n / dataParallel : bytesPerParameter * n },
      { name: 'Gradients', value: zeroStage >= 2 ? bytesPerParameter * n / dataParallel : bytesPerParameter * n },
      { name: 'OptimizerAverages', value: zeroStage >= 1 ? optimizerBytes * n / dataParallel : optimizerBytes * n }
    ]
  };
  console.log('paramGradsOpt result:', data);
  return data;
}
/**
 * Read the current control values from the DOM, recompute the memory
 * breakdown with `activationMemory` and `paramGradsOpt`, and redraw the
 * treemap inside the `<svg>` under `#graph`.
 *
 * Expects the controls with ids a, b, h, h_ff, L, s, v, k, tp, zero, dp,
 * mixed, recomputation, ff_activation and seq_parallel to exist; a missing
 * control makes the corresponding `.value` / `.checked` access throw.
 */
export function updateGraph() {
  console.log('updateGraph called');

  // Numeric controls: unary plus turns the DOM string into a number.
  const a = +document.getElementById('a').value;
  const b = +document.getElementById('b').value;
  const h = +document.getElementById('h').value;
  const h_ff = +document.getElementById('h_ff').value;
  const L = +document.getElementById('L').value;
  const s = +document.getElementById('s').value;
  const v = +document.getElementById('v').value;
  const k = +document.getElementById('k').value;
  const tp = +document.getElementById('tp').value;
  // zero and dp were previously passed on as strings; coerce them like the rest.
  const zero = +document.getElementById('zero').value;
  const dp = +document.getElementById('dp').value;
  const mixed = document.getElementById('mixed').checked;
  const recomputation = document.getElementById('recomputation').value;
  const ff_activation = document.getElementById('ff_activation').value;
  const seq_parallel = document.getElementById('seq_parallel').checked;
  console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });

  const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
  const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);

  const data = {
    name: "root",
    children: [
      {
        name: 'Total',
        value: 0,
        children: [
          activationMemoryData,
          paramGradsOptValue
        ]
      }
    ]
  };
  console.log('Data for treemap:', data);

  const width = 700;
  const height = 450;
  // Extra viewBox height below the treemap; no legend is drawn into it here.
  const legendHeight = 50;

  // Clear the previous drawing before laying out the new one.
  const svg = d3.select("#graph").select("svg");
  svg.selectAll("*").remove();
  svg.attr("viewBox", [0, 0, width, height + legendHeight]);

  const treemap = d3.treemap()
    .size([width, height])
    .paddingOuter(3)
    .paddingTop(19) // room for the header label of each container node
    .paddingInner(3)
    .round(true);

  // sum() gives every node the total of its descendants' `value` fields.
  const root = d3.hierarchy(data)
    .sum(d => d.value);
  console.log('Treemap root:', root);
  treemap(root);

  // Fill colour by node name. Several names intentionally or accidentally
  // share a colour; the values are unchanged, only the labels are corrected.
  const color = d => {
    switch (d.data.name) {
      // Root and Total (container levels)
      case 'root': return 'rgb(225, 225, 225)'; // light grey
      case 'Total': return 'rgb(225, 225, 225)'; // light grey
      // Main section containers
      case 'Activation Memory': return 'rgb(78, 165, 183)'; // teal
      case 'Parameters / Gradients / Optimizer States': return 'rgb(232, 137, 171)'; // pink
      // Parameters / Gradients / Optimizer States branch
      case 'Parameters': return 'rgb(206, 192, 250)'; // lavender
      case 'Gradients': return 'rgb(227, 138, 66)'; // orange
      case 'OptimizerAverages': return 'rgb(78, 165, 183)'; // teal
      // activationMemory branch - layer components
      case 'Attention': return 'rgb(206, 192, 250)'; // lavender
      case 'Feedforward': return 'rgb(171, 232, 241)'; // light cyan
      case 'LayerNorm': return 'rgb(232, 137, 171)'; // pink
      // activationMemory branch - other components
      case 'Dropout': return 'rgb(67, 145, 108)'; // green
      case 'Projection': return 'rgb(174, 214, 251)'; // light blue
      case 'Cross Entropy': return 'rgb(232, 137, 171)'; // pink
      // "Layer N" nodes and anything unexpected
      default: return 'rgb(227, 138, 66)'; // orange
    }
  };

  // Create the shared tooltip element once; later calls reuse it.
  if (d3.select('#tooltip').empty()) {
    d3.select('body')
      .append('div')
      .attr('id', 'tooltip')
      .style('opacity', 0)
      .style('position', 'absolute')
      .style('background-color', 'white')
      .style('padding', '4px')
      .style('font-size', '12px')
      .style('border-radius', '5px')
      .style('box-shadow', '0px 0px 5px 0px rgba(0,0,0,0.3)');
  }

  // One <g> per node, positioned at the node's top-left corner.
  const cell = svg.selectAll("g")
    .data(root.descendants().filter(d => d.depth !== 0)) // Skip root node
    .join("g")
    .attr("transform", d => `translate(${d.x0},${d.y0})`)
    .on('mouseover', (event, d) => {
      const name = d.data.name;
      const value = formatBytes(d.value);
      d3.select('#tooltip').transition().duration(200).text(`${name}: ${value}`);
    })
    .on('mouseout', function () {
      d3.select('#tooltip').style('opacity', 0);
    })
    .on('mousemove', function (event) {
      // Follow the pointer with a small offset so the tooltip does not sit under it.
      d3.select('#tooltip').style('left', (event.pageX + 10) + 'px').style('top', (event.pageY + 10) + 'px').style('opacity', 1);
    });

  cell.append("rect")
    .attr("width", d => d.x1 - d.x0)
    .attr("height", d => d.y1 - d.y0)
    .attr("fill", d => color(d))
    .attr("stroke", d => d.depth === 1 ? color(d) : "white")
    .attr("stroke-width", 1);

  const fontSize = 10;
  const padding = 2;
  cell.append("text")
    .attr("font-size", `${fontSize}px`)
    .attr("font-family", "sans-serif")
    .each(function (d) {
      const node = d3.select(this);
      const name = d.data.name;
      const value = formatBytes(d.value);
      if (d.depth === 1 || d.depth === 2) {
        // Container levels: full "name: size" header.
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .attr("font-weight", "bold")
          .attr("font-size", 12)
          .text(`${name}: ${value}`);
      } else {
        // Deeper nodes: only the first letter, full text in a <title> child.
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .text(name[0].toUpperCase())
          .attr("font-weight", "bold")
          .append("title")
          .text(`${name}: ${value}`);
      }
    });

  console.log('Treemap nodes created');
}
/**
 * Format a byte count with a binary (1024-based) unit and two decimals,
 * e.g. 1536 -> "1.50 KB".
 *
 * The unit index is clamped to the table, so values below 1 stay in "Bytes"
 * and values beyond the largest unit are expressed in that largest unit,
 * instead of printing "undefined" as the unit.
 *
 * @param {number} bytes - size in bytes; negative values keep their sign
 * @returns {string} formatted size; "0 Bytes" for zero and "<value> Bytes"
 *   for NaN / Infinity
 */
function formatBytes(bytes) {
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
  const value = Number(bytes);
  if (value === 0) return '0 Bytes';
  if (!Number.isFinite(value)) return `${value} Bytes`;

  // Repeated division by 1024 is exact in floating point (power of two),
  // which avoids the rounding risk of Math.log(x) / Math.log(1024).
  let magnitude = Math.abs(value);
  let i = 0;
  while (magnitude >= 1024 && i < sizes.length - 1) {
    magnitude /= 1024;
    i += 1;
  }
  return `${(value / (1024 ** i)).toFixed(2)} ${sizes[i]}`;
}
// Named presets for the widget controls. Each key inside a preset is the id
// of the DOM control it is written to (see setPresetValues).
const presets = {
  "Llama 3 Tiny": {
    a: 16,
    b: 3,
    h: 1024,
    h_ff: 4096,
    L: 1,
    s: 7,
    v: 30522,
    k: 8,
    tp: 1,
    zero: "1",
    dp: 1,
    mixed: true,
    recomputation: "none",
    ff_activation: "gelu",
    seq_parallel: false
  },
  "Llama 3 8B": {
    a: 32,
    b: 32,
    h: 4096,
    h_ff: 16384,
    L: 32,
    s: 256,
    v: 30522,
    k: 8,
    tp: 1,
    zero: "1",
    dp: 1,
    mixed: true,
    recomputation: "none",
    ff_activation: "swiglu",
    seq_parallel: false
  },
  "Llama 3 70B": {
    a: 64,
    b: 32,
    h: 8192,
    h_ff: 32768,
    L: 80,
    s: 256,
    v: 30522,
    k: 8,
    tp: 8,
    zero: "1",
    dp: 8,
    mixed: true,
    recomputation: "none",
    ff_activation: "swiglu",
    seq_parallel: false
  },
  "Llama 3 405B": {
    a: 128,
    b: 32,
    h: 16384,
    h_ff: 65536,
    L: 126,
    s: 256,
    v: 30522,
    k: 8,
    tp: 8,
    zero: "1",
    dp: 8,
    mixed: true,
    recomputation: "none",
    ff_activation: "swiglu",
    seq_parallel: false
  }
};
/**
 * Apply a named preset to the controls and redraw the graph.
 *
 * For every key of the preset, the element with that id and, if present, the
 * companion element `<key>_input` are updated. Checkboxes get `checked`,
 * everything else gets `value`. Keys without a matching element are skipped.
 *
 * @param {string} preset - a key of `presets`, or "custom" to leave the
 *   controls untouched. An unknown name logs a warning and does nothing
 *   (previously it threw a TypeError from `Object.keys(undefined)`).
 */
function setPresetValues(preset) {
  if (preset === "custom") return;

  // Own-property check so names like "constructor" are not treated as presets.
  if (!Object.prototype.hasOwnProperty.call(presets, preset)) {
    console.warn(`Preset ${preset} not found`);
    return;
  }

  const values = presets[preset];
  Object.keys(values).forEach(key => {
    const element = document.getElementById(key);
    const inputElement = document.getElementById(`${key}_input`);
    if (element) {
      if (element.type === 'checkbox') {
        element.checked = values[key];
      } else {
        element.value = values[key];
      }
    }
    if (inputElement) {
      inputElement.value = values[key];
    }
  });

  // Assigning .value/.checked from script fires no input/change events,
  // so the redraw has to be triggered explicitly.
  updateGraph();
}
/**
 * Keep a range slider and its companion number input in sync and redraw the
 * graph whenever either changes.
 *
 * Values typed into the input are parsed as base-10 integers and clamped to
 * the slider's [min, max]; unparsable text falls back to the minimum.
 *
 * @param {string} sliderId - id of the slider element
 * @param {string} inputId - id of the companion input element
 */
function syncSliderAndInput(sliderId, inputId) {
  const slider = document.getElementById(sliderId);
  const input = document.getElementById(inputId);

  // A slider without a min/max attribute reports "" for that property;
  // fall back to the HTML defaults for range inputs (0 and 100) so the
  // clamp below never produces NaN.
  const parseBound = (raw, fallback) => {
    const parsed = Number.parseInt(raw, 10);
    return Number.isNaN(parsed) ? fallback : parsed;
  };

  slider.addEventListener('input', () => {
    input.value = slider.value;
    updateGraph();
  });

  input.addEventListener('input', () => {
    // Bounds are read on every event because they can change after wiring.
    const min = parseBound(slider.min, 0);
    const max = parseBound(slider.max, 100);
    const typed = Number.parseInt(input.value, 10);
    const value = Math.max(min, Math.min(max, Number.isNaN(typed) ? min : typed));
    slider.value = value;
    input.value = value;
    updateGraph();
  });
}
/**
 * Wire up the memory-plot widget: sync each slider with its number input,
 * set the slider maxima, attach change listeners to the remaining controls,
 * make sure `#graph` contains an `<svg>`, and draw the initial graph.
 *
 * Missing controls are reported with `console.warn` and skipped.
 */
export const init_memory_plot = function () {
  console.log('Initializing memory plot');

  // Slider id -> maximum value assigned to the slider.
  const sliderMaxima = [
    ['a', '128'],
    ['b', '53248'],
    ['h', '16384'],
    ['h_ff', '65536'],
    ['L', '126'],
    ['s', '128000'],
    ['v', '100000'],
    ['k', '16'],
    ['tp', '16'],
    ['dp', '256']
  ];
  sliderMaxima.forEach(([id, max]) => {
    const slider = document.getElementById(id);
    const input = document.getElementById(`${id}_input`);
    if (slider && input) {
      syncSliderAndInput(id, `${id}_input`);
    } else {
      console.warn(`Elements for ${id} not found`);
    }
    if (slider) {
      slider.max = max;
    } else {
      console.warn(`Slider ${id} not found`);
    }
  });

  // Selects and checkboxes that simply trigger a redraw on change.
  // The second entry is the label used in the "not found" warning.
  const redrawControls = [
    ['recomputation', 'Recomputation select'],
    ['ff_activation', 'FF Activation select'],
    ['zero', 'Zero select'],
    ['mixed', 'Mixed checkbox'],
    ['seq_parallel', 'Seq Parallel checkbox']
  ];
  redrawControls.forEach(([id, label]) => {
    const control = document.getElementById(id);
    if (control) {
      control.addEventListener('change', updateGraph);
    } else {
      console.warn(`${label} not found`);
    }
  });

  const presetSelect = document.getElementById('presets');
  if (presetSelect) {
    presetSelect.addEventListener('change', (event) => {
      setPresetValues(event.target.value);
    });
  } else {
    console.warn('Preset select not found');
  }

  console.log('Adding svg');
  const graphContainer = document.getElementById('graph');
  if (graphContainer) {
    // Only append when no <svg> exists yet, so calling init again does not
    // stack additional empty <svg> elements inside #graph.
    const graph = d3.select("#graph");
    if (graph.select("svg").empty()) {
      graph.append("svg");
    }
  } else {
    console.warn('Graph container not found');
  }

  updateGraph();
};