{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text-to-SQL inference on a single SQLite database\n",
    "\n",
    "Loads the trained model described by `experiments/spider-configs/gap-run.jsonnet`, builds a Spider-style schema for one SQLite database under `data/sqlite_files/`, and translates natural-language questions into SQL with `infer(question)`.\n",
    "\n",
    "Run all cells top to bottom from the directory that contains `experiments/` and `data/` (all paths below are relative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import _jsonnet\n",
    "import torch\n",
    "\n",
    "from seq2struct.commands.infer import Inferer\n",
    "from seq2struct.datasets.spider import SpiderItem, load_tables_from_schema_dict\n",
    "from seq2struct.datasets.spider_lib.preprocess.get_tables import dump_db_json_schema\n",
    "from seq2struct.utils import registry\n",
    "from seq2struct.utils.api_utils import refine_schema_names"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Edit these constants to point at a different experiment, checkpoint directory, or database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP_CONFIG_PATH = \"experiments/spider-configs/gap-run.jsonnet\"\n",
    "# Directory holding one sub-directory per database id: DB_PATH/<db_id>/<db_id>.sqlite\n",
    "DB_PATH = \"data/sqlite_files/\"\n",
    "# Sub-directory of the experiment's logdir that holds the checkpoint; its name encodes the run's hyperparameters.\n",
    "MODEL_SUBDIR = \"bs=12,lr=1.0e-04,bert_lr=1.0e-05,end_lr=0e0,att=1\"\n",
    "# Database to query.\n",
    "DB_ID = \"singer\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the experiment and model configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_config = json.loads(_jsonnet.evaluate_file(EXP_CONFIG_PATH))\n",
    "model_config_path = exp_config[\"model_config\"]\n",
    "model_config_args = exp_config.get(\"model_config_args\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model config, passing the experiment's model_config_args as the Jsonnet top-level argument `args`.\n",
    "infer_config = json.loads(\n",
    "    _jsonnet.evaluate_file(\n",
    "        model_config_path,\n",
    "        tla_codes={\"args\": json.dumps(model_config_args)}))\n",
    "# Point the encoder preprocessor at the local SQLite files.\n",
    "infer_config[\"model\"][\"encoder_preproc\"][\"db_path\"] = DB_PATH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the trained checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inferer = Inferer(infer_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = os.path.join(exp_config[\"logdir\"], MODEL_SUBDIR)\n",
    "# Load the first checkpoint step listed in the experiment's eval_steps.\n",
    "checkpoint_step = exp_config[\"eval_steps\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = inferer.load_model(model_dir, checkpoint_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the database schema\n",
    "\n",
    "Dump the schema of the `DB_ID` SQLite file, optionally refine its names, and convert it into the schema objects the dataset expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "my_schema = dump_db_json_schema(\n",
    "    os.path.join(DB_PATH, DB_ID, DB_ID + \".sqlite\"), DB_ID)\n",
    "my_schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: refine the names in the dumped schema; skip this cell to keep the names exactly as dumped.\n",
    "# NOTE(review): the result is not assigned, so this only affects later cells if\n",
    "# refine_schema_names modifies my_schema in place -- confirm.\n",
    "refine_schema_names(my_schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named `schemas` so that no later cell (e.g. a loop variable) can shadow this mapping.\n",
    "schemas, eval_foreign_key_maps = load_tables_from_schema_dict(my_schema)\n",
    "schemas.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the inference dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = registry.construct(\"dataset_infer\", {\n",
    "    \"name\": \"spider\",\n",
    "    \"schemas\": schemas,\n",
    "    \"eval_foreign_key_maps\": eval_foreign_key_maps,\n",
    "    \"db_path\": DB_PATH,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the encoder preprocessor's schema step over every schema in the dataset.\n",
    "# NOTE(review): presumably this prepares per-schema state that preprocess_item relies on; confirm.\n",
    "for _, db_schema in dataset.schemas.items():\n",
    "    model.preproc.enc_preproc._preprocess_schema(db_schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spider_schema = dataset.schemas[DB_ID]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Translate questions to SQL\n",
    "\n",
    "`infer(question)` encodes a question against `spider_schema` and returns the SQL produced by the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer(question, beam_size=1, use_heuristic=True):\n",
    "    \"\"\"Translate a natural-language question about ``spider_schema`` into SQL.\n",
    "\n",
    "    Uses the notebook-level ``model``, ``inferer`` and ``spider_schema`` defined above.\n",
    "\n",
    "    Args:\n",
    "        question: The natural-language question text.\n",
    "        beam_size: Beam size forwarded to ``inferer._infer_one`` (default 1).\n",
    "        use_heuristic: Forwarded to ``inferer._infer_one`` (default True).\n",
    "\n",
    "    Returns:\n",
    "        The ``\"inferred_code\"`` entry of the first result returned by ``_infer_one``.\n",
    "    \"\"\"\n",
    "    data_item = SpiderItem(\n",
    "        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly\n",
    "        code=None,\n",
    "        schema=spider_schema,\n",
    "        orig_schema=spider_schema.orig,\n",
    "        orig={\"question\": question},\n",
    "    )\n",
    "    # NOTE(review): presumably resets preprocessor item state left over from earlier calls; confirm.\n",
    "    model.preproc.clear_items()\n",
    "    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)\n",
    "    # Only the encoder input is available (code=None above); the second slot is left None.\n",
    "    preproc_data = enc_input, None\n",
    "    with torch.no_grad():\n",
    "        output = inferer._infer_one(\n",
    "            model, data_item, preproc_data,\n",
    "            beam_size=beam_size, use_heuristic=use_heuristic)\n",
    "    return output[0][\"inferred_code\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = infer(\"How many singers are there?\")\n",
    "print(sql)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}