{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7e3a8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install MiniChain with %pip so it lands in the running kernel's environment\n",
    "# (a bare !pip runs in a subshell and may target a different Python).\n",
    "%pip install -q git+https://github.com/srush/MiniChain\n",
    "# Copy the repository's examples into the working directory; later cells read\n",
    "# olympics.data and qa.pmpt.tpl by relative path.\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49443595",
   "metadata": {
    "lines_to_next_cell": 2,
    "tags": [
     "hide_inp"
    ]
   },
   "outputs": [],
   "source": [
    "desc = \"\"\"\n",
    "### Question Answering with Retrieval\n",
    "\n",
    "Chain that answers questions with embedding-based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb)\n",
    "\n",
    "(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5183ea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import numpy as np\n",
    "from minichain import prompt, show, OpenAIEmbed, OpenAI\n",
    "from manifest import Manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf59f0d",
   "metadata": {},
   "source": [
    "We use Hugging Face Datasets as the database by assigning\n",
    "a FAISS index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f371a85e",
   "metadata": {},
   "outputs": [],
   "source": [
    "olympics = datasets.load_from_disk(\"olympics.data\")\n",
    "olympics.add_faiss_index(\"embeddings\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1099002",
   "metadata": {},
   "source": [
    "Fast KNN retrieval prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6881ae0e",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAIEmbed())\n",
    "def get_neighbors(model, inp, k):\n",
    "    \"\"\"Return the `content` field of the `k` indexed examples nearest to `inp`.\n",
    "\n",
    "    `model` is injected by the `@prompt` decorator (callers pass only `inp`\n",
    "    and `k`). Its result for `inp` is used as the query vector against the\n",
    "    FAISS index on the `embeddings` column of `olympics`.\n",
    "    \"\"\"\n",
    "    # FAISS searches float32 vectors. np.array() on a list of Python floats\n",
    "    # would give float64, which some faiss builds reject, so cast explicitly.\n",
    "    query_vec = np.array(model(inp), dtype=np.float32)\n",
    "    res = olympics.get_nearest_examples(\"embeddings\", query_vec, k)\n",
    "    return res.examples[\"content\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59cc1355",
   "metadata": {
    "lines_to_next_cell": 1
   },
   "outputs": [],
   "source": [
    "@prompt(OpenAI(), template_file=\"qa.pmpt.tpl\")\n",
    "def get_result(model, query, neighbors):\n",
    "    \"\"\"Call the model on a dict carrying the question and the retrieved docs.\n",
    "\n",
    "    The dict keys are `question` and `docs` (presumably the variables used\n",
    "    by qa.pmpt.tpl). `model` is injected by the `@prompt` decorator.\n",
    "    \"\"\"\n",
    "    payload = {\"question\": query, \"docs\": neighbors}\n",
    "    return model(payload)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2f1101",
   "metadata": {},
   "outputs": [],
   "source": [
    "def qa(query):\n",
    "    \"\"\"Answer `query` by retrieving its 3 nearest examples and passing both to `get_result`.\"\"\"\n",
    "    nearest_docs = get_neighbors(query, 3)\n",
    "    answer = get_result(query, nearest_docs)\n",
    "    return answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f70bac7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abdfcd87",
   "metadata": {},
   "outputs": [],
   "source": [
    "questions = [\"Who won the 2020 Summer Olympics men's high jump?\",\n",
    "             \"Why was the 2020 Summer Olympics originally postponed?\",\n",
    "             \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\",\n",
    "             \"What is the total number of medals won by France?\",\n",
    "             \"What is the tallest mountain in the world?\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddce3ec3",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "gradio = show(qa,\n",
    "              examples=questions,\n",
    "              subprompts=[get_neighbors, get_result],\n",
    "              description=desc,\n",
    "              )\n",
    "if __name__ == \"__main__\":\n",
    "    gradio.launch()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "tags,-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}