{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from utils import activation_memory, param_grads_opt" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "def activation_memory(\n", " a, # attention heads\n", " b, # micro batch size\n", " h, # hidden dimension size\n", " h_ff, # feedforward dimension size (often h_ff = 4h)\n", " L, # number of layers\n", " s, # sequence length\n", " mixed=True,\n", " recomputation=\"none\",\n", " ff_activation=\"relu\"\n", " ):\n", " \n", " # https://arxiv.org/pdf/2205.05198\n", " if mixed:\n", " bytes_per_value = 2 \n", " else:\n", " bytes_per_value = 4\n", "\n", " one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + ((2 * bytes_per_value + 1) * a * s * s * b) # eq (2)\n", "\n", " if ff_activation == \"relu\":\n", " one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value) # inputs of 1st/2nd linear layers\n", " + s * b * h) # dropout\n", " elif ff_activation == \"gelu\":\n", " one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value) # inputs of 1st/2nd linear layers\n", " + s * b * h_ff * bytes_per_value # inputs of activation function (not really necessary for Relu)\n", " + s * b * h) # dropout\n", " elif ff_activation == \"swiglu\":\n", " one_layer_feedforward = (s * b * h * bytes_per_value + (s * b * h_ff * bytes_per_value) # inputs of input/output linear layers\n", " + s * b * h_ff * bytes_per_value * 3 # inputs of activation function\n", " + s * b * h) # dropout (note that dropout is lower-precision - boolean)\n", "\n", "\n", " layer_norm = s * b * h * bytes_per_value\n", "\n", " if recomputation == \"none\":\n", " one_layer = one_layer_attention + one_layer_feedforward + 2 * layer_norm # eq (2)\n", " elif recomputation ==\"selective\":\n", " one_layer = s * b * h * 34 # eq (6)\n", " elif recomputation ==\"full\":\n", " one_layer = s * b * h * 2\n", " else:\n", " raise ValueError()\n", " \n", " input_dropout = s * b * h # section 4.3\n", "\n", " total = L * one_layer + input_dropout\n", " \n", " return total\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "a = 16\n", "b = 3\n", "h = 1024\n", "h_ff = 4 * h\n", "L = 1\n", "s = 7 # 128000\n", "recomputation = \"none\"\n", "mixed = True\n", "ff_activation = \"swiglu\"\n" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1086960" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "activation_memory(a=a, b=b, h=h, h_ff=h_ff, L=L, s=s, recomputation=recomputation, mixed=mixed, ff_activation=ff_activation)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from math import log\n", "\n", "def format_bytes(bytes):\n", " sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']\n", " if bytes == 0:\n", " return '0 Bytes'\n", " i = int(log(bytes, 1024))\n", " print(i)\n", " p = 1024 ** i\n", " s = round(bytes / p, 2)\n", " return f\"{s} {sizes[i]}\"\n", "\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4\n" ] }, { "data": { "text/plain": [ "'22.13 TB'" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "format_bytes(activation_memory(a, b, h, L, s, recomputation))" ] }, { "cell_type": "code", "execution_count": null, "metadata": 
{}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "jupyter", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }