import time
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
import json
import os
from pathlib import Path
import re
import sys
from typing import List, Optional, Tuple, Dict, Callable, Any

from bs4 import BeautifulSoup
import docx
from html2text import html2text
import langchain
from langchain.callbacks import get_openai_callback
from langchain.cache import SQLiteCache
from langchain.chains import LLMChain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.llms.base import LLM, BaseLLM
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import Pinecone as OriginalPinecone
import numpy as np
import openai
import pinecone
from pptx import Presentation
from pypdf import PdfReader
import requests
from requests.models import MissingSchema
import trafilatura

from streamlit_langchain_chat.constants import *
from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
from streamlit_langchain_chat.utils import maybe_is_text, maybe_is_truncated
from streamlit_langchain_chat.prompts import *

if REUSE_ANSWERS:
    CACHE_PATH = TEMP_DIR / "llm_cache.db"
    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
    langchain.llm_cache = SQLiteCache(str(CACHE_PATH))

# option 1
TextSplitter = TokenTextSplitter
# option 2
# TextSplitter = RecursiveCharacterTextSplitter  # used by gpt4_pdf_chatbot_langchain (aka GPCL)
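# Illustrative sketch (not part of the original module): both splitter options expose the same
# chunk_size/chunk_overlap constructor and split_text() method, so the rest of the file only
# depends on the TextSplitter alias. The sample string and sizes below are made up.
#
#   _splitter = TextSplitter(chunk_size=2000, chunk_overlap=50)
#   _chunks = _splitter.split_text("some long document text ...")  # -> list[str]
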
@dataclass
class Answer:
    """A class to hold the answer to a question."""

    question: str = ""
    answer: str = ""
    context: str = ""
    chunks: str = ""
    packages: List[Any] = None
    references: str = ""
    cost_str: str = ""
    passages: Dict[str, str] = None
    tokens: List[Dict] = None

    def __post_init__(self):
        """Initialize the mutable default fields."""
        if self.packages is None:
            self.packages = []
        if self.passages is None:
            self.passages = {}

    def __str__(self) -> str:
        """Return the answer as a string."""
        return self.answer

def parse_docx(path, citation, key, chunk_chars=2000, overlap=50):
    try:
        document = docx.Document(path)
        fullText = []
        for paragraph in document.paragraphs:
            fullText.append(paragraph.text)
        doc = '\n'.join(fullText) + '\n'
    except Exception as e:
        print(f"code_error: {e}")
        sys.exit(1)

    if doc:
        text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
        texts = text_splitter.split_text(doc)
        return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
    else:
        return [], []

# TODO: if you add a connector with the pattern loader = ...; data = loader.load();
#  then all of the langchain document loaders could be plugged in
#  https://langchain.readthedocs.io/en/stable/modules/document_loaders/examples/pdf.html
def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
    pdfFileObj = open(path, "rb")
    pdfReader = PdfReader(pdfFileObj)
    splits = []
    split = ""
    pages = []
    metadatas = []
    for i, page in enumerate(pdfReader.pages):
        split += page.extract_text()
        pages.append(str(i + 1))
        # split could be so long it needs to be split
        # into multiple chunks. Or it could be so short
        # that it needs to be combined with the next chunk.
        while len(split) > chunk_chars:
            splits.append(split[:chunk_chars])
            # pretty formatting of pages (e.g. 1-3, 4, 5-7)
            pg = "-".join([pages[0], pages[-1]])
            metadatas.append(
                dict(
                    citation=citation,
                    dockey=key,
                    key=f"{key} pages {pg}",
                )
            )
            split = split[chunk_chars - overlap:]
            pages = [str(i + 1)]
    if len(split) > overlap:
        splits.append(split[:chunk_chars])
        pg = "-".join([pages[0], pages[-1]])
        metadatas.append(
            dict(
                citation=citation,
                dockey=key,
                key=f"{key} pages {pg}",
            )
        )
    pdfFileObj.close()

    # # ### option 2. PyPDFLoader
    # loader = PyPDFLoader(path)
    # data = loader.load_and_split()

    # # ### option 2.1. PyPDFLoader as used by GPCL, although it then uses the
    # loader = PyPDFLoader(path)
    # rawDocs = loader.load()
    # text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    # texts = text_splitter.split_documents(rawDocs)

    # # ### option 3. PyMuPDFLoader. This seems like the best option
    # loader = PyMuPDFLoader(path)
    # data = loader.load()
    return splits, metadatas

def parse_pptx(path, citation, key, chunk_chars=2000, overlap=50):
    try:
        presentation = Presentation(path)
        fullText = []
        for slide in presentation.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    fullText.append(shape.text)
        doc = ''.join(fullText)
        if doc:
            text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
            texts = text_splitter.split_text(doc)
            return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
        else:
            return [], []
    except Exception as e:
        print(f"code_error: {e}")
        sys.exit(1)

def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
    try:
        with open(path) as f:
            doc = f.read()
    except UnicodeDecodeError:
        with open(path, encoding="utf-8", errors="ignore") as f:
            doc = f.read()
    if html:
        doc = html2text(doc)
    # yo, no idea why but the texts are not split correctly
    text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    texts = text_splitter.split_text(doc)
    return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)

def parse_url(url: str, citation, key, chunk_chars=2000, overlap=50):
    def beautifulsoup_extract_text_fallback(response_content):
        """
        This is a fallback function, so that we can always return a value for text content.
        Even for when both Trafilatura and BeautifulSoup are unable to extract the text from a
        single URL.
        """
        # Create the beautifulsoup object:
        soup = BeautifulSoup(response_content, 'html.parser')
        # Finding the text:
        text = soup.find_all(text=True)
        # Remove unwanted tag elements:
        cleaned_text = ''
        blacklist = [
            '[document]',
            'noscript',
            'header',
            'html',
            'meta',
            'head',
            'input',
            'script',
            'style',
        ]
        # Then we will loop over every item in the extracted text and make sure that the beautifulsoup4 tag
        # is NOT in the blacklist
        for item in text:
            if item.parent.name not in blacklist:
                cleaned_text += f'{item} '
        # Remove any tab separation and strip the text:
        cleaned_text = cleaned_text.replace('\t', '')
        return cleaned_text.strip()

    def extract_text_from_single_web_page(url):
        print(f"\n===========\n{url=}\n===========\n")
        downloaded_url = trafilatura.fetch_url(url)
        a = None
        try:
            a = trafilatura.extract(downloaded_url,
                                    output_format='json',
                                    with_metadata=True,
                                    include_comments=False,
                                    date_extraction_params={'extensive_search': True,
                                                            'original_date': True})
        except AttributeError:
            a = trafilatura.extract(downloaded_url,
                                    output_format='json',
                                    with_metadata=True,
                                    date_extraction_params={'extensive_search': True,
                                                            'original_date': True})
        except Exception as e:
            print(f"code_error: {e}")

        if a:
            json_output = json.loads(a)
            return json_output['text']
        else:
            try:
                headers = {'User-Agent': 'Chrome/83.0.4103.106'}
                resp = requests.get(url, headers=headers)
                print(f"{resp=}\n")
                # We will only extract the text from successful requests:
                if resp.status_code == 200:
                    return beautifulsoup_extract_text_fallback(resp.content)
                else:
                    # This line will handle any failures in both the Trafilatura and BeautifulSoup4 functions:
                    return np.nan
            # Handling for any URLs that don't have the correct protocol
            except MissingSchema:
                return np.nan

    text_to_split = extract_text_from_single_web_page(url)
    text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    texts = text_splitter.split_text(text_to_split)
    return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)

def read_source(path: str = None,
                citation: str = None,
                key: str = None,
                chunk_chars: int = 3000,
                overlap: int = 100,
                disable_check: bool = False):
    if path.endswith(".pdf"):
        return parse_pdf(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".txt"):
        return parse_txt(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".html"):
        return parse_txt(path, citation, key, chunk_chars, overlap, html=True)
    elif path.endswith(".docx"):
        return parse_docx(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".pptx"):
        return parse_pptx(path, citation, key, chunk_chars, overlap)
    elif path.startswith("http://") or path.startswith("https://"):
        return parse_url(path, citation, key, chunk_chars, overlap)
    # TODO: add more connectors
    # else:
    #     return parse_code_txt(path, citation, key, chunk_chars, overlap)
    else:
        raise ValueError(f"Unknown extension for path: {path}")
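# Illustrative usage sketch (not part of the original module): read_source dispatches on the
# path or URL and returns parallel lists of text chunks and per-chunk metadata. The file name
# and citation below are made-up placeholders.
#
#   texts, metadatas = read_source("report.pdf", citation="Doe et al., 2023", key="Doe2023")
#   # texts     -> ["chunk 1 ...", "chunk 2 ...", ...]
#   # metadatas -> [{"citation": "Doe et al., 2023", "dockey": "Doe2023", "key": "Doe2023 pages 1-3"}, ...]
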
class Dataset:
    """A collection of documents to be used for answering questions."""

    def __init__(
        self,
        chunk_size_limit: int = 3000,
        llm: Optional[BaseLLM] | Optional[BaseChatModel] = None,
        summary_llm: Optional[BaseLLM] = None,
        name: str = "default",
        index_path: Optional[Path] = None,
    ) -> None:
        """Initialize the collection of documents.

        Args:
            chunk_size_limit: The maximum number of characters to use for a single chunk of text.
            llm: The language model to use for answering questions. Default: OpenAI chat model (gpt-3.5-turbo).
            summary_llm: The language model to use for summarizing documents. If None, llm is used.
            name: The name of the collection.
            index_path: The path to the index files IF pickled. If None, defaults to TEMP_DIR / name.
        """
        self.docs = dict()
        self.keys = set()
        self.chunk_size_limit = chunk_size_limit
        self.index_docstore = None
        if llm is None:
            llm = ChatOpenAI(temperature=0.1, max_tokens=512)
        if summary_llm is None:
            summary_llm = llm
        self.update_llm(llm, summary_llm)
        if index_path is None:
            index_path = TEMP_DIR / name
        self.index_path = index_path
        self.name = name

    def update_llm(self, llm: BaseLLM | ChatOpenAI, summary_llm: Optional[BaseLLM] = None) -> None:
        """Update the LLM for answering questions."""
        self.llm = llm
        if summary_llm is None:
            summary_llm = llm
        self.summary_llm = summary_llm
        self.summary_chain = LLMChain(prompt=chat_summary_prompt, llm=summary_llm)
        self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
        self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)

    def add(
        self,
        path: str,
        citation: Optional[str] = None,
        key: Optional[str] = None,
        disable_check: bool = False,
        chunk_chars: Optional[int] = 3000,
    ) -> None:
        """Add a document to the collection."""
        if path in self.docs:
            print(f"Document {path} already in collection.")
            return None

        if citation is None:
            # peek at the first chunk to generate a citation
            texts, _ = read_source(path, "", "", chunk_chars=chunk_chars)
            with get_openai_callback() as cb:
                citation = self.cite_chain.run(texts[0])
            if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
                citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"

        if key is None:
            # get first name and year from citation
            try:
                author = re.search(r"([A-Z][a-z]+)", citation).group(1)
            except AttributeError:
                # panicking - no word??
                raise ValueError(
                    f"Could not parse key from citation {citation}. "
                    f"Consider just passing key explicitly - e.g. docs.add(path, citation, key='mykey')"
                )
            try:
                year = re.search(r"(\d{4})", citation).group(1)
            except AttributeError:
                year = ""
            key = f"{author}{year}"
        suffix = ""
        while key + suffix in self.keys:
            # move suffix to next letter
            if suffix == "":
                suffix = "a"
            else:
                suffix = chr(ord(suffix) + 1)
        key += suffix
        self.keys.add(key)

        texts, metadata = read_source(path, citation, key, chunk_chars=chunk_chars)
        # loose check to see if document was loaded
        if len("".join(texts)) < 10 or (
            not disable_check and not maybe_is_text("".join(texts))
        ):
            raise ValueError(
                f"This does not look like a text document: {path}. Pass disable_check to ignore this error."
            )
        self.docs[path] = dict(texts=texts, metadata=metadata, key=key)
        if self.index_docstore is not None:
            self.index_docstore.add_texts(texts, metadatas=metadata)

    def clear(self) -> None:
        """Clear the collection of documents."""
        self.docs = dict()
        self.keys = set()
        self.index_docstore = None
        # delete index files
        pkl = self.index_path / "index.pkl"
        if pkl.exists():
            pkl.unlink()
        fs = self.index_path / "index.faiss"
        if fs.exists():
            fs.unlink()

    def doc_previews(self) -> List[Tuple[int, str, str]]:
        """Return a list of tuples of (number of chunks, dockey, citation) for each document."""
        return [
            (
                len(doc["texts"]),
                doc["metadata"][0]["dockey"],
                doc["metadata"][0]["citation"],
            )
            for doc in self.docs.values()
        ]

    # to pickle, we have to save the index as a file
    def __getstate__(self, embedding: Embeddings = None):
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None and len(self.docs) > 0:
            self._build_faiss_index(embedding)
        state = self.__dict__.copy()
        if self.index_docstore is not None:
            state["index_docstore"].save_local(self.index_path)
        del state["index_docstore"]
        # remove LLM chains (they can have callbacks, which can't be pickled)
        del state["summary_chain"]
        state.pop("qa_chain", None)  # qa_chain is created lazily in _query and may not exist yet
        del state["cite_chain"]
        del state["search_chain"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        try:
            self.index_docstore = FAISS.load_local(self.index_path, OpenAIEmbeddings())
        except Exception:
            # they use some special exception type, but I don't want to import it
            self.index_docstore = None
        self.update_llm(
            ChatOpenAI(temperature=0.1, max_tokens=512)
        )

    def _build_faiss_index(self, embedding: Embeddings = None):
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None:
            texts = reduce(
                lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
            )
            metadatas = reduce(
                lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
            )

            # if the index exists locally, load it
            if LOAD_INDEX_LOCALLY and (self.index_path / "index.faiss").exists():
                self.index_docstore = FAISS.load_local(self.index_path, embedding)

                # check whether each text (and its metadata) already exists in the index
                for i in reversed(range(len(texts))):
                    text = texts[i]
                    metadata = metadatas[i]
                    for key, value in self.index_docstore.docstore.dict_.items():
                        if value.page_content == text:
                            if value.metadata.get('citation').split(os.sep)[-1] != metadata.get('citation').split(os.sep)[-1]:
                                self.index_docstore.docstore.dict_[key].metadata['citation'] = metadata.get('citation').split(os.sep)[-1]
                                self.index_docstore.docstore.dict_[key].metadata['dockey'] = metadata.get('citation').split(os.sep)[-1]
                                self.index_docstore.docstore.dict_[key].metadata['key'] = metadata.get('citation').split(os.sep)[-1]
                            texts.pop(i)
                            metadatas.pop(i)
                            break  # this chunk is already indexed; move on to the next one

                # add the remaining (new) texts
                if texts:
                    self.index_docstore.add_texts(texts=texts, metadatas=metadatas)
            else:
                # create a new index
                self.index_docstore = FAISS.from_texts(texts, embedding, metadatas=metadatas)

            if SAVE_INDEX_LOCALLY:
                # save the index locally
                self.index_docstore.save_local(self.index_path)

    def _build_pinecone_index(self, embedding: Embeddings = None):
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None:
            pinecone.init(
                api_key=os.environ['PINECONE_API_KEY'],         # find at app.pinecone.io
                environment=os.environ['PINECONE_ENVIRONMENT']  # next to the api key in the console
            )
            texts = reduce(
                lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
            )
            metadatas = reduce(
                lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
            )
            # TODO: when the index already exists, update it instead of deleting it
            # index_name = "langchain-demo1"
            # if index_name in pinecone.list_indexes():
            #     self.index_docstore = pinecone.Index(index_name)
            #     vectors = []
            #     for text, metadata in zip(texts, metadatas):
            #         # embed = <we would need to know which embedding was used to build the pre-existing index>
            #         self.index_docstore.upsert(vectors=vectors)
            # else:
            #     if openai.api_type == 'azure':
            #         self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
            #     else:
            #         self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
            index_name = "langchain-demo1"
            # if the index exists, delete it
            if index_name in pinecone.list_indexes():
                pinecone.delete_index(index_name)
            # create a new index
            if openai.api_type == 'azure':
                self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
            else:
                self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)

    def get_evidence(
        self,
        answer: Answer,
        embedding: Embeddings,
        k: int = 3,
        max_sources: int = 5,
        marginal_relevance: bool = True,
    ) -> str:
        if self.index_docstore is None:
            self._build_faiss_index(embedding)

        init_search_time = time.time()
        # want to work through indices but less k
        if marginal_relevance:
            docs = self.index_docstore.max_marginal_relevance_search(
                answer.question, k=k, fetch_k=5 * k
            )
        else:
            docs = self.index_docstore.similarity_search(
                answer.question, k=k, fetch_k=5 * k
            )
        if OPERATING_MODE == "debug":
            print(f"time to search docs to build context: {time.time() - init_search_time:.2f} [s]")

        init_summary_time = time.time()
        partial_summary_time = ""
        for i, doc in enumerate(docs):
            with get_openai_callback() as cb:
                init_partial_summary_time = time.time()
                summary_of_chunked_text = self.summary_chain.run(
                    question=answer.question, context_str=doc.page_content
                )
                if OPERATING_MODE == "debug":
                    partial_summary_time += f"- time to make relevant summary of doc '{i}': {time.time() - init_partial_summary_time:.2f} [s]\n"
                engine = self.summary_chain.llm.model_kwargs.get('deployment_id') or self.summary_chain.llm.model_name
                if not answer.tokens:
                    answer.tokens = [{
                        'engine': engine,
                        'total_tokens': cb.total_tokens}]
                else:
                    answer.tokens.append({
                        'engine': engine,
                        'total_tokens': cb.total_tokens
                    })
            summarized_package = (
                doc.metadata["key"],
                doc.metadata["citation"],
                summary_of_chunked_text,
                doc.page_content,
            )
            if "Not applicable" not in summary_of_chunked_text and summarized_package not in answer.packages:
                answer.packages.append(summarized_package)
            yield answer
            if len(answer.packages) == max_sources:
                break
        if OPERATING_MODE == "debug":
            print(f"time to make all relevant summaries: {time.time() - init_summary_time:.2f} [s]")
            # the last character is not printed because it is a newline
            print(partial_summary_time[:-1])

        context_str = "\n\n".join(
            [f"{citation}: {summary_of_chunked_text}"
             for key, citation, summary_of_chunked_text, chunked_text in answer.packages
             if "Not applicable" not in summary_of_chunked_text]
        )
        chunks_str = "\n\n".join(
            [f"{citation}: {chunked_text}"
             for key, citation, summary_of_chunked_text, chunked_text in answer.packages
             if "Not applicable" not in summary_of_chunked_text]
        )
        valid_keys = [key
                      for key, citation, summary_of_chunked_text, chunked_text in answer.packages
                      if "Not applicable" not in summary_of_chunked_text]
        if len(valid_keys) > 0:
            context_str += "\n\nValid keys: " + ", ".join(valid_keys)
            chunks_str += "\n\nValid keys: " + ", ".join(valid_keys)
        answer.context = context_str
        answer.chunks = chunks_str
        yield answer

    def query(
        self,
        query: str,
        embedding: Embeddings,
        chat_history: list[tuple[str, str]],
        k: int = 10,
        max_sources: int = 5,
        length_prompt: str = "about 100 words",
        marginal_relevance: bool = True,
    ):
        for answer in self._query(
            query,
            embedding,
            chat_history,
            k=k,
            max_sources=max_sources,
            length_prompt=length_prompt,
            marginal_relevance=marginal_relevance,
        ):
            pass
        return answer

    def _query(
        self,
        query: str,
        embedding: Embeddings,
        chat_history: list[tuple[str, str]],
        k: int,
        max_sources: int,
        length_prompt: str,
        marginal_relevance: bool,
    ):
        if k < max_sources:
            k = max_sources + 1
        answer = Answer(question=query)

        messages_qa = [system_message_prompt]
        if len(chat_history) != 0:
            for conversation in chat_history:
                messages_qa.append(HumanMessagePromptTemplate.from_template(conversation[0]))
                messages_qa.append(AIMessagePromptTemplate.from_template(conversation[1]))
        messages_qa.append(human_qa_message_prompt)
        chat_qa_prompt = ChatPromptTemplate.from_messages(messages_qa)
        self.qa_chain = LLMChain(prompt=chat_qa_prompt, llm=self.llm)

        for answer in self.get_evidence(
            answer,
            embedding,
            k=k,
            max_sources=max_sources,
            marginal_relevance=marginal_relevance,
        ):
            yield answer

        references_dict = dict()
        passages = dict()
        if len(answer.context) < 10:
            answer_text = "I cannot answer this question due to insufficient information."
        else:
            with get_openai_callback() as cb:
                init_qa_time = time.time()
                answer_text = self.qa_chain.run(
                    question=answer.question, context_str=answer.context, length=length_prompt
                )
                if OPERATING_MODE == "debug":
                    print(f"time to make the Q&A answer: {time.time() - init_qa_time:.2f} [s]")
                engine = self.qa_chain.llm.model_kwargs.get('deployment_id') or self.qa_chain.llm.model_name
                if not answer.tokens:
                    answer.tokens = [{
                        'engine': engine,
                        'total_tokens': cb.total_tokens}]
                else:
                    answer.tokens.append({
                        'engine': engine,
                        'total_tokens': cb.total_tokens
                    })

        # it still happens lol
        if "(Foo2012)" in answer_text:
            answer_text = answer_text.replace("(Foo2012)", "")

        for key, citation, summary, text in answer.packages:
            # do check for whole key (so we don't catch Callahan2019a with Callahan2019)
            skey = key.split(" ")[0]
            if skey + " " in answer_text or skey + ")" in answer_text:
                references_dict[skey] = citation
                passages[key] = text
        references_str = "\n\n".join(
            [f"{i+1}. ({k}): {c}" for i, (k, c) in enumerate(references_dict.items())]
        )

        # cost_str = f"{answer_text}\n\n"
        cost_str = ""
        itemized_cost = ""
        total_amount = 0
        for d in answer.tokens:
            total_tokens = d.get('total_tokens')
            if total_tokens:
                engine = d.get('engine')
                key_price = None
                for key in PRICES.keys():
                    if re.match(f"{key}", engine):
                        key_price = key
                        break
                if PRICES.get(key_price):
                    partial_amount = total_tokens / 1000 * PRICES.get(key_price)
                    total_amount += partial_amount
                    itemized_cost += f"- {engine}: {total_tokens} tokens\t ---> ${partial_amount:.4f},\n"
                else:
                    itemized_cost += f"- {engine}: {total_tokens} tokens,\n"
        # drop the trailing ",\n"
        itemized_cost = itemized_cost[:-2]

        # add the token cost summary to the formatted answer
        cost_str += f"Total cost: ${total_amount:.4f}\nItemized cost:\n{itemized_cost}"

        answer.answer = answer_text
        answer.cost_str = cost_str
        answer.references = references_str
        answer.passages = passages
        yield answer
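

# Illustrative end-to-end sketch (not part of the original module): builds a Dataset, adds a
# document, and asks one question. The file name, citation, and question are hypothetical, and
# running it requires a valid OPENAI_API_KEY (plus Pinecone credentials only if that backend is used).
if __name__ == "__main__":
    dataset = Dataset(name="demo")
    dataset.add("report.pdf", citation="Doe et al., 2023", key="Doe2023")
    final_answer = dataset.query(
        "What is the main conclusion of the report?",
        embedding=OpenAIEmbeddings(),
        chat_history=[],
    )
    print(final_answer.answer)
    print(final_answer.references)
    print(final_answer.cost_str)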