{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.7.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Import Packages and libraries","metadata":{}},{"cell_type":"code","source":"%%capture\n!pip install numpy==1.17.4\n!pip install nltk==3.4.5\n!pip install torchtext==0.4.0\n!pip install scikit_learn==0.23.2\n!pip install spacy==2.3.5\n!pip install textblob==0.15.3\n!pip install torch==1.6.0 \n!pip install torchvision==0.7.0\n!pip install tqdm\n!pip install underthesea==1.3.3\n!pip install rouge_score","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:54:40.390494Z","iopub.execute_input":"2023-02-12T04:54:40.390942Z","iopub.status.idle":"2023-02-12T04:56:23.604920Z","shell.execute_reply.started":"2023-02-12T04:54:40.390903Z","shell.execute_reply":"2023-02-12T04:56:23.603544Z"},"trusted":true},"execution_count":9,"outputs":[]},{"cell_type":"code","source":"import nltk\nnltk.download('wordnet')\n\nimport os\nimport math\nimport random\nimport argparse\nfrom pathlib import Path\nimport re\nimport numpy as np\nimport pandas as pd\nimport time\nimport gc\nimport sys\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torchtext.data import Field, Example, Dataset\nfrom torchtext.data import BucketIterator\nimport torch.nn.functional as F\n\nfrom sklearn.model_selection import train_test_split\n\nfrom tqdm import tqdm\nfrom tqdm.notebook import tqdm_notebook\n\n%rm -rf ./UIT-ViCoV19QA\n!git clone https://github.com/minhtriet2397/UIT-ViCoV19QA.git\n\nsys.path.insert(0, '..')\nsys.path.insert(0, '/kaggle/working/UIT-ViCoV19QA/models')\n%cd /kaggle/working/UIT-ViCoV19QA/models\n%pwd\n\nsys.argv=['']\ndel sys","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:56:23.607682Z","iopub.execute_input":"2023-02-12T04:56:23.608107Z","iopub.status.idle":"2023-02-12T04:56:27.462112Z","shell.execute_reply.started":"2023-02-12T04:56:23.608063Z","shell.execute_reply":"2023-02-12T04:56:27.460470Z"},"trusted":true},"execution_count":10,"outputs":[{"name":"stderr","text":"[nltk_data] Downloading package wordnet to /usr/share/nltk_data...\n[nltk_data] Package wordnet is already up-to-date!\n","output_type":"stream"},{"name":"stdout","text":"Cloning into 'UIT-ViCoV19QA'...\nremote: Enumerating objects: 62, done.\u001b[K\nremote: Counting objects: 100% (62/62), done.\u001b[K\nremote: Compressing objects: 100% (43/43), done.\u001b[K\nremote: Total 62 (delta 20), reused 51 (delta 15), pack-reused 0\u001b[K\nUnpacking objects: 100% (62/62), 1.72 MiB | 1.50 MiB/s, done.\n/kaggle/working/UIT-ViCoV19QA/models\n","output_type":"stream"}]},{"cell_type":"markdown","source":"# Main","metadata":{}},{"cell_type":"code","source":"\"\"\"Constants for the baseline models\"\"\"\nSEED = 42\nQUESTION = 'question'\n\nRNN_NAME = 'rnn'\nCNN_NAME = 'cnn'\nTRANSFORMER_NAME = 'transformer'\n\nATTENTION_1 = 'bahdanau'\nATTENTION_2 = 'luong'\n\nGPU = 'gpu'\nCPU = 'cpu'\nCUDA = 'cuda'\n\nCHECKPOINT_PATH = '/model/'\n\nANSWER_TOKEN = '<ans>'\nENTITY_TOKEN = '<ent>'\nEOS_TOKEN = '<eos>'\nSOS_TOKEN = '<sos>'\nPAD_TOKEN = '<pad>'\n\nSRC_NAME = 'src'\nTRG_NAME = 'trg'\n\npath = 
'/kaggle/working/'","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:56:27.464159Z","iopub.execute_input":"2023-02-12T04:56:27.464596Z","iopub.status.idle":"2023-02-12T04:56:27.471151Z","shell.execute_reply.started":"2023-02-12T04:56:27.464553Z","shell.execute_reply":"2023-02-12T04:56:27.470168Z"},"trusted":true},"execution_count":11,"outputs":[]},{"cell_type":"code","source":"import random\nimport numpy as np\n\ndef parse_args():\n \"\"\"Add arguments to parser\"\"\"\n parser = argparse.ArgumentParser(description='Verbalization dataset baseline models.')\n parser.add_argument('--model', default=RNN_NAME, type=str,\n choices=[RNN_NAME, CNN_NAME, TRANSFORMER_NAME], help='model to train the dataset')\n parser.add_argument('--input', default=QUESTION, type=str,\n choices=[QUESTION], help='use question as input')\n parser.add_argument('--attention', default=ATTENTION_2, type=str,\n choices=[ATTENTION_1, ATTENTION_2], help='attention layer for rnn model')\n parser.add_argument('--batch_size', default=8, type=int, help='batch size')\n parser.add_argument('--epochs_num', default=30, type=int, help='number of epochs')\n parser.add_argument('--answer_num', default=1, type=int, \n choices=[1,2,3,4], help='number of answer')\n args = parser.parse_args()\n return args\n\ndef set_SEED():\n SEED = 42\n random.seed(SEED)\n np.random.seed(SEED)\n torch.manual_seed(SEED)\n torch.cuda.manual_seed(SEED)\n torch.cuda.manual_seed_all(SEED)\n torch.backends.cudnn.enabled = False\n torch.backends.cudnn.benchmark = False\n torch.backends.cudnn.deterministic = True\n\nclass Checkpoint(object):\n \"\"\"Checkpoint class\"\"\"\n @staticmethod\n def save(model,cell, path):\n \"\"\"Save model using name\"\"\"\n name_tmp = model.name+\"_\"+ cell if model.name==RNN_NAME else model.name\n name = f'{name_tmp}.pt'\n torch.save(model.state_dict(), path+name)\n\n @staticmethod\n def load(model,path, name):\n \"\"\"Load model using name\"\"\"\n #name = f'{model.name}.pt'\n model.load_state_dict(torch.load(path+name))\n return model","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:56:27.474094Z","iopub.execute_input":"2023-02-12T04:56:27.474784Z","iopub.status.idle":"2023-02-12T04:56:27.487154Z","shell.execute_reply.started":"2023-02-12T04:56:27.474749Z","shell.execute_reply":"2023-02-12T04:56:27.486131Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"markdown","source":"## Import data and create torchtext dataset","metadata":{}},{"cell_type":"code","source":"from underthesea import word_tokenize\n\nclass VerbalDataset(object):\n \"\"\"VerbalDataset class\"\"\"\n \n def __init__(self,train,val,test):\n self.train = train\n self.val = val\n self.test = test\n self.train_data = None\n self.valid_data = None\n self.test_data = None\n self.src_field = None\n self.trg_field = None\n\n def _make_torchtext_dataset(self, data, fields):\n examples = [Example.fromlist(i, fields) for i in tqdm_notebook(data)]\n return Dataset(examples, fields)\n\n def load_data_and_fields(self, ):\n \"\"\"\n Load verbalization data\n Create source and target fields\n \"\"\"\n train, test, val = self.train, self.test, self.val\n \n train = train.melt(id_vars=['id',\"Question\"],value_name=\"Answer\")\n train = train[train['Answer'].astype(bool)].drop(['id','variable'],axis=1).values\n \n test = test.melt(id_vars=['id',\"Question\"],value_name=\"Answer\")\n test = test[test['Answer'].astype(bool)].drop(['id','variable'],axis=1).values\n \n val = val.melt(id_vars=['id',\"Question\"],value_name=\"Answer\")\n val = 
val[val['Answer'].astype(bool)].drop(['id','variable'],axis=1).values\n\n # create fields\n self.src_field = Field(tokenize=word_tokenize,\n init_token=SOS_TOKEN,\n eos_token=EOS_TOKEN,\n lower=True,\n include_lengths=True,\n batch_first=True)\n \n self.trg_field = Field(tokenize=word_tokenize,\n init_token=SOS_TOKEN,\n eos_token=EOS_TOKEN,\n lower=True,\n batch_first=True)\n\n fields_tuple = [(SRC_NAME, self.src_field), (TRG_NAME, self.trg_field)]\n\n # create toechtext datasets\n self.train_data = self._make_torchtext_dataset(train, fields_tuple)\n self.valid_data = self._make_torchtext_dataset(val, fields_tuple)\n self.test_data = self._make_torchtext_dataset(test, fields_tuple)\n\n # build vocabularies\n self.src_field.build_vocab(self.train_data, min_freq=1)\n self.trg_field.build_vocab(self.train_data, min_freq=1)\n print(\"i am field tuple\",fields_tuple)\n\n def get_data(self):\n \"\"\"Return train, validation and test data objects\"\"\"\n return self.train_data, self.valid_data, self.test_data\n\n def get_fields(self):\n \"\"\"Return source and target field objects\"\"\"\n return self.src_field, self.trg_field\n\n def get_vocabs(self):\n \"\"\"Return source and target vocabularies\"\"\"\n #print('self, trg field vocab: ', self.trg_field.vocab)\n return self.src_field.vocab, self.trg_field.vocab","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:56:27.488581Z","iopub.execute_input":"2023-02-12T04:56:27.488938Z","iopub.status.idle":"2023-02-12T04:56:27.506628Z","shell.execute_reply.started":"2023-02-12T04:56:27.488904Z","shell.execute_reply":"2023-02-12T04:56:27.505736Z"},"trusted":true},"execution_count":13,"outputs":[]},{"cell_type":"code","source":"set_SEED()\nargs = parse_args()\nanswer_num = args.answer_num\n\nset_ = ['train','val','test']\ntrain = pd.read_csv(f'{path}UIT-ViCoV19QA/dataset/{answer_num}_ans/UIT-ViCoV19QA_train.csv',na_filter=False,delimiter='|')\nval = pd.read_csv(f'{path}UIT-ViCoV19QA/dataset/{answer_num}_ans/UIT-ViCoV19QA_val.csv',na_filter=False,delimiter='|')\ntest = pd.read_csv(f'{path}UIT-ViCoV19QA/dataset/{answer_num}_ans/UIT-ViCoV19QA_test.csv',na_filter=False,delimiter='|')\n\ndataset = VerbalDataset(train,val,test)\ndataset.load_data_and_fields()\nsrc_vocab, trg_vocab = dataset.get_vocabs()\ntrain_data, valid_data, test_data = dataset.get_data()\n\nprint('--------------------------------')\nprint(f\"Training data: {len(train_data.examples)}\")\nprint(f\"Evaluation data: {len(valid_data.examples)}\")\nprint(f\"Testing data: {len(test_data.examples)}\")\nprint('--------------------------------')\nprint(f'Question example: {train_data.examples[2].src}\\n')\nprint(f'Answer example: {train_data.examples[2].trg}')\nprint('--------------------------------')\nprint(f\"Unique tokens in questions vocabulary: {len(src_vocab)}\")\nprint(f\"Unique tokens in answers vocabulary: {len(trg_vocab)}\")\nprint('--------------------------------')","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:56:27.508078Z","iopub.execute_input":"2023-02-12T04:56:27.508549Z","iopub.status.idle":"2023-02-12T04:57:40.844030Z","shell.execute_reply.started":"2023-02-12T04:56:27.508515Z","shell.execute_reply":"2023-02-12T04:57:40.842961Z"},"trusted":true},"execution_count":14,"outputs":[{"output_type":"display_data","data":{"text/plain":" 0%| | 0/3500 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"05a9e8b8ae224b9da9a92c5869ecf8b1"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":" 0%| | 
0/500 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5d4189ed992b40a9be50fe21799441c1"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/500 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b9457f3414b14bbfb685897ff5fa8aeb"}},"metadata":{}},{"name":"stdout","text":"i am field tuple [('src', <torchtext.data.field.Field object at 0x7f928e083990>), ('trg', <torchtext.data.field.Field object at 0x7f928e083050>)]\n--------------------------------\nTraining data: 3500\nEvaluation data: 500\nTesting data: 500\n--------------------------------\nQuestion example: ['tôi', 'đang', 'cho', 'con', 'bú', '(', '10', 'tháng', 'tuổi', ')', 'có', 'được', 'chủng', 'ngừa', 'vaccine', 'covid-19', 'không', '?', 'trẻ nhỏ', 'bao nhiêu', 'tháng', 'tuổi', 'mới', 'chủng', 'ngừa', 'được', 'vaccine', 'covid-19', 'ạ', '?', 'xin', 'cảm ơn', '!']\n\nAnswer example: ['chào', 'chị', ',', 'theo', 'hướng dẫn', 'của', 'bộ', 'y tế', ',', 'phụ nữ', 'đang', 'cho', 'con', 'bú', 'sẽ', 'hoãn', 'tiêm', 'vaccine', 'covid-19', 'trong', 'thời gian', 'này', '.', 'hiện nay', ',', 'mỗi', 'loại', 'vaccine', 'sẽ', 'chỉ định', 'ở', 'những', 'đối tượng', 'khác', 'nhau', 'như', 'vaccine covid-19', 'của', 'astrazeneca', 'chỉ định', 'tiêm chủng', 'cho', 'người', 'từ', '18', 'tuổi', 'trở lên', ',', 'vaccine', 'của', 'pfizer', '/', 'biontech', 'chỉ định', 'cho', 'trẻ', 'từ', '12', 'tuổi', 'trở lên', ',', 'chưa', 'có', 'vắc xin', 'nào', 'chỉ định', 'cho', 'trẻ', 'nhỏ', 'dưới', '12', 'tuổi', '.', 'cảm ơn', 'câu', 'hỏi', 'của', 'chị', '.', 'cảm ơn', 'chị', '.']\n--------------------------------\nUnique tokens in questions vocabulary: 4396\nUnique tokens in answers vocabulary: 8537\n--------------------------------\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## Define MODEL","metadata":{}},{"cell_type":"code","source":"class Seq2Seq(nn.Module):\n def __init__(self, encoder, decoder, name):\n super().__init__()\n self.encoder = encoder\n self.decoder = decoder\n self.name = name\n\n def forward(self, src_tokens, src_lengths, trg_tokens, teacher_forcing_ratio=0.5):\n encoder_out = self.encoder(src_tokens, \n src_lengths=src_lengths)\n \n decoder_out = self.decoder(trg_tokens, encoder_out,\n src_tokens=src_tokens,\n teacher_forcing_ratio=teacher_forcing_ratio)\n return decoder_out","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:40.845634Z","iopub.execute_input":"2023-02-12T04:57:40.846239Z","iopub.status.idle":"2023-02-12T04:57:40.853175Z","shell.execute_reply.started":"2023-02-12T04:57:40.846199Z","shell.execute_reply":"2023-02-12T04:57:40.852098Z"},"trusted":true},"execution_count":15,"outputs":[]},{"cell_type":"code","source":"# Choose model here\nargs.model = CNN_NAME # CNN and Transformers don't apply Attention_1, Attention_2\nargs.attention = ATTENTION_1\ncell_name = 'gru'\n\nif args.model == RNN_NAME and args.attention == ATTENTION_1:\n from rnn1 import Encoder, Decoder\nelif args.model == RNN_NAME and args.attention == ATTENTION_2:\n from rnn2 import Encoder, Decoder\nelif args.model == CNN_NAME:\n from cnn import Encoder, Decoder\nelif args.model == TRANSFORMER_NAME:\n from transformer import Encoder, Decoder, 
NoamOpt","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:40.854778Z","iopub.execute_input":"2023-02-12T04:57:40.855209Z","iopub.status.idle":"2023-02-12T04:57:40.870049Z","shell.execute_reply.started":"2023-02-12T04:57:40.855164Z","shell.execute_reply":"2023-02-12T04:57:40.869129Z"},"trusted":true},"execution_count":16,"outputs":[]},{"cell_type":"code","source":"set_SEED()\nDEVICE = torch.device(CUDA if torch.cuda.is_available() else CPU)\n\nif args.model == RNN_NAME and args.attention == ATTENTION_2:\n encoder = Encoder(src_vocab, DEVICE, cell_name)\n decoder = Decoder(trg_vocab, DEVICE, cell_name)\nelse:\n encoder = Encoder(src_vocab, DEVICE)\n decoder = Decoder(trg_vocab, DEVICE)\nmodel = Seq2Seq(encoder, decoder, args.model).to(DEVICE)\n\nparameters_num = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\nprint('--------------------------------')\nprint(f'Model: {args.model}')\nprint(f'Model input: {args.input}')\nif args.model == RNN_NAME:\n print(f'Attention: {args.attention}')\n print('Cell name: ',cell_name)\nprint(f'The model has {parameters_num:,} trainable parameters')\nprint('--------------------------------')","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:40.871571Z","iopub.execute_input":"2023-02-12T04:57:40.872416Z","iopub.status.idle":"2023-02-12T04:57:43.496740Z","shell.execute_reply.started":"2023-02-12T04:57:40.872379Z","shell.execute_reply":"2023-02-12T04:57:43.495579Z"},"trusted":true},"execution_count":17,"outputs":[{"name":"stdout","text":"--------------------------------\nModel: cnn\nModel input: question\nThe model has 28,191,065 trainable parameters\n--------------------------------\n","output_type":"stream"}]},{"cell_type":"markdown","source":"## Train model","metadata":{}},{"cell_type":"code","source":"class Evaluator(object):\n \"\"\"Evaluator class\"\"\"\n def __init__(self, criterion):\n self.criterion = criterion\n\n def evaluate(self, model, iterator, teacher_ratio=1.0):\n model.eval()\n epoch_loss = 0\n with torch.no_grad():\n for _, batch in enumerate(iterator):\n src, src_len = batch.src\n trg = batch.trg\n input_trg = trg if model.name == RNN_NAME else trg[:, :-1]\n output = model(src, src_len, input_trg, teacher_ratio)\n trg = trg.t() if model.name == RNN_NAME else trg[:, 1:]\n output = output.contiguous().view(-1, output.shape[-1])\n trg = trg.contiguous().view(-1)\n # output: (batch_size * trg_len) x output_dim\n # trg: (batch_size * trg_len)\n loss = self.criterion(output, trg)\n epoch_loss += loss.item()\n return epoch_loss / len(iterator)","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:43.500533Z","iopub.execute_input":"2023-02-12T04:57:43.500824Z","iopub.status.idle":"2023-02-12T04:57:43.508605Z","shell.execute_reply.started":"2023-02-12T04:57:43.500797Z","shell.execute_reply":"2023-02-12T04:57:43.507609Z"},"trusted":true},"execution_count":18,"outputs":[]},{"cell_type":"code","source":"from torch.cuda.amp import autocast, GradScaler\n\nclass Trainer(object):\n \"\"\"Trainer Class\"\"\"\n def __init__(self, optimizer, criterion, batch_size, device):\n self.optimizer = optimizer\n self.criterion = criterion\n self.batch_size = batch_size\n self.device = device\n self.evaluator = Evaluator(criterion=self.criterion)\n\n def _train_batch(self, model, iterator, teacher_ratio, clip):\n model.train()\n epoch_loss = 0\n #scaler = GradScaler()\n for _, batch in enumerate(tqdm_notebook(iterator)):\n src, src_len = batch.src\n trg = batch.trg\n self.optimizer.zero_grad()\n input_trg = trg if 
model.name == RNN_NAME else trg[:, :-1]\n output = model(src, src_len, input_trg, teacher_ratio)\n trg = trg.t() if model.name == RNN_NAME else trg[:, 1:]\n output = output.contiguous().view(-1, output.shape[-1])\n trg = trg.contiguous().view(-1)\n # output: (batch_size * trg_len) x output_dim\n # trg: (batch_size * trg_len)\n torch.cuda.empty_cache()\n loss = self.criterion(output, trg)\n loss.backward()\n torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n self.optimizer.step()\n epoch_loss += loss.item()\n \n return epoch_loss / len(iterator)\n\n def _get_iterators(self, train_data, valid_data, model_name):\n return BucketIterator.splits((train_data, valid_data),\n batch_size=self.batch_size,\n sort_within_batch=True if model_name == RNN_NAME else \\\n False,\n sort_key=lambda x: len(x.src),\n device=self.device)\n\n def _epoch_time(self, start_time, end_time):\n elapsed_time = end_time - start_time\n elapsed_mins = int(elapsed_time / 60)\n elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n return elapsed_mins, elapsed_secs\n\n def _log_epoch(self, train_loss, valid_loss, epoch, start_time, end_time):\n minutes, seconds = self._epoch_time(start_time, end_time)\n print(f'Epoch: {epoch+1:02} | Time: {minutes}m {seconds}s')\n print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {np.exp(train_loss):7.3f}')\n print(f'\\t Val. Loss: {valid_loss:.3f} | Val. PPL: {np.exp(valid_loss):7.3f}')\n\n def _train_epoches(self, model, train_data, valid_data, path_, num_of_epochs, teacher_ratio, clip):\n best_valid_loss = float('inf')\n # pylint: disable=unbalanced-tuple-unpacking\n train_iterator, valid_iterator = self._get_iterators(train_data, valid_data, model.name)\n train_loss_list = []\n val_loss_list = []\n for epoch in range(num_of_epochs):\n start_time = time.time()\n train_loss = self._train_batch(model, train_iterator, teacher_ratio, clip)\n valid_loss = self.evaluator.evaluate(model, valid_iterator, teacher_ratio)\n \n train_loss_list.append(train_loss)\n val_loss_list.append(valid_loss)\n \n end_time = time.time()\n self._log_epoch(train_loss, valid_loss, epoch, start_time, end_time)\n if valid_loss < best_valid_loss:\n best_valid_loss = valid_loss\n Checkpoint.save(model,cell_name,path_)\n return train_loss_list, val_loss_list\n\n def train(self, model, train_data, valid_data, path_, num_of_epochs=20, teacher_ratio=1.0, clip=1):\n \"\"\"Train model\"\"\"\n return self._train_epoches(model, train_data, valid_data, path_, num_of_epochs, teacher_ratio, clip)","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:43.510381Z","iopub.execute_input":"2023-02-12T04:57:43.511005Z","iopub.status.idle":"2023-02-12T04:57:43.531297Z","shell.execute_reply.started":"2023-02-12T04:57:43.510969Z","shell.execute_reply":"2023-02-12T04:57:43.530280Z"},"trusted":true},"execution_count":19,"outputs":[]},{"cell_type":"code","source":"# create optimizer\nif args.model ==TRANSFORMER_NAME:\n for p in model.parameters():\n if p.dim() > 1:\n nn.init.xavier_uniform_(p)\n optimizer = NoamOpt(torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))\nelse:\n optimizer = optim.Adam(model.parameters(),lr=0.001)\n\nbatch_size = 8\nepochs=10\n\n# define criterion\ncriterion = nn.CrossEntropyLoss(ignore_index=trg_vocab.stoi[PAD_TOKEN])\n\ntrainer = Trainer(optimizer, criterion, batch_size, DEVICE)\ntrain_loss, val_loss = trainer.train(model, train_data, valid_data, path, 
num_of_epochs=epochs)","metadata":{"execution":{"iopub.status.busy":"2023-02-12T04:57:43.532550Z","iopub.execute_input":"2023-02-12T04:57:43.533240Z","iopub.status.idle":"2023-02-12T05:02:30.402067Z","shell.execute_reply.started":"2023-02-12T04:57:43.533204Z","shell.execute_reply":"2023-02-12T05:02:30.401321Z"},"trusted":true},"execution_count":20,"outputs":[{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"8ba4c23730ce4082adc7e68434ea1782"}},"metadata":{}},{"name":"stdout","text":"Epoch: 01 | Time: 0m 28s\n\tTrain Loss: 5.281 | Train PPL: 196.623\n\t Val. Loss: 4.216 | Val. PPL: 67.775\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5f4d0b32027b46568806a7d354a5c1d7"}},"metadata":{}},{"name":"stdout","text":"Epoch: 02 | Time: 0m 27s\n\tTrain Loss: 4.260 | Train PPL: 70.826\n\t Val. Loss: 3.829 | Val. PPL: 46.032\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"b31778e68d504eee8b47c1d222a06225"}},"metadata":{}},{"name":"stdout","text":"Epoch: 03 | Time: 0m 28s\n\tTrain Loss: 3.869 | Train PPL: 47.907\n\t Val. Loss: 3.614 | Val. PPL: 37.105\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e721a12b8c1044d592677509df1489f6"}},"metadata":{}},{"name":"stdout","text":"Epoch: 04 | Time: 0m 28s\n\tTrain Loss: 3.628 | Train PPL: 37.635\n\t Val. Loss: 3.497 | Val. PPL: 33.008\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"e2cb9b7f88cd45e18fce9790587d22bb"}},"metadata":{}},{"name":"stdout","text":"Epoch: 05 | Time: 0m 28s\n\tTrain Loss: 3.461 | Train PPL: 31.849\n\t Val. Loss: 3.420 | Val. PPL: 30.578\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5e508c73ac384cb4894825dfa6479f26"}},"metadata":{}},{"name":"stdout","text":"Epoch: 06 | Time: 0m 28s\n\tTrain Loss: 3.305 | Train PPL: 27.254\n\t Val. Loss: 3.359 | Val. PPL: 28.772\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"83c37918834041fb8beee9e57fb328ae"}},"metadata":{}},{"name":"stdout","text":"Epoch: 07 | Time: 0m 28s\n\tTrain Loss: 3.204 | Train PPL: 24.640\n\t Val. Loss: 3.330 | Val. PPL: 27.945\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"09ad137b3f034180a265b82800e676da"}},"metadata":{}},{"name":"stdout","text":"Epoch: 08 | Time: 0m 28s\n\tTrain Loss: 3.103 | Train PPL: 22.255\n\t Val. Loss: 3.307 | Val. 
PPL: 27.293\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"881a973adcce47d5b08d5cf765583076"}},"metadata":{}},{"name":"stdout","text":"Epoch: 09 | Time: 0m 28s\n\tTrain Loss: 3.017 | Train PPL: 20.422\n\t Val. Loss: 3.298 | Val. PPL: 27.063\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":" 0%| | 0/438 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"a7403b6050274dde9704403ef6ec5160"}},"metadata":{}},{"name":"stdout","text":"Epoch: 10 | Time: 0m 28s\n\tTrain Loss: 2.958 | Train PPL: 19.267\n\t Val. Loss: 3.283 | Val. PPL: 26.656\n","output_type":"stream"}]},{"cell_type":"code","source":"import matplotlib.pyplot as plt\nimport os \nos.chdir(r'/kaggle/working')\nfrom IPython.display import FileLink\n\nplt.plot(train_loss)\nplt.plot(val_loss)\nplt.title('model loss')\nplt.xlabel('epoch')\nplt.legend(['train', 'val'], loc='upper left')\n\n#plt.savefig('loss.png')\nplt.show()","metadata":{"_kg_hide-input":true,"execution":{"iopub.status.busy":"2023-02-12T05:02:30.403778Z","iopub.execute_input":"2023-02-12T05:02:30.404565Z","iopub.status.idle":"2023-02-12T05:02:30.626564Z","shell.execute_reply.started":"2023-02-12T05:02:30.404525Z","shell.execute_reply":"2023-02-12T05:02:30.625678Z"},"trusted":true},"execution_count":21,"outputs":[{"output_type":"display_data","data":{"text/plain":"<Figure size 432x288 with 1 Axes>","image/png":"iVBORw0KGgoAAAANSUhEUgAAAXQAAAEWCAYAAAB2X2wCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAuS0lEQVR4nO3dd3hc1Z3/8fdXvVvVtqxiuYHBxtjGNjYQQoAQukNwsOjhl8SwZDckv+xuwj6bJcmS3WQ3+9uENGICCd02YBJ6CKGFuDdcAVdZcpVkSVbXSHN+f9zBlo2LJEu6M6PP63nm0cy9d2a+msf+zNE5955jzjlERCTyxfhdgIiI9A4FuohIlFCgi4hECQW6iEiUUKCLiEQJBbqISJRQoMuAY2a/N7P7u3jsDjO79FRfR6Q/KNBFRKKEAl1EJEoo0CUshbo6/snM1ppZo5k9bGZDzOxVM6s3szfMLKvT8dea2QYzqzWzt83sjE77JpnZqtDz5gNJR73X1Wa2JvTcRWY2oYc1f9XMtpjZATN7wcyGhbabmf2vme03s7rQ7zQ+tO9KM9sYqm2Xmf1jjz4wERToEt6uBz4LnAZcA7wK/AuQi/dv9+sAZnYa8DTwDSAPeAV40cwSzCwB+APwOJANPBN6XULPnQw8AtwJ5AC/AV4ws8TuFGpmFwP/CdwA5ANlwLzQ7suAC0O/RyYwG6gO7XsYuNM5lw6MB97szvuKdKZAl3D2c+fcPufcLuCvwFLn3GrnXCvwPDApdNxs4GXn3J+dcwHgJ0AycB4wHYgHfuqcCzjnngWWd3qPrwK/cc4tdc51OOceBVpDz+uOm4FHnHOrQvXdC8wwsxIgAKQDYwFzzm1yzu0JPS8AnGlmGc65Gufcqm6+r8ghCnQJZ/s63W8+xuO00P1heC1iAJxzQaAcKAjt2+WOnIWurNP94cC3Qt0ttWZWCxSFntcdR9fQgNcKL3DOvQn8AvglsM/M5ppZRujQ64ErgTIze8fMZnTzfUUOUaBLNNiNF8yA12eNF8q7gD1AQWjbx4o73S8Hfuicy+x0S3HOPX2KNaTideHsAnDOPeCcOwcYh9f18k+h7cudczOBwXhdQwu6+b4ihyjQJRosAK4ys0vMLB74Fl63ySJgMdAOfN3M4szsC8C0Ts99CLjLzM4NDV6mmtlVZpbezRqeAu4ws4mh/vf/wOsi2mFmU0OvHw80Ai1AR6iP/2YzGxTqKjoIdJzC5yADnAJdIp5z7kPgFuDnQBXeAOo1zrk251wb8AXgS0ANXn/7wk7PXYHXj/6L0P4toWO7W8NfgO8Cz+H9VTAKKA3tzsD74qjB65apxuvnB7gV2GFmB4G7Qr+HSI+YFrgQEYkOaqGLiEQJBbqISJRQoIuIRAkFuohIlIjz641zc3NdSUmJX28vIhKRVq5cWeWcyzvWPt8CvaSkhBUrVvj19iIiEcnMyo63T10uIiJRQoEuIhIlFOgiIlHCtz70YwkEAlRUVNDS0uJ3KX0uKSmJwsJC4uPj/S5FRKJEWAV6RUUF6enplJSUcOTkeNHFOUd1dTUVFRWMGDHC73JEJEqEVZdLS0sLOTk5UR3mAGZGTk7OgPhLRET6T1gFOhD1Yf6xgfJ7ikj/CbtAP5mWQAe7a5sJapZIEZEjRFygt7UHqWpopb4l0OuvXVtby69+9atuP+/KK6+ktra21+sREemOiAv09KQ44mNjONDYf4He0XHiRWReeeUVMjMze70eEZHuCKuzXLrCzMhKSaCyvoW29iAJcb33nfSd73yHrVu3MnHiROLj40lLSyM/P581a9awceNGPv/5z1NeXk5LSwv33HMPc+bMAQ5PY9DQ0MAVV1zBBRdcwKJFiygoKOCPf/wjycnJvVajiMjxhG2gf//FDWzcffCY+5xzNLV1kBAXQ3xs1wP9zGEZ3HfNuOPu/9GPfsT69etZs2YNb7/9NldddRXr168/d
GrhI488QnZ2Ns3NzUydOpXrr7+enJycI15j8+bNPP300zz00EPccMMNPPfcc9xyi1YVE5G+F7aBfiJmRmyMEehwxMf23ftMmzbtiPPEH3jgAZ5//nkAysvL2bx58ycCfcSIEUycOBGAc845hx07dvRdgSIinYRtoJ+oJQ1Q29TGzgNNjMhNJT2pb662TE1NPXT/7bff5o033mDx4sWkpKRw0UUXHfM88sTExEP3Y2NjaW5u7pPaRESOFnGDoh/LSIonNsaoaWzrtddMT0+nvr7+mPvq6urIysoiJSWFDz74gCVLlvTa+4qI9IawbaGfTEyMNzha3dhGe0eQuG70pR9PTk4O559/PuPHjyc5OZkhQ4Yc2nf55Zfz4IMPMmHCBE4//XSmT59+yu8nItKbzPl0gc6UKVPc0QtcbNq0iTPOOKPLr9Hc1sHm/fUMG5RMbnriyZ8QZrr7+4qImNlK59yUY+2L2C4XgOSEWFISYjnQ1IZfX0wiIuEiogMdICslgZZAB82BE1/8IyIS7SI+0DNT4okx40AvDo6KiESiiA/02JgYBiXHU9sUoCOobhcRGbgiPtABslMTCDpHXXPvz+8iIhIpoiLQUxJiSYyLVbeLiAxoURHoZkZ2ajxNbe209OPgaFpaWr+9l4jIyURFoANkpiRg9O6VoyIikSRirxQ9WnxsDBnJcdQ0BRgyKImYHizx9u1vf5vhw4dz9913A/C9730PM+Pdd9+lpqaGQCDA/fffz8yZM3u7fBGRUxa+gf7qd2Dvum49pSAYpCUQJBgfQ0zMMf74GHoWXPGj4z6/tLSUb3zjG4cCfcGCBbz22mt885vfJCMjg6qqKqZPn861116rNUFFJOyEb6D3QGyMYQaBDkdP1r2YNGkS+/fvZ/fu3VRWVpKVlUV+fj7f/OY3effdd4mJiWHXrl3s27ePoUOH9v4vICJyCsI30E/Qkj4eA+rrWthf38LYoRk9Ws1o1qxZPPvss+zdu5fS0lKefPJJKisrWblyJfHx8ZSUlBxz2lwREb91KfHMbIeZrTOzNWa24hj7zcweMLMtZrbWzCb3fqldk53qzY1e09SzwdHS0lLmzZvHs88+y6xZs6irq2Pw4MHEx8fz1ltvUVZW1pvlioj0mu600D/jnKs6zr4rgDGh27nAr0M/+11CXCxpiXHUNLYxOD2x233d48aNo76+noKCAvLz87n55pu55pprmDJlChMnTmTs2LF9VLmIyKnprS6XmcBjzpvycImZZZpZvnNuTy+9frdkpyaw80ATDa3tPVrNaN26w4Oxubm5LF68+JjHNTQ09LhGEZHe1tVOZge8bmYrzWzOMfYXAOWdHleEth3BzOaY2QozW1FZWdn9arsoI9lbzUhXjorIQNLVQD/fOTcZr2vla2Z24VH7j9Wv8YmZspxzc51zU5xzU/Ly8rpZatfFmLea0cGWdto7gn32PiIi4aRLge6c2x36uR94Hph21CEVQFGnx4XA7p4U1FsLVWSnJuCco6YpPCfs0oIcItLbThroZpZqZukf3wcuA9YfddgLwG2hs12mA3U96T9PSkqiurq6V8IuKT6WlIQ4asJwNSPnHNXV1SQlJfldiohEka4Mig4Bng+dLRIHPOWce83M7gJwzj0IvAJcCWwBmoA7elJMYWEhFRUV9Fb/emNrOzVNAZr3J/bonPS+lJSURGFhod9liEgUOWmgO+e2AWcfY/uDne474GunWkx8fDwjRow41Zc5pKG1nWk/fINrJgzjx7PG9drrioiEo/BqtvaytMQ4rp6Qz4trd9PQ2u53OSIifSqqAx1g9tRimto6eHltj8ZoRUQiRtQH+uTiTEYPTmPe8vKTHywiEsGiPtDNjNKpRazeWctH++r9LkdEpM9EfaADXDepgPhYY75a6SISxQZEoOekJfLZM4ewcFUFre39t+aoiEh/GhCBDt7gaE1TgDc27ve7FBGRPjFgAv2C0bkUZCYzb/lOv0sREekTAybQY2OMWecU8t6WKipqmvwuR0Sk1w2YQAf44hTvUvtnVlT4XImISO8bUIFemJXCBaNzeWZFOR3B8JqwS0TkVA2oQAconVrM7roW3ttyvNX0REQi04AL9EvPHEx2agLzNTgqIlFmwAV6Ylws100q4M8b91Hd0Op3OSIivWbABTrA7KlFBDocz6/e5XcpIiK9ZkAG+mlD0plcnMm85eVht5qRiEhPDchAB6+VvmV/A6t21vhdiohIrxiwgX71hGGkJsRqwi4RiRoDNtBTE+O45uxhvPj+HupbAn6XIyJyygZsoAPcMLWI5kAHL63d43cpIiKnbEAH+qSiTE4bkqZuFxGJCgM60M2M2VOLWVNeywd7D/pdjojIKRnQgQ5azUhEoseAD/Ts1AQuGzeU51fv0mpGIhLRBnygA5ROLaK2KcDrG/b5XYqISI8p0IHzR3mrGanbRUQimQIdiIkxbphSxHtbqig/oNWMRCQyKdBDvjilEDN4ZoVa6SISmRToIcMyk7lwTB7PrKzQakYiEpEU6J2UTi1iT10L726u9LsUEZFuU6B3cskZQ8hJTWD+MnW7iEjkUaB3khAXwxcmF/DGpn1UaTUjEYkwCvSjzJ5aRHvQsXBVhd+liIh0iwL9KKMHp3PO8CytZiQiEafLgW5msWa22sxeOsa+i8yszszWhG7/1rtl9q/ZU4vYVtnIyjKtZiQikaM7LfR7gE0n2P9X59zE0O0Hp1iXr646K5+0xDjm6cpREYkgXQp0MysErgJ+27flhAdvNaN8Xl6r1YxEJHJ0tYX+U+CfgeAJjplhZu+b2atmNu5YB5jZHDNbYWYrKivD+1zv2VOLaQ508OL7Ws1IRCLDSQPdzK4G9jvnVp7gsFXAcOfc2cDPgT8c6yDn3Fzn3BTn3JS8vLye1Ntvzi4cxNih6cxfvtPvUkREuqQrLfTzgWvNbAcwD7jYzJ7ofIBz7qBzriF0/xUg3sxye7vY/mTmTdj1fkUdm/ZoNSMRCX8nDXTn3L3OuULnXAlQCrzpnLul8zFmNtTMLHR/Wuh1q/ug3n513aQCEmJjNK2uiESEHp+HbmZ3mdldoYezgPVm9j7wAFDqouAk7qzUBD433lvNqCWg1YxEJLx1K9Cdc287564O3X/QOfdg6P4vnHPjnHNnO+emO+cW9UWxfpg9pYi65gB/2rDX71JERE5IV4qexHmjcijMSmaB5kkXkTCnQD+JmBhj9pQi/ralmp3VWs1IRMKXAr0LZk0pJMZQK11EwpoCvQvyByXz6dPyeHZlBe0dJ7q2SkTEPwr0Lpo9tZi9B7WakYiELwV6F11yxmBy0xJ0TrqIhC0FehfFx8Zw/eRC/rJpP/vrW/wuR0TkExTo3XDDodWMdvldiojIJyjQu2FUXhpTS7JYoNWMRCQMKdC7afbUYrZVNbJ8h1YzEpHwokDvpivPGhpazUjT6opIeFGgd1NKQhzXThzGK+v2cFCrGYlIGFGg90Dp1CJaAkFeWLPb71JERA5RoPfAWQXeakZPLd2pK0dFJGwo0HvAzLjz0yPZuOcg98xbQ0Ch
LiJhIM7vAiLVdZMKqapv44evbKI9GOTnN04mIU7fjyLin8hMoECz3xUA8NULR/JvV5/Jnzbs4+4nV9HarlWNRMQ/kRfoW96ABybD3nV+VwLA/7lgBD+YOY43Nu3j755YpaXqRMQ3kRfoWSPAYuD3V8OuVX5XA8BtM0r44XXjefOD/dz5+EqFuoj4IvICPWcU3PEKJGXAYzOhfLnfFQFw87nD+fH1Z/Hu5kq++tgKmtsU6iLSvyIv0AGyhsMdr0JqLjz+eSgLjzWpZ08t5r+un8B7W6r48qPLaWpr97skERlAIjPQAQYVwpdegYxh8MT1sO0dvysC4ItTivh/N5zNkm3VfOl3y2lsVaiLSP+I3EAHyMiHL70MWSXw1A3egGkYuG5SIf87eyIry2q4/ZFlNCjURaQfRHagA6QNhttfgtwx8PSN8OFrflcEwMyJBTxQOonV5bXc9vBSzfsiIn0u8gMdIDUHbnsBhoyD+bfAphf9rgiAqybk88ubJrG2oo5bH15GXbNCXUT6TnQEOkBKNtz2Rxg2CRbcDuuf87siAC4fn8+vbp7Mxt113PLbpdQ2tfldkohEqegJdICkQXDrQiieDs99Bd6f53dFAFw2bigP3nIOH+6t56aHllLTqFAXkd4XXYEOkJgONz8DJRfA83fBqsf9rgiAS84YwtzbzmFLZQM3PrSE6oZWv0sSkSgTfYEOkJAKNy2A0ZfAC38Py3/rd0UAXHT6YB6+fQrbqxq58aElVNYr1EWk90RnoAPEJ0PpU3DaFfDyt2DJr/2uCIBPjcnjd1+ays4DTdz40BL217f4XZKIRInoDXSAuES44TE441p47Tvw3k/9rgiA80bn8vs7prG7tpnSuUvYd1ChLiKnLroDHSAuAWb9DsZfD2/cB+/8t98VATB9ZA6/v2Ma++pamP2bxeypC48pgUUkckV/oAPExsEXHoKzb4S37oc3fwjO+V0V00Zk89iXp1HV0Mbs3yxhV61CXUR6rsuBbmaxZrbazF46xj4zswfMbIuZrTWzyb1bZi+IiYWZv4LJt8G7/+W11sMg1M8Zns3jX55GTVMbs3+zmPIDTX6XJCIRqjst9HuATcfZdwUwJnSbA4THCOTRYmLg6p/B1K/A334Gr90bFqE+qTiLJ79yLgebA5TOXUJZdaPfJYlIBOpSoJtZIXAVcLzz/2YCjznPEiDTzPJ7qcbeFRMDV/4Epn8Nlv7aOwMm6P8izxMKM3nqq9NpbGundO4Stlcp1EWke7raQv8p8M/A8ZKvACjv9LgitO0IZjbHzFaY2YrKysru1Nm7zOBzP4QLvgkrHoYXvw5B/xekGF8wiKe+Mp2WQAelcxeztbLB75JEJIKcNNDN7Gpgv3Nu5YkOO8a2T/RlOOfmOuemOOem5OXldaPMPmAGl9wHn/4OrH4c/nA3dPg/ze2ZwzJ4es502jscpXOXsGV/vd8liUiE6EoL/XzgWjPbAcwDLjazJ446pgIo6vS4ENjdKxX2JTP4zL1w8Xdh7TxY+FXo8H9GxLFDM5g3ZzrOQencJXy0T6EuIid30kB3zt3rnCt0zpUApcCbzrlbjjrsBeC20Nku04E659ye3i+3j1z4j3DZ/bBhITzzJWj3f/KsMUPSmTdnOjFmlM5dwqY9B/0uSUTCXI/PQzezu8zsrtDDV4BtwBbgIeDuXqitf533D3DFf8EHL8GCWyHg/9WbowenMf/OGSTExnDjQ0tYv6vO75JEJIyZ8+m0vSlTprgVK1b48t4ntOJ38NI3YNTF3lww8cl+V0RZdSM3zl1CY1sHT3z5XM4qHOR3SSLiEzNb6Zybcqx9A+NK0e6YcgfM/CVsfQue/CK0+X/64PCcVObfOYO0xDhu+u0S1pTX+l2SiIQhBfqxTLoFvjAXyv4GT8yCVv8HJYuyU5h/53QyU+K59bdLWVlW43dJIhJmFOjHM+EGmPUIVCyDx6+D5lq/K6IwK4X5c2aQnZbAbQ8vZfmOA36XJCJhRIF+IuOu86bf3b0GHpsJTf4H6LDMZObPmcHgjCRuf2QZS7dV+12SiIQJBfrJjL3KGxzdvwkevRYaq/yuiKGDkpg/Zzr5g5L40u+W89fNPl51KyJhQ4HeFaddBjfNg+ot8PuroX6f3xUxOCOJeXNmUJSdzK0PL+PehWu1+LTIAKdA76pRF3uLT9fuhN9fBQf9vxA2Lz2RhXefz1cuGMGCFRVc/D9vM3/5ToJB/2eQFJH+p0DvjhGfglsXQv1e+N2VUFt+8uf0sbTEOP716jN5+esXMCovjW8/t45ZDy5iw25dhCQy0CjQu6t4Otz2B2+A9OHPwpqnw2KmxrFDM1hw5wz+e9YEyqqbuObn7/G9FzZwsMX/uWlEpH8o0HuicArc8TKk58Mf7oLfXAib/+z7YhkxMcYXpxTx5rcu4qZzi3l08Q4u+Z93+OOaXfh1RbCI9B8Fek8NPQu++qa3AHVbIzw5Cx69Bnat8rsyBqXEc//nz+IPd59P/qAk7pm3hpseWqqpeEWinOZy6Q3tbbDy9/DOj6GpCsZ9AS75LmSP9LsyOoKOp5bt5L9f+4DmQAdf+dRI/uHi0aQkxPldmoj0wInmclGg96aWg7Do57D4F9686lP+D3z6nyE11+/KqGpo5T9f+YDnVlVQkJnMv11zJpedOQSzY61NIiLhSoHe3+r3wts/glWPQXwKnP91mPE1SEj1uzKWbT/Ad/+wng/31fOZ0/P4/rXjKc5J8bssEekiBbpfqjbDX74Pm16EtCFw0Xdg0m0Q6293R6AjyKOLdvC/f/6I9qDj7otGc+enR5IUH+trXSJycgp0v5Uvgz//G+xcDDlj4NL7YOzV3hJ4Ptpb18K/v7yRl9fuoSQnhe/PHM+nT/N5rVcROSHNh+63omlwx6tQ+jRYDMy/BR75HOxc4mtZQwcl8cubJvP4l6dhZtz+yDLufnIle+qafa1LRHpGLfT+1tEOa56Et/8T6vfA6VfCpd+DvNN9Lau1vYO572zjF29tITbG+MalY7jj/BHEx+o7XyScqMslHLU1wZJfwd9+Bm0N3qIaF/0LZOT7Wlb5gSa+98IG/vLBfk4bksa/zxzPuSNzfK1JRA5ToIezxmp4979h+W8hJg5m3A3n3wNJ/q4b+ueN+/jeCxvYVdvMFyYVcO+VZ5CXnuhrTSKiQI8MB7bDWz+Edc9AcjZc+E8w9csQ51+INrd18Iu3NjP33W0kxcfyT587nZvPHU5sjM5dF/GLAj2S7F4Db9wH296GzOFw8Xdh/PUQ419f9pb9Ddz3wnr+tqWa8QUZ3P/5s5hYlOlbPSIDmc5yiSTDJsJtf4RbFkJiBiz8Cjx0EWx9y7eSRg9O44kvn8vPb5zE/oOtXPerv3HvwnXUNmlBDZFwohZ6OAsGvS6YN++Hup3eIhuXfh/yJ/hWUn1LgJ++sZnfL9rBoOR4vnPFWGZNLiRG3TAi/UJdLpEu0OINmv71J9BcA2fdABf/K2QN962kTXsO8t0/rGdFWQ3nDM/i32eO58xhGb7VIzJQKNCjRXMtvPe/sPRBcEGY+lW
48B8hJduXcoJBx3OrKvjPVz+grjnAFyYVcPt5JYwv8PcMHZFopkCPNnW74O3/gDVPQUK6dzbM2Ktg2GRfBk9rm9r46Rubmb+8nOZAB5OLM7n9vBKuGJ9PQpyGaUR6kwI9Wu3bCG/+O3z0mtdiT82D0Z+F0z4Hoz7T7+ey1zUHeG5lBY8vKWN7VSO5aQncOK2Ym84tJn9Qcr/WIhKtFOjRrukAbHkDPvqT97Ol1rtIqXgGnHa5F/A5o/ttMrBg0PHelioeW7yDv3ywnxgzLjtzCLfOGM6MkTmag13kFCjQB5KOdqhY5oX75tdh/0Zve9aIULhfBsPP77cLlsoPNPHE0jLmLy+ntinAmMFp3DZjONdNLiQtUasmiXSXAn0gq93phftHf4Lt70JHKySkwciLvJb7mMsgfWifl9ES6ODF93fz2OIy1u2qIy0xjusnF3DrjBJGD07r8/cXiRYKdPG0NXmh/tFrXuv94C5ve/7EULh/DoZN6tOBVecca8preWxxGS+v3UNbR5ALRudy64zhXDJ2MHGa3VHkhBTo8knOwb4Nh8O9YnloYHUwjPms13IfdTEk9d255VUNrcxfXs6TS8rYXddCQWYyN51bTOnUInLSNBGYyLEo0OXkGqu9AdXNHw+s1kFMPAwPDayO+Rzkju6Tt27vCPLGpv08vmQHf9tSTUJsDFdPyOe280o0Z4zIUU4p0M0sCXgXSATigGedc/cddcxFwB+B7aFNC51zPzjR6yrQw1hHO5Qv9cL9o9ehcpO3PXtkKNw/HlhN6PW33ryvnseXlPHcygoa2zqYUDiI22aUcPWEfK15KsKpB7oBqc65BjOLB94D7nHOLel0zEXAPzrnru5qUQr0CFJT5nXLfPQabP/r4YHVUZ/xWu5jLoP0Ib36lg2t7Ty/qoJHF5exZX8DWSnxzJ5azM3nFlOUndKr7yUSSXqty8XMUvAC/e+cc0s7bb8IBfrA0NZ4eGD1o9ehfre3fdgkr9WeP9GbPCxnNMSceovaOcfirdU8triM1zfuBeDisUO4bcZwLhidq0nBZMA55UA3s1hgJTAa+KVz7ttH7b8IeA6oAHbjhfuGY7zOHGAOQHFx8TllZWXd+kUkzDgH+9aHBlb/DLtXQ0doSt34FBgyDvLPhqETvJAffOYpnf++u7aZp5bu5OllO6lubGNkbiq3zhjO9ecUkpEU30u/lEh4680WeibwPPAPzrn1nbZnAMFQt8yVwM+cc2NO9FpqoUehjgBUfgh718Ke92HPWti7Dtrqvf0xcZB3hhfuQyeEwn48JKZ3621a2zt4dd1eHl28g9U7a0lJiOW6SQXcNqOE04d277VEIk2vnuViZvcBjc65n5zgmB3AFOdc1fGOUaAPEMEg1Gz3An7vWi/k97wPTR//0zBvsDV/QqfW/NmQmtull19XUcdji3fwwvu7aW0PMm1ENjefW8znxg3VIKpEpVMdFM0DAs65WjNLBl4Hfuyce6nTMUOBfc45Z2bTgGeB4e4EL65AH8Ccg/o9oRZ8p9Z83c7Dx2QUHO6q+TjkBxUedz6amsY2Fqwo54mlZZQfaCYzJZ7rJhVQOrVYrXaJKqca6BOAR4FYvCXrFjjnfmBmdwE45x40s78H/g5oB5qB/+ucW3Si11Wgyyc0HfC6aDqHfPVm74IngOSsw+H+cWs+Z9QRg6/BoGPR1mqeXr6T1zfsJdDhmFScSenUIq6eMIxUzR8jEU4XFknkamvyrmjd+/7hkN+/sdPga6rXD9+5NZ93OsQnc6CxjYWrKpi3vJwt+xtITYjl2onDmD21mLMLB2nWR4lICnSJLh8Pvnbul+88+IrBoCLvytacMbjcMWzuGMqC7Uk8uamd5kCQsUPTKZ1axOcnFZCZ0vsXSIn0FQW6RL+PB1/3roXKj7yumqrNUL0F2hoOHebiU6lJLmJd6xBWNeSyM6aAYaPO4sLpM5h2eqFa7RL2ThTo6lCU6BAT4/Wn54w6cvvHA7BVm6F6M1a1hezqzXy66iMubH0Xw8EOYAfss1zaMkeRM3wcKfljIXcM5IzxBmh9WNpPpLsU6BLdzCBjmHcb+ekjdwWa4cA22vZ9yOaNq6gu20BG9Q4yDzwNa5oPHxiX7F35GurCIfe00P3R3T6HXqQvqctFpJOtlQ0sWLaTt1auJ7uljEkplVw6uJ4zE/aRXLfVWzDk47NuANLzQ2F/2uEWfe5orw+/F6Y+EDma+tBFuqmtPcibH+zj6WXlvLu5EoALRudy8zlDuWRIA/E1Ww/304e6c2ipO/wCsQmQkgPJ2ZCS7Z1ymZwVut95W/bhbclZEKs/muXEFOgip2BXbTPPrChnwfJydte1kJOawPXnFHLDlKLDy+c5B41VoZD/CA5sg6ZqaKqB5gPQXOOdZ998AILtx3+zxIyTB39K1pHbEjP6bQFw8Z8CXaQXdAQdf91cybxl5byxaR/tQcfUkixKpxZz5Vn5JCd0oYvFOWitPyrkO4V95/udt3Vu/R8tJu7wXwDHCv7EDEga5N0SM7xVqD7+mZCuAd8Io0AX6WWV9a0sXFXB/OXlbKtqJD0xjpmThlE6tZjxBYN6/w072qGl9vAXQOfA/8S2Tl8U7c0neWH7ZMh/IvgHHXn/6G0JqfoLoR8p0EX6iHOOZdsPMH95OS+v20Nre5DxBRnMnlLEZ8YOpjDL58U4As3QchBaD4Z+1nk/W+o6bQs97ny/8zbXceL3sFjvbJ9DIT/ok18GCanekoaxCd44QWxC6PHHtwTvL41D94+1L8F7fPT9AfZlokAX6Qd1zQH+uGYXTy8rZ9OegwAUZSczfUQOM0Z5t/xByT5X2U3OQaDpGF8Cx/tiOM5x9GHOdP6iONmXRmyCN1d/fHLoZ1Kn+8neKaqdH8cnHXl8XOfHyb6cyaRAF+lHzjk2729g0ZYqFm+rZsm2A9Q1BwAoyUlhxqgcpo/0An5wepLP1faDYBDaWyAY8KZt6Ah4c/EE272fHYGT7Ot0/9CxbV43VEfbkc890b6ONu8WaPL+cgm0HL5/0q6p44hNOM4XQucvjKO/EJKheDqUXNCjt1Sgi/goGHRs2nuQxVurWbKtmqXbDlDf6p3pMiov1Wu9j8xl+shsctJ6vqKTnALnvC+dQHOnwG8+8nH70ds6fSEEmk7w/M77m7zrGC74v3DpfT0qVYEuEkY6go4Nu+tYvLWaxduqWb79AI1tXj/16UPSD7Xgp4/M1sRh0cY5768E57zWew8o0EXCWKAjyLpddYda8Mt3HKAlEMQMzhiaEWrB5zBtZLbWThUFukgkaWsP8n5FrdeC31rNyp01tLUHiTEYXzCIGSNzmD4qh6kl2aRpwY4BR4EuEsFaAh2s3lnrDbBurWZ1eQ2BDkdsjDGh0Av4GaNyOGd4FikJCvhop0AXiSLNbR2sLKth8bYqFm+tZm1FHe1BR3yscXZh5qFTJCcXZ2mh7CikQBeJYo2t7SzfceBQC37drjqCDhLiYphcnMmMkbmcNzqHswszSYjTZf6RToEuMoAcbAmwYscBFm+tZtHWajbuOYhzkBwfy5
SSLM4blcuMUTmMH5ZBXKwCPtIo0EUGsNqmNpZsO8CSbdUs2lrFR/u8JfnSE+OYNiL7UBfNGUMziIkZWJfRRyItQScygGWmJHD5+KFcPn4o4E0stmSbdw784q3V/OWD/QBkpcQfuoL1vFE5jMpL0xqrEUYtdJEBbk9d86HumcVbq9lV610Gn5eeyIyRXrjPGJVDcXaKAj4MqMtFRLrEOUf5gWYWbfXmoVm0tZrK+lYACjKTmd4p4IdlRthEY1FCgS4iPeKcY2tlI4u3VrEodCVrTVPnicZyD13JmpeueWj6gwJdRHpFMOj4YG99qP+96oiJxsYMTgu13nM1D00fUqCLSJ9o7wiyYfdBr/89NNFYc6ADMzgzP8Prgx/tTVOQrnloeoUCXUT6RVt7kLUVtSza6p0iuWpnLW3t3kRjYwanMbEok4lFWUwqzuS0IenE6jTJblOgi4gvWgIdrCqrYfmOGtaU17C6vJbaUB98SkIsEwoHMbEoi4lFmUwuzmRwxgBY8OMU6Tx0EfFFUnws543O5bzRuYA3yFpW3cTq8hrW7KxlTXktD7+3jUCH17AcNiiJicWZTCrKYmJxJuOHDSI5QfPRdJUCXUT6jZlRkptKSW4q100qBLxW/IbdB1lTXsvqnTWsKa/llXV7AYiLMcbmpzOx6HDIj8hJ1RWtx6EuFxEJO5X1rawpr2VNuRfw75fX0RA6myYjKY6JxVmhkM9kYlEmWakD54wa9aGLSETrCDq2VjawZmctq8trWL2zlo/21RMMxVdJTooX8KGgPyM/I2pnljylQDezJOBdIBGvi+ZZ59x9Rx1jwM+AK4Em4EvOuVUnel0FuoicisbWdtZW1B3RVbM/dFVrQlwM44dleAOuxV5LvjArOSqmLjjVQdFW4GLnXIOZxQPvmdmrzrklnY65AhgTup0L/Dr0U0SkT6Qmxh2aKRK8Adc9dS2s3nm4q+bJpWU88rftAOSmJXBGfgYjQ334I0K3gszkqJlG+KSB7rwmfEPoYXzodnSzfibwWOjYJWaWaWb5zrk9vVqtiMhxmBnDMpMZlpnMVRPyAW8B7g/31rM61IrfvK+B51btOtQfDxAfaxRlp3hBn5PKiLxURoR+DklPiqgB2C6d5WJmscBKYDTwS+fc0qMOKQDKOz2uCG07ItDNbA4wB6C4uLiHJYuIdE18bAzjCwYxvmAQt04fDngt+aqGNrZXNbKjqpFtoZ/bqxr56+YqWtuDh56fFB/jhXyoNV+Sm3qohZ+TmhB2XThdCnTnXAcw0cwygefNbLxzbn2nQ471W32ic945NxeYC14fevfLFRE5NWZGXnoieemJTBuRfcS+YNCx92AL248K+g/31vPnjftoDx6OrfSkuMNBn5PKyDzvZ0luKoOS/ZnmoFvnoTvnas3sbeByoHOgVwBFnR4XArtPuToRkX4UE3O42+b80MVQH2vvCFJR08z2UMhvr2pkR3UjK3bU8ML7u+l8fklOasKhFv2ITreSnNQ+vVDqpIFuZnlAIBTmycClwI+POuwF4O/NbB7eYGid+s9FJJrExcYcuijqM0ftawl0sPNA0+GgD/1896NKnl1ZccSx+YOS+PIFI/jKp0b2fo1dOCYfeDTUjx4DLHDOvWRmdwE45x4EXsE7ZXEL3mmLd/R6pSIiYSopPpbThqRz2pD0T+xraG1nR6g1v72yke3VjX02d7wuLBIRiSAnOg89Ok6+FBERBbqISLRQoIuIRAkFuohIlFCgi4hECQW6iEiUUKCLiEQJBbqISJTw7cIiM6sEynr49FygqhfLiXT6PI6kz+MwfRZHiobPY7hzLu9YO3wL9FNhZiuOd6XUQKTP40j6PA7TZ3GkaP881OUiIhIlFOgiIlEiUgN9rt8FhBl9HkfS53GYPosjRfXnEZF96CIi8kmR2kIXEZGjKNBFRKJExAW6mV1uZh+a2RYz+47f9fjJzIrM7C0z22RmG8zsHr9r8puZxZrZajN7ye9a/GZmmWb2rJl9EPo3MsPvmvxiZt8M/R9Zb2ZPm1mS3zX1hYgK9NAyeL8ErgDOBG40szP9rcpX7cC3nHNnANOBrw3wzwPgHmCT30WEiZ8BrznnxgJnM0A/FzMrAL4OTHHOjQdigVJ/q+obERXowDRgi3Num3OuDZgHzPS5Jt845/Y451aF7tfj/Yct8Lcq/5hZIXAV8Fu/a/GbmWUAFwIPAzjn2pxztb4W5a84INnM4oAUYLfP9fSJSAv0AqC80+MKBnCAdWZmJcAkYKnPpfjpp8A/A0Gf6wgHI4FK4HehLqjfmlmq30X5wTm3C/gJsBPYA9Q55173t6q+EWmBbsfYNuDPuzSzNOA54BvOuYN+1+MHM7sa2O+cW+l3LWEiDpgM/No5NwloBAbkmJOZZeH9JT8CGAakmtkt/lbVNyIt0CuAok6PC4nSP526yszi8cL8SefcQr/r8dH5wLVmtgOvK+5iM3vC35J8VQFUOOc+/ovtWbyAH4guBbY75yqdcwFgIXCezzX1iUgL9OXAGDMbYWYJeAMbL/hck2/MzPD6SDc55/6f3/X4yTl3r3Ou0DlXgvfv4k3nXFS2wrrCObcXKDez00ObLgE2+liSn3YC080sJfR/5hKidIA4zu8CusM5125mfw/8CW+k+hHn3Aafy/LT+cCtwDozWxPa9i/OuVf8K0nCyD8AT4YaP9uAO3yuxxfOuaVm9iywCu/MsNVE6RQAuvRfRCRKRFqXi4iIHIcCXUQkSijQRUSihAJdRCRKKNBFRKKEAl2kB8zsIs3oKOFGgS4iEiUU6BLVzOwWM1tmZmvM7Deh+dIbzOx/zGyVmf3FzPJCx040syVmttbMng/NAYKZjTazN8zs/dBzRoVePq3TfONPhq5CFPGNAl2ilpmdAcwGznfOTQQ6gJuBVGCVc24y8A5wX+gpjwHfds5NANZ12v4k8Evn3Nl4c4DsCW2fBHwDb27+kXhX7or4JqIu/RfppkuAc4DlocZzMrAfb3rd+aFjngAWmtkgINM5905o+6PAM2aWDhQ4554HcM61AIReb5lzriL0eA1QArzX57+VyHEo0CWaGfCoc+7eIzaaffeo4040/8WJulFaO93vQP+fxGfqcpFo9hdglpkNBjCzbDMbjvfvflbomJuA95xzdUCNmX0qtP1W4J3Q/PIVZvb50GskmllKf/4SIl2lFoVELefcRjP7V+B1M4sBAsDX8BZ7GGdmK4E6vH52gNuBB0OB3Xl2wluB35jZD0Kv8cV+/DVEukyzLcqAY2YNzrk0v+sQ6W3qchERiRJqoYuIRAm10EVEooQCXUQkSijQRUSihAJdRCRKKNBFRKLE/wdyIpx2bznNvgAAAABJRU5ErkJggg==\n"},"metadata":{"needs_background":"light"}}]},{"cell_type":"markdown","source":"## Evaluation","metadata":{}},{"cell_type":"code","source":"val_ref = [list(filter(None, 
np.delete(i,[0,1]))) for i in val.values]\ntest_ref = [list(filter(None, np.delete(i,[0,1]))) for i in test.values]\n\nval_trg = []\ntest_trg = []\ntrg_ = [val_trg,test_trg]\nfor t in trg_:\n for i in val_ref:\n tmp=[]\n for j in i:\n s = word_tokenize(j)\n tmp.append(s)\n t.append(tmp)\n\nval_src = [i.src for i in valid_data.examples]\nnew_valid = [[val_src[i],val_trg[i]] for i in range(len(val_trg)) ]\ntest_src = [i.src for i in test_data.examples]\nnew_test = [[test_src[i],test_trg[i]] for i in range(len(test_trg))]","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:10:41.155405Z","iopub.execute_input":"2023-02-12T05:10:41.155774Z","iopub.status.idle":"2023-02-12T05:10:52.815140Z","shell.execute_reply.started":"2023-02-12T05:10:41.155744Z","shell.execute_reply":"2023-02-12T05:10:52.814150Z"},"trusted":true},"execution_count":28,"outputs":[]},{"cell_type":"code","source":"import nltk\nfrom nltk.translate.bleu_score import SmoothingFunction\nfrom rouge_score import rouge_scorer\n\nsmoothie = SmoothingFunction().method4\n\nclass BleuScorer(object):\n \"\"\"Blue scorer class\"\"\"\n def __init__(self):\n self.results = []\n self.results_meteor = []\n \n self.score = 0\n self.bleu_4 = 0\n self.meteor_score = 0\n self.rouge_score = 0\n \n self.instances = 0\n self.meteor_instances = 0\n\n def example_score(self, reference, hypothesis):\n \"\"\"Calculate blue score for one example\"\"\"\n bleu_1 = nltk.translate.bleu_score.sentence_bleu(reference, hypothesis,weights=(1,0,0,0),smoothing_function=SmoothingFunction().method4)\n bleu_4 = nltk.translate.bleu_score.sentence_bleu(reference, hypothesis,weights=(0.25,0.25,0.25,0.25),smoothing_function=SmoothingFunction().method4)\n return bleu_1, bleu_4\n \n def example_score_rouge(self, reference, hypothesis):\n scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)\n scores = []\n for i in reference:\n scores.append(scorer.score(i,hypothesis)['rougeL'][-1])\n return np.max(scores) #best\n \n \n def example_score_meteor(self, reference, hypothesis):\n \"\"\"Calculate blue score for one example\"\"\"\n return nltk.translate.meteor_score.meteor_score(reference,hypothesis)\n\n def data_score(self, data, predictor):\n \"\"\"Score complete list of data\"\"\"\n results_prelim = []\n for example in tqdm_notebook(data):\n #i = 1\n# src = [t.lower() for t in example.src]\n# reference = [t.lower() for t in example.trg]\n \n src = example[0]\n reference = [[string.lower() for string in sublist] for sublist in example[1]]\n\n #and calculate bleu score average of all hypothesis\n #hypothesis = predictor.predict(example.src)\n hypothesis = predictor.predict(src)\n bleu_1,bleu_4 = self.example_score(reference, hypothesis)\n meteor_score = self.example_score_meteor([' '.join(i) for i in reference], ' '.join(hypothesis))\n rouge_score = self.example_score_rouge([' '.join(i) for i in reference], ' '.join(hypothesis))\n \n f = open(\"result.txt\", \"a\")\n f.write('Question: '+\" \".join(src)+'\\n')\n for i in range(len(reference)):\n f.write('Reference_{}: '.format(i)+\" \".join(reference[i])+'\\n')\n f.write('Hypothesis: '+\" \".join(hypothesis)+'\\n')\n f.write('BLEU-1: '+ str(bleu_1*100)+'\\n')\n f.write('BLEU-4: '+str(bleu_4*100)+'\\n')\n f.write('METEOR: '+str(meteor_score*100)+'\\n')\n f.write('ROUGE-L: '+str(rouge_score*100)+'\\n\\n')\n \n f.close()\n \n \n results_prelim.append({\n 'question': '\"' + str(src) + '\"',\n 'reference': reference,\n 'hypothesis': hypothesis,\n 'bleu_1': bleu_1,\n 'bleu_4': bleu_4,\n 'meteor_score': 
meteor_score,\n 'rouge_score': rouge_score,\n \n })\n \n results = [max((v for v in results_prelim if v['question'] == x), key=lambda y:y['bleu_1']) for x in set(v['question'] for v in results_prelim)] \n\n with open(path+'result_output.txt', 'w') as f:\n for elem in results:\n f.write(\"%s\\n\" % elem)\n self.results.append(elem)\n self.score += elem['bleu_1']\n self.bleu_4 += elem['bleu_4']\n self.meteor_score += elem['meteor_score']\n self.rouge_score += elem['rouge_score']\n self.instances += 1\n return self.score / self.instances, self.bleu_4 / self.instances, self.meteor_score / self.instances, self.rouge_score / self.instances\n\n def average_score(self):\n \"\"\"Return bleu average score\"\"\"\n return self.score / self.instances, self.bleu_4 / self.instances\n \n def average_rouge_score(self):\n \"\"\"Return bleu average score\"\"\"\n return self.rouge_score / self.instances\n \n \n def data_meteor_score(self, data, predictor):\n \"\"\"Score complete list of data\"\"\"\n results_prelim = []\n for example in data:\n src = [t.lower() for t in example.src]\n reference = [t.lower() for t in example.trg]\n hypothesis = predictor.predict(example.src)\n meteor_score = self.example_score_meteor(' '.join(reference), ' '.join(hypothesis))\n results_prelim.append({\n 'question': '\"' + str(src) + '\"',\n 'reference': reference,\n 'hypothesis': hypothesis,\n 'meteor_score': meteor_score\n })\n results_meteor = [max((v for v in results_prelim if v['question'] == x), key=lambda y:y['meteor_score']) for x in set(v['question'] for v in results_prelim)] \n\n with open(path+'result_meteor_output.txt', 'w') as f:\n for elem in results_meteor:\n f.write(\"%s\\n\" % elem)\n self.results_meteor.append(elem)\n self.meteor_score += elem['meteor_score']\n self.meteor_instances += 1\n return self.meteor_score/self.meteor_instances\n \n def average_meteor_score(self):\n \"\"\"Return meteor average score\"\"\"\n return self.meteor_score/self.instances\n\n def reset(self):\n \"\"\"Reset object properties\"\"\"\n self.results = []\n self.results_meteor = []\n self.score = 0\n self.meteor_score = 0\n self.instances = 0\n self.meteor_instances = 0","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:05:50.377470Z","iopub.execute_input":"2023-02-12T05:05:50.378036Z","iopub.status.idle":"2023-02-12T05:05:50.427362Z","shell.execute_reply.started":"2023-02-12T05:05:50.378000Z","shell.execute_reply":"2023-02-12T05:05:50.426483Z"},"trusted":true},"execution_count":24,"outputs":[]},{"cell_type":"code","source":"class Predictor(object):\n \"\"\"Predictor class\"\"\"\n def __init__(self, model, src_vocab, trg_vocab, device):\n self.model = model\n self.src_vocab = src_vocab\n self.trg_vocab = trg_vocab\n self.device = device\n\n def _predict_step(self, tokens):\n self.model.eval()\n tokenized_sentence = [SOS_TOKEN] + [t.lower() for t in tokens] + [EOS_TOKEN]\n numericalized = [self.src_vocab.stoi[token] for token in tokenized_sentence]\n src_tensor = torch.LongTensor(numericalized).unsqueeze(0).to(self.device)\n\n with torch.no_grad():\n encoder_out = self.model.encoder(src_tensor)\n\n outputs = [self.trg_vocab.stoi[SOS_TOKEN]]\n\n # cnn positional embedding gives assertion error for tensor\n # of size > max_positions-1, we predict tokens for max_positions-2\n # to avoid the error\n for _ in range(self.model.decoder.max_positions-2):\n trg_tensor = torch.LongTensor(outputs).unsqueeze(0).to(self.device)\n\n with torch.no_grad():\n output = self.model.decoder(trg_tensor, encoder_out, src_tokens=src_tensor)\n\n 
prediction = output.argmax(2)[:, -1].item()\n\n if prediction == self.trg_vocab.stoi[EOS_TOKEN] or len(outputs)==500:\n break\n \n outputs.append(prediction)\n\n translation = [self.trg_vocab.itos[i] for i in outputs]\n\n return translation[1:] # , attention\n\n def _predict_rnn_step(self, tokens):\n self.model.eval()\n with torch.no_grad():\n tokenized_sentence = [SOS_TOKEN] + [t.lower() for t in tokens] + [EOS_TOKEN]\n numericalized = [self.src_vocab.stoi[t] for t in tokenized_sentence]\n\n src_len = torch.LongTensor([len(numericalized)]).to(self.device)\n tensor = torch.LongTensor(numericalized).unsqueeze(1).to(self.device)\n\n translation_tensor_logits = self.model(tensor.t(), src_len, None)\n\n translation_tensor = torch.argmax(translation_tensor_logits.squeeze(1), 1)\n translation = [self.trg_vocab.itos[t] for t in translation_tensor]\n\n return translation[1:] # , attention\n\n def predict(self, tokens):\n \"\"\"Perform prediction on given tokens\"\"\"\n return self._predict_rnn_step(tokens) if self.model.name == RNN_NAME else \\\n self._predict_step(tokens)","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:05:54.736292Z","iopub.execute_input":"2023-02-12T05:05:54.736655Z","iopub.status.idle":"2023-02-12T05:05:54.751475Z","shell.execute_reply.started":"2023-02-12T05:05:54.736624Z","shell.execute_reply":"2023-02-12T05:05:54.750546Z"},"trusted":true},"execution_count":25,"outputs":[]},{"cell_type":"code","source":"name = args.model+\"_\"+cell_name if args.model==RNN_NAME else args.model\nmodel = Checkpoint.load(model,path,'./{}.pt'.format(name))\n\nvalid_iterator, test_iterator = BucketIterator.splits(\n (valid_data, test_data),\n batch_size=8,\n sort_within_batch=True if args.model == RNN_NAME else False,\n sort_key=lambda x: len(x.src),\n device=DEVICE)\n\n# evaluate model\nvalid_loss = trainer.evaluator.evaluate(model, valid_iterator)\ntest_loss = trainer.evaluator.evaluate(model, test_iterator)\n\n# calculate blue score for valid and test data\npredictor = Predictor(model, src_vocab, trg_vocab, DEVICE)\n\n# # train_scorer = BleuScorer()\nvalid_scorer = BleuScorer()\ntest_scorer = BleuScorer()\n\nvalid_scorer.data_score(new_valid, predictor)\ntest_scorer.data_score(new_test, predictor)","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:05:59.035087Z","iopub.execute_input":"2023-02-12T05:05:59.035456Z","iopub.status.idle":"2023-02-12T05:07:51.964640Z","shell.execute_reply.started":"2023-02-12T05:05:59.035426Z","shell.execute_reply":"2023-02-12T05:07:51.963354Z"},"trusted":true},"execution_count":26,"outputs":[{"output_type":"display_data","data":{"text/plain":" 0%| | 0/500 [00:00<?, ?it/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"5b1e766e58d04f1eb714a08c586e5c3a"}},"metadata":{}},{"name":"stdout","text":"| Test Loss: 3.169 | Test PPL: 23.787 |\n| Test Data Average BLEU score (0.13299530736768786, 0.08383683068952133) |\n| Test Data Average METEOR score 0.14519957374057366 |\n","output_type":"stream"}]},{"cell_type":"code","source":"print(f'| Val. Loss: {valid_loss:.3f} | Test PPL: {math.exp(valid_loss):7.3f} |')\nprint(f'| Val. Data Average BLEU1, BLEU4 score {valid_scorer.average_score()} |')\nprint(f'| Val. 
Data Average METEOR score {valid_scorer.average_meteor_score()} |')\nprint(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')\nprint(f'| Test Data Average BLEU1, BLEU4 score {test_scorer.average_score()} |')\nprint(f'| Test Data Average METEOR score {test_scorer.average_meteor_score()} |')","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:02:42.325507Z","iopub.status.idle":"2023-02-12T05:02:42.325985Z","shell.execute_reply.started":"2023-02-12T05:02:42.325740Z","shell.execute_reply":"2023-02-12T05:02:42.325764Z"},"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"r = {'ppl':[round(math.exp(test_loss),3)],\n 'BLEU-1':[test_scorer.average_score()[0]*100],\n 'BLEU-4':[test_scorer.average_score()[1]*100],\n 'METEOR':[test_scorer.average_meteor_score()*100],\n 'ROUGE-L':[test_scorer.average_rouge_score()*100]}\n\ndf_result = pd.DataFrame(data=r)\n\nhtml = df_result.style.set_table_styles([{'selector': 'th', 'props': [('font-size', '15pt')]}]).set_properties(**{'font-size': '15pt'})\nhtml","metadata":{"execution":{"iopub.status.busy":"2023-02-12T05:09:12.132953Z","iopub.execute_input":"2023-02-12T05:09:12.133326Z","iopub.status.idle":"2023-02-12T05:09:12.198299Z","shell.execute_reply.started":"2023-02-12T05:09:12.133295Z","shell.execute_reply":"2023-02-12T05:09:12.197237Z"},"trusted":true},"execution_count":27,"outputs":[{"execution_count":27,"output_type":"execute_result","data":{"text/plain":"<pandas.io.formats.style.Styler at 0x7f91cc4f0150>","text/html":"<style type=\"text/css\">\n#T_c7a0a_ th {\n font-size: 15pt;\n}\n#T_c7a0a_row0_col0, #T_c7a0a_row0_col1, #T_c7a0a_row0_col2, #T_c7a0a_row0_col3, #T_c7a0a_row0_col4 {\n font-size: 15pt;\n}\n</style>\n<table id=\"T_c7a0a_\">\n <thead>\n <tr>\n <th class=\"blank level0\" > </th>\n <th class=\"col_heading level0 col0\" >ppl</th>\n <th class=\"col_heading level0 col1\" >BLEU-1</th>\n <th class=\"col_heading level0 col2\" >BLEU-4</th>\n <th class=\"col_heading level0 col3\" >METEOR</th>\n <th class=\"col_heading level0 col4\" >ROUGE-L</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th id=\"T_c7a0a_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n <td id=\"T_c7a0a_row0_col0\" class=\"data row0 col0\" >23.787000</td>\n <td id=\"T_c7a0a_row0_col1\" class=\"data row0 col1\" >13.299531</td>\n <td id=\"T_c7a0a_row0_col2\" class=\"data row0 col2\" >8.383683</td>\n <td id=\"T_c7a0a_row0_col3\" class=\"data row0 col3\" >14.519957</td>\n <td id=\"T_c7a0a_row0_col4\" class=\"data row0 col4\" >27.549394</td>\n </tr>\n </tbody>\n</table>\n"},"metadata":{}}]}]} |
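A note on the data preparation performed in VerbalDataset.load_data_and_fields() above: each CSV row carries one Question plus up to answer_num answer columns, and pandas melt() flattens it into one (Question, Answer) pair per non-empty answer before the torchtext examples are built. Below is a minimal, self-contained sketch with toy data; the column names Answer1/Answer2 are assumptions for illustration only, not the real headers of the UIT-ViCoV19QA files (which are read with delimiter='|').

# Toy illustration (not the project's loader): how melt() flattens a
# question/multi-answer table into (Question, Answer) training pairs.
import pandas as pd

toy = pd.DataFrame({
    "id": [1, 2],
    "Question": ["q1 ?", "q2 ?"],
    # "Answer1"/"Answer2" are made-up headers for this sketch.
    "Answer1": ["first answer", "another answer"],
    "Answer2": ["second answer", ""],   # empty string = missing answer
})

melted = toy.melt(id_vars=["id", "Question"], value_name="Answer")
# Keep only non-empty answers, then drop the bookkeeping columns,
# mirroring load_data_and_fields() above.
pairs = melted[melted["Answer"].astype(bool)].drop(["id", "variable"], axis=1).values
print(pairs)   # -> rows of [Question, Answer]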
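Likewise, a small sketch of the sentence-level metrics computed by BleuScorer above: BLEU-1 and BLEU-4 via NLTK's sentence_bleu with method4 smoothing, and ROUGE-L F1 via the rouge_score package. This mirrors example_score() and example_score_rouge() on a toy reference/hypothesis pair and is only an illustration under those assumptions, not the evaluation pipeline itself.

# Toy metric check mirroring BleuScorer.example_score / example_score_rouge.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

reference = [["the", "vaccine", "is", "safe", "for", "adults"]]  # list of token lists
hypothesis = ["the", "vaccine", "is", "safe"]

smooth = SmoothingFunction().method4
bleu_1 = sentence_bleu(reference, hypothesis, weights=(1, 0, 0, 0),
                       smoothing_function=smooth)
bleu_4 = sentence_bleu(reference, hypothesis, weights=(0.25, 0.25, 0.25, 0.25),
                       smoothing_function=smooth)

# rouge_score works on whitespace-joined strings; the F-measure is what the
# notebook keeps (scores['rougeL'][-1]).
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)
rouge_l = scorer.score(" ".join(reference[0]), " ".join(hypothesis))["rougeL"].fmeasure

print(f"BLEU-1: {bleu_1*100:.2f}  BLEU-4: {bleu_4*100:.2f}  ROUGE-L: {rouge_l*100:.2f}")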