{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ID8Vmc0vF4Ay", "outputId": "76f7bedf-7711-483b-9d5e-e340d8cfbc36" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.47.1)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.16.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.27.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.0)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (2024.9.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.4.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.3.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.12.14)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (3.2.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (1.26.4)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (17.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets) (2.2.2)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.11/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.11/dist-packages (from datasets) (4.67.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets) (3.5.0)\n", "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets) (3.11.11)\n", "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.27.1)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from datasets) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (2.4.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.3.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (24.3.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.5.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (0.2.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.18.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.4.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2.3.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.32.2->datasets) (2024.12.14)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\n", "Requirement already satisfied: evaluate in /usr/local/lib/python3.11/dist-packages (0.4.3)\n", "Requirement already satisfied: datasets>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.2.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from evaluate) (1.26.4)\n", "Requirement already satisfied: dill in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.2.2)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.11/dist-packages (from evaluate) (4.67.1)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from evaluate) (3.5.0)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.70.16)\n", "Requirement already satisfied: fsspec>=2021.05.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]>=2021.05.0->evaluate) (2024.9.0)\n", "Requirement already satisfied: huggingface-hub>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from evaluate) (0.27.1)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from evaluate) (24.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (3.16.1)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (17.0.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (3.11.11)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from datasets>=2.0.0->evaluate) (6.0.2)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.7.0->evaluate) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.4.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2.3.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->evaluate) (2024.12.14)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->evaluate) (2024.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (2.4.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.3.2)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (24.3.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.5.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (6.1.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (0.2.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets>=2.0.0->evaluate) (1.18.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.17.0)\n", "Requirement already satisfied: rouge_score in /usr/local/lib/python3.11/dist-packages (0.1.2)\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.11/dist-packages (from rouge_score) (1.4.0)\n", "Requirement already satisfied: nltk in /usr/local/lib/python3.11/dist-packages (from rouge_score) (3.9.1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from rouge_score) (1.26.4)\n", "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.11/dist-packages (from rouge_score) (1.17.0)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk->rouge_score) (8.1.8)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from nltk->rouge_score) (1.4.2)\n", "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.11/dist-packages (from nltk->rouge_score) (2024.11.6)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from nltk->rouge_score) (4.67.1)\n", "Requirement already satisfied: nltk in /usr/local/lib/python3.11/dist-packages (3.9.1)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk) (8.1.8)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from nltk) (1.4.2)\n", "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.11/dist-packages (from nltk) (2024.11.6)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from nltk) (4.67.1)\n", "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (0.2.0)\n" ] } ], "source": [ "!pip install transformers\n", "!pip install datasets\n", "!pip install evaluate\n", "!pip install rouge_score\n", "!pip install nltk\n", "!pip install sentencepiece" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Vhj7zRU5F2tU" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer\n", "import torch\n", "import time\n", "import evaluate\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_rMVgsurF2tW", "outputId": "03ba2bdb-f8d0-42dc-f2e7-94ccb3f44287" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA available, using CUDA\n" ] } ], "source": [ "if torch.cuda.is_available():\n", " print(\"CUDA available, using CUDA\")\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " print(\"MLX available, using MLX\")\n", " device = torch.device(\"mps\")\n", "else:\n", " print(\"Using CPU\")\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "78YtJIJeF2tW", "outputId": "5fddb779-0c9d-4752-af4e-50a894a25c01" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 12460\n", " })\n", " validation: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 500\n", " })\n", " test: Dataset({\n", " features: ['id', 'dialogue', 'summary', 'topic'],\n", " num_rows: 1500\n", " })\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_dataset_name = \"knkarthick/dialogsum\"\n", "dataset = load_dataset(training_dataset_name)\n", "dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "4AjBmSxVF2tW" }, "outputs": [], "source": [ "model_name = \"google/flan-t5-base\"\n", "base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "gwWyK9aEF2tW", "outputId": "c7e9a9fe-fae6-4739-a6dd-38e54e6b83d3" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'All parameters: 247577856 \\n Trainable parameters: 247577856 \\n Percentage Trainable: 100.0'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def print_number_of_trainable_model_parameters(model):\n", " trainable_model_params = 0\n", " all_model_params = 0\n", " for _, param in model.named_parameters():\n", " all_model_params += param.numel()\n", " if param.requires_grad:\n", " trainable_model_params += param.numel()\n", " return f\"All parameters: {all_model_params} \\n Trainable parameters: {trainable_model_params} \\n Percentage Trainable: {trainable_model_params/all_model_params * 100}\"\n", "\n", "print_number_of_trainable_model_parameters(base_model)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3ZGNban5F2tX", "outputId": "0587a92b-853e-44cd-bcec-3da20d073ab3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "\n", "Summarize the following conversation.\n", "#Person1#: Taxi!\n", "#Person2#: Where will you go, sir?\n", "#Person1#: Friendship Hotel.\n", "#Person2#: OK, it's not far from here.\n", "#Person1#: I have something important to do, can you fast the speed?\n", "#Person2#: Sure, I'll try my best. Here we are.\n", "#Person1#: It's fast! How much should I pay you?\n", "#Person2#: The reading on the meter is 15 yuan.\n", "#Person1#: Here's 20 yuan, keep the change.\n", "#Person2#: Thank you very much.\n", "Summary:\n", "\n", "---------------------------------------------------------------------------------------------------\n", "BASELINE SUMMARY:\n", " #Person1# takes a taxi to the Friendship Hotel for something important.\n", "\n", "---------------------------------------------------------------------------------------------------\n", "MODEL GENERATED SUMMARY:\n", " The taxi will pick you up at the Friendship Hotel.\n" ] } ], "source": [ "index = 150\n", "\n", "dialogue = dataset['test'][index]['dialogue']\n", "summary = dataset['test'][index]['summary']\n", "prompt = f\"\"\"\n", "Summarize the following conversation.\n", "{dialogue}\n", "Summary:\n", "\"\"\"\n", "inputs = tokenizer(prompt, return_tensors='pt')\n", "output = tokenizer.decode(\n", " base_model.generate(\n", " inputs[\"input_ids\"],\n", " max_new_tokens = 200,\n", " )[0],\n", " skip_special_tokens=True\n", ")\n", "dash_line = '-'.join('' for x in range(100))\n", "print(dash_line)\n", "print(f'INPUT PROMPT:\\n{prompt}')\n", "print(dash_line)\n", "print(f'BASELINE SUMMARY:\\n {summary}\\n')\n", "print(dash_line)\n", "print(f'MODEL GENERATED SUMMARY:\\n {output}')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "4ZRPbFyuF2tX" }, "outputs": [], "source": [ "def tokenize_function(example):\n", " start_prompt = \"Summarize the following conversation.\\n\\n\"\n", " end_prompt = '\\n\\nSummary: '\n", " prompt = [start_prompt + dialogue + end_prompt for dialogue in example['dialogue']]\n", " example['input_ids'] = tokenizer(prompt, padding='max_length', truncation=True, return_tensors='pt').input_ids\n", " example['labels'] = tokenizer(example['summary'], padding='max_length', truncation=True, return_tensors='pt').input_ids\n", " return example" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "191e83ce3953499c9be33286ec762c0a", "1fee7f795e714d028ff8af71b9ae56de", "ba8f14998569469b95e7153e89a86098", "f692c5360afe44839ed23e80a35e7f62", "3d15497caf044cc7bc4034ef8bc48238", "733007b2c33a42ebb42d6ce8392b514b", "5cfb7f92123f4884aa4442e227277378", "2c4d28068f1d4db5987ddd166c71477e", "775b61d5bde74ac7a8f74262ee208f0a", "89cd27d582994cfb8db85a8be2df76e3", "20b63bfeec184873a1246086ececfc0c" ] }, "id": "PXto6CyVF2tX", "outputId": "876baf5f-1b0c-4b5a-d119-364b956519d4" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "191e83ce3953499c9be33286ec762c0a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/500 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /content/wandb/run-20250120_072049-urybv4uy" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run ./training-1737357642 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/dhrumeen-umass/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/dhrumeen-umass/huggingface/runs/urybv4uy" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [100/100 04:43, Epoch 6/7]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
149.000000
223.875000
314.937500
46.031200
54.562500
64.312500
74.000000
83.640600
93.109400
102.843800
112.406200
122.343800
131.851600
141.687500
151.460900
161.203100
171.015600
181.085900
190.761700
200.632800
210.558600
220.494100
230.425800
240.347700
250.320300
260.243200
270.259800
280.222700
290.219700
300.359400
310.207000
320.202100
330.191400
340.204100
350.156200
360.118700
370.136700
380.127900
390.248000
400.112800
410.135700
420.148400
430.150400
440.094200
450.127900
460.119600
470.165000
480.218800
490.119100
500.123500
510.124500
520.130900
530.102500
540.079600
550.128900
560.074200
570.084500
580.099600
590.111800
600.102100
610.090300
620.096200
630.120600
640.158200
650.061500
660.082500
670.067900
680.096200
690.098100
700.073700
710.079100
720.106900
730.155300
740.096200
750.075200
760.096200
770.092800
780.084000
790.075200
800.064500
810.058600
820.067900
830.080600
840.094200
850.066400
860.089400
870.088400
880.095700
890.098600
900.057600
910.084500
920.056900
930.072800
940.147500
950.070800
960.053200
970.069300
980.064900
990.067400
1000.088900

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=100, training_loss=1.41803466796875, metrics={'train_runtime': 292.467, 'train_samples_per_second': 2.735, 'train_steps_per_second': 0.342, 'total_flos': 535480249614336.0, 'train_loss': 1.41803466796875, 'epoch': 6.25})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "XWoSZE_zF2tY" }, "outputs": [], "source": [ "trained_model_dir = './trained_model'\n", "trainer.save_model(trained_model_dir)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "0DgzurAoF2tY" }, "outputs": [], "source": [ "trained_model_dir = './trained_model'\n", "trained_model = AutoModelForSeq2SeqLM.from_pretrained(trained_model_dir)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wPOZmfIoF2tY", "outputId": "7d31d962-b30e-420d-cf19-b7822624f2f4" }, "outputs": [ { "data": { "text/plain": [ "T5ForConditionalGeneration(\n", " (shared): Embedding(32128, 768)\n", " (encoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 768)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " (relative_attention_bias): Embedding(32, 12)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseGatedActDense(\n", " (wi_0): Linear(in_features=768, out_features=2048, bias=False)\n", " (wi_1): Linear(in_features=768, out_features=2048, bias=False)\n", " (wo): Linear(in_features=2048, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): NewGELUActivation()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1-11): 11 x T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerFF(\n", " (DenseReluDense): T5DenseGatedActDense(\n", " (wi_0): Linear(in_features=768, out_features=2048, bias=False)\n", " (wi_1): Linear(in_features=768, out_features=2048, bias=False)\n", " (wo): Linear(in_features=2048, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): NewGELUActivation()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (final_layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (decoder): T5Stack(\n", " (embed_tokens): Embedding(32128, 768)\n", " (block): ModuleList(\n", " (0): T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " (relative_attention_bias): Embedding(32, 12)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseGatedActDense(\n", " (wi_0): Linear(in_features=768, out_features=2048, bias=False)\n", " (wi_1): Linear(in_features=768, out_features=2048, bias=False)\n", " (wo): Linear(in_features=2048, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): NewGELUActivation()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (1-11): 11 x T5Block(\n", " (layer): ModuleList(\n", " (0): T5LayerSelfAttention(\n", " (SelfAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (1): T5LayerCrossAttention(\n", " (EncDecAttention): T5Attention(\n", " (q): Linear(in_features=768, out_features=768, bias=False)\n", " (k): Linear(in_features=768, out_features=768, bias=False)\n", " (v): Linear(in_features=768, out_features=768, bias=False)\n", " (o): Linear(in_features=768, out_features=768, bias=False)\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): T5LayerFF(\n", " (DenseReluDense): T5DenseGatedActDense(\n", " (wi_0): Linear(in_features=768, out_features=2048, bias=False)\n", " (wi_1): Linear(in_features=768, out_features=2048, bias=False)\n", " (wo): Linear(in_features=2048, out_features=768, bias=False)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (act): NewGELUActivation()\n", " )\n", " (layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (final_layer_norm): T5LayerNorm()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (lm_head): Linear(in_features=768, out_features=32128, bias=False)\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n", "input_ids = input_ids.to(device)\n", "base_model.to(device)\n", "trained_model.to(device)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "H9f5QCZzF2tY" }, "outputs": [], "source": [ "generation_config = GenerationConfig(max_new_tokens=200, num_beams=1)\n", "original_model_outputs = base_model.generate(input_ids=input_ids, generation_config=generation_config)\n", "original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "Wmeo0oqGF2tY" }, "outputs": [], "source": [ "trained_model_outputs = trained_model.generate(input_ids=input_ids, generation_config=generation_config)\n", "trained_model_text_output = tokenizer.decode(trained_model_outputs[0], skip_special_tokens=True)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "So_8WpbbF2tY", "outputId": "d1585112-5af9-4c8b-f0d6-70fdec6186fa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n", "BASELINE HUMAN SUMMARY: \n", "#Person1# takes a taxi to the Friendship Hotel for something important.\n", "--------------------------------------------------\n", "ORIGINAL MODEL: \n", "#Person1# is a taxi to Friendship Hotel. #Person2# will try his best to get the meter, but #Person1# will try his best.\n", "--------------------------------------------------\n", "TRAINED MODEL: \n", "#Person1# will go to Friendship Hotel. #Person1# will try his best to pay #Person1#.\n" ] } ], "source": [ "human_baseline_summary = summary\n", "dash_line = '-' * 50\n", "print(dash_line)\n", "print(f'BASELINE HUMAN SUMMARY: \\n{human_baseline_summary}')\n", "print(dash_line)\n", "print(f'ORIGINAL MODEL: \\n{original_model_text_output}')\n", "print(dash_line)\n", "print(f'TRAINED MODEL: \\n{trained_model_text_output}')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "18201ec200b143d0b982926365167688": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "191e83ce3953499c9be33286ec762c0a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_1fee7f795e714d028ff8af71b9ae56de", "IPY_MODEL_ba8f14998569469b95e7153e89a86098", "IPY_MODEL_f692c5360afe44839ed23e80a35e7f62" ], "layout": "IPY_MODEL_3d15497caf044cc7bc4034ef8bc48238" } }, "1fee7f795e714d028ff8af71b9ae56de": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_733007b2c33a42ebb42d6ce8392b514b", "placeholder": "​", "style": "IPY_MODEL_5cfb7f92123f4884aa4442e227277378", "value": "Map: 100%" } }, "20b63bfeec184873a1246086ececfc0c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2c4d28068f1d4db5987ddd166c71477e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2e038389704c4f40b8ab9cead672e725": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_e85554e5cd22434e892693edbcc83a4b", "IPY_MODEL_b9eb2c31392945a1bc3bacdd782addac", "IPY_MODEL_5a5a96fba0334b50814fd03b86c418fc" ], "layout": "IPY_MODEL_e6466602735343eabb67125d74258634" } }, "3d15497caf044cc7bc4034ef8bc48238": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4885cbf3e7594b0ab825721187fe0c27": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "5a5a96fba0334b50814fd03b86c418fc": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_9ce011dc17e64018ba76b7f408f99c78", "placeholder": "​", "style": "IPY_MODEL_5ee5e919d1e8426aa9aa9108ac46d26d", "value": " 500/500 [00:00<00:00, 1056.20 examples/s]" } }, "5cfb7f92123f4884aa4442e227277378": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5ee5e919d1e8426aa9aa9108ac46d26d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "5f181ed93540409eabf5481cfcca8912": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "733007b2c33a42ebb42d6ce8392b514b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "775b61d5bde74ac7a8f74262ee208f0a": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "89cd27d582994cfb8db85a8be2df76e3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "9ce011dc17e64018ba76b7f408f99c78": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "b9eb2c31392945a1bc3bacdd782addac": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_5f181ed93540409eabf5481cfcca8912", "max": 500, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_4885cbf3e7594b0ab825721187fe0c27", "value": 500 } }, "ba8f14998569469b95e7153e89a86098": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2c4d28068f1d4db5987ddd166c71477e", "max": 500, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_775b61d5bde74ac7a8f74262ee208f0a", "value": 500 } }, "e6466602735343eabb67125d74258634": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e85554e5cd22434e892693edbcc83a4b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_18201ec200b143d0b982926365167688", "placeholder": "​", "style": "IPY_MODEL_f716d287b7e44e1f9399fa3d42a3a360", "value": "Filter: 100%" } }, "f692c5360afe44839ed23e80a35e7f62": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_89cd27d582994cfb8db85a8be2df76e3", "placeholder": "​", "style": "IPY_MODEL_20b63bfeec184873a1246086ececfc0c", "value": " 500/500 [00:01<00:00, 331.09 examples/s]" } }, "f716d287b7e44e1f9399fa3d42a3a360": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "nbformat": 4, "nbformat_minor": 0 }