Spaces:
Running
Running
import * as d3 from 'd3'; | |
/**
 * Estimates activation memory (in bytes) for a transformer, broken down per
 * component, following Korthikanti et al., "Reducing Activation Recomputation
 * in Large Transformer Models" (https://arxiv.org/pdf/2205.05198).
 *
 * @param {number} a - number of attention heads
 * @param {number} b - micro batch size
 * @param {number} h - hidden dimension size
 * @param {number} h_ff - feedforward dimension size (often h_ff = 4h)
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocabulary size
 * @param {number} [tp=1] - tensor model parallel degree
 * @param {boolean} [mixed=true] - 2 bytes per stored activation when true, 4 (FP32) otherwise
 * @param {string} [recomputation="none"] - "none", "selective" or "full"
 * @param {string} [ff_activation="relu"] - "relu", "gelu" or "swiglu"
 * @param {boolean} [seq_parallel=false] - also shard the non-tensor-parallel
 *   regions (layer norms, dropouts, output layer) across `tp`
 * @returns {{name: string, children: object[]}} a d3-hierarchy-style tree whose
 *   leaf `value`s are byte counts
 * @throws {Error} "Invalid recomputation value" / "Invalid ff_activation value"
 *   when either option is not one of the supported strings
 */
export function activationMemory(
  a, // attention heads
  b, // micro batch size
  h, // hidden dimension size
  h_ff, // feedforward dimension size (often h_ff = 4h)
  L, // number of layers
  s, // sequence length
  v, // vocab size
  tp = 1, // tensor model parallelism
  mixed = true,
  recomputation = "none",
  ff_activation = "relu",
  seq_parallel = false
) {
  console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });
  // https://arxiv.org/pdf/2205.05198
  if (recomputation !== "none" && recomputation !== "selective" && recomputation !== "full") {
    throw new Error("Invalid recomputation value");
  }

  const bytesPerValue = mixed ? 2 : 4;
  const sbh = s * b * h;
  // Regions outside the tensor-parallel blocks are replicated on every TP rank
  // unless sequence parallelism splits them along the sequence dimension.
  const shard = (bytes) => (seq_parallel ? bytes / tp : bytes);

  // Attention, eq (2): Q, K, V and the output-projection input (4 values) are
  // always split by tp; the QKV input (bytesPerValue) and the 1-byte dropout
  // mask are split only under sequence parallelism.
  const attentionLinear = seq_parallel
    ? sbh / tp * (bytesPerValue * 5 + 1)
    : sbh * (bytesPerValue * 4 / tp + bytesPerValue + 1);
  // Softmax output and attention-dropout output (bytesPerValue each) plus the
  // 1-byte dropout mask for every a*s*s*b score. Heads are split across tensor
  // parallel ranks, so this term is divided by tp with or without sequence
  // parallelism.
  const attentionScores = (2 * bytesPerValue + 1) * a * s * s * b / tp;
  // Selective recomputation (table 2) recomputes the s^2 score tensors instead of storing them.
  const oneLayerAttention = recomputation === "selective"
    ? attentionLinear
    : attentionLinear + attentionScores;

  // Number of h_ff-sized tensors kept for the activation function's backward pass.
  let activationInputs;
  switch (ff_activation) {
    case "relu": activationInputs = 0; break; // not really necessary for Relu
    case "gelu": activationInputs = 1; break;
    case "swiglu": activationInputs = 3; break;
    default: throw new Error("Invalid ff_activation value");
  }
  const oneLayerFeedforward =
    shard(sbh * bytesPerValue) // input of the first linear layer
    + s * b * h_ff * bytesPerValue / tp // input of the second linear layer
    + s * b * h_ff * bytesPerValue * activationInputs / tp // inputs of activation function
    + shard(sbh); // dropout mask (lower precision - 1 byte per value)

  const layerNorm = shard(sbh * bytesPerValue); // input of a single LayerNorm
  const inputDropout = shard(sbh); // section 4.3
  const outputLayerNorm = shard(sbh * bytesPerValue);
  const outputLayerProjection = shard(sbh * bytesPerValue);
  const outputCrossEntropy = shard(s * b * v * 4); // In FP32

  const outputComponents = [
    { name: 'Dropout', value: inputDropout },
    { name: 'LayerNorm', value: outputLayerNorm },
    { name: 'Projection', value: outputLayerProjection },
    { name: 'Cross Entropy', value: outputCrossEntropy },
  ];

  if (recomputation === "full") {
    // Only each layer's input is stored; the layer internals are recomputed.
    return {
      name: "Activation Memory",
      children: [
        { name: 'LayerInput', value: sbh * bytesPerValue * L },
        ...outputComponents,
      ],
    };
  }

  return {
    name: "Activation Memory",
    children: [
      ...Array.from({ length: L }, (_, index) => ({
        name: `Layer ${index + 1}`,
        children: [
          { name: 'Attention', value: oneLayerAttention },
          { name: 'Feedforward', value: oneLayerFeedforward },
          { name: 'LayerNorm', value: 2 * layerNorm }, // two LayerNorms per layer
        ],
      })),
      ...outputComponents,
    ],
  };
}
/**
 * Estimates the bytes held by model parameters, their gradients and the
 * optimizer state, with optional ZeRO sharding across data-parallel ranks.
 *
 * Parameter count n = h*(v + s) embeddings + L * (12h^2 + 13h) per layer + 2h.
 * ZeRO stage >= 1 shards optimizer state, >= 2 also gradients, >= 3 also
 * parameters (each divided by `dp`).
 *
 * @param {number} h - hidden dimension size
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {number} [k=8] - optimizer bytes per parameter (Adam: 4 bytes first moment + 4 bytes variance)
 * @param {number} [dp=1] - data parallelism degree
 * @param {number} [zero=0] - ZeRO stage (0, 1, 2 or 3)
 * @param {boolean} [mixed=true] - mixed precision: 2-byte weights/grads plus 4 extra optimizer bytes
 * @returns {{name: string, children: {name: string, value: number}[]}} byte counts per category
 */
