{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e12b9784-0a73-447c-bd95-5c4db12213ec",
   "metadata": {},
   "source": [
    "## Load model and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "94c34109-799b-4094-934b-85df33a3be99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import transformers\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "# Directory holding the fine-tuned BERT model and the label table.\n",
    "# Set the BERT_DOMAINS_DIR environment variable to run on another machine.\n",
    "path = os.environ.get('BERT_DOMAINS_DIR', '/home/colombo_phd/ItalianLaws/Data/BERT-Domains/')\n",
    "\n",
    "# Every title is padded / truncated to this many tokens.\n",
    "# NOTE(review): presumably the sequence length used during fine-tuning - confirm.\n",
    "MAX_LENGTH = 389\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
    "# Defined here so that every later cell can move tensors to the same device.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Lookup table: numeric class id (column 'label') -> domain name (column 'Ministries').\n",
    "label = pd.read_csv(os.path.join(path, 'label_tokens.csv'), sep=';')\n",
    "\n",
    "# Load the fine-tuned model from the same directory on both CPU and GPU.\n",
    "# SECURITY: torch.load unpickles the file - only load checkpoints from a trusted source.\n",
    "# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a fully\n",
    "# pickled model; pass weights_only=False there if this file is trusted.\n",
    "model = torch.load(os.path.join(path, 'bert_model'), map_location=device)\n",
    "\n",
    "# eval() switches off dropout so that repeated runs give the same prediction.\n",
    "model = model.to(device).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8df905b-9a7b-46ec-8aab-adb15b50aad5",
   "metadata": {},
   "source": [
    "## String to evaluate - title of the law"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5868b342-3161-4862-b269-1d4959359d48",
   "metadata": {},
   "outputs": [],
   "source": [
    "title = 'Regolamento per il commercio di prodotti agricoli in europa'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c2d173a-4702-4fb3-93f8-c1ea366bdc41",
   "metadata": {},
   "source": [
    "## Run model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d57af582-61bc-4be9-b305-63a40ede1311",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assumes the model was fine-tuned with the 'bert-base-uncased' vocabulary - confirm.\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
    "\n",
    "# Tokenize the title into fixed-length tensors of shape (1, MAX_LENGTH).\n",
    "# padding='max_length' replaces the deprecated pad_to_max_length=True.\n",
    "encoded_dict = tokenizer(\n",
    "    title,\n",
    "    add_special_tokens=True,\n",
    "    max_length=MAX_LENGTH,\n",
    "    truncation=True,\n",
    "    padding='max_length',\n",
    "    return_attention_mask=True,\n",
    "    return_tensors='pt',\n",
    ")\n",
    "test_input_ids = encoded_dict['input_ids']\n",
    "test_attention_masks = encoded_dict['attention_mask']\n",
    "\n",
    "# Inputs must live on the same device as the model.\n",
    "b_input_ids = test_input_ids.to(device)\n",
    "b_input_mask = test_attention_masks.to(device)\n",
    "\n",
    "# Inference only: no gradients are needed.\n",
    "with torch.no_grad():\n",
    "    output = model(b_input_ids,\n",
    "                   token_type_ids=None,\n",
    "                   attention_mask=b_input_mask)\n",
    "\n",
    "# Highest-scoring class id for each input row (a single row here).\n",
    "logits = output.logits.cpu().numpy()\n",
    "pred_flat = np.argmax(logits, axis=1).flatten()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbf3114a-c0ad-411c-b35b-1d7ec922035d",
   "metadata": {},
   "source": [
    "## Derive domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c2e3004b-d735-4462-bbda-6bdb02586102",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'economia'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the predicted class id back to its domain name.\n",
    "label.loc[label['label'] == pred_flat[0], 'Ministries'].iloc[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}