# Intelligent Document Search prototype — Gradio app (Hugging Face Spaces).
# Pipeline: Pinecone MMR retrieval -> Pinecone hosted reranker -> GPT-4 answer generation.
import os | |
from dotenv import load_dotenv | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema import HumanMessage | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_voyageai import VoyageAIEmbeddings | |
from langchain_pinecone import PineconeVectorStore | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from typing import List, Tuple | |
from langchain.schema import BaseRetriever | |
from langchain_core.documents import Document | |
from langchain_core.runnables import chain | |
from pinecone import Pinecone, ServerlessSpec | |
import openai | |
import numpy as np | |
from pinecone.grpc import PineconeGRPC as Pinecone | |
import gradio as gr | |
import asyncio | |
load_dotenv() | |
# Initialize OpenAI and Pinecone credentials | |
openai.api_key = os.environ.get("OPENAI_API_KEY") | |
pinecone_api_key = os.environ.get("PINECONE_API_KEY") | |
pinecone_environment = os.environ.get("PINECONE_ENV") | |
voyage_api_key = os.environ.get("VOYAGE_API_KEY") | |
pinecone_index_name = "briefmeta" | |
# Initialize Pinecone | |
try: | |
pc = Pinecone(api_key=pinecone_api_key) | |
except Exception as e: | |
print(f"Error connecting to Pinecone: {str(e)}") | |
embeddings = VoyageAIEmbeddings( | |
voyage_api_key=voyage_api_key, model="voyage-law-2" | |
) | |
def search_documents(query): | |
try: | |
vector_store = PineconeVectorStore(index_name=pinecone_index_name, embedding=embeddings) | |
results = vector_store.max_marginal_relevance_search(query, k=10, fetch_k=30) # Adjust fetch_k for more diverse results | |
# Filter results to ensure uniqueness based on metadata.id | |
seen_ids = set() | |
unique_results = [] | |
for result in results: | |
unique_id = result.metadata.get("id") | |
if unique_id not in seen_ids: | |
seen_ids.add(unique_id) | |
unique_results.append(result) | |
# Collect relevant context from unique results | |
context = [] | |
for result in unique_results: | |
context.append({ | |
"doc_id": result.metadata.get("doc_id", "N/A"), | |
"chunk_id": result.metadata.get("id", "N/A"), | |
"title": result.metadata.get("source", "N/A"), | |
"text": result.page_content, | |
"page_number": str(result.metadata.get("page", "N/A")), | |
"score": str(result.metadata.get("score", "N/A")), | |
}) | |
return context | |
except Exception as e: | |
return [], f"Error searching documents: {str(e)}" | |
# Reranker | |
def rerank(query, context): | |
result = pc.inference.rerank( | |
model="bge-reranker-v2-m3", | |
query=query, | |
documents=context, | |
top_n=5, | |
return_documents=True, | |
# parameters={ | |
# "truncate": "END" | |
# } | |
) | |
return result | |
def generate_output(context, query): | |
try: | |
llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.7) | |
prompt_template = PromptTemplate( | |
template=""" | |
Use the following context to answer the question as accurately as possible: | |
Context: {context} | |
Question: {question} | |
Answer:""", | |
input_variables=["context", "question"] | |
) | |
prompt = prompt_template.format(context=context, question=query) | |
response = llm([HumanMessage(content=prompt)]) | |
return response.content | |
except Exception as e: | |
return f"Error generating output: {str(e)}" | |
def complete_workflow(query): | |
try: | |
context_data = search_documents(query) | |
reranked = rerank(query, context_data) | |
context_data= [] | |
for i, entry in enumerate(reranked.data): # Access the 'data' attribute | |
context_data.append({ | |
'chunk_id': entry['document']['chunk_id'], | |
'doc_id': entry['document']['doc_id'], | |
'title': entry['document']['title'], | |
'text': entry['document']['text'], | |
'page_number': str(entry['document']['page_number']), | |
'score': str(entry['score']) | |
}) | |
document_titles = list({os.path.basename(doc["title"]) for doc in context_data}) # Get only file names | |
formatted_titles = " " + "\n".join(document_titles) | |
total_results = len(context_data) # Count the total number of results | |
results = { | |
"results": [ | |
{ | |
"natural_language_output": generate_output(doc["text"], query), | |
"chunk_id": doc["chunk_id"], | |
"document_id": doc["doc_id"], # Assuming doc_id is the UUID | |
"title": doc["title"], | |
"text": doc["text"], | |
"page_number": doc["page_number"], | |
"score": doc["score"], | |
} | |
for doc in context_data | |
], | |
"total_results": total_results # Added total_results field | |
} | |
return results, formatted_titles # Return results and formatted document titles | |
except Exception as e: | |
return {"results": [], "total_results": 0}, f"Error in workflow: {str(e)}" | |
def gradio_app(): | |
with gr.Blocks(css=".result-output {width: 150%; font-size: 16px; padding: 10px;}") as app: | |
gr.Markdown("### Intelligent Document Search Prototype-v0.1.2 ") | |
with gr.Row(): | |
user_query = gr.Textbox(label="Enter Your Search Query") | |
search_btn = gr.Button("Search") | |
with gr.Row(): | |
result_output = gr.JSON(label="Search Results", elem_id="result-output") | |
with gr.Row(): | |
titles_output = gr.Textbox(label="Document Titles", interactive=False) # New Textbox for Titles | |
search_btn.click( | |
complete_workflow, | |
inputs=user_query, | |
outputs=[result_output, titles_output], | |
) | |
return app | |
# Launch the app | |
gradio_app().launch() |