{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json\n", "import os\n", "import _jsonnet\n", "import os" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from seq2struct.commands.infer import Inferer\n", "from seq2struct.datasets.spider import SpiderItem\n", "from seq2struct.utils import registry" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "exp_config = json.loads(\n", " _jsonnet.evaluate_file(\n", " \"experiments/spider-configs/spider-mBART50MtoM-large-en-pt-es-fr-train_en-pt-es-fr-eval.jsonnet\"))\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "model_config_path = exp_config[\"model_config\"]\n", "model_config_args = exp_config.get(\"model_config_args\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "infer_config = json.loads(\n", " _jsonnet.evaluate_file(\n", " model_config_path, \n", " tla_codes={'args': json.dumps(model_config_args)}))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "infer_config[\"model\"][\"encoder_preproc\"][\"db_path\"] = \"data/sqlite_files/\"" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING : superfluous {'name': 'EncDec'}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SpiderEncoderBartPreproc Model: facebook/mbart-large-50-many-to-many-mmt\n", "SpiderEncoderBartPreproc Pretrained Checkpoint: models/mBART50MtoM-large/pretrained_checkpoint/pytorch_model.bin\n", "BART load Model: facebook/mbart-large-50-many-to-many-mmt\n" ] } ], "source": [ "inferer = Inferer(infer_config)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "model_dir = exp_config[\"logdir\"] + \"/bs=12,lr=1.0e-04,bert_lr=1.0e-05,end_lr=0e0,att=1\"\n", "#checkpoint_step = exp_config[\"eval_steps\"][0]\n", "checkpoint_step = 42100" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/mnt/Databases/nl2sql/gap-text2sql/mrat-sql-gap/logdir/mBART50MtoM-large-en-pt-es-fr-train/bs=12,lr=1.0e-04,bert_lr=1.0e-05,end_lr=0e0,att=1'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_dir" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING : superfluous {'decoder_preproc': {'grammar': {'clause_order': None, 'end_with_from': True, 'factorize_sketch': 2, 'include_literals': False, 'infer_from_conditions': True, 'name': 'spider', 'output_from': True, 'use_table_pointer': True}, 'save_path': 'data/spider-en-pt-es-fr/mBART50MtoM-large-nl2code-1115,output_from=true,fs=2,emb=bart,cvlink', 'use_seq_elem_rules': True}, 'encoder_preproc': {'bart_version': 'facebook/mbart-large-50-many-to-many-mmt', 'compute_cv_link': True, 'compute_sc_link': True, 'db_path': 'data/sqlite_files/', 'fix_issue_16_primary_keys': True, 'include_table_name_in_column': False, 'pretrained_checkpoint': 'models/mBART50MtoM-large/pretrained_checkpoint/pytorch_model.bin', 'save_path': 'data/spider-en-pt-es-fr/mBART50MtoM-large-nl2code-1115,output_from=true,fs=2,emb=bart,cvlink'}}\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SpiderEncoderBart Model: facebook/mbart-large-50-many-to-many-mmt\n", "SpiderEncoderBart Pretrained Checkpoint: models/mBART50MtoM-large/pretrained_checkpoint/pytorch_model.bin\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at facebook/mbart-large-50-many-to-many-mmt were not used when initializing MBartModel: ['final_logits_bias', 'lm_head.weight']\n", "- This IS expected if you are initializing MBartModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing MBartModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "tensor([[-0.0428, -0.0166, 0.0040, ..., -0.0377, 0.0135, 0.0064],\n", " [-0.0486, 0.0342, 0.0125, ..., -0.0298, 0.0023, -0.0117],\n", " [-0.0438, -0.0331, -0.0250, ..., -0.0430, 0.0212, -0.0113],\n", " ...,\n", " [-0.0331, -0.0959, -0.0164, ..., -0.0314, 0.0189, -0.0045],\n", " [-0.0296, -0.0487, 0.0050, ..., 0.0017, -0.0259, 0.0182],\n", " [-0.0449, -0.0099, -0.0020, ..., -0.0484, 0.0027, -0.0029]],\n", " requires_grad=True)\n", "Updated the model with models/mBART50MtoM-large/pretrained_checkpoint/pytorch_model.bin\n", "Parameter containing:\n", "tensor([[-0.0428, -0.0166, 0.0040, ..., -0.0377, 0.0135, 0.0064],\n", " [-0.0486, 0.0342, 0.0125, ..., -0.0298, 0.0023, -0.0117],\n", " [-0.0438, -0.0331, -0.0250, ..., -0.0430, 0.0212, -0.0113],\n", " ...,\n", " [ 0.0103, 0.0193, -0.0229, ..., 0.0224, 0.0545, -0.0211],\n", " [-0.0350, -0.0251, 0.0002, ..., -0.0259, 0.0119, 0.0058],\n", " [ 0.0126, -0.0247, -0.0331, ..., 0.0142, -0.0418, 0.0507]],\n", " requires_grad=True)\n", "Loading model from /mnt/Databases/nl2sql/gap-text2sql/mrat-sql-gap/logdir/mBART50MtoM-large-en-pt-es-fr-train/bs=12,lr=1.0e-04,bert_lr=1.0e-05,end_lr=0e0,att=1/model_checkpoint-00042100\n" ] } ], "source": [ "model = inferer.load_model(model_dir, checkpoint_step)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from seq2struct.datasets.spider_lib.preprocess.get_tables import dump_db_json_schema\n", "from seq2struct.datasets.spider import load_tables_from_schema_dict" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "db_id = \"singer\"" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "my_schema = dump_db_json_schema(\"data/sqlite_files/{db_id}/{db_id}.sqlite\".format(db_id=db_id), db_id)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from seq2struct.utils.api_utils import refine_schema_names" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'db_id': 'singer',\n", " 'table_names_original': ['singer', 'song'],\n", " 'table_names': ['singer', 'song'],\n", " 'column_names_original': [(-1, '*'),\n", " (0, 'Singer_ID'),\n", " (0, 'Name'),\n", " (0, 'Birth_Year'),\n", " (0, 'Net_Worth_Millions'),\n", " (0, 'Citizenship'),\n", " (1, 'Song_ID'),\n", " (1, 'Title'),\n", " (1, 'Singer_ID'),\n", " (1, 'Sales'),\n", " (1, 'Highest_Position')],\n", " 'column_names': [(-1, '*'),\n", " (0, 'singer id'),\n", " (0, 'name'),\n", " (0, 'birth year'),\n", " (0, 'net worth millions'),\n", " (0, 'citizenship'),\n", " (1, 'song id'),\n", " (1, 'title'),\n", " (1, 'singer id'),\n", " (1, 'sales'),\n", " (1, 'highest position')],\n", " 'column_types': ['text',\n", " 'number',\n", " 'text',\n", " 'number',\n", " 'number',\n", " 'text',\n", " 'number',\n", " 'text',\n", " 'number',\n", " 'number',\n", " 'number'],\n", " 'primary_keys': [1, 6],\n", " 'foreign_keys': [[8, 1]]}" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "my_schema" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# If you want to change your schema name, then run this; Otherwise you can skip this.\n", "#refine_schema_names(my_schema)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "schema, eval_foreign_key_maps = load_tables_from_schema_dict(my_schema)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['singer'])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "schema.keys()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "dataset = registry.construct('dataset_infer', {\n", " \"name\": \"spider\", \"schemas\": schema, \"eval_foreign_key_maps\": eval_foreign_key_maps, \n", " \"db_path\": \"data/sqlite_files/\"\n", "})" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "for _, schema in dataset.schemas.items():\n", " model.preproc.enc_preproc._preprocess_schema(schema)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "spider_schema = dataset.schemas[db_id]" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def infer(question):\n", " data_item = SpiderItem(\n", " text=None, # intentionally None -- should be ignored when the tokenizer is set correctly\n", " code=None,\n", " schema=spider_schema,\n", " orig_schema=spider_schema.orig,\n", " orig={\"question\": question}\n", " )\n", " model.preproc.clear_items()\n", " enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)\n", " preproc_data = enc_input, None\n", " with torch.no_grad():\n", " output = inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)\n", " return output[0][\"inferred_code\"]" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Código do T5, mas na execução BART\n", "Question: ['', '▁How', '▁many', '▁sing', 'ers', '▁are', '▁there', '▁?', '']\n", "Columns: [['▁*', ''], ['▁sing', 'er', '▁id', ''], ['▁name', ''], ['▁birth', '▁year', ''], ['▁net', '▁worth', '▁millions', ''], ['▁citizen', 'ship', ''], ['▁song', '▁id', ''], ['▁title', ''], ['▁sing', 'er', '▁id', ''], ['▁sales', ''], ['▁highest', '▁position', '']]\n", "Tables: [['▁sing', 'er', ''], ['▁song', '']]\n", "\n", "column_indexes: [9, 11, 15, 17, 20, 24, 27, 30, 32, 36, 38]\n", "column_indexes_2: [9, 13, 15, 18, 22, 25, 28, 30, 34, 36, 39]\n", "table_indexes: [41, 44]\n", "table_indexes_2: [42, 44]\n", "\n", "SELECT Count(*) FROM singer\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/marchanjo/miniconda3/envs/gap-text2sql/lib/python3.7/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead\n", " warnings.warn(\"pickle support for Storage will be removed in 1.5. Use `torch.save` instead\", FutureWarning)\n" ] } ], "source": [ "code = infer(\"How many singers are there?\")\n", "print(code)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Código do T5, mas na execução BART\n", "Question: ['', '▁Quanto', 's', '▁can', 'tores', '▁nós', '▁temos', '▁?', '']\n", "Columns: [['▁*', ''], ['▁sing', 'er', '▁id', ''], ['▁name', ''], ['▁birth', '▁year', ''], ['▁net', '▁worth', '▁millions', ''], ['▁citizen', 'ship', ''], ['▁song', '▁id', ''], ['▁title', ''], ['▁sing', 'er', '▁id', ''], ['▁sales', ''], ['▁highest', '▁position', '']]\n", "Tables: [['▁sing', 'er', ''], ['▁song', '']]\n", "\n", "column_indexes: [9, 11, 15, 17, 20, 24, 27, 30, 32, 36, 38]\n", "column_indexes_2: [9, 13, 15, 18, 22, 25, 28, 30, 34, 36, 39]\n", "table_indexes: [41, 44]\n", "table_indexes_2: [42, 44]\n", "\n", "SELECT Count(*) FROM singer\n" ] } ], "source": [ "code = infer(\"Quantos cantores nós temos?\")\n", "print(code)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Código do T5, mas na execução BART\n", "Question: ['', '▁¿', 'Cu', 'án', 'tos', '▁cantante', 's', '▁tenemos', '▁?', '']\n", "Columns: [['▁*', ''], ['▁sing', 'er', '▁id', ''], ['▁name', ''], ['▁birth', '▁year', ''], ['▁net', '▁worth', '▁millions', ''], ['▁citizen', 'ship', ''], ['▁song', '▁id', ''], ['▁title', ''], ['▁sing', 'er', '▁id', ''], ['▁sales', ''], ['▁highest', '▁position', '']]\n", "Tables: [['▁sing', 'er', ''], ['▁song', '']]\n", "\n", "column_indexes: [10, 12, 16, 18, 21, 25, 28, 31, 33, 37, 39]\n", "column_indexes_2: [10, 14, 16, 19, 23, 26, 29, 31, 35, 37, 40]\n", "table_indexes: [42, 45]\n", "table_indexes_2: [43, 45]\n", "\n", "SELECT Count(*) FROM singer\n" ] } ], "source": [ "code = infer(\"¿Cuántos cantantes tenemos?\")\n", "print(code)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Código do T5, mas na execução BART\n", "Question: ['', '▁Combi', 'en', '▁de', '▁chan', 'teurs', '▁avons', '-', 'nous', '▁?', '']\n", "Columns: [['▁*', ''], ['▁sing', 'er', '▁id', ''], ['▁name', ''], ['▁birth', '▁year', ''], ['▁net', '▁worth', '▁millions', ''], ['▁citizen', 'ship', ''], ['▁song', '▁id', ''], ['▁title', ''], ['▁sing', 'er', '▁id', ''], ['▁sales', ''], ['▁highest', '▁position', '']]\n", "Tables: [['▁sing', 'er', ''], ['▁song', '']]\n", "\n", "column_indexes: [11, 13, 17, 19, 22, 26, 29, 32, 34, 38, 40]\n", "column_indexes_2: [11, 15, 17, 20, 24, 27, 30, 32, 36, 38, 41]\n", "table_indexes: [43, 46]\n", "table_indexes_2: [44, 46]\n", "\n", "SELECT Count(*) FROM singer\n" ] } ], "source": [ "code = infer(\"Combien de chanteurs avons-nous?\")\n", "print(code)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Código do T5, mas na execução BART\n", "Question: ['', '▁Quanto', 's', '▁can', 'tores', '▁estão', '▁dispon', 'i', 'veis', '▁?', '']\n", "Columns: [['▁*', ''], ['▁sing', 'er', '▁id', ''], ['▁name', ''], ['▁birth', '▁year', ''], ['▁net', '▁worth', '▁millions', ''], ['▁citizen', 'ship', ''], ['▁song', '▁id', ''], ['▁title', ''], ['▁sing', 'er', '▁id', ''], ['▁sales', ''], ['▁highest', '▁position', '']]\n", "Tables: [['▁sing', 'er', ''], ['▁song', '']]\n", "\n", "column_indexes: [11, 13, 17, 19, 22, 26, 29, 32, 34, 38, 40]\n", "column_indexes_2: [11, 15, 17, 20, 24, 27, 30, 32, 36, 38, 41]\n", "table_indexes: [43, 46]\n", "table_indexes_2: [44, 46]\n", "\n", "SELECT Count(*) FROM singer\n" ] } ], "source": [ "code = infer(\"Quantos cantores estão disponiveis?\")\n", "print(code)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 4 }