{"cells":[{"cell_type":"markdown","metadata":{"id":"28e4c4d1-a73f-437b-a1bd-c2cc3874924a"},"source":["# 강의 11주차: midm-food-order-understanding\n","\n","1. KT-AI/midm-bitext-S-7B-inst-v1를 주문 문장 이해에 미세 튜닝\n","\n","- food-order-understanding-small-3200.json (학습)\n","- food-order-understanding-small-800.json (검증)\n","\n","\n","사전 준비 사항\n","- Hugging Face 계정 설정 및 KT-AI/midm-bitext-S-7B-inst-v1 모델 접근 권한(필요 시 사용 승인)\n","- 로깅을 위한 wandb"],"id":"28e4c4d1-a73f-437b-a1bd-c2cc3874924a"},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":28937,"status":"ok","timestamp":1702304112153,"user":{"displayName":"강소연","userId":"16433369612605492183"},"user_tz":-540},"id":"nDZe_wqKU6J3","outputId":"45e12280-6f1d-4fd9-9961-4cea5ec99fa2"},"outputs":[{"output_type":"stream","name":"stdout","text":["Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.35.2)\n","Collecting peft\n"," Downloading peft-0.7.0-py3-none-any.whl (168 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.3/168.3 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting accelerate\n"," Downloading accelerate-0.25.0-py3-none-any.whl (265 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m265.7/265.7 kB\u001b[0m \u001b[31m7.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting optimum\n"," Downloading optimum-1.15.0-py3-none-any.whl (400 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m400.9/400.9 kB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting bitsandbytes\n"," Downloading bitsandbytes-0.41.3.post1-py3-none-any.whl (92.6 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting trl\n"," 
Downloading trl-0.7.4-py3-none-any.whl (133 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.9/133.9 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting wandb\n"," Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m101.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting einops\n"," Downloading einops-0.7.0-py3-none-any.whl (44 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.6/44.6 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.4)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n","Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.0)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.1)\n","Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n","Requirement already satisfied: 
psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n","Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.1.0+cu118)\n","Collecting coloredlogs (from optimum)\n"," Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from optimum) (1.12)\n","Collecting datasets (from optimum)\n"," Downloading datasets-2.15.0-py3-none-any.whl (521 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m521.2/521.2 kB\u001b[0m \u001b[31m56.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting tyro>=0.5.11 (from trl)\n"," Downloading tyro-0.6.0-py3-none-any.whl (100 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.9/100.9 kB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.7)\n","Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n"," Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.6/190.6 kB\u001b[0m \u001b[31m26.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting sentry-sdk>=1.0.0 (from wandb)\n"," Downloading sentry_sdk-1.38.0-py2.py3-none-any.whl (252 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m252.8/252.8 kB\u001b[0m \u001b[31m33.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n"," Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n","Collecting setproctitle (from wandb)\n"," 
Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n","Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n","Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n","Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n","Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n","Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n"," Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (2023.6.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.16.4->transformers) (4.5.0)\n","Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n","Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages 
(from torch>=1.13.0->peft) (3.1.2)\n","Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.1.0)\n","Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers)\n"," Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m53.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl)\n"," Downloading docstring_parser-0.15-py3-none-any.whl (36 kB)\n","Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.0)\n","Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl)\n"," Downloading shtab-1.6.5-py3-none-any.whl (13 kB)\n","Collecting humanfriendly>=9.1 (from coloredlogs->optimum)\n"," Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (9.0.0)\n","Collecting pyarrow-hotfix (from datasets->optimum)\n"," Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)\n","Collecting dill<0.3.8,>=0.3.0 (from datasets->optimum)\n"," Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m16.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (1.5.3)\n","Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.4.1)\n","Collecting multiprocess (from 
datasets->optimum)\n"," Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->optimum) (3.9.1)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->optimum) (1.3.0)\n","Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (23.1.0)\n","Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (6.0.4)\n","Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.9.3)\n","Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.4.0)\n","Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (1.3.1)\n","Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->optimum) (4.0.3)\n","Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n"," Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n","Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n","Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.3)\n","Requirement already satisfied: python-dateutil>=2.8.1 in 
/usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2.8.2)\n","Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->optimum) (2023.3.post1)\n","Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n","Installing collected packages: sentencepiece, bitsandbytes, smmap, shtab, setproctitle, sentry-sdk, pyarrow-hotfix, humanfriendly, einops, docstring-parser, docker-pycreds, dill, multiprocess, gitdb, coloredlogs, tyro, GitPython, accelerate, wandb, datasets, trl, peft, optimum\n","Successfully installed GitPython-3.1.40 accelerate-0.25.0 bitsandbytes-0.41.3.post1 coloredlogs-15.0.1 datasets-2.15.0 dill-0.3.7 docker-pycreds-0.4.0 docstring-parser-0.15 einops-0.7.0 gitdb-4.0.11 humanfriendly-10.0 multiprocess-0.70.15 optimum-1.15.0 peft-0.7.0 pyarrow-hotfix-0.6 sentencepiece-0.1.99 sentry-sdk-1.38.0 setproctitle-1.3.3 shtab-1.6.5 smmap-5.0.1 trl-0.7.4 tyro-0.6.0 wandb-0.16.1\n"]}],"source":["# Pin the exact versions this notebook was last run with (taken from the recorded install log)\n","# so that a fresh runtime reproduces the same environment; %pip targets the running kernel.\n","%pip install transformers==4.35.2 peft==0.7.0 accelerate==0.25.0 optimum==1.15.0 bitsandbytes==0.41.3.post1 trl==0.7.4 wandb==0.16.1 einops==0.7.0"],"id":"nDZe_wqKU6J3"},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":16540,"status":"ok","timestamp":1702304128691,"user":{"displayName":"강소연","userId":"16433369612605492183"},"user_tz":-540},"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d","outputId":"1bb8ec1d-fcce-42fe-c50d-9a9eb4eab964"},"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.10/dist-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n"," warnings.warn(\n"]}],"source":["import os\n","from dataclasses import dataclass, field\n","from typing import Optional\n","import re\n","\n","import torch\n","import tyro\n","from accelerate import 
Accelerator\n","from datasets import load_dataset, Dataset\n","from peft import AutoPeftModelForCausalLM, LoraConfig\n","from tqdm import tqdm\n","from transformers import (\n"," AutoModelForCausalLM,\n"," AutoTokenizer,\n"," BitsAndBytesConfig,\n"," TrainingArguments,\n",")\n","\n","from trl import SFTTrainer\n","\n","from trl.trainer import ConstantLengthDataset"],"id":"51eb00d7-2928-41ad-9ae9-7f0da7d64d6d"},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":145,"referenced_widgets":["2b28c142d3504b168c8ca93856e5f400","f8b42df8b3ec48c5adad8e45bd3cc0e5","6fcba02908d049c7b6d18a08039b354f","76082eb9f69c41f0a57504b47d706e70","57303db2d6194d6a8fec0e96bd7c130c","20531fc5c93344cf930bab835ab50e96","f960a2a70e2d48049eae8364cc9abedb","a819e4fcfb634c029ef7c003fdc135ff","10d04d442c1f4a6c9ba259afddd18d38","9523fe6a14974dfda3eae377a38347bf","82e9a4454761432f986881b658ed1611","662a04a5e49141448bc2740b16b1a1f3","bb476aa148fd41f0b4d9db3a6c3868ab","b261703915894d2185bfad000fdc749f","966845f63156483ba43dce0c76c7d755","6ea387fff5e545a0b5993d120bf19afc","56389cfeabdb4469839d273dc8385be5","df584a10f11a4d12988de9e1a4b41bab","452d8f4301134e49bb76dc4de68c6a20","f43ec23c00884a78964f259f2a759dfa","b1a6a6a0f2be42c089140e9ded9a1402","dee8e2d3da654053bad95371cd7da293","2a49ee7ef8e84e708101806c8ab5bcdc","f233f20aa0e84286a595583db146c997","9c7ecf7f9c91431787ac58d9daacc5a5","4677dcb3c9244357a72f2c6f114a4739","c44beec239d34c2a9b5e7c8447480c51","4032b77e3ec4459baff16a6f0625a9e6","c853335fe4c24286975d25bcdf062aac","bb353e431fc74ee5bfbb1fb6883c19da","b02c90c1b1274747ac1ce1dfa91e17fb","f933303dd9fa4db3a6d35445ca4408bc"]},"executionInfo":{"elapsed":5,"status":"ok","timestamp":1702304128691,"user":{"displayName":"강소연","userId":"16433369612605492183"},"user_tz":-540},"id":"tX7gYxZaVhYL","outputId":"87758fcb-9e73-4754-cd66-27ad14b5657f"},"outputs":[{"output_type":"display_data","data":{"text/plain":["VBox(children=(HTML(value='
/content/wandb/run-20231116_163415-mthnd6sk
"],"text/plain":["Step | \n","Training Loss | \n","
---|---|
50 | \n","1.043500 | \n","
100 | \n","0.547800 | \n","
150 | \n","0.505000 | \n","
200 | \n","0.495700 | \n","
250 | \n","0.518000 | \n","
300 | \n","0.497200 | \n","
"],"text/plain":["
Copy a token from your Hugging Face\ntokens page and paste it below.
Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file.