{ "cells": [ { "cell_type": "markdown", "id": "19817716", "metadata": {}, "source": [ "# Question Answering\n", "This notebook contains several question answering models. We first introduce a representation for the dataset and the corresponding DataLoader, and then evaluate the different models." ] }, { "cell_type": "code", "execution_count": 50, "id": "49bf46c6", "metadata": {}, "outputs": [], "source": [ "from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertConfig, \\\n", " DistilBertTokenizerFast, AutoTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertConfig\n", "from torch import nn\n", "from pathlib import Path\n", "import torch\n", "import numpy as np  # needed for np.mean in the training loops below\n", "import pandas as pd\n", "from typing import Optional\n", "from tqdm.auto import tqdm\n", "from util import eval_test_set, count_parameters\n", "from torch.optim import AdamW, RMSprop\n", "\n", "\n", "from qa_model import QuestionDistilBERT, SimpleQuestionDistilBERT, ReuseQuestionDistilBERT, Dataset, test_model" ] }, { "cell_type": "markdown", "id": "3ea47820", "metadata": {}, "source": [ "## Data\n", "The data processing is partly based on the Hugging Face tutorial (https://huggingface.co/course/chapter7/7?fw=pt)." ] }, { "cell_type": "code", "execution_count": 51, "id": "7b1b2b3e", "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')" ] }, { "cell_type": "code", "execution_count": 52, "id": "f276eba7", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# create datasets and loaders for the training and test sets\n", "squad_paths = [str(x) for x in Path('data/training_squad/').glob('**/*.txt')]\n", "nat_paths = [str(x) for x in Path('data/natural_questions_train/').glob('**/*.txt')]\n", "hotpotqa_paths = [str(x) for x in Path('data/hotpotqa_training/').glob('**/*.txt')]" ] }, { "cell_type": "markdown", "id": "ad8d532a", "metadata": {}, "source": [ "## POC Model\n", "* Works very well:\n", " * Dropout 0.1 is too small (overfitting after the first epoch) - changed to 0.15\n", " * The difference between AdamW and RMSprop is minimal\n", " \n", "### Results:\n", "Dropout = 0.15\n", "* Mean EM: 0.5374\n", "* Mean F-1: 0.6826317532406944\n", "\n", "Dropout = 0.2 (overfitting relatively similar to the first run, but the dropout seems to be too high)\n", "* Mean EM: 0.5044\n", "* Mean F-1: 0.6437359169276439" ] }, { "cell_type": "code", "execution_count": 54, "id": "703e7f38", "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n", "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n", "\n", "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n", " natural_question_paths=None, \n", " hotpotqa_paths = None, tokenizer=tokenizer)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)" ] }, { "cell_type": "code", "execution_count": 55, "id": "6672f614", "metadata": {}, "outputs": [], "source": [ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n", "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n", "mod = model.distilbert" ] }, { "cell_type": "code", "execution_count": 56, "id": "dec15198", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SimpleQuestionDistilBERT(\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): 
Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): 
Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=2, bias=True)\n", ")" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model = SimpleQuestionDistilBERT(mod)\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 57, "id": "9def3c83", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------------------------------------------------------+------------+\n", "| Modules | Parameters |\n", "+---------------------------------------------------------+------------+\n", "| distilbert.embeddings.word_embeddings.weight | 23440896 |\n", "| distilbert.embeddings.position_embeddings.weight | 393216 |\n", "| distilbert.embeddings.LayerNorm.weight | 768 |\n", "| distilbert.embeddings.LayerNorm.bias | 768 |\n", "| distilbert.transformer.layer.0.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.0.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.0.attention.k_lin.weight | 589824 |\n", "| distilbert.transformer.layer.0.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.0.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.0.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.0.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.0.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.0.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.0.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.0.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.0.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.0.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.0.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.0.output_layer_norm.weight | 768 |\n", "| 
distilbert.transformer.layer.0.output_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.1.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.1.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.1.attention.k_lin.weight | 589824 |\n", "| distilbert.transformer.layer.1.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.1.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.1.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.1.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.1.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.1.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.1.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.1.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.1.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.1.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.1.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.1.output_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.1.output_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.2.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.2.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.2.attention.k_lin.weight | 589824 |\n", "| distilbert.transformer.layer.2.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.2.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.2.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.2.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.2.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.2.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.2.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.2.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.2.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.2.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.2.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.2.output_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.2.output_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.3.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.3.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.3.attention.k_lin.weight | 589824 |\n", "| distilbert.transformer.layer.3.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.3.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.3.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.3.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.3.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.3.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.3.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.3.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.3.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.3.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.3.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.3.output_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.3.output_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.4.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.4.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.4.attention.k_lin.weight | 589824 |\n", "| 
distilbert.transformer.layer.4.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.4.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.4.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.4.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.4.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.4.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.4.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.4.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.4.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.4.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.4.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.4.output_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.4.output_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.5.attention.q_lin.weight | 589824 |\n", "| distilbert.transformer.layer.5.attention.q_lin.bias | 768 |\n", "| distilbert.transformer.layer.5.attention.k_lin.weight | 589824 |\n", "| distilbert.transformer.layer.5.attention.k_lin.bias | 768 |\n", "| distilbert.transformer.layer.5.attention.v_lin.weight | 589824 |\n", "| distilbert.transformer.layer.5.attention.v_lin.bias | 768 |\n", "| distilbert.transformer.layer.5.attention.out_lin.weight | 589824 |\n", "| distilbert.transformer.layer.5.attention.out_lin.bias | 768 |\n", "| distilbert.transformer.layer.5.sa_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.5.sa_layer_norm.bias | 768 |\n", "| distilbert.transformer.layer.5.ffn.lin1.weight | 2359296 |\n", "| distilbert.transformer.layer.5.ffn.lin1.bias | 3072 |\n", "| distilbert.transformer.layer.5.ffn.lin2.weight | 2359296 |\n", "| distilbert.transformer.layer.5.ffn.lin2.bias | 768 |\n", "| distilbert.transformer.layer.5.output_layer_norm.weight | 768 |\n", "| distilbert.transformer.layer.5.output_layer_norm.bias | 768 |\n", "| classifier.weight | 1536 |\n", "| classifier.bias | 2 |\n", "+---------------------------------------------------------+------------+\n", "Total Trainable Params: 66364418\n" ] }, { "data": { "text/plain": [ "66364418" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_parameters(model)" ] }, { "cell_type": "markdown", "id": "426a6311", "metadata": {}, "source": [ "### Testing the model" ] }, { "cell_type": "code", "execution_count": 58, "id": "6151c201", "metadata": {}, "outputs": [], "source": [ "# get smaller dataset\n", "batch_size = 8\n", "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n", "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", "optim = RMSprop(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": 59, "id": "aeae0c56", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Passed\n" ] } ], "source": [ "test_model(model, optim, test_ds_loader, device)" ] }, { "cell_type": "markdown", "id": "59928d34", "metadata": {}, "source": [ "### Model Training" ] }, { "cell_type": "code", "execution_count": 60, "id": "a8017b8c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SimpleQuestionDistilBERT(\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", 
" (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", 
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (dropout): Dropout(p=0.5, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=2, bias=True)\n", ")" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model = SimpleQuestionDistilBERT(mod)\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 61, "id": "f13c12dc", "metadata": {}, "outputs": [], "source": [ "model.train()\n", "optim = RMSprop(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "id": "e4fa54d9", "metadata": {}, "outputs": [], "source": [ "epochs = 5\n", "\n", "for epoch in range(epochs):\n", " loop = tqdm(loader, leave=True)\n", " model.train()\n", " mean_training_error = []\n", " for batch in loop:\n", " optim.zero_grad()\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n", " loss = outputs['loss']\n", " loss.backward()\n", " # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n", " optim.step()\n", " mean_training_error.append(loss.item())\n", " loop.set_description(f'Epoch {epoch}')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Training Error\", np.mean(mean_training_error))\n", " \n", " \n", " loop = tqdm(test_loader, leave=True)\n", " model.eval()\n", " mean_test_error = []\n", " for batch in loop:\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " # 
print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n", " loss = outputs['loss']\n", " \n", " mean_test_error.append(loss.item())\n", " loop.set_description(f'Epoch {epoch} Testset')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Test Error\", np.mean(mean_test_error))" ] }, { "cell_type": "code", "execution_count": 19, "id": "6ff26fb4", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"simple_distilbert_qa.model\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "a5e7abeb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = SimpleQuestionDistilBERT(mod)\n", "model.load_state_dict(torch.load(\"simple_distilbert_qa.model\"))" ] }, { "cell_type": "code", "execution_count": 18, "id": "f5ad7bee", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2500/2500 [02:09<00:00, 19.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Mean EM: 0.5374\n", "Mean F-1: 0.6826317532406944\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "eval_test_set(model, tokenizer, test_loader, device)" ] }, { "cell_type": "markdown", "id": "fa6017a8", "metadata": {}, "source": [ "## Freeze baseline and train new head\n", "My initial idea was to freeze the baseline layers and add a completely new head, trained from scratch. I tried a lot of different configurations, but nothing really worked; the CrossEntropyLoss usually stayed at about 3 the whole time. Below, you can see the different heads I tried.\n", "\n", "Furthermore, I experimented with different data, because I thought there might not have been enough data overall. I would conclude that this didn't work because (1) Transformers are very data-hungry and I probably still used too little data (one epoch already took about 1h, so it wasn't possible to use even more), and (2) the layers are trained completely from scratch, which means they contain absolutely no prior structure about the problem and task. 
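\n", "\n", "For illustration, a minimal sketch of this frozen-baseline setup (the linear `head` here is only a placeholder; the heads actually tried are the classes in `qa_model`):\n", "\n", "```python\n", "from torch import nn\n", "from transformers import DistilBertModel\n", "\n", "base = DistilBertModel.from_pretrained('distilbert-base-uncased')\n", "for p in base.parameters():\n", "    p.requires_grad = False  # the baseline stays frozen\n", "\n", "head = nn.Linear(768, 2)  # per-token start/end logits\n", "\n", "def qa_logits(input_ids, attention_mask):\n", "    hidden = base(input_ids, attention_mask=attention_mask).last_hidden_state\n", "    start_logits, end_logits = head(hidden).split(1, dim=-1)\n", "    return start_logits.squeeze(-1), end_logits.squeeze(-1)\n", "```\n", "\n", "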
I do not think that this way of training leads to better results / less energy used all in all, because it would be too resource intense.\n", "\n", "The following setup is partly based on the HuggingFace implementation of the question answering model (https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/models/distilbert/modeling_distilbert.py#L805)" ] }, { "cell_type": "code", "execution_count": 62, "id": "92b21967", "metadata": {}, "outputs": [], "source": [ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")" ] }, { "cell_type": "code", "execution_count": 63, "id": "1d7b3a8c", "metadata": {}, "outputs": [], "source": [ "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")" ] }, { "cell_type": "code", "execution_count": 64, "id": "91444894", "metadata": {}, "outputs": [], "source": [ "# only take base model, we do not need the classification head\n", "mod = model.distilbert" ] }, { "cell_type": "code", "execution_count": 65, "id": "74ca6c07", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "QuestionDistilBERT(\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (relu): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (te): TransformerEncoder(\n", " (layers): ModuleList(\n", " (0): TransformerEncoderLayer(\n", " (self_attn): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n", " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout1): Dropout(p=0.1, inplace=False)\n", " (dropout2): Dropout(p=0.1, 
inplace=False)\n", " )\n", " (1): TransformerEncoderLayer(\n", " (self_attn): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n", " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout1): Dropout(p=0.1, inplace=False)\n", " (dropout2): Dropout(p=0.1, inplace=False)\n", " )\n", " (2): TransformerEncoderLayer(\n", " (self_attn): MultiheadAttention(\n", " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", " )\n", " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n", " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout1): Dropout(p=0.1, inplace=False)\n", " (dropout2): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (classifier): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): ReLU()\n", " (2): Linear(in_features=768, out_features=512, bias=True)\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): ReLU()\n", " (5): Linear(in_features=512, out_features=256, bias=True)\n", " (6): Dropout(p=0.1, inplace=False)\n", " (7): ReLU()\n", " (8): Linear(in_features=256, out_features=128, bias=True)\n", " (9): Dropout(p=0.1, inplace=False)\n", " (10): ReLU()\n", " (11): Linear(in_features=128, out_features=64, bias=True)\n", " (12): Dropout(p=0.1, inplace=False)\n", " (13): ReLU()\n", " (14): Linear(in_features=64, out_features=2, bias=True)\n", " )\n", ")" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model = QuestionDistilBERT(mod)\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 66, "id": "340857f9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+---------------------------------------+------------+\n", "| Modules | Parameters |\n", "+---------------------------------------+------------+\n", "| te.layers.0.self_attn.in_proj_weight | 1769472 |\n", "| te.layers.0.self_attn.in_proj_bias | 2304 |\n", "| te.layers.0.self_attn.out_proj.weight | 589824 |\n", "| te.layers.0.self_attn.out_proj.bias | 768 |\n", "| te.layers.0.linear1.weight | 1572864 |\n", "| te.layers.0.linear1.bias | 2048 |\n", "| te.layers.0.linear2.weight | 1572864 |\n", "| te.layers.0.linear2.bias | 768 |\n", "| te.layers.0.norm1.weight | 768 |\n", "| te.layers.0.norm1.bias | 768 |\n", "| te.layers.0.norm2.weight | 768 |\n", "| te.layers.0.norm2.bias | 768 |\n", "| te.layers.1.self_attn.in_proj_weight | 1769472 |\n", "| te.layers.1.self_attn.in_proj_bias | 2304 |\n", "| te.layers.1.self_attn.out_proj.weight | 589824 |\n", "| te.layers.1.self_attn.out_proj.bias | 768 |\n", "| te.layers.1.linear1.weight | 1572864 |\n", "| te.layers.1.linear1.bias | 2048 |\n", "| te.layers.1.linear2.weight | 1572864 |\n", "| te.layers.1.linear2.bias | 768 |\n", "| te.layers.1.norm1.weight | 768 |\n", "| te.layers.1.norm1.bias | 768 |\n", "| te.layers.1.norm2.weight | 768 |\n", "| te.layers.1.norm2.bias | 768 |\n", "| 
te.layers.2.self_attn.in_proj_weight | 1769472 |\n", "| te.layers.2.self_attn.in_proj_bias | 2304 |\n", "| te.layers.2.self_attn.out_proj.weight | 589824 |\n", "| te.layers.2.self_attn.out_proj.bias | 768 |\n", "| te.layers.2.linear1.weight | 1572864 |\n", "| te.layers.2.linear1.bias | 2048 |\n", "| te.layers.2.linear2.weight | 1572864 |\n", "| te.layers.2.linear2.bias | 768 |\n", "| te.layers.2.norm1.weight | 768 |\n", "| te.layers.2.norm1.bias | 768 |\n", "| te.layers.2.norm2.weight | 768 |\n", "| te.layers.2.norm2.bias | 768 |\n", "| classifier.2.weight | 393216 |\n", "| classifier.2.bias | 512 |\n", "| classifier.5.weight | 131072 |\n", "| classifier.5.bias | 256 |\n", "| classifier.8.weight | 32768 |\n", "| classifier.8.bias | 128 |\n", "| classifier.11.weight | 8192 |\n", "| classifier.11.bias | 64 |\n", "| classifier.14.weight | 128 |\n", "| classifier.14.bias | 2 |\n", "+---------------------------------------+------------+\n", "Total Trainable Params: 17108290\n" ] }, { "data": { "text/plain": [ "17108290" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_parameters(model)" ] }, { "cell_type": "markdown", "id": "9babd013", "metadata": {}, "source": [ "### Testing the model\n", "This is the same procedure as in `distilbert.ipynb`. " ] }, { "cell_type": "code", "execution_count": 67, "id": "694c828b", "metadata": {}, "outputs": [], "source": [ "# get smaller dataset\n", "batch_size = 8\n", "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n", "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", "optim=torch.optim.Adam(model.parameters())" ] }, { "cell_type": "code", "execution_count": 68, "id": "a76587df", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Passed\n" ] } ], "source": [ "test_model(model, optim, test_ds_loader, device)" ] }, { "cell_type": "markdown", "id": "7c326e8e", "metadata": {}, "source": [ "### Training the model\n", "* Parameter Tuning:\n", " * Learning Rate: I experimented with several values, 1e-4 seemed to work best for me. 
1e-3 was very unstable and 1e-5 was too small.\n", " * Gradient Clipping: I experimented with this, but the difference was only minimal\n", "\n", "Data:\n", "* I first used only the SQuAD dataset, but generalisation is a problem\n", " * The dataset is relatively small and we often have entries with the same context but different questions\n", " * I believe the diversity is not large enough to train a fully functional model\n", "* Hence, I included the Natural Questions dataset too\n", " * It is, however, a lot messier - I elaborated a bit more on this in `load_data.ipynb`\n", "* The HotpotQA data was used as well\n", "\n", "Tested with:\n", "* 3 Linear Layers\n", " * Training Error high - needed more layers\n", " * Already expected - this was mostly a Proof of Concept\n", "* 1 TransformerEncoder with 4 attention heads + 1 Linear Layer:\n", " * Training Error was high, still too simple\n", "* 1 TransformerEncoder with 8 heads + 1 Linear Layer:\n", " * Training Error gets lower, but stagnates at some point\n", " * Probably still too simple, it doesn't generalise either\n", "* 2 TransformerEncoders with 8 and 4 heads + 1 Linear Layer:\n", " * Loss goes down but stops improving after some time\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2e9f4bd3", "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=nat_paths, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n", "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n", "\n", "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n", " natural_question_paths=None, \n", " hotpotqa_paths = None, tokenizer=tokenizer)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)" ] }, { "cell_type": "code", "execution_count": 26, "id": "03a6de37", "metadata": {}, "outputs": [], "source": [ "model = QuestionDistilBERT(mod)" ] }, { "cell_type": "code", "execution_count": 41, "id": "ed854b73", "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW, RMSprop\n", "\n", "model.train()\n", "optim = RMSprop(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": 42, "id": "79fdfcc9", "metadata": {}, "outputs": [], "source": [ "from torch.utils.tensorboard import SummaryWriter\n", "writer = SummaryWriter()" ] }, { "cell_type": "code", "execution_count": null, "id": "f7bddb43", "metadata": {}, "outputs": [], "source": [ "epochs = 20\n", "\n", "for epoch in range(epochs):\n", " loop = tqdm(loader, leave=True)\n", " model.train()\n", " mean_training_error = []\n", " for batch in loop:\n", " optim.zero_grad()\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " \n", " loss = outputs['loss']\n", " loss.backward()\n", " \n", " optim.step()\n", " mean_training_error.append(loss.item())\n", " loop.set_description(f'Epoch {epoch}')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Training Error\", np.mean(mean_training_error))\n", " writer.add_scalar(\"Loss/train\", np.mean(mean_training_error), epoch)\n", " \n", " loop = tqdm(test_loader, leave=True)\n", " model.eval()\n", " mean_test_error = []\n", " for batch in loop:\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = 
batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n", " loss = outputs['loss']\n", " \n", " mean_test_error.append(loss.item())\n", " loop.set_description(f'Epoch {epoch} Testset')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Test Error\", np.mean(mean_test_error))\n", " writer.add_scalar(\"Loss/test\", np.mean(mean_test_error), epoch)" ] }, { "cell_type": "code", "execution_count": 238, "id": "a9d6af2e", "metadata": {}, "outputs": [], "source": [ "writer.close()" ] }, { "cell_type": "code", "execution_count": 33, "id": "ba43447e", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"distilbert_qa.model\")" ] }, { "cell_type": "code", "execution_count": 34, "id": "ffc49aca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = QuestionDistilBERT(mod)\n", "model.load_state_dict(torch.load(\"distilbert_qa.model\"))" ] }, { "cell_type": "code", "execution_count": 35, "id": "730a86c1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2500/2500 [02:57<00:00, 14.09it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Mean EM: 0.0479\n", "Mean F-1: 0.08989175857485086\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "eval_test_set(model, tokenizer, test_loader, device)" ] }, { "cell_type": "markdown", "id": "bd1c7076", "metadata": {}, "source": [ "## Reuse Layer\n", "This was inspired by how well the original model with just one classification head worked. The main problem with the previous approach seemed to be that the freshly initialised head lacked the structure already present in the pretrained layers, combined with the massive amount of resources a Transformer needs.\n", "\n", "Hence, I tried cloning the last (and then the last two) layers of the DistilBERT model, putting a classifier on top and using this as the head. The base DistilBERT model is completely frozen. This worked extremely well, even though we fine-tune only about 21% of the parameters we did before (14 million as opposed to 66 million!). 
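\n", "\n", "A minimal sketch of the cloning idea (reusing the frozen base model `mod` from above; the actual implementation is the `ReuseQuestionDistilBERT` class imported from `qa_model`):\n", "\n", "```python\n", "import copy\n", "from torch import nn\n", "\n", "# freeze the pretrained base model completely\n", "for p in mod.parameters():\n", "    p.requires_grad = False\n", "\n", "# trainable head: deep copies of the last two pretrained transformer blocks\n", "head_blocks = nn.ModuleList(copy.deepcopy(mod.transformer.layer[i]) for i in (-2, -1))\n", "classifier = nn.Linear(768, 2)  # start/end logits per token\n", "```\n", "\n", "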
Below you can see the results.\n", "\n", "### Last DistilBERT layer\n", "\n", "Dropout 0.1 and RMSprop 1e-4:\n", "* Mean EM: 0.3888\n", "* Mean F-1: 0.5122932744694068\n", "\n", "Dropout 0.25: stagnates very early\n", "* Mean EM: 0.3552\n", "* Mean F-1: 0.4711235721312687\n", "\n", "Dropout 0.15: seems to work well - training and test error stagnate around 1.7 and 1.8, but good generalisation (need to add more layers)\n", "* Mean EM: 0.4119\n", "* Mean F-1: 0.5296387232893214\n", "\n", "### Last DistilBERT layer + more dense layers\n", "Dropout 0.15 + 4 dense layers ((768-512)-(512-256)-(256-128)-(128-2)) & ReLU: doesn't work too well - stagnates at around 2.4\n", "\n", "### Last two DistilBERT layers\n", "Dropout 0.1 with the last 2 DistilBERT layers: works very well, but overfits early - maybe use more data\n", "* Mean EM: 0.458\n", "* Mean F-1: 0.6003368353673634\n", "\n", "Dropout 0.1 - last 2 DistilBERT layers, all data\n", "* Mean EM: 0.484\n", "* Mean F-1: 0.6344960035215299\n", "\n", "Dropout 0.15 - **BEST**\n", "* Mean EM: 0.5178\n", "* Mean F-1: 0.6671140689626448\n", "\n", "Dropout 0.2 - doesn't work too well\n", "* Mean EM: 0.4353\n", "* Mean F-1: 0.5776847879304647\n" ] }, { "cell_type": "code", "execution_count": 69, "id": "654e09e8", "metadata": {}, "outputs": [], "source": [ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n", "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n", "\n", "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n", " natural_question_paths=None, \n", " hotpotqa_paths = None, tokenizer=tokenizer)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)" ] }, { "cell_type": "code", "execution_count": 70, "id": "707c0cb5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ReuseQuestionDistilBERT(\n", " (te): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " )\n", " )\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): 
GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (relu): ReLU()\n", " (dropout): Dropout(p=0.15, inplace=False)\n", " (classifier): Linear(in_features=768, out_features=2, bias=True)\n", ")" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n", "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n", "mod = model.distilbert\n", "\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model = ReuseQuestionDistilBERT(mod)\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 71, "id": "d2c6bff5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+-------------------------------+------------+\n", "| Modules | Parameters |\n", "+-------------------------------+------------+\n", "| te.0.attention.q_lin.weight | 589824 |\n", "| te.0.attention.q_lin.bias | 768 |\n", "| te.0.attention.k_lin.weight | 589824 |\n", "| te.0.attention.k_lin.bias | 768 |\n", "| te.0.attention.v_lin.weight | 589824 |\n", "| te.0.attention.v_lin.bias | 768 |\n", "| te.0.attention.out_lin.weight | 589824 |\n", "| te.0.attention.out_lin.bias | 768 |\n", "| te.0.sa_layer_norm.weight | 768 |\n", "| te.0.sa_layer_norm.bias | 768 |\n", "| te.0.ffn.lin1.weight | 2359296 |\n", "| te.0.ffn.lin1.bias | 3072 |\n", "| te.0.ffn.lin2.weight | 2359296 |\n", "| te.0.ffn.lin2.bias | 768 |\n", "| te.0.output_layer_norm.weight | 768 |\n", "| te.0.output_layer_norm.bias | 768 |\n", "| te.1.attention.q_lin.weight | 589824 |\n", "| te.1.attention.q_lin.bias | 768 |\n", "| te.1.attention.k_lin.weight | 589824 |\n", "| te.1.attention.k_lin.bias | 768 |\n", "| te.1.attention.v_lin.weight | 589824 |\n", "| te.1.attention.v_lin.bias | 
768 |\n", "| te.1.attention.out_lin.weight | 589824 |\n", "| te.1.attention.out_lin.bias | 768 |\n", "| te.1.sa_layer_norm.weight | 768 |\n", "| te.1.sa_layer_norm.bias | 768 |\n", "| te.1.ffn.lin1.weight | 2359296 |\n", "| te.1.ffn.lin1.bias | 3072 |\n", "| te.1.ffn.lin2.weight | 2359296 |\n", "| te.1.ffn.lin2.bias | 768 |\n", "| te.1.output_layer_norm.weight | 768 |\n", "| te.1.output_layer_norm.bias | 768 |\n", "| classifier.weight | 1536 |\n", "| classifier.bias | 2 |\n", "+-------------------------------+------------+\n", "Total Trainable Params: 14177282\n" ] }, { "data": { "text/plain": [ "14177282" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_parameters(model)" ] }, { "cell_type": "markdown", "id": "c386c2eb", "metadata": {}, "source": [ "### Testing the Model" ] }, { "cell_type": "code", "execution_count": 72, "id": "818deed3", "metadata": {}, "outputs": [], "source": [ "# get smaller dataset\n", "batch_size = 8\n", "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n", "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n", "optim=torch.optim.Adam(model.parameters())" ] }, { "cell_type": "code", "execution_count": 73, "id": "9da40760", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Passed\n" ] } ], "source": [ "test_model(model, optim, test_ds_loader, device)" ] }, { "cell_type": "markdown", "id": "c3f80248", "metadata": {}, "source": [ "### Model Training" ] }, { "cell_type": "code", "execution_count": 24, "id": "e1adabe6", "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW, RMSprop\n", "\n", "model.train()\n", "optim = AdamW(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "id": "efe1cbd5", "metadata": {}, "outputs": [], "source": [ "epochs = 16\n", "\n", "for epoch in range(epochs):\n", " loop = tqdm(loader, leave=True)\n", " model.train()\n", " mean_training_error = []\n", " for batch in loop:\n", " optim.zero_grad()\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n", " loss = outputs['loss']\n", " loss.backward()\n", " # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n", " optim.step()\n", " mean_training_error.append(loss.item())\n", " loop.set_description(f'Epoch {epoch}')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Training Error\", np.mean(mean_training_error))\n", " \n", " loop = tqdm(test_loader, leave=True)\n", " model.eval()\n", " mean_test_error = []\n", " for batch in loop:\n", " \n", " input_ids = batch['input_ids'].to(device)\n", " attention_mask = batch['attention_mask'].to(device)\n", " start = batch['start_positions'].to(device)\n", " end = batch['end_positions'].to(device)\n", " \n", " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n", " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n", " loss = outputs['loss']\n", " \n", " mean_test_error.append(loss.item())\n", " 
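# update the progress bar with the current batch loss\n", " 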
loop.set_description(f'Epoch {epoch} Testset')\n", " loop.set_postfix(loss=loss.item())\n", " print(\"Mean Test Error\", np.mean(mean_test_error))\n", " torch.save(model.state_dict(), \"distilbert_reuse_{}\".format(epoch))" ] }, { "cell_type": "code", "execution_count": 48, "id": "fdf37d18", "metadata": {}, "outputs": [], "source": [ "torch.save(model.state_dict(), \"distilbert_reuse.model\")" ] }, { "cell_type": "code", "execution_count": 49, "id": "d1cfded4", "metadata": {}, "outputs": [], "source": [ "m = ReuseQuestionDistilBERT(mod)\n", "m.load_state_dict(torch.load(\"distilbert_reuse.model\"))\n", "model = m" ] }, { "cell_type": "code", "execution_count": 47, "id": "233bdc18", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2500/2500 [02:51<00:00, 14.59it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Mean EM: 0.5178\n", "Mean F-1: 0.6671140689626448\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "eval_test_set(model, tokenizer, test_loader, device)" ] }, { "cell_type": "code", "execution_count": null, "id": "0fb1ce9e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.8 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false }, "vscode": { "interpreter": { "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa" } } }, "nbformat": 4, "nbformat_minor": 5 }