{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49192d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q git+https://github.com/srush/MiniChain\n",
    "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf1da24c",
   "metadata": {},
   "source": [
    "Summarize a long document by chunking and summarizing parts.  Uses\n",
    "aynchronous calls to the API.  Adapted from LangChain [Map-Reduce\n",
    "summary](https://langchain.readthedocs.io/en/stable/_modules/langchain/chains/mapreduce.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce74ed6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import trio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25908e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from minichain import TemplatePrompt, show_log, start_chain"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "174e7a29",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "Prompt that asks LLM to produce a bash command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12b26a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SummaryPrompt(TemplatePrompt):\n",
    "    template_file = \"summary.pmpt.tpl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98747659",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chunk(f, width=4000, overlap=800):\n",
    "    \"Split a documents into 4800 character overlapping chunks\"\n",
    "    text = open(f).read().replace(\"\\n\\n\", \"\\n\")\n",
    "    chunks = []\n",
    "    for i in range(4):\n",
    "        if i * width > len(text):\n",
    "            break\n",
    "        chunks.append({\"text\": text[i * width : (i + 1) * width + overlap]})\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0ccfddc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with start_chain(\"summary\") as backend:\n",
    "    prompt = SummaryPrompt(backend.OpenAI())\n",
    "    list_prompt = prompt.map()\n",
    "\n",
    "    # Map - Summarize each chunk in parallel\n",
    "    out = trio.run(list_prompt.arun, chunk(\"../state_of_the_union.txt\"))\n",
    "\n",
    "    # Reduce - Summarize the summarized chunks\n",
    "    print(prompt({\"text\": \"\\n\".join(out)}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3ffd907",
   "metadata": {
    "tags": [
     "hide_inp"
    ]
   },
   "outputs": [],
   "source": [
    "SummaryPrompt().show(\n",
    "    {\"text\": \"One way to fight is to drive down wages and make Americans poorer.\"},\n",
    "    \"Make Americans poorer\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52be8068",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_log(\"summary.log\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "tags,-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}