{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# oobabooga/text-generation-webui\n",
        "\n",
        "After you run both cells, a public Gradio URL will appear at the bottom of the second cell's output within a few minutes. To also generate API links, tick the `api` checkbox before launching.\n",
        "\n",
        "* Project page: https://github.com/oobabooga/text-generation-webui\n",
        "* Gradio server status: https://status.gradio.app/"
      ],
      "metadata": {
        "id": "MFQl6-FjSYtY"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 1. Keep this tab alive to prevent Colab from disconnecting you { display-mode: \"form\" }\n",
        "\n",
        "#@markdown Press play on the music player that will appear below:\n",
        "%%html\n",
        "<!-- Audio player for the clip named in src. The explicit end tag is needed because the audio element's closing tag is not optional in HTML. -->\n",
        "<audio src=\"https://oobabooga.github.io/silence.m4a\" controls></audio>"
      ],
      "metadata": {
        "id": "f7TVVj_z4flw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title 2. Launch the web UI\n",
        "\n",
        "#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
        "\n",
        "import torch\n",
        "from pathlib import Path\n",
        "\n",
        "# One-time setup. It is skipped when the working directory is already the\n",
        "# cloned repository, which is the case when this cell is re-run in the same session.\n",
        "if Path.cwd().name != 'text-generation-webui':\n",
        "  print(\"Installing the webui...\")\n",
        "\n",
        "  !git clone https://github.com/oobabooga/text-generation-webui\n",
        "  %cd text-generation-webui\n",
        "\n",
        "  torver = torch.__version__\n",
        "  print(f\"TORCH: {torver}\")\n",
        "  is_cuda118 = '+cu118' in torver  # 2.1.0+cu118\n",
        "  is_cuda117 = '+cu117' in torver  # 2.0.1+cu117\n",
        "\n",
        "  # Rewrite the CUDA tags in requirements.txt to match the installed torch build\n",
        "  # and save the result as temp_requirements.txt, which is what pip installs below.\n",
        "  # read_text() closes the file for us, unlike a bare open().read().\n",
        "  textgen_requirements = Path('requirements.txt').read_text().splitlines()\n",
        "  if is_cuda117:\n",
        "      textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n",
        "  elif is_cuda118:\n",
        "      textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n",
        "  with open('temp_requirements.txt', 'w') as file:\n",
        "      file.write('\\n'.join(textgen_requirements))\n",
        "\n",
        "  !pip install -r extensions/api/requirements.txt --upgrade\n",
        "  !pip install -r temp_requirements.txt --upgrade\n",
        "\n",
        "  print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",
        "  print(\"\\033[1;32;1m\\n --> There is no need to restart the runtime.\\n\\033[0;37;0m\")\n",
        "\n",
        "  # Best effort: if flash_attn cannot be imported for any reason, uninstall it.\n",
        "  # Catch Exception rather than using a bare except so that KeyboardInterrupt\n",
        "  # and SystemExit still propagate.\n",
        "  try:\n",
        "    import flash_attn\n",
        "  except Exception:\n",
        "    !pip uninstall -y flash_attn\n",
        "\n",
        "# Parameters\n",
        "model_url = \"https://huggingface.co/turboderp/Mistral-7B-instruct-exl2\" #@param {type:\"string\"}\n",
        "branch = \"4.0bpw\" #@param {type:\"string\"}\n",
        "command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n",
        "api = False #@param {type:\"boolean\"}\n",
        "\n",
        "if api:\n",
        "  # Match whole flags rather than substrings, so a longer flag that merely\n",
        "  # begins with --api does not suppress adding --api itself.\n",
        "  given_flags = command_line_flags.split()\n",
        "  for param in ['--api', '--public-api']:\n",
        "    if param not in given_flags:\n",
        "      command_line_flags += f\" {param}\"\n",
        "\n",
        "model_url = model_url.strip()\n",
        "if model_url != \"\":\n",
        "    # Accept a bare user/model identifier as well as a full URL.\n",
        "    if not model_url.startswith('http'):\n",
        "        model_url = 'https://huggingface.co/' + model_url\n",
        "\n",
        "    # Download the model\n",
        "    # output_folder is built from the last two URL segments, plus the branch when one is given.\n",
        "    url_parts = model_url.strip('/').strip().split('/')\n",
        "    output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
        "    branch = branch.strip('\"\\' ')\n",
        "    if branch.strip() != '':\n",
        "        output_folder += f\"_{branch}\"\n",
        "        !python download-model.py {model_url} --branch {branch}\n",
        "    else:\n",
        "        !python download-model.py {model_url}\n",
        "else:\n",
        "    # No model requested: start the server without a --model argument.\n",
        "    output_folder = \"\"\n",
        "\n",
        "# Start the web UI\n",
        "cmd = \"python server.py --share\"\n",
        "if output_folder != \"\":\n",
        "    cmd += f\" --model {output_folder}\"\n",
        "cmd += f\" {command_line_flags}\"\n",
        "print(cmd)\n",
        "!$cmd"
      ],
      "metadata": {
        "id": "LGQ8BiMuXMDG",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}