{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.19.1)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.1)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.40.2)\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.10.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.2)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.0.0)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.2)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2023.10.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.0b0)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.0)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.8.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from 
torch) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.18.1)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.3.101)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.4.28)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.6)\n", "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.30.0)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.1.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.11.17)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", 
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "pip install datasets torch transformers peft" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from tqdm.notebook import tqdm\n", "\n", "from datasets import load_dataset\n", "import torch\n", "from torch.utils.data import DataLoader\n", "\n", "from peft import (\n", " get_peft_model,\n", " LoraConfig,\n", " TaskType,\n", ")\n", "from transformers import default_data_collator, Trainer, TrainingArguments\n", "\n", "from short_hf import ShortHFModel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# data = load_dataset(\"pg19\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n", "# dataloader = DataLoader(\n", "# data,\n", "# batch_size=2,\n", "# shuffle=True,\n", "# )" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "data = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n", "dataloader = DataLoader(\n", " data,\n", " batch_size=1,\n", " shuffle=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# !huggingface-cli login\n", "# pip install huggingface_hub\n", "!python3 -c \"from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#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" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "asifahmed\n" ] } ], "source": [ "!huggingface-cli whoami" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# pip install git+https://github.com/tri-ml/linear_open_lm.git\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", "  warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9fcf366ecc414808b39285438599f5b9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# NOTE: model_name, layers_path, and MAX_SEQ_LEN below are assumptions; they are\n", "# consistent with the 32-layer Mistral-7B architecture and parameter count shown\n", "# in the following cells and the 9-layer pruning used later in this notebook.\n", "MAX_SEQ_LEN = 1024\n", "short_model = ShortHFModel(\n", "    model_name=\"mistralai/Mistral-7B-v0.1\",\n", "    layers_path=\"model.layers\",\n", "    n_prune_layers=9,\n", ")" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<generator object Module.parameters at 0x...>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "short_model.model.parameters()" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7241732096" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n", "pytorch_total_params" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "# Save the model state to the specified path.\n", "# model_dir = 'ShortModelSaved/'\n", "# short_model.model.save_pretrained(\n", "#     save_directory=model_dir,\n", "#     safe_serialization=True,\n", "# )" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MistralDecoderLayer(\n", "  (self_attn): MistralSdpaAttention(\n", "    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", "    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", "    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", "    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", "    (rotary_emb): MistralRotaryEmbedding()\n", "  )\n", "  (mlp): MistralMLP(\n", "    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", "    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", "    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", "    (act_fn): SiLU()\n", "  )\n", "  (input_layernorm): MistralRMSNorm()\n", "  (post_attention_layernorm): MistralRMSNorm()\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "short_model.layers[0]" ] },
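{ "cell_type": "markdown", "metadata": {}, "source": [ "All 32 decoder layers share this shape, so a rough estimate of the savings from pruning is just the per-layer parameter count times the number of layers removed. The sketch below assumes the 9-layer pruning applied later in this notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: estimate parameters saved by removing n identical decoder layers.\n", "per_layer_params = sum(p.numel() for p in short_model.layers[0].parameters())\n", "n_prune = 9  # assumption: matches the layer list in the pruning section below\n", "print(f\"params per layer: {per_layer_params:,}\")\n", "print(f\"projected size after pruning: {pytorch_total_params - n_prune * per_layer_params:,}\")" ] },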
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" ] }, { "data": { "text/plain": [ "['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and have been involved in the development of several 3D printers. I have also been involved in the development of several 3D printing software packages.\n\nI have been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages.']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sample generation\n", "gen = short_model.model.generate(\n", "    short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n", "    max_new_tokens=256\n", ")\n", "short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# # sample generation\n", "# gen = short_model.model.generate(\n", "#     short_model.tokenizer([\"The evolution of AI has led to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n", "#     max_new_tokens=256\n", "# )\n", "# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Compute Importances" ] },
{ "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "# for i, batch in enumerate(tqdm(dataloader)):\n", "#     prompts = batch['text']\n", "\n", "#     short_model.eval_importance(\n", "#         prompts=prompts,\n", "#         max_seq_len=MAX_SEQ_LEN,\n", "#         stride=256,\n", "#         max_gen_len=0\n", "#     )" ] },
{ "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "# short_model.importances" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Remove unimportant layers\n", "\n", "Layers removed when using a subset of the pg19 validation set: [25, 26, 24, 27, 22, 23, 28, 21, 29]\n", "\n", "The authors note that the exact removal order is nuanced and can vary across datasets; the relative ordering, however, suggests similar importance." ] },
{ "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "# short_model.remove_layers()" ] },
{ "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "# short_model.layers" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "# # reassign layer_idx to attentions for caching\n", "# for layer_idx, module in enumerate(short_model.layers):\n", "#     module.self_attn.layer_idx = layer_idx" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<generator object Module.parameters at 0x...>" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# short_model.model.parameters()" ] },
{ "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7241732096" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n", "# pytorch_total_params" ] },
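{ "cell_type": "markdown", "metadata": {}, "source": [ "The actual removal logic lives in `ShortHFModel.remove_layers`; the cell below is only a minimal sketch of what such a step involves: dropping the selected decoder layers from the `ModuleList`, updating the config, and re-indexing `layer_idx` so KV caching stays consistent. `drop_layers` is a hypothetical helper, not part of `short_hf`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only -- drop_layers is a hypothetical helper, not the short_hf implementation.\n", "import torch.nn as nn\n", "\n", "def drop_layers(hf_model, layer_ids):\n", "    \"\"\"Remove the given decoder layers in place and re-index the rest.\"\"\"\n", "    keep = [layer for i, layer in enumerate(hf_model.model.layers) if i not in set(layer_ids)]\n", "    hf_model.model.layers = nn.ModuleList(keep)\n", "    hf_model.config.num_hidden_layers = len(keep)\n", "    for idx, layer in enumerate(hf_model.model.layers):\n", "        layer.self_attn.layer_idx = idx  # keep attention cache indices consistent\n", "    return hf_model" ] },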
{ "cell_type": "markdown", "metadata": {}, "source": [ "As the paper states:\n", "\n", "> \"Our experiments reveal that the effect of layer removal is significantly more pronounced on generative tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance drop, with scores approaching zero.\"" ] },
{ "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "# gen = short_model.model.generate(\n", "#     short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n", "#     max_new_tokens=20,\n", "#     use_cache=True\n", "# )\n", "# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)" ] },
{ "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "# gen = short_model.model.generate(\n", "#     short_model.tokenizer([\"The evolution of AI has led to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n", "#     max_new_tokens=20,\n", "#     use_cache=True\n", "# )\n", "# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Compute Angular Importances" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a6fd2bf4360b4aba801085bab0755a06", "version_major": 2, "version_minor": 0 }, "text/plain": [ "  0%|          | 0/3760 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Assumption: the angular variant reuses the importance loop above with angular=True.\n", "for i, batch in enumerate(tqdm(dataloader)):\n", "    prompts = batch['text']\n", "\n", "    short_model.eval_importance(\n", "        prompts=prompts,\n", "        max_seq_len=MAX_SEQ_LEN,\n", "        stride=256,\n", "        max_gen_len=0,\n", "        angular=True\n", "    )" ] },
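{ "cell_type": "markdown", "metadata": {}, "source": [ "For reference, the angular variant scores a block by the angular distance between the hidden states entering and leaving it, averaged over tokens. The function below is a minimal sketch of that metric, not the `short_hf` implementation." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch of the angular-distance metric (arccos of cosine similarity, scaled to [0, 1]).\n", "import math\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "def angular_distance(h_in, h_out, eps=1e-7):\n", "    \"\"\"Mean angular distance between per-token hidden states entering (h_in)\n", "    and leaving (h_out) a block of layers.\"\"\"\n", "    cos = F.cosine_similarity(h_in, h_out, dim=-1)\n", "    return (torch.arccos(cos.clamp(-1 + eps, 1 - eps)) / math.pi).mean()" ] },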
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (9995 > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (10738 > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (12860 > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (23091 > 8192). Running this sequence through the model will result in indexing errors\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "121ffe72baf143f4aeea4616bee88405", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (… > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (15886 > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (28727 > 8192). Running this sequence through the model will result in indexing errors\n", "Token indices sequence length is longer than the specified maximum sequence length for this model (8257 > 8192). Running this sequence through the model will result in indexing errors\n" ] } ], "source": [ "def preprocess_function(examples):\n", "    return tokenizer([\" \".join(x) for x in examples[\"Text\"]])\n", "\n", "\n", "tokenized_Falcon = Falcon.map(\n", "    preprocess_function,\n", "    batched=True,\n", "    num_proc=4,\n", "    remove_columns=Falcon[\"train\"].column_names,\n", ")" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d7b13436ae54624bd96973987373482", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# block_size is an assumption; this cell otherwise follows the standard\n", "# Hugging Face causal-LM recipe.\n", "block_size = 128\n", "\n", "def group_texts(examples):\n", "    # Concatenate all texts.\n", "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n", "    # Drop the small remainder of tokens that does not fill a full block.\n", "    if total_length >= block_size:\n", "        total_length = (total_length // block_size) * block_size\n", "    # Split by chunks of block_size.\n", "    result = {\n", "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", "        for k, t in concatenated_examples.items()\n", "    }\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result\n", "\n", "\"\"\"Apply the `group_texts` function over the entire dataset:\"\"\"\n", "\n", "lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)\n" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorForLanguageModeling\n", "\n", "# tokenizer.pad_token = tokenizer.eos_token\n", "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n", "# import torch\n", "# model = AutoModelForCausalLM.from_pretrained(\"tensorplex-labs/pretraining-sn9-7B-5\", torch_dtype=torch.bfloat16)\n", "\n", "# print('Model Loaded!')\n" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MistralForCausalLM(\n", "  (model): MistralModel(\n", "    (embed_tokens): Embedding(32000, 4096)\n", "    (layers): ModuleList(\n", "      (0-29): 30 x MistralDecoderLayer(\n", "        (self_attn): MistralSdpaAttention(\n", "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", "          (rotary_emb): MistralRotaryEmbedding()\n", "        )\n", "        (mlp): MistralMLP(\n", "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n", "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n", "          (act_fn): SiLU()\n", "        )\n", "        (input_layernorm): MistralRMSNorm()\n", "        (post_attention_layernorm): MistralRMSNorm()\n", "      )\n", "    )\n", "    (norm): MistralRMSNorm()\n", "  )\n", "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n", ")" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.to('cuda')" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6805508096" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pytorch_total_params = sum(p.numel() for p in model.parameters())\n", "pytorch_total_params" ] },
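{ "cell_type": "markdown", "metadata": {}, "source": [ "After grouping, every training example should be exactly `block_size` tokens, with `labels` mirroring `input_ids` (the `Trainer` shifts labels internally for causal LM). A quick check, assuming the cells above have been run:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: verify the grouped examples have the expected shape.\n", "ex = lm_dataset[\"train\"][0]\n", "print(len(ex[\"input_ids\"]), ex[\"input_ids\"][:5] == ex[\"labels\"][:5])" ] },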
output_dir=\"Fine-Tuned-S9-2\",\n", " overwrite_output_dir=True,\n", " bf16=True,\n", " # evaluation_strategy=\"epoch\",\n", " evaluation_strategy=\"steps\",\n", " learning_rate=2e-5,\n", " weight_decay=0.01,\n", " num_train_epochs=1,\n", " per_device_train_batch_size=2,\n", " per_device_eval_batch_size=2,\n", " lr_scheduler_type = 'cosine',\n", " push_to_hub=False,\n", " save_total_limit = 2,\n", " # save_strategy = “no”\n", " load_best_model_at_end=False,\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=lm_dataset[\"train\"],\n", " eval_dataset=lm_dataset[\"validation\"],\n", " # eval_dataset=lm_dataset[\"test\"],\n", " data_collator=data_collator,\n", ")" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Started Training!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mthatmlguy\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.17.0" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /workspace/ShortGPT/short_gpt/wandb/run-20240516_090043-ni1hktjg" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run misty-serenity-4 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/thatmlguy/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [ 2/6459 : < :, Epoch 0.00/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "OutOfMemoryError", "evalue": "CUDA out of memory. Tried to allocate 112.00 MiB. GPU ", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[49], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# trainer.train()\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mStarted Training!\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m 
(torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3138\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3138\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 3141\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3161\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3159\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3160\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3161\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 3164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m 
\u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:822\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32..forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 822\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16\u001b[0m, in \u001b[0;36mautocast_decorator..decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1158\u001b[0m, in \u001b[0;36mMistralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1155\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1157\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1158\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1159\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1160\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1166\u001b[0m \u001b[43m 
\u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1167\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1168\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1170\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1171\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1043\u001b[0m, in \u001b[0;36mMistralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, 
output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1033\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1034\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1035\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1040\u001b[0m use_cache,\n\u001b[1;32m 1041\u001b[0m )\n\u001b[1;32m 1042\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1052\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:770\u001b[0m, in \u001b[0;36mMistralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 768\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 769\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 770\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 773\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:179\u001b[0m, in \u001b[0;36mMistralMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fn(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mup_proj(x))\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
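{ "cell_type": "markdown", "metadata": {}, "source": [ "The run above exhausts GPU memory at the first step. A sketch of common mitigations is below; the exact values are illustrative, not tuned." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative memory-saving tweaks (not run here): trade batch size for\n", "# gradient accumulation and enable activation checkpointing.\n", "# training_args = TrainingArguments(\n", "#     output_dir=\"Fine-Tuned-S9-2\",\n", "#     per_device_train_batch_size=1,\n", "#     gradient_accumulation_steps=4,\n", "#     gradient_checkpointing=True,\n", "#     bf16=True,\n", "# )" ] },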
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "eval_results = trainer.evaluate()\n", "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")\n" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# # referencing https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/huggingface_trainer/peft_finetuning.ipynb\n", "# eval_prompt = \"\"\"\n", "# Summarize this dialog:\n", "# A: Hi Tom, are you busy tomorrow's afternoon?\n", "# B: I'm pretty sure I am. What's up?\n", "# A: Can you go with me to the animal shelter?.\n", "# B: What do you want to do?\n", "# A: I want to get a puppy for my son.\n", "# B: That will make him so happy.\n", "# A: Yeah, we've discussed it many times. I think he's ready now.\n", "# B: That's good. Raising a dog is a tough issue. Like having a baby ;-) \n", "# A: I'll get him one of those little dogs.\n", "# B: One that won't grow up too big;-)\n", "# A: And eat too much;-))\n", "# B: Do you know which one he would like?\n", "# A: Oh, yes, I took him there last Monday. 
He showed me one that he really liked.\n", "# B: I bet you had to drag him away.\n", "# A: He wanted to take it home right away ;-).\n", "# B: I wonder what he'll name it.\n", "# A: He said he'd name it after his dead hamster - Lemmy - he's a great Motorhead fan :-)))\n", "# ---\n", "# Summary:\n", "# \"\"\"\n", "\n", "# model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n", "\n", "# model.eval()\n", "# with torch.no_grad():\n", "# print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100, use_cache=True)[0], skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# def get_preprocessed_samsum():\n", "# dataset = load_dataset(\"samsum\", split=\"train\")\n", "\n", "# prompt = (\n", "# f\"Summarize this dialog:\\n{{dialog}}\\n---\\nSummary:\\n\"\n", "# )\n", "\n", "# def apply_prompt_template(sample):\n", "# return {\n", "# \"prompt\": prompt.format(dialog=sample[\"dialogue\"]),\n", "# \"summary\": sample[\"summary\"],\n", "# }\n", "\n", "# dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))\n", "\n", "# def tokenize_add_label(sample):\n", "# prompt = tokenizer.encode(tokenizer.bos_token + sample[\"prompt\"], add_special_tokens=False)\n", "# summary = tokenizer.encode(sample[\"summary\"] + tokenizer.eos_token, add_special_tokens=False)\n", "# sample = {\n", "# \"input_ids\": prompt + summary,\n", "# \"attention_mask\" : [1] * (len(prompt) + len(summary)),\n", "# \"labels\": [-100] * len(prompt) + summary,\n", "# }\n", "\n", "# return sample\n", "\n", "# dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))\n", "\n", "# return dataset" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# model.train()\n", "\n", "# def create_peft_config(model):\n", "# peft_config = LoraConfig(\n", "# task_type=TaskType.CAUSAL_LM,\n", "# inference_mode=False,\n", "# r=8,\n", "# lora_alpha=32,\n", "# lora_dropout=0.05,\n", "# target_modules = [\"q_proj\", \"v_proj\"]\n", "# )\n", "\n", "# model = get_peft_model(model, peft_config)\n", "# model.print_trainable_parameters()\n", "# return model, peft_config\n", "\n", "# # create peft config\n", "# model, lora_config = create_peft_config(model)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# output_dir = \"tmp/\"\n", "\n", "# config = {\n", "# 'lora_config': lora_config,\n", "# 'learning_rate': 1e-6,\n", "# 'num_train_epochs': 1,\n", "# 'per_device_train_batch_size': 1,\n", "# 'gradient_checkpointing': False,\n", "# }\n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# training_args = TrainingArguments(\n", "# output_dir=output_dir,\n", "# overwrite_output_dir=True,\n", "# # logging strategies\n", "# logging_strategy=\"steps\",\n", "# logging_steps=10,\n", "# save_strategy=\"no\",\n", "# optim=\"adamw_torch_fused\",\n", "# **{k:v for k,v in config.items() if k != 'lora_config'}\n", "# )\n", "\n", "# # Create Trainer instance\n", "# trainer = Trainer(\n", "# model=model,\n", "# args=training_args,\n", "# train_dataset=get_preprocessed_samsum(),\n", "# data_collator=default_data_collator,\n", "# callbacks=[],\n", "# )\n", "\n", "# # Start training\n", "# trainer.train()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# model.eval()\n", "# with torch.no_grad():\n", "# print(tokenizer.decode(model.generate(**model_input, 
max_new_tokens=100)[0], skip_special_tokens=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 4 }