export function paramGradsOpt(h, L, s, v, k = 8, dp = 1, zero = 0, mixed = true) {
  console.log('paramGradsOpt called with:', { h, L, s, v, k, dp, zero, mixed });
  const parameterCount = h * (v + s) + L * (12 * h ** 2 + 13 * h) + 2 * h;

  // Mixed precision keeps an extra FP32 copy of the weights alongside the optimizer moments.
  const optimizerBytes = mixed ? k + 4 : k;
  const weightBytes = mixed ? 2 : 4;

  // Bytes for one category, divided across dp ranks once ZeRO reaches `stage`.
  const bytesFor = (stage, bytesPerParam) => {
    if (zero >= stage) {
      return bytesPerParam * parameterCount / dp;
    }
    return bytesPerParam * parameterCount;
  };

  const result = {
    name: "Parameters / Gradients / Optimizer States",
    children: [
      { name: 'Parameters', value: bytesFor(3, weightBytes) },
      { name: 'Gradients', value: bytesFor(2, weightBytes) },
      { name: 'OptimizerAverages', value: bytesFor(1, optimizerBytes) },
    ],
  };
  console.log('paramGradsOpt result:', result);
  return result;
}
/**
 * Recomputes the memory estimate from the current page controls and redraws
 * it as a treemap.
 *
 * Reads the numeric sliders (a, b, h, h_ff, L, s, v, k, tp, dp), the `zero`,
 * `recomputation` and `ff_activation` selects, and the `mixed` / `seq_parallel`
 * checkboxes. It builds a root -> Total -> {Activation Memory, Parameters /
 * Gradients / Optimizer States} hierarchy from `activationMemory` and
 * `paramGradsOpt`, then renders it into the existing `#graph svg`, replacing
 * anything drawn before. A shared `#tooltip` div is appended to <body> the
 * first time it is needed.
 */
