# llm-app/app.py
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
import os
import chainlit as cl
from chainlit.types import AskFileResponse
from pydantic import BaseModel
from typing import Any, Optional
from unstructured.partition.pdf import partition_pdf
from prompts import *
import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.runnable import RunnablePassthrough
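# Semi-structured RAG over PDFs: `unstructured` partitions the upload into text
# and table elements, gpt-4 summarizes each one, the summaries are embedded into
# Chroma, and a MultiVectorRetriever maps each retrieved summary back to the raw
# element that grounds the final answer.
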
welcome_message = """Welcome to the Semi-Structured PDF QA! To get started:
1. Upload a PDF or text file
2. Ask a question about the file
3. (Optional) Ask a question about any table in the PDF

Note: loading a PDF takes a while because `unstructured` detects the tables and
`gpt-4` summarizes every element, so please be patient.
"""

class Element(BaseModel):
    """A partitioned PDF element, tagged as either "table" or "text"."""

    type: str
    text: Any

def get_elements(raw_pdf_elements):
    # Categorize the partitioned elements by type
    categorized_elements = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            categorized_elements.append(Element(type="table", text=str(element)))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            categorized_elements.append(Element(type="text", text=str(element)))
    # Tables
    table_elements = [e for e in categorized_elements if e.type == "table"]
    print(f"{len(table_elements)} table elements")
    # Text
    text_elements = [e for e in categorized_elements if e.type == "text"]
    print(f"{len(text_elements)} text elements")
    return table_elements, text_elements
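
# (A sturdier check than string-matching on the type name would be
# isinstance(element, Table) / isinstance(element, CompositeElement), assuming
# those classes are importable from unstructured.documents.elements in your
# installed version of `unstructured`.)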

def process_docs(file: AskFileResponse):
    import tempfile

    # Persist the upload to disk so `unstructured` can read it;
    # delete=False keeps the file around after the handle closes
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp_file:
        tmp_file.write(file.content)
        tmp_file.flush()
        raw_pdf_elements = partition_pdf(
            filename=tmp_file.name,
            # Unstructured first finds embedded image blocks
            extract_images_in_pdf=False,
            # Use a layout model (YOLOX) to get bounding boxes (for tables)
            # and to find titles, i.e. any sub-section of the document
            infer_table_structure=True,
            # Post-process to aggregate text once we have the titles
            chunking_strategy="by_title",
            # Chunking params: cap chunks at 4000 chars, start a new chunk
            # after 3800 chars, and merge sections shorter than 2000 chars
            max_characters=4000,
            new_after_n_chars=3800,
            combine_text_under_n_chars=2000,
        )
    table_elements, text_elements = get_elements(raw_pdf_elements)
    return table_elements, text_elements
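
# NOTE: with infer_table_structure=True, `unstructured` runs its high-resolution
# layout pipeline, which may require extra model and system dependencies; see
# the `unstructured` docs for the `hi_res` strategy requirements.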

@cl.on_chat_start
async def on_chat_start():
    await cl.Avatar(
        name="QA Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()
    await cl.Avatar(
        name="User",
        path="icon/avatar.png",
    ).send()

    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`... Please wait.", disable_human_feedback=True
    )
    await msg.send()

    # Run the blocking PDF parsing in a worker thread so the event loop stays free
    table_elements, text_elements = await cl.make_async(process_docs)(file)
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )
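    # NOTE: this conversation memory is initialized but not wired into the QA
    # chain below, so each question is answered statelessly.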
    model = ChatOpenAI(
        streaming=True,
        temperature=0,
        model="gpt-4",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )

    # TABLE_TEXT_SUMMARY_PROMPT is star-imported from prompts.py
    prompt = ChatPromptTemplate.from_template(TABLE_TEXT_SUMMARY_PROMPT)
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
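    # The leading dict feeds each raw string into the prompt's `element`
    # variable; StrOutputParser unwraps the chat response to plain text.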
    # Apply to tables
    tables = [i.text for i in table_elements]
    table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
    # Apply to texts
    texts = [i.text for i in text_elements]
    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
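    # `batch` fans the summarization calls out concurrently, at most 5 in flight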
    vectorstore = Chroma(
        persist_directory="./chroma_db",
        collection_name="summaries",
        embedding_function=OpenAIEmbeddings(),
    )
    # The storage layer for the parent documents
    store = InMemoryStore()
    id_key = "doc_id"
    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
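    # Multi-vector pattern: the *summaries* are embedded and searched in Chroma,
    # but each hit is resolved via `doc_id` to the *raw* text/table in the
    # docstore, and that raw element is what the LLM receives as context.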
    # Add texts
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=s, metadata={id_key: doc_ids[i]})
        for i, s in enumerate(text_summaries)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))
    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=s, metadata={id_key: table_ids[i]})
        for i, s in enumerate(table_summaries)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))
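    # NOTE: Chroma persists to ./chroma_db across restarts, but InMemoryStore
    # does not, so the summary index and the raw documents only stay in sync
    # within a single session.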
    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()

    # QA prompt (QA_PROMPT also lives in prompts.py): the retriever fills
    # {context} with the resolved raw elements, while RunnablePassthrough
    # forwards the user's question verbatim into {question}
    prompt = ChatPromptTemplate.from_template(QA_PROMPT)
    runnable = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)

@cl.on_message
async def on_message(message: cl.Message):
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")
    # Stream tokens into the UI as the model generates them
    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)
    await msg.send()