{ "cells": [ { "cell_type": "markdown", "id": "33faae25-af36-4781-bf8f-2084ddc96a52", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": 17, "id": "73e72549-69f2-46b5-b0f5-655777139972", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:59:00.459773Z", "iopub.status.busy": "2025-01-24T18:59:00.458472Z", "iopub.status.idle": "2025-01-24T18:59:00.517418Z", "shell.execute_reply": "2025-01-24T18:59:00.517026Z", "shell.execute_reply.started": "2025-01-24T18:59:00.459726Z" } }, "outputs": [], "source": [ "from datetime import datetime\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "from transformers import BertTokenizer, BertModel\n", "from huggingface_hub import (\n", " PyTorchModelHubMixin,\n", " notebook_login,\n", " ModelCard,\n", " ModelCardData,\n", " EvalResult,\n", ")\n", "from datasets import DatasetDict, load_dataset\n", "from torch.utils.data import Dataset, DataLoader\n", "from statsmodels.stats.proportion import proportion_confint" ] }, { "cell_type": "code", "execution_count": 2, "id": "07e0787e-c72b-41f3-baba-43cef3f8d6f8", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:22:01.628023Z", "iopub.status.busy": "2025-01-24T18:22:01.627838Z", "iopub.status.idle": "2025-01-24T18:22:01.629825Z", "shell.execute_reply": "2025-01-24T18:22:01.629635Z", "shell.execute_reply.started": "2025-01-24T18:22:01.628013Z" } }, "outputs": [], "source": [ "notebook_login(new_session=False)" ] }, { "cell_type": "markdown", "id": "a919d72c-8d10-4275-a2ca-4ead295f41a8", "metadata": {}, "source": [ "# Functions" ] }, { "cell_type": "code", "execution_count": 12, "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:23:58.768682Z", "iopub.status.busy": "2025-01-24T18:23:58.768083Z", "iopub.status.idle": "2025-01-24T18:23:58.787548Z", "shell.execute_reply": "2025-01-24T18:23:58.786993Z", "shell.execute_reply.started": 
"2025-01-24T18:23:58.768631Z" } }, "outputs": [], "source": [ "def my_print(x):\n", " time_str = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n", " print(time_str, x)\n", "\n", "\n", "def model_metrics(model, dataloader):\n", " \"\"\"Return (mean per-sample loss, accuracy) of `model` over `dataloader`.\n", "\n", " Runs on whatever device the model's parameters live on and restores the\n", " model's previous train/eval mode before returning.\n", " \"\"\"\n", " criterion = nn.CrossEntropyLoss()\n", " # Use the model's own device rather than a notebook-global `device`.\n", " device = next(model.parameters()).device\n", " was_training = model.training\n", " model.eval()\n", " with torch.no_grad():\n", " total_loss = 0\n", " total_correct = 0\n", " total_length = 0\n", " for batch in dataloader:\n", " input_ids = batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = batch[\"labels\"].to(device)\n", "\n", " outputs = model(input_ids, attention_mask)\n", " loss = criterion(outputs, labels)\n", " predictions_cpu = torch.argmax(outputs, dim=1).cpu().numpy()\n", " labels_cpu = labels.cpu().numpy()\n", " correct_count = (predictions_cpu == labels_cpu).sum()\n", "\n", " # criterion returns the batch mean; weight it by batch size so a\n", " # short final batch does not skew the average.\n", " total_loss += loss.item() * len(labels_cpu)\n", " total_correct += correct_count\n", " total_length += len(labels_cpu)\n", " avg_loss = total_loss / total_length\n", " avg_acc = total_correct / total_length\n", " model.train(was_training)\n", " return float(avg_loss), float(avg_acc)\n", "\n", "\n", "def print_model_status(epoch, num_epochs, model, train_dataloader, test_dataloader):\n", " train_loss, train_acc = model_metrics(model, train_dataloader)\n", " test_loss, test_acc = model_metrics(model, test_dataloader)\n", " loss_str = f\"Loss: Train {train_loss:0.3f}, Test {test_loss:0.3f}\"\n", " acc_str = f\"Acc: Train {train_acc:0.3f}, Test {test_acc:0.3f}\"\n", " my_print(f\"Epoch {epoch+1:2}/{num_epochs} done. 
{loss_str}; and {acc_str}\")\n", " metrics = dict(\n", " train_loss=train_loss,\n", " train_acc=train_acc,\n", " test_loss=test_loss,\n", " test_acc=test_acc,\n", " )\n", " return metrics\n", "\n", "\n", "class BertClassifier(nn.Module, PyTorchModelHubMixin):\n", " def __init__(self, num_labels=8, bert_variety=\"bert-base-uncased\"):\n", " super().__init__()\n", " self.bert = BertModel.from_pretrained(bert_variety)\n", " self.config = self.bert.config\n", " self.config.num_labels = num_labels\n", " self.dropout = nn.Dropout(0.05)\n", " self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_labels)\n", "\n", " def forward(self, input_ids, attention_mask):\n", " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", " pooled_output = outputs.pooler_output\n", " pooled_output = self.dropout(pooled_output)\n", " logits = self.classifier(pooled_output)\n", " return logits\n", "\n", "\n", "class TextDataset(Dataset):\n", " def __init__(self, texts, labels, tokenizer, max_length=256):\n", " self.texts = texts\n", " self.encodings = tokenizer(\n", " texts,\n", " truncation=True,\n", " padding=True,\n", " max_length=max_length,\n", " return_tensors=\"pt\",\n", " )\n", " self.labels = torch.tensor([int(l[0]) for l in labels])\n", "\n", " def __getitem__(self, idx):\n", " item = {key: val[idx] for key, val in self.encodings.items()}\n", " item[\"labels\"] = self.labels[idx]\n", " return item\n", "\n", " def __len__(self) -> int:\n", " return len(self.labels)\n", "\n", "\n", "def train_model(model, train_dataloader, test_dataloader, device, num_epochs):\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n", " criterion = nn.CrossEntropyLoss()\n", " model.train()\n", "\n", " _ = print_model_status(-1, num_epochs, model, train_dataloader, test_dataloader)\n", " for epoch in range(num_epochs):\n", " total_loss = 0\n", " for batch in train_dataloader:\n", " optimizer.zero_grad()\n", "\n", " input_ids = 
batch[\"input_ids\"].to(device)\n", " attention_mask = batch[\"attention_mask\"].to(device)\n", " labels = batch[\"labels\"].to(device)\n", "\n", " outputs = model(input_ids, attention_mask)\n", " loss = criterion(outputs, labels)\n", "\n", " loss.backward()\n", " optimizer.step()\n", "\n", " total_loss += loss.item()\n", " avg_loss = total_loss / len(train_dataloader)\n", " metrics = print_model_status(\n", " epoch, num_epochs, model, train_dataloader, test_dataloader\n", " )\n", " return metrics" ] }, { "cell_type": "code", "execution_count": 13, "id": "07131bce-23ad-4787-8622-cce401f3e5ce", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:23:59.127835Z", "iopub.status.busy": "2025-01-24T18:23:59.126787Z", "iopub.status.idle": "2025-01-24T18:23:59.136440Z", "shell.execute_reply": "2025-01-24T18:23:59.135267Z", "shell.execute_reply.started": "2025-01-24T18:23:59.127791Z" } }, "outputs": [], "source": [ "if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " torch.mps.empty_cache()\n", "elif torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "695bc080-bbd7-4937-af5b-50db1c936500", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:23:59.442432Z", "iopub.status.busy": "2025-01-24T18:23:59.441786Z", "iopub.status.idle": "2025-01-24T18:23:59.453218Z", "shell.execute_reply": "2025-01-24T18:23:59.452473Z", "shell.execute_reply.started": "2025-01-24T18:23:59.442367Z" } }, "outputs": [], "source": [ "def run_training(\n", " max_dataset_size=16 * 200,\n", " bert_variety=\"bert-base-uncased\",\n", " max_length=256,\n", " num_epochs=3,\n", " batch_size=32,\n", "):\n", " training_regime = dict(\n", " max_dataset_size=max_dataset_size,\n", " bert_variety=bert_variety,\n", " max_length=max_length,\n", " num_epochs=num_epochs,\n", " batch_size=batch_size,\n", " )\n", " hf_dataset = 
load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n", " test_size = 0.2\n", " test_seed = 42\n", " train_test = hf_dataset[\"train\"].train_test_split(\n", " test_size=test_size, seed=test_seed\n", " )\n", " train_dataset = train_test[\"train\"]\n", " test_dataset = train_test[\"test\"]\n", " if max_dataset_size != \"full\" and max_dataset_size < len(hf_dataset[\"train\"]):\n", " train_dataset = train_dataset[:max_dataset_size]\n", " test_dataset = test_dataset[:max_dataset_size]\n", "\n", " tokenizer = BertTokenizer.from_pretrained(bert_variety, max_length=max_length)\n", " # Seed torch's RNG so classifier-head init, dropout and DataLoader shuffling\n", " # are reproducible across runs (the split above is already seeded).\n", " torch.manual_seed(test_seed)\n", " model = BertClassifier(bert_variety=bert_variety)\n", " if torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", " torch.mps.empty_cache()\n", " elif torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", " else:\n", " device = torch.device(\"cpu\")\n", " model.to(device)\n", "\n", " text_dataset_train = TextDataset(\n", " train_dataset[\"quote\"],\n", " train_dataset[\"label\"],\n", " tokenizer=tokenizer,\n", " max_length=max_length,\n", " )\n", " text_dataset_test = TextDataset(\n", " test_dataset[\"quote\"],\n", " test_dataset[\"label\"],\n", " tokenizer=tokenizer,\n", " max_length=max_length,\n", " )\n", " dataloader_train = DataLoader(\n", " text_dataset_train, batch_size=batch_size, shuffle=True\n", " )\n", " dataloader_test = DataLoader(\n", " text_dataset_test, batch_size=batch_size, shuffle=False\n", " )\n", "\n", " metrics = train_model(\n", " model, dataloader_train, dataloader_test, device, num_epochs=num_epochs\n", " )\n", " return model, tokenizer, training_regime, metrics" ] }, { "cell_type": "markdown", "id": "5af751f3-1fc4-4540-ae25-638db9d33c67", "metadata": {}, "source": [ "# Exploration" ] }, { "cell_type": "code", "execution_count": 15, "id": "11890d3b-8bcb-4a9b-b421-5431081cca39", "metadata": { "execution": { "iopub.execute_input": 
"2025-01-24T18:24:00.153856Z", "iopub.status.busy": "2025-01-24T18:24:00.153044Z", "iopub.status.idle": "2025-01-24T18:24:00.158876Z", "shell.execute_reply": "2025-01-24T18:24:00.157762Z", "shell.execute_reply.started": "2025-01-24T18:24:00.153804Z" } }, "outputs": [], "source": [ "base_model_repo = \"google/bert_uncased_L-12_H-768_A-12\"\n", "model_and_repo_name = \"frugal-ai-text-bert-base\"" ] }, { "cell_type": "markdown", "id": "a847135f-ce86-46a1-9c61-3459a847cb29", "metadata": { "execution": { "iopub.execute_input": "2025-01-20T19:13:05.482383Z", "iopub.status.busy": "2025-01-20T19:13:05.481449Z", "iopub.status.idle": "2025-01-20T19:13:05.487546Z", "shell.execute_reply": "2025-01-20T19:13:05.486557Z", "shell.execute_reply.started": "2025-01-20T19:13:05.482339Z" } }, "source": [ "## Check if runs" ] }, { "cell_type": "code", "execution_count": 16, "id": "34a7c310-c486-4db1-b94d-4363c3d3df5b", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T18:24:00.721937Z", "iopub.status.busy": "2025-01-24T18:24:00.721190Z", "iopub.status.idle": "2025-01-24T18:24:06.157768Z", "shell.execute_reply": "2025-01-24T18:24:06.157299Z", "shell.execute_reply.started": "2025-01-24T18:24:00.721894Z" } }, "outputs": [], "source": [ "model, tokenizer, regime, metrics = run_training(\n", " max_dataset_size=16 * 100,\n", " bert_variety=base_model_repo,\n", " max_length=128,\n", " num_epochs=3,\n", " batch_size=16,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": 
"0aedfcca-843e-4f4c-8062-3e4625161bcc", "metadata": { "editable": true, "execution": { "iopub.status.busy": "2025-01-24T18:24:06.157956Z", "iopub.status.idle": "2025-01-24T18:24:06.158060Z", "shell.execute_reply": "2025-01-24T18:24:06.158008Z", "shell.execute_reply.started": "2025-01-24T18:24:06.158002Z" }, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "model.eval()\n", "test_text = [\n", " \"This was a great experience!\", # 0_not_relevant\n", " \"My favorite hike is Laguna de los Tres.\", # 0_not_relevant\n", " \"Crops will grow great in Finland if it's warmer there.\", # 3_not_bad\n", " \"Climate change is fake.\", # 1_not_happening\n", " \"The apparent warming is caused by solar cycles.\", # 2_not_human\n", " \"Solar panels emit bad vibes.\", # 4_solutions_harmful_unnecessary\n", " \"All those so-called scientists are Democrats.\", # 6_proponents_biased\n", "]\n", "test_encoding = tokenizer(\n", " test_text,\n", " truncation=True,\n", " padding=True,\n", " return_tensors=\"pt\",\n", " max_length=256,\n", ")\n", "\n", "with torch.no_grad():\n", " test_input_ids = test_encoding[\"input_ids\"].to(device)\n", " test_attention_mask = test_encoding[\"attention_mask\"].to(device)\n", " outputs = model(test_input_ids, test_attention_mask)\n", " predictions = torch.argmax(outputs, dim=1)\n", " my_print(f\"Predictions: {predictions}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "1201bf29-5040-4317-be30-77bec0bfe5b4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "0c3ea938-dd87-4673-b1d6-f06c70b19455", "metadata": {}, "source": [ "## Hyperparameters" ] }, { "cell_type": "markdown", "id": "6264418d-10ef-4eca-b188-2b6b7f487797", "metadata": {}, "source": [ "Overall top performance per model. 
Machine: bert-base is using an Nvidia 1xL40S, no inference-time cleverness attempted.\n", "\n", "[accidentally cheating bert-base by training on full dataset](https://huggingface.co/datasets/frugal-ai-challenge/public-leaderboard-text/blob/main/submissions/Nonnormalizable_20250117_220350.json):\\\n", "acc 0.954, energy 0.736 Wh\n", "\n", "[bert-base some hp tuning](https://huggingface.co/datasets/frugal-ai-challenge/public-leaderboard-text/blob/main/submissions/Nonnormalizable_20250120_231350.json):\\\n", "acc 0.707, energy 0.803 Wh\n", "\n", "Added normal data loader, batch size 32. Moved to Nvidia T4 small.\n", "\n", "bert-tiny\\\n", "acc 0.618, energy 0.079 Wh\n", "\n", "bert-mini\\\n", "acc 0.650, energy 0.129 Wh\n", "\n", "bert-small\\\n", "acc 0.656, energy 0.256 Wh\n", "\n", "bert-medium\\\n", "acc 0.645, energy 0.273 Wh\n", "\n", "bert-base\\\n", "acc 0.691, energy 1.053 Wh" ] }, { "cell_type": "code", "execution_count": 23, "id": "6c35f222-79d9-4166-8601-8a6240a49c91", "metadata": { "execution": { "iopub.execute_input": "2025-01-24T19:03:41.276772Z", "iopub.status.busy": "2025-01-24T19:03:41.276125Z", "iopub.status.idle": "2025-01-24T19:03:41.284530Z", "shell.execute_reply": "2025-01-24T19:03:41.283079Z", "shell.execute_reply.started": "2025-01-24T19:03:41.276731Z" } }, "outputs": [ { "data": { "text/plain": [ "(0.6284344081642794, 0.6817389605903139)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nobs = 1219\n", "acc = 0.656\n", "proportion_confint(\n", " count=int(nobs * acc),\n", " nobs=nobs,\n", " method=\"jeffreys\",\n", ")" ] }, { "cell_type": "markdown", "id": "df067c27-9d58-49fc-860d-ba79e5512013", "metadata": {}, "source": [ "Looking at bert-tiny.\n", "Scanning max_length and batch_size with num_epochs set to 3, it looks like we want 256 and 16. That gets us\\\n", "`2025-01-21 10:18:56 Epoch 3/3 done. 
Loss: Train 1.368, Test 1.432; and Acc: Train 0.499, Test 0.477`.\n", "\n", "Then looking at num_epochs, we saturate test set performance at 15 (~3 minutes), giving e.g.\\\n", "`2025-01-21 10:38:30 Epoch 15/20 done. Loss: Train 0.553, Test 1.157; and Acc: Train 0.833, Test 0.595`\n", "\n", "For bert-mini, just looking at num_epochs, we choose 8\\\n", "`2025-01-22 10:56:12 Epoch 8/20 done. Loss: Train 0.305, Test 1.090; and Acc: Train 0.920, Test 0.646`\n", "\n", "For bert-small, 4\\\n", "`2025-01-22 11:39:41 Epoch 4/15 done. Loss: Train 0.301, Test 0.978; and Acc: Train 0.920, Test 0.664`\n", "\n", "For bert-medium, 4\\\n", "`2025-01-22 12:09:51 Epoch 4/10 done. Loss: Train 0.294, Test 1.020; and Acc: Train 0.922, Test 0.660`\n", "\n", "For bert-base, 3 does happen to be correct, just checking for completeness\\\n", "`2025-01-22 12:59:10 Epoch 3/7 done. Loss: Train 0.156, Test 0.930; and Acc: Train 0.964, Test 0.703`" ] }, { "cell_type": "code", "execution_count": 9, "id": "37794952-703c-466c-9d26-ee6cb2834246", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:19:34.065427Z", "iopub.status.busy": "2025-01-22T18:19:34.065327Z", "iopub.status.idle": "2025-01-22T18:19:34.066925Z", "shell.execute_reply": "2025-01-22T18:19:34.066714Z", "shell.execute_reply.started": "2025-01-22T18:19:34.065418Z" } }, "outputs": [], "source": [ "static_hyperparams = dict(\n", " max_dataset_size=\"full\",\n", " bert_variety=base_model_repo,\n", " max_length=256,\n", " batch_size=16,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "id": "28354e8c-886a-4523-8968-8c688c13f6a3", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:19:34.067286Z", "iopub.status.busy": "2025-01-22T18:19:34.067206Z", "iopub.status.idle": "2025-01-22T18:38:14.108104Z", "shell.execute_reply": "2025-01-22T18:38:14.107193Z", "shell.execute_reply.started": "2025-01-22T18:19:34.067278Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-22 
13:21:10 Epoch 0/3 done. Loss: Train 2.088, Test 2.085; and Acc: Train 0.137, Test 0.135\n", "2025-01-22 13:26:50 Epoch 1/3 done. Loss: Train 0.780, Test 1.012; and Acc: Train 0.747, Test 0.648\n", "2025-01-22 13:32:30 Epoch 2/3 done. Loss: Train 0.346, Test 0.890; and Acc: Train 0.904, Test 0.689\n", "2025-01-22 13:38:14 Epoch 3/3 done. Loss: Train 0.167, Test 0.968; and Acc: Train 0.959, Test 0.691\n" ] } ], "source": [ "model, tokenizer, training_regime, testing_metrics = run_training(\n", " **static_hyperparams,\n", " num_epochs=3,\n", ")" ] }, { "cell_type": "markdown", "id": "982ba556-c589-4cbb-b392-614942a64ab3", "metadata": {}, "source": [ "# Model to upload" ] }, { "cell_type": "code", "execution_count": 11, "id": "ec2516f9-79f2-4ae1-ab9a-9a51a7a50587", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:38:14.109094Z", "iopub.status.busy": "2025-01-22T18:38:14.108996Z", "iopub.status.idle": "2025-01-22T18:38:14.124982Z", "shell.execute_reply": "2025-01-22T18:38:14.124768Z", "shell.execute_reply.started": "2025-01-22T18:38:14.109081Z" }, "scrolled": true }, "outputs": [], "source": [ "card_data = ModelCardData(\n", " model_name=model_and_repo_name,\n", " base_model=static_hyperparams[\"bert_variety\"],\n", " license=\"apache-2.0\",\n", " language=[\"en\"],\n", " datasets=[\"QuotaClimat/frugalaichallenge-text-train\"],\n", " tags=[\"model_hub_mixin\", \"pytorch_model_hub_mixin\", \"climate\"],\n", " pipeline_tag=\"text-classification\",\n", ")\n", "card = ModelCard.from_template(\n", " card_data,\n", " model_summary=f\"Classify text into 8 categories of climate misinformation using {base_model_repo}.\",\n", " model_description=\"Fine-tuned BERT for classifying climate information as part of the Frugal AI Challenge, for submission to https://huggingface.co/frugal-ai-challenge and scoring on accuracy and efficiency. 
Trained on only the non-evaluation 80% of the data, so its (non-cheating) score will be lower.\",\n", " developers=\"Andre Bach\",\n", " funded_by=\"N/A\",\n", " shared_by=\"Andre Bach\",\n", " model_type=\"Text classification\",\n", " repo=model_and_repo_name,\n", " training_regime=training_regime,\n", " testing_metrics=testing_metrics,\n", ")\n", "# print(card_data.to_yaml())\n", "# print(card)" ] }, { "cell_type": "code", "execution_count": 12, "id": "29d3bbf9-ab2a-48e2-a550-e16da5025720", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:38:14.125523Z", "iopub.status.busy": "2025-01-22T18:38:14.125395Z", "iopub.status.idle": "2025-01-22T18:38:14.126978Z", "shell.execute_reply": "2025-01-22T18:38:14.126771Z", "shell.execute_reply.started": "2025-01-22T18:38:14.125514Z" } }, "outputs": [], "source": [ "model_final = model\n", "tokenizer_final = tokenizer" ] }, { "cell_type": "code", "execution_count": 13, "id": "e3b099c6-6b98-473b-8797-5032213b9fcb", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:38:14.127531Z", "iopub.status.busy": "2025-01-22T18:38:14.127415Z", "iopub.status.idle": "2025-01-22T18:38:14.157055Z", "shell.execute_reply": "2025-01-22T18:38:14.156821Z", "shell.execute_reply.started": "2025-01-22T18:38:14.127524Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2025-01-22 13:38:14 Predictions: tensor([0, 0, 3, 1, 2, 4, 6], device='mps:0')\n" ] } ], "source": [ "model_final.eval()\n", "test_text = [\n", " \"This was a great experience!\", # 0_not_relevant\n", " \"My favorite hike is Laguna de los Tres.\", # 0_not_relevant\n", " \"Crops will grow great in Finland if it's warmer there.\", # 3_not_bad\n", " \"Climate change is fake.\", # 1_not_happening\n", " \"The apparent warming is caused by solar cycles.\", # 2_not_human\n", " \"Solar panels emit bad vibes.\", # 4_solutions_harmful_unnecessary\n", " \"All those so-called scientists are Democrats.\", # 6_proponents_biased\n", "]\n", 
"test_encoding = tokenizer_final(\n", " test_text,\n", " truncation=True,\n", " padding=True,\n", " return_tensors=\"pt\",\n", " max_length=256,\n", ")\n", "\n", "with torch.no_grad():\n", " test_input_ids = test_encoding[\"input_ids\"].to(device)\n", " test_attention_mask = test_encoding[\"attention_mask\"].to(device)\n", " outputs = model_final(test_input_ids, test_attention_mask)\n", " predictions = torch.argmax(outputs, dim=1)\n", " my_print(f\"Predictions: {predictions}\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0", "metadata": { "execution": { "iopub.execute_input": "2025-01-22T18:38:14.157429Z", "iopub.status.busy": "2025-01-22T18:38:14.157356Z", "iopub.status.idle": "2025-01-22T18:38:53.948196Z", "shell.execute_reply": "2025-01-22T18:38:53.947738Z", "shell.execute_reply.started": "2025-01-22T18:38:14.157421Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "54e4f39d398f45ceb760107e5b57744a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/438M [00:00