export function updateGraph() {
  console.log('updateGraph called');
  // Unary `+` so every numeric control reaches the calculators as a number, not a string.
  const readNumber = (id) => +document.getElementById(id).value;
  const a = readNumber('a');
  const b = readNumber('b');
  const h = readNumber('h');
  const h_ff = readNumber('h_ff');
  const L = readNumber('L');
  const s = readNumber('s');
  const v = readNumber('v');
  const k = readNumber('k');
  const tp = readNumber('tp'); // tensor parallel degree
  const zero = readNumber('zero'); // ZeRO stage, compared numerically in paramGradsOpt
  const dp = readNumber('dp'); // data parallel degree, used as a divisor in paramGradsOpt
  const mixed = document.getElementById('mixed').checked;
  const recomputation = document.getElementById('recomputation').value;
  const ff_activation = document.getElementById('ff_activation').value;
  const seq_parallel = document.getElementById('seq_parallel').checked;
  console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });

  const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
  const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);
  const data = {
    name: "root",
    children: [
      {
        name: 'Total',
        value: 0, // hierarchy.sum() adds this to the children's sums, so it contributes nothing
        children: [activationMemoryData, paramGradsOptValue],
      },
    ],
  };
  console.log('Data for treemap:', data);

  const width = 700;
  const height = 450;
  // Space kept below the treemap in the viewBox for a legend (no legend is drawn at present).
  const legendHeight = 50;
  const svg = d3.select("#graph").select("svg");
  svg.selectAll("*").remove();
  svg.attr("viewBox", [0, 0, width, height + legendHeight]);

  const treemap = d3.treemap()
    .size([width, height])
    .paddingOuter(3)
    .paddingTop(19) // room for each container's title text
    .paddingInner(3)
    .round(true);
  const root = d3.hierarchy(data).sum((d) => d.value);
  console.log('Treemap root:', root);
  treemap(root);

  // Fill colour per node name. `Layer N` containers, `LayerInput` and any
  // unrecognised name fall through to the default orange.
  const color = (d) => {
    switch (d.data.name) {
      // Root and Total (container levels)
      case 'root': return 'rgb(225, 225, 225)'; // light grey
      case 'Total': return 'rgb(225, 225, 225)'; // light grey
      // Main section containers
      case 'Activation Memory': return 'rgb(78, 165, 183)'; // teal
      case 'Parameters / Gradients / Optimizer States': return 'rgb(232, 137, 171)'; // pink
      // Parameters / Gradients / Optimizer States branch
      case 'Parameters': return 'rgb(206, 192, 250)'; // lavender
      case 'Gradients': return 'rgb(227, 138, 66)'; // orange
      case 'OptimizerAverages': return 'rgb(78, 165, 183)'; // teal
      // Activation memory: per-layer components
      case 'Attention': return 'rgb(206, 192, 250)'; // lavender
      case 'Feedforward': return 'rgb(171, 232, 241)'; // light blue
      case 'LayerNorm': return 'rgb(232, 137, 171)'; // pink
      // Activation memory: input dropout and output-layer components
      case 'Dropout': return 'rgb(67, 145, 108)'; // dark green
      case 'Projection': return 'rgb(174, 214, 251)'; // sky blue
      case 'Cross Entropy': return 'rgb(232, 137, 171)'; // pink
      default: return 'rgb(227, 138, 66)'; // orange
    }
  };

  if (d3.select('#tooltip').empty()) {
    d3.select('body')
      .append('div')
      .attr('id', 'tooltip')
      .style('opacity', 0)
      .style('position', 'absolute')
      .style('background-color', 'white')
      .style('padding', '4px')
      .style('font-size', '12px')
      .style('border-radius', '5px')
      .style('box-shadow', '0px 0px 5px 0px rgba(0,0,0,0.3)');
  }

  const cell = svg.selectAll("g")
    .data(root.descendants().filter((d) => d.depth !== 0)) // skip the synthetic root node
    .join("g")
    .attr("transform", (d) => `translate(${d.x0},${d.y0})`)
    .on('mouseover', (event, d) => {
      d3.select('#tooltip').transition().duration(200).text(`${d.data.name}: ${formatBytes(d.value)}`);
    })
    .on('mouseout', () => {
      d3.select('#tooltip').style('opacity', 0);
    })
    .on('mousemove', (event) => {
      d3.select('#tooltip')
        .style('left', `${event.pageX + 10}px`)
        .style('top', `${event.pageY + 10}px`)
        .style('opacity', 1);
    });

  cell.append("rect")
    .attr("width", (d) => d.x1 - d.x0)
    .attr("height", (d) => d.y1 - d.y0)
    .attr("fill", (d) => color(d))
    .attr("stroke", (d) => (d.depth === 1 ? color(d) : "white"))
    .attr("stroke-width", 1);

  const fontSize = 10;
  const padding = 2;
  cell.append("text")
    .attr("font-size", `${fontSize}px`)
    .attr("font-family", "sans-serif")
    .each(function (d) {
      const node = d3.select(this);
      const name = d.data.name;
      const value = formatBytes(d.value);
      if (d.depth === 1 || d.depth === 2) {
        // Top-level containers get a full bold "name: size" title.
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .attr("font-weight", "bold")
          .attr("font-size", 12)
          .text(`${name}: ${value}`);
      } else {
        // Deeper nodes show only their initial; the full label is in a <title> hover hint.
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .text(name[0].toUpperCase())
          .attr("font-weight", "bold")
          .append("title")
          .text(`${name}: ${value}`);
      }
    });

  console.log('Treemap nodes created');
}
/**
 * Formats a byte count with binary (1024-based) units and two decimals,
 * e.g. 1536 -> "1.50 KB".
 *
 * Values below 1 byte stay in "Bytes". Values beyond the largest unit are
 * expressed in that unit instead of indexing past the end of the unit table.
 * Negative values are scaled by their magnitude. Non-finite inputs are echoed
 * as "<value> Bytes".
 *
 * @param {number} bytes - size in bytes
 * @returns {string} human-readable size
 */
