{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !pip install transformers" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from __future__ import absolute_import\n", "import torch\n", "import logging\n", "import torch.nn as nn\n", "from model import Seq2Seq\n", "from transformers import (\n", " RobertaConfig, \n", " RobertaModel, \n", " RobertaTokenizer\n", ")\n", "\n", "import regex as re\n", "\n", "# disable warnings\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# base model is RoBERTa\n", "MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}\n", "\n", "# initialize logging\n", "logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n", " datefmt = '%m/%d/%Y %H:%M:%S',\n", " level = logging.INFO)\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class CONFIG:\n", " max_source_length = 256\n", " max_target_length = 128\n", " beam_size = 10\n", " local_rank = -1\n", " no_cuda = False\n", "\n", " do_train = True\n", " do_eval = True\n", " do_test = True\n", " train_batch_size = 12\n", " eval_batch_size = 32\n", "\n", " model_type = \"roberta\"\n", " model_name_or_path = \"microsoft/codebert-base\"\n", " output_dir = \"/content/drive/MyDrive/CodeSummarization\"\n", " load_model_path = None\n", " train_filename = \"dataset/python/train.jsonl\"\n", " dev_filename = \"dataset/python/valid.jsonl\"\n", " test_filename = \"dataset/python/test.jsonl\"\n", " config_name = \"\"\n", " tokenizer_name = \"\"\n", " cache_dir = \"cache\"\n", "\n", " save_every = 5000\n", "\n", " gradient_accumulation_steps = 1\n", " learning_rate = 5e-5\n", " weight_decay = 1e-4\n", " adam_epsilon = 1e-8\n", " max_grad_norm = 1.0\n", " num_train_epochs = 3.0\n", " max_steps = -1\n", " warmup_steps = 0\n", " train_steps = 100000\n", " eval_steps = 10000\n", " n_gpu = torch.cuda.device_count()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load tokenizer" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ded94a2103074dc5b4413a2774888bca", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/899k [00:00 index: 0\n", " index: 2\n", " index: 1\n", " index: 50264\n" ] } ], "source": [ "import logging\n", "from transformers import RobertaTokenizer\n", "logger = logging.getLogger(__name__)\n", "tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', cache_dir=CONFIG.cache_dir)\n", "\n", "print(f'{tokenizer.cls_token} index: {tokenizer.cls_token_id}')\n", "print(f'{tokenizer.sep_token} index: {tokenizer.sep_token_id}')\n", "print(f'{tokenizer.pad_token} index: {tokenizer.pad_token_id}')\n", "print(f'{tokenizer.mask_token} index: {tokenizer.mask_token_id}') " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "input_str = \"def sina_xml_to_url_list(xml_data):\\n \\\"\\\"\\\"str->list\\n Convert XML to URL List.\\n From Biligrab.\\n \\\"\\\"\\\"\\n rawurl = []\\n dom = parseString(xml_data)\\n for node in dom.getElementsByTagName('durl'):\\n url = node.getElementsByTagName('url')[0]\\n rawurl.append(url.childNodes[0].data)\\n return rawurl\"\n", "input_tokens = tokenizer.tokenize(input_str)\n", "print(input_tokens)" ] }, { "cell_type": 
"code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['def',\n", " 'sina_xml_to_url_list',\n", " '(',\n", " 'xml_data',\n", " ')',\n", " ':',\n", " 'rawurl',\n", " '=',\n", " '[',\n", " ']',\n", " 'dom',\n", " '=',\n", " 'parseString',\n", " '(',\n", " 'xml_data',\n", " ')',\n", " 'for',\n", " 'node',\n", " 'in',\n", " 'dom',\n", " '.',\n", " 'getElementsByTagName',\n", " '(',\n", " \"'\",\n", " 'durl',\n", " \"'\",\n", " ')',\n", " ':',\n", " 'url',\n", " '=',\n", " 'node',\n", " '.',\n", " 'getElementsByTagName',\n", " '(',\n", " \"'\",\n", " 'url',\n", " \"'\",\n", " ')',\n", " '[',\n", " '0',\n", " ']',\n", " 'rawurl',\n", " '.',\n", " 'append',\n", " '(',\n", " 'url',\n", " '.',\n", " 'childNodes',\n", " '[',\n", " '0',\n", " ']',\n", " '.',\n", " 'data',\n", " ')',\n", " 'return',\n", " 'rawurl']" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def preprocessing(code_segment):\n", " \n", " # remove newlines\n", " code_segment = re.sub(r'\\n', ' ', code_segment)\n", " \n", " # remove docstring\n", " code_segment = re.sub(r'\"\"\".*?\"\"\"', '', code_segment, flags=re.DOTALL)\n", " \n", " # remove multiple spaces\n", " code_segment = re.sub(r'\\s+', ' ', code_segment)\n", " \n", " # remove comments\n", " code_segment = re.sub(r'#.*', '', code_segment)\n", "\n", " # remove html tags\n", " code_segment = re.sub(r'<.*?>', '', code_segment)\n", "\n", " # remove urls\n", " code_segment = re.sub(r'http\\S+', '', code_segment)\n", " \n", " # split special chars into different tokens\n", " code_segment = re.sub(r'([^\\w\\s])', r' \\1 ', code_segment)\n", " \n", " return code_segment.split()\n", "\n", "preprocessing(input_str)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tokens = ['def', 'get_data', '(', ')', ':', 'data', '=', '[', ']', 'for', 'i', 'in', 'range', '(', '10', ')', ':', 'data', '.', 'append', '(', 'i', ')', 'return', 'data']\n" ] } ], "source": [ "input_str = \"def get_data():\\n data = []\\n for i in range(10):\\n data.append(i)\\n return data\"\n", "input_tokens = preprocessing(input_str)\n", "print(f'Tokens = {input_tokens}')\n", "# tokenizer.encode_plus(input_tokens, max_length=CONFIG.max_source_length, pad_to_max_length=True, truncation=True, return_tensors=\"pt\")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tokens = ['def', 'sina_xml_to_url_list', '(', 'xml_data', ')', ':', 'rawurl', '=', '[', ']', 'dom', '=', 'parseString', '(', 'xml_data', ')', 'for', 'node', 'in', 'dom', '.', 'getElementsByTagName', '(', \"'\", 'durl', \"'\", ')', ':', 'url', '=', 'node', '.', 'getElementsByTagName', '(', \"'\", 'url', \"'\", ')', '[', '0', ']', 'rawurl', '.', 'append', '(', 'url', '.', 'childNodes', '[', '0', ']', '.', 'data', ')', 'return', 'rawurl']\n" ] }, { "data": { "text/plain": [ "{'input_ids': tensor([[ 0, 9232, 3, 1640, 3, 43, 35, 3, 5214, 10975,\n", " 742, 12623, 5214, 3, 1640, 3, 43, 1990, 46840, 179,\n", " 12623, 4, 3, 1640, 108, 3, 108, 43, 35, 6423,\n", " 5214, 46840, 4, 3, 1640, 108, 6423, 108, 43, 10975,\n", " 288, 742, 3, 4, 48696, 1640, 6423, 4, 3, 10975,\n", " 288, 742, 4, 23687, 43, 30921, 3, 2, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_str = \"def sina_xml_to_url_list(xml_data):\\n \\\"\\\"\\\"str->list\\n Convert XML to URL List.\\n From Biligrab.\\n \\\"\\\"\\\"\\n rawurl = []\\n dom = parseString(xml_data)\\n for node in dom.getElementsByTagName('durl'):\\n url = node.getElementsByTagName('url')[0]\\n rawurl.append(url.childNodes[0].data)\\n return rawurl\"\n", "input_tokens = preprocessing(input_str)\n", "print(f'Tokens = {input_tokens}')\n", "# tokenizer.encode_plus(input_tokens, max_length=CONFIG.max_source_length, pad_to_max_length=True, truncation=True, return_tensors=\"pt\")" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 0, 9232, 3, 1640, 43, 35, 23687, 5214, 10975, 742,\n", " 1990, 118, 179, 9435, 1640, 698, 43, 35, 23687, 4,\n", " 48696, 1640, 118, 43, 30921, 23687, 2, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" ] } ], "source": [ "encoded_input = tokenizer.encode_plus(\n", " input_tokens, \n", " max_length=CONFIG.max_source_length, \n", " pad_to_max_length=True, \n", " truncation=True, \n", " return_tensors=\"pt\"\n", ")\n", "print(encoded_input)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "# Config model\n", "config_class, model_class, tokenizer_class = (RobertaConfig, RobertaModel, RobertaTokenizer)\n", "model_config = config_class.from_pretrained(CONFIG.config_name if CONFIG.config_name else CONFIG.model_name_or_path, cache_dir=CONFIG.cache_dir)\n", "model_config.save_pretrained('config')\n", "\n", "# load tokenizer\n", "tokenizer = tokenizer_class.from_pretrained(\n", " CONFIG.tokenizer_name if CONFIG.tokenizer_name else CONFIG.model_name_or_path,\n", " cache_dir=CONFIG.cache_dir,\n", " # do_lower_case=args.do_lower_case\n", ")\n", "\n", "# load encoder from pretrained RoBERTa\n", "encoder = model_class.from_pretrained(CONFIG.model_name_or_path, config=model_config, cache_dir=CONFIG.cache_dir) \n", "\n", "# build decoder \n", "decoder_layer = nn.TransformerDecoderLayer(d_model=model_config.hidden_size, nhead=model_config.num_attention_heads)\n", "decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n", "\n", "# build seq2seq model from pretrained encoder and from-scratch decoder\n", "model=Seq2Seq(\n", " encoder=encoder,\n", " decoder=decoder,\n", " config=model_config,\n", " beam_size=CONFIG.beam_size,\n", " max_length=CONFIG.max_target_length,\n", " sos_id=tokenizer.cls_token_id,\n", " eos_id=tokenizer.sep_token_id\n", ")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "state_dict = torch.load(\"./models/pytorch_model.bin\")\n", "model.load_state_dict(state_dict)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "# move model to GPU\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() and not CONFIG.no_cuda else \"cpu\")\n", "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'input_ids': tensor([[ 0, 9232, 3, 1640, 43, 35, 23687, 5214, 10975, 742,\n", " 1990, 118, 179, 9435, 1640, 698, 43, 35, 23687, 4,\n", " 48696, 1640, 118, 43, 30921, 23687, 2, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" ] } ], "source": [ "input_str = \"def get_data():\\n data = []\\n for i in range(10):\\n data.append(i)\\n return data\"\n", "input_tokens = preprocessing(input_str)\n", "encoded_input = tokenizer.encode_plus(\n", " input_tokens, \n", " max_length=CONFIG.max_source_length, \n", " pad_to_max_length=True, \n", " truncation=True, \n", " return_tensors=\"pt\"\n", ")\n", "print(encoded_input)\n", "\n", "input_ids = encoded_input[\"input_ids\"].to(device)\n", "input_mask = encoded_input[\"attention_mask\"].to(device)\n" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Summary.shape = torch.Size([1, 10, 128])\n", "Summary = tensor([[[42555, 10, 889, ..., 0, 0, 0],\n", " [42555, 10, 889, ..., 0, 0, 0],\n", " [42555, 10, 889, ..., 0, 0, 0],\n", " ...,\n", " [42555, 10, 889, ..., 0, 0, 0],\n", " [42555, 10, 889, ..., 0, 0, 0],\n", " [42555, 10, 889, ..., 0, 0, 0]]], device='cuda:0')\n" ] } ], "source": [ "output = model(input_ids, input_mask)\n", "print(f'Summary.shape = {output.shape}')\n", "print(f'Summary = {output}')" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128])\n", "Return a list of data.\n", "torch.Size([128])\n", "Return a list of int values.\n", "torch.Size([128])\n", "Return a list of ints.\n", "torch.Size([128])\n", "Return a list of ints\n", "torch.Size([128])\n", "Return a list of the number of integers.\n", "torch.Size([128])\n", "Return a list of the number of data.\n", "torch.Size([128])\n", "Return a list of the number of digits.\n", "torch.Size([128])\n", "Return a list of the number of numbers.\n", "torch.Size([128])\n", "Return a list of data in a list.\n", "torch.Size([128])\n", "Return a list of data in a list of data\n" ] } ], "source": [ "# decode summary with tokenizer\n", "summary = output[0]\n", "for i in range(10):\n", " print(f'{summary[i].shape}')\n", " pred = tokenizer.decode(summary[i], skip_special_tokens=True)\n", " print(pred)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "aio", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "c4b1d2403d5bedfc2b499b2d1212ae0437b5f8ebf43026ed45c1b9608ddeb20c" } } }, "nbformat": 4, "nbformat_minor": 2 }