import nltk
nltk.download('punkt')
nltk.download('punkt_tab')  # Explicitly download the missing resource
nltk.download('averaged_perceptron_tagger_eng')
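# (punkt / punkt_tab and the averaged perceptron tagger are used by the `unstructured`
#  partitioners that back UnstructuredURLLoader below.)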
import os
import streamlit as st
import pickle
import faiss
import time
import numpy as np
from langchain_openai import OpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredURLLoader
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import requests
from dotenv import load_dotenv

load_dotenv()  # take environment variables from .env (especially the OpenAI API key)
# api_key = os.getenv("OPENAI_API_KEY")
# if not api_key:
#     st.error("OpenAI API Key not found. Ensure it's set in the environment variables.")
# else:
#     st.write("OpenAI API Key loaded successfully.")
st.title("Article/News Research Tool")
st.sidebar.title("Article URLs...")

# Initialize session state for Q&A history
if "qa_history" not in st.session_state:
    st.session_state.qa_history = []

# Ask the user how many URLs they want to input
num_urls = st.sidebar.number_input("How many URLs do you want to process?", min_value=1, max_value=10, value=2)

urls = []
for i in range(num_urls):
    url = st.sidebar.text_input(f"URL {i+1}")
    if url:  # skip inputs the user left blank
        urls.append(url)
process_url_clicked = st.sidebar.button("Process Article URLs")
main_placeholder = st.empty()
llm = OpenAI(temperature=0.5, max_tokens=500)

# Paths used to persist the FAISS index, the document chunks, and the id mapping
index_path = "faiss_index.bin"
docs_path = "docs.pkl"
index_to_docstore_id_path = "index_to_docstore_id.pkl"
if process_url_clicked:
    # Quick reachability check for each URL (status codes go to the server log)
    for url in urls:
        try:
            response = requests.get(url)
            print(f"Status for {url}: {response.status_code}")
        except Exception as e:
            print(f"Error accessing {url}: {e}")
    # load data
    loader = UnstructuredURLLoader(urls=urls)
    main_placeholder.text("Data Loading...Initiated...")
    data = []  # fall back to an empty list so the steps below don't fail if loading errors out
    try:
        data = loader.load()
        if not data:
            st.error("No data loaded from the provided URLs. Please check the URLs.")
        else:
            for i, doc in enumerate(data):
                print(f"Document {i}: {doc.page_content[:100]}")  # Preview first 100 characters
    except Exception as e:
        print(f"Error loading URLs: {e}")
        st.error(f"Error loading URLs: {e}")
    # split data
    main_placeholder.text("Text Splitter...Initiated...")
    text_splitter = RecursiveCharacterTextSplitter(
        # separators=['\n\n', '\n', '.', ','],
        chunk_size=1000,
        chunk_overlap=200
    )
    docs = text_splitter.split_documents(data)
    print(f"Number of chunks created: {len(docs)}")
    for i, doc in enumerate(docs[:5]):  # Limit output
        print(f"Chunk {i}: {doc.page_content[:100]}")  # Preview first 100 characters
    # create embeddings and save them to a FAISS index
    main_placeholder.text("Embedding Vector Building Initiated...")
    embeddings = OpenAIEmbeddings()
    embedding_dimension = 1536  # dimension of the default OpenAI embedding model (text-embedding-ada-002)

    docstore_dict = {str(i): doc for i, doc in enumerate(docs)}
    docstore = InMemoryDocstore(docstore_dict)

    # Create FAISS vector index
    index = faiss.IndexFlatL2(embedding_dimension)

    # Initialize the FAISS vector store with a correct mapping
    index_to_docstore_id = {i: str(i) for i in range(len(docs))}
    vector_store = FAISS(embedding_function=embeddings, index=index, docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id)

    # Add document embeddings to the FAISS index (FAISS expects float32 vectors)
    for i, vector in enumerate([embeddings.embed_query(doc.page_content) for doc in docs]):
        print(f"Adding vector {i} to FAISS index.")
        index.add(np.array([vector], dtype=np.float32))
    # vector_store.add_documents(docs)

    # Save the FAISS index, the document chunks, and the index_to_docstore_id mapping separately
    faiss.write_index(vector_store.index, index_path)
    with open(docs_path, "wb") as f:
        pickle.dump(docs, f)
    with open(index_to_docstore_id_path, "wb") as f:
        pickle.dump(vector_store.index_to_docstore_id, f)
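    # Alternative (not used here): FAISS.from_documents(docs, embeddings) builds the index,
    # docstore, and id mapping in one call, and vector_store.save_local("faiss_store")
    # persists everything to a folder instead of the three separate files above.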

query = main_placeholder.text_input("Question: ")
if query:
    if not process_url_clicked:
        # Load the precomputed FAISS index, document chunks, and id mapping from disk
        if os.path.exists(index_path) and os.path.exists(docs_path) and os.path.exists(index_to_docstore_id_path):
            index = faiss.read_index(index_path)
            with open(docs_path, "rb") as f:
                docs = pickle.load(f)
            with open(index_to_docstore_id_path, "rb") as f:
                index_to_docstore_id = pickle.load(f)
            docstore = InMemoryDocstore({str(i): doc for i, doc in enumerate(docs)})
            # print(f"Loaded document store keys: {list(docstore._dict.keys())[:10]}")  # Debug output
            embeddings = OpenAIEmbeddings()  # Recreate embeddings object
            vector_store = FAISS(embedding_function=embeddings, index=index, docstore=docstore,
                                 index_to_docstore_id=index_to_docstore_id)
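            # Alternative (not used here): an index saved with vector_store.save_local() can be
            # reloaded in recent LangChain versions via FAISS.load_local("faiss_store", embeddings,
            # allow_dangerous_deserialization=True) instead of reassembling the pieces by hand.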
        else:
            st.error("Precomputed files are missing. Please process the URLs first.")
            st.stop()  # stop here so the chain below is not built without a vector store

    # Build the retrieval QA chain over the FAISS-backed retriever
    chain = RetrievalQAWithSourcesChain.from_llm(llm=llm, retriever=vector_store.as_retriever())
    result = chain.invoke({"question": query}, return_only_outputs=True)
    # st.write(f"Raw result: {result}")
    # result is a dictionary of the form --> {"answer": "", "sources": ""}

    # Extract and display the result
    answer = result.get("answer", "No answer found.")
    sources = result.get("sources", "")

    # Add to session state history
    st.session_state.qa_history.append({"question": query, "answer": answer,
                                        "sources": sources or "No sources available."})

    st.subheader("Response:")
    st.write(answer)

    # Display sources, if available
    if sources:
        st.subheader("Sources:")
        sources_list = sources.split("\n")  # Split the sources by newline
        for source in sources_list:
            st.write(source)

# Display all questions and answers from the session
if st.session_state.qa_history:
    st.write("---------------------------------------------------------------------")
    st.subheader("History:")
    for entry in st.session_state.qa_history:
        st.write(f"**Q:** {entry['question']}")
        st.write(f"**A:** {entry['answer']}")
        st.write(f"**Sources:** {entry['sources']}")