# %%
from datetime import datetime
import numpy as np
import torch
from torch import nn
from transformers import BertTokenizer, BertModel
from huggingface_hub import PyTorchModelHubMixin, notebook_login
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# %%
notebook_login(new_session=False)

# %%
def my_print(x):
    """Print ``x`` prefixed with the current local time (``YYYY-MM-DD HH:MM:SS``)."""
    time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(time_str, x)


class BertClassifier(nn.Module, PyTorchModelHubMixin):
    """Pretrained BERT encoder followed by dropout and a linear classification head.

    Args:
        num_labels: Number of output classes (size of the logits vector).
        bert_variety: Checkpoint name passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, num_labels=8, bert_variety="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_variety)
        self.dropout = nn.Dropout(0.05)
        # The head's input width is read from the encoder's pooler so it
        # follows whichever checkpoint was loaded.
        self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) class logits for a batch of token ids.

        Args:
            input_ids: Token-id tensor forwarded to the BERT encoder.
            attention_mask: Mask tensor forwarded to the BERT encoder.

        Returns:
            The output of the linear head applied to the dropped-out pooled
            BERT output; no softmax is applied.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


def _label_to_index(label):
    """Convert a label such as ``"3_not_bad"`` (or a plain integer) to its class index.

    The whole integer prefix before the first underscore is parsed, so
    multi-digit class ids (``"12_foo"`` -> 12) are handled as well.
    """
    if isinstance(label, int):
        return label
    return int(str(label).split("_", 1)[0])


class TextDataset(Dataset):
    """Map-style dataset holding pre-tokenised texts and integer labels.

    All texts are tokenised once in ``__init__`` (truncated to ``max_length``
    and padded to a common length), so ``__getitem__`` is a pure tensor lookup.

    Args:
        texts: Sequence of input strings.
        labels: Sequence of labels whose class index is the integer prefix
            before the first ``_`` (e.g. ``"0_not_relevant"``); plain
            integers are accepted too.
        tokenizer: Callable tokenizer returning a mapping of tensors.
        max_length: Truncation length handed to the tokenizer.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        # Parse the full numeric prefix rather than only the first character,
        # which would silently mis-map any class index >= 10.
        self.labels = torch.tensor([_label_to_index(label) for label in labels])

    def __getitem__(self, idx):
        """Return the encoding tensors for row ``idx`` plus its label under ``"labels"``."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self) -> int:
        return len(self.labels)


def train_model(model, train_dataloader, device, num_epochs):
    """Fine-tune ``model`` in place with AdamW (lr=2e-5) and cross-entropy loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning class logits.
        train_dataloader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device the batch tensors are moved to before the forward pass.
        num_epochs: Number of passes over ``train_dataloader``.

    Raises:
        ValueError: If an epoch yields no batches (empty dataloader).

    The model is left in training mode; callers should switch to
    ``model.eval()`` before inference.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()
    model.train()

    my_print("Starting epoch 1.")
    for epoch in range(num_epochs):
        total_loss = 0.0
        # Count batches while iterating instead of calling len(): this also
        # works for loaders without __len__ and lets us detect an empty one.
        num_batches = 0
        for batch in train_dataloader:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_dataloader yielded no batches; nothing to train on.")
        avg_loss = total_loss / num_batches
        my_print(f"Epoch {epoch+1}/{num_epochs} done, Average Loss: {avg_loss:0.4f}")
# %%
def _select_device():
    """Return the preferred torch device: MPS first, then CUDA, then CPU.

    When MPS is chosen its allocator cache is emptied first.
    """
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Module-level device used by the inference cells below.
device = _select_device()