function formatBytes(bytes) {
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
  if (bytes === 0) return '0 Bytes';
  if (!Number.isFinite(bytes)) return `${bytes} Bytes`;
  // log2 is exact for powers of two, so 1024**n lands precisely on unit n.
  const magnitude = Math.floor(Math.log2(Math.abs(bytes)) / 10);
  const i = Math.min(Math.max(magnitude, 0), sizes.length - 1);
  return `${(bytes / 1024 ** i).toFixed(2)} ${sizes[i]}`;
}
/**
 * Named model configurations offered by the `#presets` select. Each key of an
 * entry is the id of a page control that receives the value. All entries share
 * v = 30522, k = 8, ZeRO stage "1", mixed precision, no recomputation and no
 * sequence parallelism.
 *
 * NOTE(review): these use h_ff = 4h and a 30522-token vocabulary. The published
 * Llama 3 configs use narrower SwiGLU FFNs (e.g. 14336 for 8B) and a
 * 128256-token vocabulary. Confirm whether the simplification is intentional.
 */
const makePreset = ({ a, b, h, h_ff, L, s, tp, dp, ff_activation }) => ({
  a,
  b,
  h,
  h_ff,
  L,
  s,
  v: 30522,
  k: 8,
  tp,
  zero: "1",
  dp,
  mixed: true,
  recomputation: "none",
  ff_activation,
  seq_parallel: false,
});

const presets = {
  "Llama 3 Tiny": makePreset({ a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, tp: 1, dp: 1, ff_activation: "gelu" }),
  "Llama 3 8B": makePreset({ a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, tp: 1, dp: 1, ff_activation: "swiglu" }),
  "Llama 3 70B": makePreset({ a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, tp: 8, dp: 8, ff_activation: "swiglu" }),
  "Llama 3 405B": makePreset({ a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, tp: 8, dp: 8, ff_activation: "swiglu" }),
};
/**
 * Copies a named preset from `presets` onto the page controls and redraws.
 *
 * For every key in the preset, the element with that id is updated (`checked`
 * for checkboxes, `value` otherwise), and so is its `<id>_input` companion when
 * one exists. Missing elements are skipped. An unknown preset name is reported
 * with `console.warn` and leaves the page untouched.
 *
 * @param {string} preset - a key of `presets`, or "custom" to do nothing
 */
