from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
import os
import chainlit as cl
from chainlit.types import AskFileResponse
from lxml import html
from pydantic import BaseModel
from typing import Any, Optional
from unstructured.partition.pdf import partition_pdf
from prompts import *
import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from operator import itemgetter
from langchain.schema.runnable import RunnablePassthrough
welcome_message = """Welcome to the Semi-Structured PDF QA! To get started:
1. Upload a PDF or text file
2. Ask a question about the file
3. (Optional) Ask a question about any table in the PDF

Note: Loading the PDF takes time because `unstructured` is used to detect tables
and create summaries. Please be patient. The chatbot uses `gpt-4`.
"""
class Element(BaseModel):
    type: str
    text: Any
def get_elements(raw_pdf_elements):
    # Categorize the elements returned by unstructured by type
    categorized_elements = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            categorized_elements.append(Element(type="table", text=str(element)))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            categorized_elements.append(Element(type="text", text=str(element)))

    # Tables
    table_elements = [e for e in categorized_elements if e.type == "table"]
    print(f"Found {len(table_elements)} table elements")

    # Text
    text_elements = [e for e in categorized_elements if e.type == "text"]
    print(f"Found {len(text_elements)} text elements")

    return table_elements, text_elements
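
# Write the uploaded file to a temporary path, then partition it with `unstructured`
# and return the table and text elements found in the document.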
def process_docs(file: AskFileResponse):
    import tempfile

    # Use a distinct name to avoid shadowing the `tempfile` module
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
        if file.type == "text/plain":
            temp_file.write(file.content)
        elif file.type == "application/pdf":
            with open(temp_file.name, "wb") as f:
                f.write(file.content)

    raw_pdf_elements = partition_pdf(
        filename=temp_file.name,
        # Unstructured first finds embedded image blocks
        extract_images_in_pdf=False,
        # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
        # Titles are any sub-section of the document
        infer_table_structure=True,
        # Post processing to aggregate text once we have the title
        chunking_strategy="by_title",
        # Chunking params to aggregate text blocks
        # Attempt to create a new chunk at 3800 chars
        # Attempt to keep chunks > 2000 chars
        max_characters=4000,
        new_after_n_chars=3800,
        combine_text_under_n_chars=2000,
    )

    table_elements, text_elements = get_elements(raw_pdf_elements)
    return table_elements, text_elements
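
# Runs once at the start of each chat session: asks for a file, builds a
# multi-vector retriever over its summaries, and stores the QA chain in the user session.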
@cl.on_chat_start
async def on_chat_start():
    await cl.Avatar(
        name="QA Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()
    await cl.Avatar(
        name="User",
        path="icon/avatar.png",
    ).send()
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()

    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`... Please wait.", disable_human_feedback=True
    )
    await msg.send()

    # Heavy, blocking work (unstructured parsing) runs in a worker thread via make_async
    table_elements, text_elements = await cl.make_async(process_docs)(file)
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    model = ChatOpenAI(
        streaming=True,
        temperature=0,
        model="gpt-4",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )

    # Summarization chain: each element is passed through unchanged as the `element`
    # variable, summarized by the model, and parsed to a plain string
    prompt = ChatPromptTemplate.from_template(TABLE_TEXT_SUMMARY_PROMPT)
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
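
    # `batch` runs the summarization chain over every element, with at most
    # 5 concurrent model calls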
    # Apply to tables
    tables = [i.text for i in table_elements]
    table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})

    # Apply to texts
    texts = [i.text for i in text_elements]
    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
    vectorstore = Chroma(
        persist_directory="./chroma_db",
        collection_name="summaries",
        embedding_function=OpenAIEmbeddings(),
    )

    # The storage layer for the parent documents
    store = InMemoryStore()
    id_key = "doc_id"

    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
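
    # Multi-vector pattern: the summaries are embedded and indexed in Chroma, while the
    # raw texts/tables are kept in the docstore under the same doc_id. A query matches
    # against the summaries, and the retriever returns the original content.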
    # Add texts
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=s, metadata={id_key: doc_ids[i]})
        for i, s in enumerate(text_summaries)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))

    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=s, metadata={id_key: table_ids[i]})
        for i, s in enumerate(table_summaries)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))

    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()
    # Prompt template
    prompt = ChatPromptTemplate.from_template(QA_PROMPT)
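
    # LCEL chain: the retriever fills in `context`, RunnablePassthrough forwards the
    # user question, and the model's streamed output is parsed into plain strings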
    runnable = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

    cl.user_session.set("runnable", runnable)
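
# Handles every user message: fetches the chain stored for this session and streams
# the answer back token by token.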
@cl.on_message
async def on_message(message: cl.Message):
    runnable = cl.user_session.get("runnable")  # type: Runnable

    msg = cl.Message(content="")

    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)

    await msg.send()