# %%
def run_training(
    max_dataset_size=16 * 200,
    bert_variety="bert-base-uncased",
    max_length=200,
    num_epochs=3,
    batch_size=32,
    seed=None,
):
    """Load the training split, build a ``BertClassifier`` and fine-tune it.

    Args:
        max_dataset_size: Maximum number of training rows to use. Pass
            ``"full"`` (or ``None``) to use the whole split. When a number is
            given, the first rows in dataset order are taken (no shuffling).
        bert_variety: Checkpoint name for both tokenizer and encoder.
        max_length: Maximum sequence length in tokens; used for truncation
            during training and stored on the returned tokenizer.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
        seed: Optional integer passed to ``torch.manual_seed`` before the
            model is built, to make runs repeatable. ``None`` leaves the RNG
            untouched (previous behaviour).

    Returns:
        ``(model, tokenizer)`` — the fine-tuned model (still in training mode,
        on the selected device) and the tokenizer used to encode the data.
    """
    if seed is not None:
        torch.manual_seed(seed)

    hf_dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
    full_train = hf_dataset["train"]
    use_full = max_dataset_size is None or max_dataset_size == "full"
    if not use_full and max_dataset_size < len(full_train):
        # Slicing the split yields a plain mapping of column name -> values.
        train_dataset = full_train[:max_dataset_size]
    else:
        train_dataset = full_train

    # `model_max_length` (not `max_length`) is the tokenizer setting that
    # `truncation=True` falls back to when a call gives no explicit limit,
    # so inference with the returned tokenizer truncates like training did.
    tokenizer = BertTokenizer.from_pretrained(bert_variety, model_max_length=max_length)
    model = BertClassifier(bert_variety=bert_variety)
    device = _select_device()
    model.to(device)

    # Materialise the columns as plain lists so the tokenizer receives the
    # same type whether the data is the sliced mapping or the full split.
    dataset = TextDataset(
        list(train_dataset["quote"]),
        list(train_dataset["label"]),
        tokenizer=tokenizer,
        max_length=max_length,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    train_model(model, dataloader, device, num_epochs=num_epochs)
    return model, tokenizer

# %%
# Quick smoke run on the first 1600 rows.
model, tokenizer = run_training(
    max_dataset_size=16 * 100,
    bert_variety="bert-base-uncased",
    max_length=128,
    num_epochs=3,
    batch_size=32,
)

# %%
# Sanity check on hand-written sentences; trailing comments give the label
# each sentence is expected to receive.
model.eval()
test_text = [
    "This was a great experience!",  # 0_not_relevant
    "My favorite hike is Laguna de los Tres.",  # 0_not_relevant
    "Crops will grow great in Finland if it's warmer there.",  # 3_not_bad
    "Climate change is fake.",  # 1_not_happening
    "The apparent warming is caused by solar cycles.",  # 2_not_human
    "Solar panels emit bad vibes.",  # 4_solutions_harmful_unnecessary
    "All those so-called scientists are Democrats.",  # 6_proponents_biased
]
test_encoding = tokenizer(
    test_text,
    truncation=True,
    padding=True,
    return_tensors="pt",
)

with torch.no_grad():
    test_input_ids = test_encoding["input_ids"].to(device)
    test_attention_mask = test_encoding["attention_mask"].to(device)
    outputs = model(test_input_ids, test_attention_mask)
    predictions = torch.argmax(outputs, dim=1)
    my_print(f"Predictions: {predictions}")

# %%
# Full dataset with sequences truncated to 64 tokens.
model, tokenizer = run_training(
    max_dataset_size="full",
    bert_variety="bert-base-uncased",
    max_length=64,
    num_epochs=3,
    batch_size=32,
)
# %%
# Full training split, 128-token sequences, batches of 32.
run_kwargs = {
    "max_dataset_size": "full",
    "bert_variety": "bert-base-uncased",
    "max_length": 128,
    "num_epochs": 3,
    "batch_size": 32,
}
model, tokenizer = run_training(**run_kwargs)

# %%
# Same configuration with the batch size halved to 16.
run_kwargs = {
    "max_dataset_size": "full",
    "bert_variety": "bert-base-uncased",
    "max_length": 128,
    "num_epochs": 3,
    "batch_size": 16,
}
model, tokenizer = run_training(**run_kwargs)
# %%
# Full training split, 256-token sequences, batches of 16.
run_kwargs = {
    "max_dataset_size": "full",
    "bert_variety": "bert-base-uncased",
    "max_length": 256,
    "num_epochs": 3,
    "batch_size": 16,
}
model, tokenizer = run_training(**run_kwargs)

# %% [markdown]
# # Model to upload

# %%
# Train the model kept under the `*_final` names: full split, 128 tokens,
# batches of 16.
final_run_kwargs = {
    "max_dataset_size": "full",
    "bert_variety": "bert-base-uncased",
    "max_length": 128,
    "num_epochs": 3,
    "batch_size": 16,
}
model_final, tokenizer_final = run_training(**final_run_kwargs)

# %%
# Sanity check of the final model; the comment above each sentence is the
# label it is expected to receive.
model_final.eval()
test_text = [
    # 0_not_relevant
    "This was a great experience!",
    # 0_not_relevant
    "My favorite hike is Laguna de los Tres.",
    # 3_not_bad
    "Crops will grow great in Finland if it's warmer there.",
    # 1_not_happening
    "Climate change is fake.",
    # 2_not_human
    "The apparent warming is caused by solar cycles.",
    # 4_solutions_harmful_unnecessary
    "Solar panels emit bad vibes.",
    # 6_proponents_biased
    "All those so-called scientists are Democrats.",
]
test_encoding = tokenizer_final(
    test_text, truncation=True, padding=True, return_tensors="pt"
)

# Move the inputs to the same device selection the training run used.
test_input_ids = test_encoding["input_ids"].to(device)
test_attention_mask = test_encoding["attention_mask"].to(device)

with torch.no_grad():
    outputs = model_final(test_input_ids, test_attention_mask)

predictions = outputs.argmax(dim=1)
my_print(f"Predictions: {predictions}")