{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "from typing import List, Union, Any\n",
    "from tqdm import tqdm\n",
    "from sentence_transformers import CrossEncoder\n",
    "from langchain.chains import RetrievalQA\n",
    "from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings\n",
    "from langchain.document_loaders import TextLoader\n",
    "from langchain.indexes import VectorstoreIndexCreator\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import FAISS\n",
    "from sentence_transformers import CrossEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageInstructEmbeddings(HuggingFaceInstructEmbeddings):\n",
    "    \"\"\"Instructor embeddings that average over fixed-size character windows.\n",
    "\n",
    "    A text longer than ``max_length`` characters is cut into consecutive\n",
    "    windows of ``max_length`` characters. Each window is encoded together\n",
    "    with ``embed_instruction`` and the window vectors are averaged into a\n",
    "    single vector. Shorter texts are encoded in one call.\n",
    "    \"\"\"\n",
    "\n",
    "    # Window size in characters; None, 0 or a negative value disables windowing.\n",
    "    max_length: int = None\n",
    "\n",
    "    def __init__(self, max_length: int = 512, **kwargs: Any):\n",
    "        \"\"\"Build the wrapped model and remember the window size.\n",
    "\n",
    "        Args:\n",
    "            max_length: Window size in characters. None, 0 or a negative\n",
    "                value disables windowing, so every text is encoded whole.\n",
    "            **kwargs: Forwarded unchanged to the parent constructor.\n",
    "        \"\"\"\n",
    "        super().__init__(**kwargs)\n",
    "        self.max_length = max_length\n",
    "        # None and 0 are treated like a negative value: windowing is off.\n",
    "        if not self.max_length or self.max_length < 0:\n",
    "            print('max_length is not specified, using model default max_seq_length')\n",
    "\n",
    "    def embed_documents(self, texts: List[str]) -> List[List[float]]:\n",
    "        \"\"\"Embed every text, averaging window embeddings for long texts.\n",
    "\n",
    "        Args:\n",
    "            texts: Documents to embed.\n",
    "\n",
    "        Returns:\n",
    "            One embedding (a list of floats) per input text, in input order.\n",
    "        \"\"\"\n",
    "        # None and 0 both collapse to 0 so the guard below never divides or\n",
    "        # slices with a zero-width window.\n",
    "        window = self.max_length or 0\n",
    "        all_embeddings = []\n",
    "        for text in tqdm(texts, desc=\"Embedding documents\"):\n",
    "            if window > 0 and len(text) > window:\n",
    "                # Stepping by `window` gives ceil(len(text) / window) slices;\n",
    "                # only the last one may be shorter than `window`.\n",
    "                chunks = [text[start:start + window] for start in range(0, len(text), window)]\n",
    "                instruction_pairs = [[self.embed_instruction, chunk] for chunk in chunks]\n",
    "                chunk_embeddings = self.client.encode(instruction_pairs)\n",
    "                # Unweighted mean: a short trailing window counts as much as a full one.\n",
    "                avg_embedding = np.mean(chunk_embeddings, axis=0)\n",
    "                all_embeddings.append(avg_embedding.tolist())\n",
    "            else:\n",
    "                instruction_pairs = [[self.embed_instruction, text]]\n",
    "                embeddings = self.client.encode(instruction_pairs)\n",
    "                all_embeddings.append(embeddings[0].tolist())\n",
    "\n",
    "        return all_embeddings\n",
    "\n",
    "\n",
    "class BenchDataST:\n",
    "    def __init__(self, path: str, percentage: float = 0.005, chunk_size: int = 512, chunk_overlap: int = 100):\n",
    "        self.path = path\n",
    "        self.percentage = percentage\n",
    "        self.docs = []\n",
    "        self.metadata = []\n",
    "        self.load()\n",
    "        self.text_splitter = CharacterTextSplitter(separator=\"\", chunk_size=chunk_size, chunk_overlap=chunk_overlap)\n",
    "        self.docs_processed = self.text_splitter.create_documents(self.docs, self.metadata)\n",
    "\n",
    "    def load(self):\n",
    "        for p in Path(self.path).iterdir():\n",
    "            if not p.is_dir():\n",
    "                with open(p) as f:\n",
    "                    source = f.readline().strip().replace('source: ', '')\n",
    "                    self.docs.append(f.read())\n",
    "                    self.metadata.append({\"source\": source})\n",
    "        self.docs = self.docs[:int(len(self.docs) * self.percentage)]\n",
    "        self.metadata = self.metadata[:int(len(self.metadata) * self.percentage)]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.docs)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.docs[idx], self.metadata[idx]\n",
    "\n",
    "    def __iter__(self):\n",
    "        for doc, metadata in zip(self.docs, self.metadata):\n",
    "            yield doc, metadata\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f'BenchDataST({len(self)} docs) at {self.path} with {self.percentage} percentage \\nSources: {self.metadata} \\nChunks: {self.text_splitter}'\n",
    "    \n",
    "\n",
    "class BenchmarkST:\n",
    "    \"\"\"Runs the same query against a baseline index and a set of other indexes.\n",
    "\n",
    "    One FAISS index over ``data.docs_processed`` is built per model when the\n",
    "    object is created.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data: BenchDataST, baseline_model: Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, AverageInstructEmbeddings], embedding_models: List[Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, AverageInstructEmbeddings]]):\n",
    "        \"\"\"Keep the inputs and build every index up front.\n",
    "\n",
    "        Args:\n",
    "            data: Dataset whose ``docs_processed`` chunks get indexed.\n",
    "            baseline_model: Embedding model for the reference index.\n",
    "            embedding_models: Embedding models for the compared indexes.\n",
    "        \"\"\"\n",
    "        self.data = data\n",
    "        self.baseline_model = baseline_model\n",
    "        self.embedding_models = embedding_models\n",
    "        self.baseline_index, self.indexes = self.build_indexes()\n",
    "\n",
    "    def build_indexes(self):\n",
    "        \"\"\"Build one FAISS index per model, baseline first.\n",
    "\n",
    "        Returns:\n",
    "            ``(baseline_index, other_indexes)``; the list follows the order of\n",
    "            ``self.embedding_models``.\n",
    "        \"\"\"\n",
    "        def build_one(model):\n",
    "            # Announce the model before its index is built.\n",
    "            print(f\"Building index for {model}\")\n",
    "            return FAISS.from_documents(self.data.docs_processed, model)\n",
    "\n",
    "        baseline_index, *other_indexes = [\n",
    "            build_one(model) for model in [self.baseline_model] + self.embedding_models\n",
    "        ]\n",
    "        return baseline_index, other_indexes\n",
    "\n",
    "    def add_index(self, index: FAISS):\n",
    "        \"\"\"Register an already-built index alongside the compared ones.\"\"\"\n",
    "        self.indexes.append(index)\n",
    "\n",
    "    def evaluate(self, query: str, k: int = 3):\n",
    "        \"\"\"Search every index for ``query`` and collect the top ``k`` hits.\n",
    "\n",
    "        Returns:\n",
    "            ``(baseline_hits, other_hits)`` where ``other_hits`` holds one\n",
    "            ``similarity_search_with_score`` result per entry of ``self.indexes``.\n",
    "        \"\"\"\n",
    "        baseline_hits = self.baseline_index.similarity_search_with_score(query, k=k)\n",
    "        other_hits = [\n",
    "            candidate.similarity_search_with_score(query, k=k)\n",
    "            for candidate in self.indexes\n",
    "        ]\n",
    "        return baseline_hits, other_hits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load INSTRUCTOR_Transformer\n",
      "max_seq_length  512\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No sentence-transformers model found with name /Users/michalwilinski/.cache/torch/sentence_transformers/cross-encoder_ms-marco-MiniLM-L-12-v2. Creating a new one with MEAN pooling.\n",
      "Some weights of the model checkpoint at /Users/michalwilinski/.cache/torch/sentence_transformers/cross-encoder_ms-marco-MiniLM-L-12-v2 were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building index for client=INSTRUCTOR(\n",
      "  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: T5EncoderModel \n",
      "  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False})\n",
      "  (2): Dense({'in_features': 768, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})\n",
      "  (3): Normalize()\n",
      ") model_name='hkunlp/instructor-base' cache_folder=None model_kwargs={} encode_kwargs={} embed_instruction='Represent this piece of text for searching relevant information:' query_instruction='Query the most relevant piece of information from the Hugging Face documentation' max_length=512\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Embedding documents: 100%|██████████| 278/278 [00:19<00:00, 14.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building index for client=SentenceTransformer(\n",
      "  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
      "  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
      ") model_name='cross-encoder/ms-marco-MiniLM-L-12-v2' cache_folder=None model_kwargs={} encode_kwargs={} multi_process=False\n"
     ]
    }
   ],
   "source": [
    "# Dataset: a small, fixed fraction of the documentation dump, pre-chunked.\n",
    "data = BenchDataST(\n",
    "    path=\"./datasets/huggingface_docs/\",\n",
    "    percentage=0.005,\n",
    "    chunk_size=512,\n",
    "    chunk_overlap=100\n",
    ")\n",
    "\n",
    "# Baseline: instructor model with window averaging for long chunks.\n",
    "baseline_embedding_model = AverageInstructEmbeddings(\n",
    "    model_name=\"hkunlp/instructor-base\",\n",
    "    embed_instruction=\"Represent this piece of text for searching relevant information:\",\n",
    "    query_instruction=\"Query the most relevant piece of information from the Hugging Face documentation\",\n",
    "    max_length=512,\n",
    ")\n",
    "\n",
    "# Candidate bi-encoder that is compared against the baseline.\n",
    "embedding_model = HuggingFaceEmbeddings(\n",
    "    model_name=\"intfloat/e5-large-v2\",\n",
    ")\n",
    "\n",
    "# The cross-encoder checkpoint is intentionally not wrapped in\n",
    "# HuggingFaceEmbeddings: loaded that way it is rebuilt with fresh mean\n",
    "# pooling and its classifier weights are dropped, so it is not a valid\n",
    "# embedding model for a FAISS index.\n",
    "benchmark = BenchmarkST(\n",
    "    data=data,\n",
    "    baseline_model=baseline_embedding_model,\n",
    "    embedding_models=[embedding_model]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baseline results:\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.23610792\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24087097\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24181677\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24541612\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24639006\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.24780047\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.2535807\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.25887597\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27293646\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27374876\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.27710187\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.28146794\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.29536068\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.29784447\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.30452335\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.3061711\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.31600478\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.3166225\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.33345556\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.3469957\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.35222226\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.36451602\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.36925688\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 0.37025565\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.37112093\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37146708\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.3766507\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37794292\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.37923962\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.38359642\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.3878625\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.39796114\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.40057343\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.40114868\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.40156174\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.40341228\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.40720195\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41241395\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4134417\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4134435\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41754264\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41917825\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41928726\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.41988587\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.42029166\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.42128915\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4226097\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.42302307\n",
      "{'source': 'https://github.com/gradio-app/gradio/blob/main/demo/stt_or_tts/run.ipynb'} 0.4252566\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/README.md'} 0.42704937\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4297651\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43067485\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43116528\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.43272027\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.43434155\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43486434\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43524152\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.43530554\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4371896\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43753576\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43824\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4384127\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43900505\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.43903238\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.44034868\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.44217598\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/schedulers/euler_ancestral.md'} 0.4426194\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44303834\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.4452571\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44619536\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.44652176\n",
      "{'source': 'https://github.com/gradio-app/gradio/blob/main/demo/stt_or_tts/run.ipynb'} 0.44683564\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.44743723\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44768596\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4477852\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.44906363\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45155957\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45215163\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45415214\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4541726\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.4542602\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.4544394\n",
      "{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.45448524\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.454512\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.45478693\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/schedulers/euler_ancestral.md'} 0.45494407\n",
      "{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.45494407\n",
      "{'source': 'https://github.com/gradio-app/gradio/blob/main/js/accordion/CHANGELOG.md'} 0.45520714\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.4559689\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.4568352\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.4577096\n",
      "{'source': 'https://github.com/huggingface/simulate/blob/main/docs/source/api/lights.mdx'} 0.4577096\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.45773098\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.45818624\n",
      "{'source': 'https://github.com/huggingface/optimum/blob/main/docs/source/exporters/onnx/usage_guides/export_a_model.mdx'} 0.45871085\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/bloom.md'} 0.4591412\n",
      "{'source': 'https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md'} 0.46033093\n",
      "{'source': 'https://github.com/huggingface/blog/blob/main/accelerate-deepspeed.md'} 0.4605264\n",
      "{'source': 'https://github.com/huggingface/pytorch-image-models/blob/main/docs/changes.md'} 0.46091354\n",
      "{'source': 'https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/open-llama.md'} 0.46182537\n",
      "Cross encoder results:\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} 6.840022\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} -0.98426485\n",
      "{'source': 'https://github.com/huggingface/course/blob/main/chapters/en/chapter6/4.mdx'} -1.9345549\n",
      "bye\n"
     ]
    }
   ],
   "source": [
    "query = \"textual inversion\"\n",
    "k = 100\n",
    "baseline_results, results = benchmark.evaluate(query=query, k=k)\n",
    "print(\"Baseline results:\")\n",
    "for doc, score in baseline_results:\n",
    "    print(doc.metadata, score)\n",
    "\n",
    "# Score every chunk of the dataset (not only the retrieved hits) against the\n",
    "# query with the cross-encoder, then order the chunks by that score.\n",
    "reranker = CrossEncoder(\"cross-encoder/ms-marco-MiniLM-L-12-v2\")\n",
    "cross_encoder_scores = reranker.predict([(query, doc.page_content) for doc in data.docs_processed])\n",
    "# Highest score first.\n",
    "cross_encoder_results = sorted(zip(data.docs_processed, cross_encoder_scores), key=lambda x: x[1], reverse=True)\n",
    "print(\"Cross encoder results:\")\n",
    "final_results = cross_encoder_results[:3]\n",
    "for doc, score in final_results:\n",
    "    print(doc.metadata, score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "es where the space character is not used (like Chinese or Japanese).\n",
      "\n",
      "The other main feature of SentencePiece is *reversible tokenization*: since there is no special treatment of spaces, decoding the tokens is done simply by concatenating them and replacing the `_`s with spaces -- this results in the normalized text. As we saw earlier, the BERT tokenizer removes repeating spaces, so its tokenization is not reversible.\n",
      "\n",
      "## Algorithm overview[[algorithm-overview]]\n",
      "\n",
      "In the following sections, we'll dive into t\n"
     ]
    }
   ],
   "source": [
    "# Show the text of the chunk the cross-encoder ranked first.\n",
    "best_doc, _best_score = final_results[0]\n",
    "print(best_doc.page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_qa_bot",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}