{ "cells": [ { "cell_type": "markdown", "id": "47700837", "metadata": {}, "source": [ "# DistilBERT Base Model\n", "The following contains the code to create and train a DistilBERT model using the Hugging Face `transformers` library. It works quite well for a moderate amount of data, but the runtime grows drastically as the amount of data increases.\n", "\n", "In the end, I decided to use the pretrained model after all. Still, creating the model myself was quite interesting!" ] }, { "cell_type": "code", "execution_count": 4, "id": "c09fa906", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import torch\n", "import time\n", "from transformers import DistilBertTokenizerFast\n", "import os\n", "from transformers import DistilBertConfig\n", "from transformers import DistilBertForMaskedLM\n", "from tokenizers import BertWordPieceTokenizer\n", "from tqdm.auto import tqdm\n", "from torch.optim import AdamW\n", "import torchtest\n", "from transformers import pipeline\n", "\n", "\n", "from distilbert import test_model\n", "from distilbert import Dataset\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "id": "3b773fac", "metadata": {}, "source": [ "## Tokeniser\n", "We need a way to convert the input strings into numerical tokens that we can feed to the neural network. Hence, we take a BertWordPieceTokenizer (which works for DistilBERT too) and create tokens from our words."
] }, { "cell_type": "code", "execution_count": 5, "id": "24277c5b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tokeniser created\n" ] } ], "source": [ "fit_new_tokenizer = True\n", "\n", "if fit_new_tokenizer:\n", "    paths = [str(x) for x in Path('data/original').glob('**/*.txt')]\n", "\n", "    tokenizer = BertWordPieceTokenizer(\n", "        clean_text=True,\n", "        handle_chinese_chars=False,\n", "        strip_accents=False,\n", "        lowercase=True\n", "    )\n", "    print(\"Tokeniser created\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "beacf3e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\n" ] } ], "source": [ "# fit the tokenizer\n", "if fit_new_tokenizer:\n", "    tokenizer.train(files=paths[:10], vocab_size=30_000, min_frequency=2,\n", "                    limit_alphabet=1000, wordpieces_prefix='##',\n", "                    special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])" ] }, { "cell_type": "code", "execution_count": null, "id": "0d462cc5", "metadata": {}, "outputs": [], "source": [ "if fit_new_tokenizer:\n", "    # create the output directory; do not fail if it already exists\n", "    os.makedirs('tokeniser', exist_ok=True)\n", "    tokenizer.save_model('tokeniser')\n", "    print(\"Tokeniser saved\")" ] }, { "cell_type": "markdown", "id": "7eaa1667", "metadata": {}, "source": [ "After creating a basic tokeniser, we use its saved files to initialise a DistilBERT tokenizer, which we need for the model later on. We save this tokenizer separately."
] }, { "cell_type": "code", "execution_count": 8, "id": "f4dd0684", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('distilbert_tokenizer/tokenizer_config.json',\n", " 'distilbert_tokenizer/special_tokens_map.json',\n", " 'distilbert_tokenizer/vocab.txt',\n", " 'distilbert_tokenizer/added_tokens.json',\n", " 'distilbert_tokenizer/tokenizer.json')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer = DistilBertTokenizerFast.from_pretrained('tokeniser', max_len=512)\n", "tokenizer.save_pretrained(\"distilbert_tokenizer\")" ] }, { "cell_type": "markdown", "id": "bfcafcde", "metadata": {}, "source": [ "### Testing\n", "We now test the created tokenizer. We take a simple example and tokenise the input. It can be seen that a special token is added at the beginning and at the end ('[CLS]' and '[SEP]'), which is how the BERT input format is defined.\n", "\n", "When we decode the ids back, we get the same text (lowercased), except for the added first and last tokens. Also, we can see that question marks and commas are encoded as separate tokens." ] }, { "cell_type": "code", "execution_count": 9, "id": "37e7f6a8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': [2, 10958, 16, 2175, 1993, 1965, 35, 3], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}\n" ] } ], "source": [ "tokens = tokenizer('Hello, how are you?')\n", "print(tokens)" ] }, { "cell_type": "code", "execution_count": 10, "id": "bbd0c4b1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[CLS] hello, how are you? 
[SEP]'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(tokens['input_ids'])" ] }, { "cell_type": "code", "execution_count": 11, "id": "4ab6e506", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[CLS]\n", "hello\n", ",\n", "how\n", "are\n", "you\n", "?\n", "[SEP]\n" ] } ], "source": [ "for tok in tokens['input_ids']:\n", "    print(tokenizer.decode(tok))" ] }, { "cell_type": "code", "execution_count": 12, "id": "c75d3255", "metadata": {}, "outputs": [], "source": [ "assert len(tokenizer.vocab) == 30_000" ] }, { "cell_type": "markdown", "id": "dd114355", "metadata": {}, "source": [ "## Dataset\n", "We now define a function to mask some of the tokens. In particular, we create a Dataset class that automates loading the data and tokenising it for us. Lastly, we use a DataLoader to load the data into memory step by step.\n", "\n", "The big problem with the limited resources we have is memory. For that reason, I am loading the data sequentially, file by file, keeping track of how many samples have been read. Shuffling wouldn't work here (it would also not make a lot of sense for this dataset)." ] }, { "cell_type": "code", "execution_count": 10, "id": "bff9ea54", "metadata": {}, "outputs": [], "source": [ "# create dataset and dataloader \n", "dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][50:70], tokenizer=tokenizer)\n", "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n", "\n", "test_dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][10:12], tokenizer=tokenizer)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)" ] }, { "cell_type": "markdown", "id": "6bbe6e63", "metadata": {}, "source": [ "### Testing\n", "The randomisation makes this a bit difficult to test. Still, we can check that the input ids, attention masks and labels have the same shape. 
Also, as we mask 15% of the tokens, when decoding a given sample, we can see that some tokens are now '[MASK]'." ] }, { "cell_type": "code", "execution_count": 11, "id": "436ab745", "metadata": {}, "outputs": [], "source": [ "i = iter(dataset)" ] }, { "cell_type": "code", "execution_count": 12, "id": "330e599d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Passed\n" ] } ], "source": [ "for j in range(10):\n", "    sample = next(i)\n", "\n", "    input_ids = sample['input_ids']\n", "    attention_masks = sample['attention_mask']\n", "    labels = sample['labels']\n", "\n", "    # check if the dimensions are right\n", "    assert input_ids.shape[0] == 512\n", "    assert attention_masks.shape[0] == 512\n", "    assert labels.shape[0] == 512\n", "\n", "    # 4 is the id of '[MASK]' (fifth entry in the special tokens list)\n", "    # if the input ids are not masked, the labels are the same as the input ids\n", "    assert np.array_equal(input_ids[input_ids != 4].numpy(), labels[input_ids != 4].numpy())\n", "    # input ids are zero if the attention masks are zero\n", "    assert np.all(input_ids[attention_masks == 0].numpy() == 0)\n", "    # check if the input contains masked tokens (not guaranteed, but very likely with 15% masking)\n", "    assert np.any(input_ids.numpy() == 4)\n", "print(\"Passed\")" ] }, { "cell_type": "markdown", "id": "08db6d22", "metadata": {}, "source": [ "## Model\n", "In the following section, we initialise and train a model."
] }, { "cell_type": "code", "execution_count": 13, "id": "7803bda6", "metadata": {}, "outputs": [], "source": [ "config = DistilBertConfig(\n", " vocab_size=30000,\n", " max_position_embeddings=514\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "8ca03f6a", "metadata": {}, "outputs": [], "source": [ "model = DistilBertForMaskedLM(config)" ] }, { "cell_type": "code", "execution_count": 15, "id": "4da22bff", "metadata": { "scrolled": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/sanju/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/cuda/__init__.py:83: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n", " return torch._C._cuda_getDeviceCount() > 0\n" ] }, { "data": { "text/plain": [ "DistilBertForMaskedLM(\n", " (activation): GELUActivation()\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30000, 768, padding_idx=0)\n", " (position_embeddings): Embedding(514, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, 
out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): 
Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " 
(dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (vocab_transform): Linear(in_features=768, out_features=768, bias=True)\n", " (vocab_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (vocab_projector): Linear(in_features=768, out_features=30000, bias=True)\n", " (mlm_loss_fct): CrossEntropyLoss()\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# if we have a GPU - train on gpu\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "model.to(device)" ] }, { "cell_type": "markdown", "id": "6fb8c2e2", "metadata": {}, "source": [ "### Testing the model\n", "I stumbled across some articles and packages on how to test deep learning models before training them:\n", "* https://thenerdstation.medium.com/how-to-unit-test-machine-learning-code-57cf6fd81765: the package is, however, deprecated\n", "* https://towardsdatascience.com/testing-your-pytorch-models-with-torcheck-cb689ecbc08c: released a package (torcheck)\n", "* https://github.com/suriyadeepan/torchtest: I found this package, which is the PyTorch version of the first one and is still maintained.\n", "\n", "Testing a model is inherently difficult, because we do not know the expected result in advance. Still, the following four conditions should be satisfied by every model (see second reference above):\n", "1. The parameters should change during training (if they are not frozen).\n", "2. The parameters should not change if they are frozen.\n", "3. The output should lie within a predefined range.\n", "4. The parameters should never contain NaN. 
The same goes for the outputs.\n", "\n", "I tried using the packages, but they do not trivially apply to models with multiple inputs (we have input ids and attention masks). The following is partly adapted from the torchtest package (https://github.com/suriyadeepan/torchtest/blob/master/torchtest/torchtest.py)." ] }, { "cell_type": "code", "execution_count": 16, "id": "cfd33fa1", "metadata": {}, "outputs": [], "source": [ "# get smaller dataset\n", "test_ds = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][:2], tokenizer=tokenizer)\n", "test_ds_loader = torch.utils.data.DataLoader(test_ds, batch_size=2)\n", "optim = torch.optim.Adam(model.parameters())" ] }, { "cell_type": "code", "execution_count": 17, "id": "907db815", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Passed\n" ] } ], "source": [ "from distilbert import test_model\n", "\n", "test_model(model, optim, test_ds_loader, device)" ] }, { "cell_type": "markdown", "id": "c02c9c4b", "metadata": {}, "source": [ "### Training the model\n", "We use AdamW as the optimiser and train for 10 epochs."
] }, { "cell_type": "code", "execution_count": 18, "id": "178914f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DistilBertForMaskedLM(\n", " (activation): GELUActivation()\n", " (distilbert): DistilBertModel(\n", " (embeddings): Embeddings(\n", " (word_embeddings): Embedding(30000, 768, padding_idx=0)\n", " (position_embeddings): Embedding(514, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (transformer): Transformer(\n", " (layer): ModuleList(\n", " (0): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (1): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): 
Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (2): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (3): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (4): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, 
bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (5): TransformerBlock(\n", " (attention): MultiHeadSelfAttention(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n", " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n", " )\n", " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (ffn): FFN(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n", " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n", " (activation): GELUActivation()\n", " )\n", " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " )\n", " (vocab_transform): Linear(in_features=768, out_features=768, bias=True)\n", " (vocab_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (vocab_projector): Linear(in_features=768, out_features=30000, bias=True)\n", " (mlm_loss_fct): CrossEntropyLoss()\n", ")" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = DistilBertForMaskedLM(config)\n", "# if we have a GPU - train on gpu\n", "device = torch.device('cuda') if torch.cuda.is_available() 
else torch.device('cpu')\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 19, "id": "bb6532be", "metadata": {}, "outputs": [], "source": [ "# we use AdamW as the optimiser\n", "optim = AdamW(model.parameters(), lr=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "id": "2fd5d609", "metadata": {}, "outputs": [], "source": [ "epochs = 10\n", "\n", "for epoch in range(epochs):\n", "    loop = tqdm(loader, leave=True)\n", "\n", "    # set model to training mode\n", "    model.train()\n", "    losses = []\n", "\n", "    # iterate over the training set\n", "    for batch in loop:\n", "        optim.zero_grad()\n", "\n", "        # copy input to device\n", "        input_ids = batch['input_ids'].to(device)\n", "        attention_mask = batch['attention_mask'].to(device)\n", "        labels = batch['labels'].to(device)\n", "\n", "        # predict\n", "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n", "\n", "        # update weights\n", "        loss = outputs.loss\n", "        loss.backward()\n", "\n", "        optim.step()\n", "\n", "        # output current loss\n", "        loop.set_description(f'Epoch {epoch}')\n", "        loop.set_postfix(loss=loss.item())\n", "        losses.append(loss.item())\n", "\n", "        del input_ids\n", "        del attention_mask\n", "        del labels\n", "\n", "    print(\"Mean Training Loss\", np.mean(losses))\n", "    losses = []\n", "    loop = tqdm(test_loader, leave=True)\n", "\n", "    # set model to evaluation mode\n", "    model.eval()\n", "\n", "    # iterate over the test set without tracking gradients\n", "    with torch.no_grad():\n", "        for batch in loop:\n", "            # copy input to device\n", "            input_ids = batch['input_ids'].to(device)\n", "            attention_mask = batch['attention_mask'].to(device)\n", "            labels = batch['labels'].to(device)\n", "\n", "            # predict\n", "            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n", "\n", "            # read the loss (no weight update during evaluation)\n", "            loss = outputs.loss\n", "\n", "            # output current loss\n", "            loop.set_description(f'Epoch {epoch}')\n", "            loop.set_postfix(loss=loss.item())\n", "            losses.append(loss.item())\n", "\n", "            del input_ids\n", "            del attention_mask\n", "            del labels\n", "    print(\"Mean Test Loss\", np.mean(losses))" ] }, { "cell_type": "code", "execution_count": 22, "id": "03c23c3e", "metadata": {}, "outputs": [], "source": [ "# save the pretrained model\n", "torch.save(model, \"distilbert.model\")" ] }, { "cell_type": "code", "execution_count": 25, "id": "9b18d3e3", "metadata": {}, "outputs": [], "source": [ "model = torch.load(\"distilbert.model\")" ] }, { "cell_type": "markdown", "id": "e6ad94db", "metadata": {}, "source": [ "### Testing\n", "Hugging Face provides a `fill-mask` pipeline that lets us quickly see which words the model would predict for a masked token." ] }, { "cell_type": "code", "execution_count": 27, "id": "7c8582d2", "metadata": {}, "outputs": [], "source": [ "fill = pipeline(\"fill-mask\", model='distilbert', config=config, tokenizer='distilbert_tokenizer')" ] }, { "cell_type": "code", "execution_count": 28, "id": "d309e57f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'score': 0.19730663299560547,\n", " 'token': 2965,\n", " 'token_str': 'change',\n", " 'sequence': 'it seems important to tackle the climate change.'},\n", " {'score': 0.12946806848049164,\n", " 'token': 5215,\n", " 'token_str': 'crisis',\n", " 'sequence': 'it seems important to tackle the climate crisis.'},\n", " {'score': 0.05868387222290039,\n", " 'token': 3688,\n", " 'token_str': 'issues',\n", " 'sequence': 'it seems important to tackle the climate issues.'},\n", " {'score': 0.047418754547834396,\n", " 'token': 3406,\n", " 'token_str': 'issue',\n", " 'sequence': 'it seems important to tackle the climate issue.'},\n", " {'score': 0.027855267748236656,\n", " 'token': 2629,\n", " 'token_str': 'here',\n", " 'sequence': 'it seems important to tackle the climate here.'}]" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fill(f'It seems important to tackle the climate {fill.tokenizer.mask_token}.')" ] }, { "cell_type": "code", "execution_count": null, "id": 
"94e3e623", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.8 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false }, "vscode": { "interpreter": { "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa" } } }, "nbformat": 4, "nbformat_minor": 5 }