function setPresetValues(preset) {
  if (preset === "custom") return;
  // Own-property check so names like "toString" are not resolved through the prototype.
  if (!Object.hasOwn(presets, preset)) {
    console.warn(`Unknown preset: ${preset}`);
    return;
  }
  for (const [key, value] of Object.entries(presets[preset])) {
    const element = document.getElementById(key);
    if (element) {
      if (element.type === 'checkbox') {
        element.checked = value;
      } else {
        element.value = value;
      }
    }
    const companionInput = document.getElementById(`${key}_input`);
    if (companionInput) {
      companionInput.value = value;
    }
  }
  // Assigning .value/.checked programmatically fires no input/change events, so redraw explicitly.
  updateGraph();
}
/**
 * Keeps a slider and its companion text input in sync, redrawing the graph
 * after every edit.
 *
 * Slider -> input: the slider's value is copied verbatim.
 * Input -> slider: the typed text is parsed as a base-10 integer and clamped to
 * the slider's `min`/`max`. Unparseable text falls back to `min`, or to the
 * slider's current value when the slider has no `min`.
 *
 * @param {string} sliderId - id of the slider element
 * @param {string} inputId - id of the companion text/number input
 */
function syncSliderAndInput(sliderId, inputId) {
  const slider = document.getElementById(sliderId);
  const input = document.getElementById(inputId);
  // Base-10 integer parse that returns `fallback` for empty or non-numeric text.
  const parseIntOr = (text, fallback) => {
    const parsed = Number.parseInt(text, 10);
    return Number.isNaN(parsed) ? fallback : parsed;
  };

  slider.addEventListener('input', () => {
    input.value = slider.value;
    updateGraph();
  });

  input.addEventListener('input', () => {
    // A missing min/max attribute reads as "". Treat that side as unbounded
    // instead of letting NaN propagate into both controls.
    const min = parseIntOr(slider.min, -Infinity);
    const max = parseIntOr(slider.max, Infinity);
    const fallback = Number.isFinite(min) ? min : parseIntOr(slider.value, 0);
    const value = Math.max(min, Math.min(max, parseIntOr(input.value, fallback)));
    slider.value = value;
    input.value = value;
    updateGraph();
  });
}
/**
 * One-time page setup for the memory plot.
 *
 * - Links each numeric slider with its `<id>_input` box via syncSliderAndInput.
 * - Redraws on 'change' of the select/checkbox controls.
 * - Applies preset selections via setPresetValues.
 * - Sets slider maximums.
 * - Appends the <svg> that updateGraph draws into.
 * - Renders the first graph.
 *
 * Any missing element is reported with console.warn and otherwise skipped.
 */
export const init_memory_plot = function () {
  console.log('Initializing memory plot');

  // Upper bound for each numeric slider (also used to clamp its text input).
  // NOTE(review): b's max (53248) looks unusually large for a micro-batch size; confirm it is intended.
  const sliderMaxima = {
    a: '128',
    b: '53248',
    h: '16384',
    h_ff: '65536',
    L: '126',
    s: '128000',
    v: '100000',
    k: '16',
    tp: '16',
    dp: '256',
  };
  const sliderIds = Object.keys(sliderMaxima);

  for (const id of sliderIds) {
    const slider = document.getElementById(id);
    const input = document.getElementById(`${id}_input`);
    if (slider && input) {
      syncSliderAndInput(id, `${id}_input`);
    } else {
      console.warn(`Elements for ${id} not found`);
    }
  }

  // Select/checkbox controls that redraw on 'change', with the warning logged when absent.
  const changeControls = [
    ['recomputation', 'Recomputation select not found'],
    ['ff_activation', 'FF Activation select not found'],
    ['zero', 'Zero select not found'],
    ['mixed', 'Mixed checkbox not found'],
    ['seq_parallel', 'Seq Parallel checkbox not found'],
  ];
  for (const [id, missingMessage] of changeControls) {
    const control = document.getElementById(id);
    if (control) {
      control.addEventListener('change', updateGraph);
    } else {
      console.warn(missingMessage);
    }
  }

  const presetSelect = document.getElementById('presets');
  if (presetSelect) {
    presetSelect.addEventListener('change', (event) => {
      setPresetValues(event.target.value);
    });
  } else {
    console.warn('Preset select not found');
  }

  // Set max values for sliders
  for (const id of sliderIds) {
    const slider = document.getElementById(id);
    if (slider) {
      slider.max = sliderMaxima[id];
    } else {
      console.warn(`Slider ${id} not found`);
    }
  }

  console.log('Adding svg');
  if (document.getElementById('graph')) {
    d3.select("#graph").append("svg");
  } else {
    console.warn('Graph container not found');
  }

  updateGraph();
};