import os
import shutil
import json
import pandas as pd
import chainlit as cl
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.vectorstores import Qdrant
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.tools import tool
from langchain.schema import HumanMessage
from operator import itemgetter
# Load environment variables
load_dotenv()

# Define paths
UPLOAD_PATH = "upload/"
OUTPUT_PATH = "output/"
os.makedirs(UPLOAD_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)
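
# Embedding model (shared by the semantic chunker and the vector store) and the chat LLM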
model_id = "Snowflake/snowflake-arctic-embed-m" | |
embedding_model = HuggingFaceEmbeddings(model_name=model_id) | |
semantic_splitter = SemanticChunker(embedding_model, add_start_index=True, buffer_size=30) | |
llm = ChatOpenAI(model="gpt-4o-mini") | |

# Export comparison prompt
export_prompt = """
CONTEXT:
{context}

QUERY:
{question}

You are a helpful assistant. Use the available context to answer the question.
Between these two files containing protocols, identify and match **entire assessment sections** based on conceptual similarity. Do NOT match individual questions.

### **Output Format:**
Return the response in **valid JSON format**, structured as a list of dictionaries, where each dictionary contains:
[
  {{
    "Derived Description": "A short name for the matched concept",
    "Protocol_1": "Protocol 1 - Matching Element",
    "Protocol_2": "Protocol 2 - Matching Element"
  }},
  ...
]

### **Example Output:**
[
  {{
    "Derived Description": "Pain Coping Strategies",
    "Protocol_1": "Pain Coping Strategy Scale (PCSS-9)",
    "Protocol_2": "Chronic Pain Adjustment Index (CPAI-10)"
  }},
  {{
    "Derived Description": "Work Stress and Fatigue",
    "Protocol_1": "Work-Related Stress Scale (WRSS-8)",
    "Protocol_2": "Occupational Fatigue Index (OFI-7)"
  }},
  ...
]
### Rules:
1. Only output **valid JSON** with no explanations, summaries, or markdown formatting.
2. Ensure each entry in the JSON list represents a single matched data element from the two protocols.
3. If no matching element is found in a protocol, leave it empty ("").
4. **Do NOT include headers, explanations, or additional formatting**; only return the raw JSON list.
5. Include every element from both protocols.
6. If an element cannot be matched, still create the row: fill in the element that was found and put "could not match" in the other protocol's column.
"""
compare_export_prompt = ChatPromptTemplate.from_template(export_prompt)
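
# Prompt for free-form Q&A over the uploaded documents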
QUERY_PROMPT = """ | |
You are a helpful assistant. Use the available context to answer the question concisely and informatively. | |
CONTEXT: | |
{context} | |
QUERY: | |
{question} | |
Provide a natural-language response using the given information. If you do not know the answer, say so. | |
""" | |
query_prompt = ChatPromptTemplate.from_template(QUERY_PROMPT) | |

## tool configurations
@tool
def document_query_tool(question: str) -> dict:
    """Retrieves relevant document sections and answers questions based on the uploaded documents."""
    retriever = cl.user_session.get("qdrant_retriever")
    if not retriever:
        return {
            "messages": [HumanMessage(content="Error: No documents available for retrieval. Please upload two PDF files first.")],
            "context": [],
        }
    retriever.search_kwargs["k"] = 10  # retrieve the top 10 chunks per query
    # Build a RAG chain: fetch context for the question, then answer with the LLM
    rag_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | query_prompt | llm | StrOutputParser()
    )
    response_text = rag_chain.invoke({"question": question})
    # Also return the retrieved docs so the caller can inspect the context
    retrieved_docs = retriever.invoke(question)
    return {
        "messages": [HumanMessage(content=response_text)],
        "context": retrieved_docs,
    }

@tool
def document_comparison_tool(question: str) -> str:
    """Compares documents, identifies matched elements, exports them as JSON, formats them into a CSV, and returns the file path."""
    # Retrieve the vector database retriever
    retriever = cl.user_session.get("qdrant_retriever")
    if not retriever:
        return "Error: No documents available for retrieval. Please upload two PDF files first."
    retriever.search_kwargs["k"] = 10  # retrieve the top 10 chunks per query
    # Process the query using RAG
    rag_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | compare_export_prompt | llm | StrOutputParser()
    )
    response_text = rag_chain.invoke({"question": question})
    # Parse the response and save it as a CSV
    try:
        structured_data = json.loads(response_text)
        if not structured_data:
            return "Error: No matched elements found."
        # Define the output file path
        file_path = os.path.join(OUTPUT_PATH, "comparison_results.csv")
        # Save to CSV
        df = pd.DataFrame(structured_data, columns=["Derived Description", "Protocol_1", "Protocol_2"])
        df.to_csv(file_path, index=False)
        return file_path  # Path to the CSV file
    except json.JSONDecodeError:
        return "Error: Response is not valid JSON."
async def process_files(files: list[cl.File]):
    """Load each uploaded PDF, split it into semantic chunks, and index the chunks in an in-memory Qdrant collection."""
    documents_with_metadata = []
    for file in files:
        file_path = os.path.join(UPLOAD_PATH, file.name)
        shutil.copyfile(file.path, file_path)
        loader = PyMuPDFLoader(file_path)
        documents = loader.load()
        for doc in documents:
            source_name = file.name
            chunks = semantic_splitter.split_text(doc.page_content)
            for chunk in chunks:
                doc_chunk = Document(page_content=chunk, metadata={"source": source_name})
                documents_with_metadata.append(doc_chunk)
    if documents_with_metadata:
        qdrant_vectorstore = Qdrant.from_documents(
            documents_with_metadata,
            embedding_model,
            location=":memory:",
            collection_name="document_comparison",
        )
        return qdrant_vectorstore.as_retriever()
    return None
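
## chainlit event handlers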
@cl.on_chat_start
async def start():
    """Prompt the user for two PDFs and build the retriever when the chat starts."""
    cl.user_session.set("qdrant_retriever", None)
    files = await cl.AskFileMessage(
        content="Please upload **two PDF files** for comparison:",
        accept=["application/pdf"],
        max_files=2,
    ).send()
    if not files or len(files) != 2:
        await cl.Message("Error: You must upload exactly two PDF files.").send()
        return
    retriever = await process_files(files)
    if retriever:
        cl.user_session.set("qdrant_retriever", retriever)
        await cl.Message("Files uploaded and processed successfully! You can now enter your query.").send()
    else:
        await cl.Message("Error: Unable to process files. Please try again.").send()

@cl.on_message
async def handle_message(message: cl.Message):
    """Route user messages to the comparison tool or the Q&A tool."""
    user_input = message.content.lower()
    # If the user asks for a comparison or export, run the document comparison tool
    if "compare" in user_input or "export" in user_input:
        file_path = document_comparison_tool.invoke(user_input)
        if file_path and file_path.endswith(".csv"):
            await cl.Message(
                content="Comparison complete! Download the CSV below:",
                elements=[cl.File(name="comparison_results.csv", path=file_path, display="inline")],
            ).send()
        else:
            # The tool returned an error message instead of a file path
            await cl.Message(file_path).send()
    else:
        # Otherwise answer the question with retrieval-augmented generation
        response_text = document_query_tool.invoke(user_input)
        await cl.Message(response_text["messages"][0].content).send()
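
# Run locally with `chainlit run app.py -w` (assuming this file is saved as app.py)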