import os |
import streamlit as st |
import pdfplumber |
import requests |
import google.generativeai as genai |
from bs4 import BeautifulSoup |
from langchain.schema import Document |
from langchain.text_splitter import RecursiveCharacterTextSplitter |
from langchain_pinecone import PineconeVectorStore |
from langchain_groq import ChatGroq |
from langchain.chains import create_retrieval_chain |
from langchain.chains.combine_documents import create_stuff_documents_chain |
from langchain_core.prompts import ChatPromptTemplate |
from langchain_core.embeddings import Embeddings |
from langchain_community.tools import DuckDuckGoSearchRun |
from pinecone import Pinecone |
from dotenv import load_dotenv |
import numpy as np |
import time |
import random |
from typing import List |
import arxiv |
import wikipedia |
from selenium import webdriver |
from selenium.webdriver.common.by import By |
from selenium.webdriver.chrome.options import Options |
from selenium.webdriver.common.action_chains import ActionChains |
from lxml import html |
import base64 |
import os |
import streamlit as st |
import pdfplumber |
import requests |
import google.generativeai as genai |
# Pull settings from a local .env file (if any) into the process environment.
load_dotenv()
groq_key = os.getenv("GROQ_API_KEY")
pinecone_key = os.getenv("PINECONE_API_KEY")
# Either variable name is accepted for the Gemini key; GEMINI_API_KEY wins.
gemini_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

# The page configuration is issued before any other Streamlit call (including
# the st.error below): Streamlit documents set_page_config as the command that
# should come first on a page.
st.set_page_config(
    page_title="AI Research Assistant",
    # NOTE(review): the original icon glyph was destroyed by a mis-encoding
    # (only a stray Greek letter survived); a magnifier emoji is assumed here.
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Only configure the Gemini client when a key is actually available; otherwise
# tell the user which variables to set.
if gemini_key:
    genai.configure(api_key=gemini_key)
else:
    st.error("Gemini API key is missing. Please set either GEMINI_API_KEY or GOOGLE_API_KEY environment variable.")
class GeminiEmbeddings(Embeddings):
    """Embeddings adapter that delegates to ``genai.embed_content``.

    Every vector returned by the API is passed through a float32 round-trip
    and handed back as a plain Python list.
    """

    def __init__(self, api_key):
        """Configure the Gemini client with ``api_key`` and pin the model name."""
        genai.configure(api_key=api_key)
        self.model_name = "models/embedding-001"

    @staticmethod
    def _convert_to_float32(embedding):
        """Return ``embedding`` as a list after casting its values to float32."""
        as_float32 = np.array(embedding, dtype=np.float32)
        return as_float32.tolist()

    def embed_documents(self, texts):
        """Embed each text in ``texts`` (one API call per text), in order."""
        vectors = []
        for text in texts:
            result = genai.embed_content(
                model=self.model_name,
                content=text,
                task_type="retrieval_document",
            )
            vectors.append(self._convert_to_float32(result["embedding"]))
        return vectors

    def embed_query(self, text):
        """Embed a single query string using the ``retrieval_query`` task type."""
        result = genai.embed_content(
            model=self.model_name,
            content=text,
            task_type="retrieval_query",
        )
        return self._convert_to_float32(result["embedding"])
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at ``pdf_path``.

    Pages that yield no text are skipped; the remaining page texts are
    separated by newlines and the result is stripped. Any failure is reported
    through ``st.error`` and an empty string is returned instead.
    """
    try:
        with pdfplumber.open(pdf_path) as pdf:
            page_texts = []
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    page_texts.append(page_text)
            return "\n".join(page_texts).strip()
    except Exception as e:
        st.error(f"Error extracting text from PDF: {e}")
        return ""
def read_data_from_doc(uploaded_file):
    """Turn each non-empty page of a PDF into a ``Document``.

    A page's content is its extracted text, followed by its tables (cells
    joined with tabs, rows and tables joined with newlines) and a note with
    the number of images found on the page. The 1-based page number is stored
    in the document metadata under ``"page"``.
    """
    page_documents = []
    with pdfplumber.open(uploaded_file) as pdf:
        for page_number, page in enumerate(pdf.pages, start=1):
            body_text = page.extract_text() or ""

            tables = page.extract_tables()
            rendered_tables = []
            for table in tables or []:
                if not table:
                    continue
                rows = [
                    "\t".join("" if cell is None else cell for cell in row)
                    for row in table
                ]
                rendered_tables.append("\n".join(rows))
            table_text = "\n".join(rendered_tables)

            images = page.images
            image_note = f"[{len(images)} image(s) detected]" if images else ""

            content = f"{body_text}\n\n{table_text}\n\n{image_note}".strip()
            if not content:
                continue
            page_documents.append(
                Document(page_content=content, metadata={"page": page_number})
            )
    return page_documents
def make_chunks(docs, chunk_len=1000, chunk_overlap=200):
    """Split ``docs`` into chunks with a ``RecursiveCharacterTextSplitter``.

    ``chunk_len`` and ``chunk_overlap`` are forwarded as the splitter's
    ``chunk_size`` and ``chunk_overlap``. Each resulting chunk is re-wrapped in
    a new ``Document`` carrying the chunk's content and metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_len,
        chunk_overlap=chunk_overlap,
    )
    rebuilt = []
    for piece in splitter.split_documents(docs):
        rebuilt.append(
            Document(page_content=piece.page_content, metadata=piece.metadata)
        )
    return rebuilt
def get_gemini_model(model_name="gemini-1.5-pro", temperature=0.4):
    """Create a ``genai.GenerativeModel`` for ``model_name``.

    Note: ``temperature`` is accepted but not forwarded to the model here.
    """
    model = genai.GenerativeModel(model_name)
    return model
def get_generation_config(temperature=0.4):
    """Build the generation settings dict used for Gemini requests.

    Only ``temperature`` is configurable; ``top_p``, ``top_k`` and
    ``max_output_tokens`` are fixed at 1, 1 and 2048.
    """
    config = {"temperature": temperature}
    config.update(top_p=1, top_k=1, max_output_tokens=2048)
    return config
def get_safety_settings():
    """Return one ``BLOCK_NONE`` safety entry per configured category.

    The category collection is empty, so the result is currently always an
    empty list.
    """
    # NOTE(review): presumably harm-category names were meant to be listed
    # here — confirm with the author before adding any.
    categories = ()
    settings = []
    for category in categories:
        settings.append({"category": category, "threshold": "BLOCK_NONE"})
    return settings
def generate_gemini_response(model, prompt):
    """Send ``prompt`` to ``model`` and return the first text part of the reply.

    The request uses the shared generation config and safety settings.

    Returns:
        The text of the first part of the first candidate, or ``''`` when the
        response has no candidates or the first candidate carries no parts.
    """
    response = model.generate_content(
        prompt,
        generation_config=get_generation_config(),
        safety_settings=get_safety_settings(),
    )
    candidates = response.candidates
    if not candidates:
        return ''
    # A candidate can come back without any content parts; indexing parts[0]
    # unconditionally would raise IndexError in that case.
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None)
    if not parts:
        return ''
    return parts[0].text
def summarize_text(text):
    """Ask Gemini for a very concise summary of ``text``.

    Only the first 5000 characters of ``text`` are included in the prompt.
    """
    model = get_gemini_model()
    excerpt = text[:5000]
    prompt_text = f"Summarize the following research paper very concisely:\n{excerpt}"
    return generate_gemini_response(model, prompt_text)
def download_pdf(pdf_url, save_path="temp_paper.pdf"):
    """Download ``pdf_url`` to ``save_path``.

    Args:
        pdf_url: URL of the PDF to fetch.
        save_path: Local file path the response body is written to.

    Returns:
        ``save_path`` when the server answers with HTTP 200 and the file was
        written; ``None`` for any other status code or on error (errors are
        reported through ``st.error``).
    """
    try:
        # A timeout keeps a stalled server from blocking the app forever.
        response = requests.get(pdf_url, timeout=30)
        if response.status_code != 200:
            return None
        with open(save_path, "wb") as file:
            file.write(response.content)
        return save_path
    except Exception as e:
        st.error(f"Error downloading PDF: {e}")
        return None
def _summarize_arxiv_pdf(pdf_link):
    """Download the PDF at ``pdf_link`` and return a summary or status message."""
    pdf_path = download_pdf(pdf_link)
    if not pdf_path:
        return "PDF could not be downloaded."
    try:
        text = extract_text_from_pdf(pdf_path)
        if not text:
            # Nothing to summarise; avoid sending an empty paper to the LLM.
            return "PDF text could not be extracted."
        return summarize_text(text)
    finally:
        # Remove the temporary file even when extraction/summarisation fails.
        if os.path.exists(pdf_path):
            os.remove(pdf_path)


def search_arxiv(query, max_results=2):
    """Search arXiv for ``query`` and return one ``Document`` per result.

    For each result the first link whose URL contains ``'pdf'`` is downloaded,
    its text extracted and summarised. The document content is a Markdown
    block with title, authors, publication date, abstract, PDF summary and PDF
    link; metadata holds ``source="arXiv"`` and the paper title.

    Args:
        query: Search string passed to arXiv.
        max_results: Maximum number of results requested.

    Returns:
        A list of ``Document`` objects, sorted by relevance as returned by arXiv.
    """
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    arxiv_docs = []
    for result in client.results(search):
        pdf_link = next((link.href for link in result.links if 'pdf' in link.href), None)
        if pdf_link:
            with st.spinner(f"Processing arXiv paper: {result.title}"):
                summary = _summarize_arxiv_pdf(pdf_link)
        else:
            summary = "No PDF available."
        # The block starts and ends with a newline, as in the original layout.
        content = "\n".join([
            "",
            f"**Title:** {result.title}",
            f"**Authors:** {', '.join(author.name for author in result.authors)}",
            f"**Published:** {result.published.strftime('%Y-%m-%d')}",
            f"**Abstract:** {result.summary}",
            f"**PDF Summary:** {summary}",
            f"**PDF Link:** {pdf_link if pdf_link else 'Not available'}",
            "",
        ])
        arxiv_docs.append(
            Document(page_content=content, metadata={"source": "arXiv", "title": result.title})
        )
    return arxiv_docs
def search_wikipedia(query, max_results=2):
    """Search Wikipedia for ``query`` and return the matching articles.

    Each returned ``Document`` holds the first 2000 characters of an article,
    with ``source="Wikipedia"`` and the article title in its metadata.

    Args:
        query: Search string.
        max_results: Maximum number of titles requested from the search.

    Returns:
        A list of ``Document`` objects. Pages that raise a disambiguation or
        page error are skipped with a warning; any other failure is reported
        through ``st.error`` and an empty list is returned.
    """
    try:
        page_titles = wikipedia.search(query, results=max_results)
        wiki_docs = []
        for title in page_titles:
            try:
                with st.spinner(f"Processing Wikipedia article: {title}"):
                    # Titles come straight from the search results, so they are
                    # looked up verbatim: auto-suggestion may rewrite a valid
                    # title into a different (or missing) page.
                    page = wikipedia.page(title, auto_suggest=False)
                    wiki_docs.append(Document(
                        page_content=page.content[:2000],
                        metadata={"source": "Wikipedia", "title": title},
                    ))
            except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError) as e:
                st.warning(f"Error retrieving Wikipedia page {title}: {e}")
        return wiki_docs
    except Exception as e:
        st.error(f"Error searching Wikipedia: {e}")
        return []
class ResearchAssistant:
    """Answers research questions with a Groq chat model plus arXiv/Wikipedia context."""

    def __init__(self):
        """Create the chat model, the prompt template and the stuff-documents chain."""
        self.llm = ChatGroq(api_key=groq_key, model="llama3-70b-8192", temperature=0.2)
        template = (
            "\n"
            "You are an expert research assistant. Use the following context to answer the question.\n"
            "If you don't know the answer, say so, but try your best to find relevant information\n"
            "from the provided context and additional context.\n"
            "Context from user documents:\n"
            "{context}\n"
            "Additional context from research sources:\n"
            "{additional_context}\n"
            "Question: {input}\n"
            "Answer:\n"
        )
        self.prompt = ChatPromptTemplate.from_template(template)
        # Note: chat() below talks to self.llm directly and does not go
        # through this chain.
        self.question_answer_chain = create_stuff_documents_chain(self.llm, self.prompt)

    def retrieve_documents(self, query):
        """Collect context for ``query``.

        Returns:
            ``(user_context, summarized_context)`` where ``user_context`` is
            always an empty list and ``summarized_context`` holds one Markdown
            snippet per arXiv result followed by one per Wikipedia result.
        """
        user_context = []
        arxiv_docs = search_arxiv(query)
        wiki_docs = search_wikipedia(query)
        labelled_sources = (("ArXiv", arxiv_docs), ("Wikipedia", wiki_docs))
        summarized_context = [
            f"**{label} - {doc.metadata.get('title', 'Unknown Title')}**:\n{doc.page_content}..."
            for label, source_docs in labelled_sources
            for doc in source_docs
        ]
        return user_context, summarized_context

    def chat(self, question):
        """Answer ``question`` and return ``(answer_text, summarized_context)``.

        The retrieved snippets are joined with blank lines and embedded in a
        plain-text prompt that is sent straight to the chat model.
        """
        _user_context, summarized_context = self.retrieve_documents(question)
        additional_context = "\n\n".join(summarized_context)
        with st.spinner("Generating answer..."):
            prompt_text = (
                "\n"
                f"Question: {question}\n"
                "Additional context:\n"
                f"{additional_context}\n"
                "Please provide a comprehensive answer based on the above information.\n"
            )
            response = self.llm.invoke(prompt_text)
        return response.content, summarized_context
@st.cache_resource(show_spinner=False)
def get_retrieval_chain(uploaded_file, model):
    """Index ``uploaded_file`` in Pinecone and build a retrieval QA chain.

    The PDF is read page by page, chunked, embedded with Gemini and stored in
    the ``research-rag`` Pinecone index (created on first use). The returned
    chain retrieves the 4 most similar chunks and answers with the Groq model
    named by ``model``.

    Args:
        uploaded_file: The uploaded PDF, as accepted by ``read_data_from_doc``.
        model: Groq model name passed to ``ChatGroq``.

    Returns:
        A retrieval chain; invoke it with ``"input"`` and
        ``"additional_context"`` keys.

    Raises:
        ValueError: If no text chunks could be produced from the document.
        TimeoutError: If a newly created index is not ready within 60 seconds.
    """
    # Imported locally so the file-level import block does not need to change.
    from pinecone import ServerlessSpec

    with st.spinner("Processing document... This may take a minute."):
        genai.configure(api_key=gemini_key)
        embeddings = GeminiEmbeddings(api_key=gemini_key)
        docs = read_data_from_doc(uploaded_file)
        splits = make_chunks(docs)
        if not splits:
            # Without this guard nothing would be indexed for this file and
            # answers would come from whatever is already in the shared index.
            raise ValueError("No extractable text was found in the uploaded PDF.")

        pc = Pinecone(api_key=pinecone_key)
        index_name = "research-rag"
        if index_name not in pc.list_indexes().names():
            # create_index needs a deployment spec; cloud/region can be
            # overridden through the environment.
            # NOTE(review): the aws/us-east-1 defaults are an assumption —
            # confirm they match the Pinecone project.
            pc.create_index(
                name=index_name,
                dimension=768,
                metric="cosine",
                spec=ServerlessSpec(
                    cloud=os.getenv("PINECONE_CLOUD", "aws"),
                    region=os.getenv("PINECONE_REGION", "us-east-1"),
                ),
            )
            # A freshly created index is not immediately usable; wait for it,
            # but never indefinitely.
            deadline = time.monotonic() + 60
            while not pc.describe_index(index_name).status["ready"]:
                if time.monotonic() > deadline:
                    raise TimeoutError(
                        f"Pinecone index '{index_name}' was not ready after 60 seconds."
                    )
                time.sleep(1)

        # NOTE(review): every uploaded document is written to the same index
        # without a namespace, so chunks from earlier uploads remain
        # retrievable — confirm whether that is intended.
        vectorstore = PineconeVectorStore.from_documents(
            splits,
            embeddings,
            index_name=index_name,
        )
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
        llm = ChatGroq(model_name=model, temperature=0.75, api_key=groq_key)
        system_prompt = (
            "\n"
            "You are an AI assistant answering questions based on retrieved documents and additional context.\n"
            "Use the provided context from both database retrieval and additional sources to answer the question.\n"
            "- **Discard irrelevant context:** If one of the contexts (retrieved or additional) does not match the question, ignore it.\n"
            "- **Highlight conflicting information:** If multiple sources provide conflicting information, explicitly mention it by saying:\n"
            "- \"According to the retrieved context, ... but as per internet sources, ...\"\n"
            "- \"According to the retrieved context, ... but as per internet sources, ...\"\n"
            "- **Prioritize accuracy:** If neither context provides a relevant answer, say \"I don't know\" instead of guessing.\n"
            "Provide concise yet informative answers, ensuring clarity and completeness.\n"
            "Retrieved Context: {context}\n"
            "Additional Context: {additional_context}\n"
        )
        # NOTE(review): both contexts appear in the system and the human
        # message, so they are sent twice per request; left as-is here.
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}\n\nRetrieved Context: {context}\n\nAdditional Context: {additional_context}"),
        ])
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        chain = create_retrieval_chain(retriever, question_answer_chain)
    return chain
def create_search_prompt(query, context=""):
    """Build the prompt that asks the model whether a web search is needed.

    The prompt is the fixed instructions, then a ``Context:`` section (only
    when ``context`` is non-empty), then the ``Query:`` section, each separated
    by a blank line.
    """
    instruction_lines = (
        "You are a smart assistant designed to determine whether a query needs data from a web search or can be answered using a document database.",
        "Consider the provided context if available.",
        "If the query requires external information, No context is provided, Irrelevent context is present or latest information is required, then output the special token <SEARCH>",
        "followed by relevant keywords extracted from the query to optimize for search engine results.",
        "Ensure the keywords are concise and relevant. If document data is sufficient, simply return blank.",
    )
    sections = ["\n".join(instruction_lines)]
    if context:
        sections.append(f"Context: {context}")
    sections.append(f"Query: {query}")
    return "\n\n".join(sections)
def create_summary_prompt(content):
    """Wrap ``content`` in the instructions used to request a summary."""
    instructions = (
        "Please provide a comprehensive yet concise summary of the following content, "
        "highlighting the most important points and maintaining factual accuracy. "
        "Organize the information in a clear and coherent manner:"
    )
    pieces = (instructions, "Content to summarize:", f"{content}", "Summary:")
    return "\n".join(pieces)
def init_selenium_driver():
    """Start a Chrome WebDriver with the headless/sandbox flags listed below."""
    chrome_options = Options()
    launch_flags = (
        "--headless",
        "--disable-gpu",
        "--no-sandbox",
        "--disable-dev-shm-usage",
    )
    for flag in launch_flags:
        chrome_options.add_argument(flag)
    return webdriver.Chrome(options=chrome_options)
def extract_static_page(url):
    """Fetch ``url`` with ``requests`` and return up to 5000 characters of its text.

    The request uses a 5 second timeout. Request failures (including HTTP
    error statuses) are reported through ``st.error`` and yield ``None``.
    """
    try:
        page = requests.get(url, timeout=5)
        page.raise_for_status()
        parsed = BeautifulSoup(page.text, 'lxml')
        visible_text = parsed.get_text(separator=" ", strip=True)
        return visible_text[:5000]
    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching page: {e}")
        return None
def extract_dynamic_page(url, driver):
    """Load ``url`` in ``driver`` and return up to 1000 characters of body text.

    After loading, the function pauses for a random 2-5 seconds, moves the
    pointer onto the ``<body>`` element, pauses again, and then reads every
    text node under ``<body>`` from the page source. Any failure is reported
    through ``st.error`` and yields ``None``.
    """
    try:
        driver.get(url)
        time.sleep(random.uniform(2, 5))
        body_element = driver.find_element(By.TAG_NAME, "body")
        ActionChains(driver).move_to_element(body_element).perform()
        time.sleep(random.uniform(2, 5))
        document_tree = html.fromstring(driver.page_source)
        text_nodes = document_tree.xpath('//body//text()')
        joined_text = ' '.join(text_nodes).strip()
        return joined_text[:1000]
    except Exception as e:
        st.error(f"Error fetching dynamic page: {e}")
        return None
def scrape_page(url):
    """Return the text of ``url``, or ``None`` if it could not be fetched.

    URLs containing ``"javascript"`` or ``"dynamic"`` are rendered with a
    Selenium driver; every other URL is fetched as a static page.
    """
    if "javascript" in url or "dynamic" in url:
        driver = init_selenium_driver()
        try:
            return extract_dynamic_page(url, driver)
        finally:
            # Always shut the browser down, even if extraction is interrupted.
            driver.quit()
    return extract_static_page(url)
def scrape_web(urls, max_urls=5):
    """Scrape at most ``max_urls`` of ``urls`` and return the texts that were found.

    URLs that yield no content are skipped with a ``st.warning``.
    """
    collected = []
    for url in urls[:max_urls]:
        page_text = scrape_page(url)
        if not page_text:
            st.warning(f"Failed to retrieve content from {url}")
            continue
        collected.append(page_text)
    return collected
def check_search_needed(model, query, context):
    """Ask the model whether ``query`` needs a web search.

    Args:
        model: Gemini model passed to ``generate_gemini_response``.
        query: The user's question.
        context: Optional context shown to the model alongside the query.

    Returns:
        ``(True, search_terms)`` when the reply contains the ``<SEARCH>``
        token, where ``search_terms`` is the text following the first token;
        ``(False, None)`` otherwise.
    """
    prompt = create_search_prompt(query, context)
    response = generate_gemini_response(model, prompt)
    if "<SEARCH>" not in response:
        return False, None
    search_terms = response.split("<SEARCH>")[1].strip()
    # The model sometimes emits the token with no keywords; fall back to the
    # query itself so the search tool is never called with an empty string.
    return True, search_terms or query.strip()
def summarize_content(model, content):
    """Return the model's summary of ``content``."""
    summary_prompt = create_summary_prompt(content)
    summary = generate_gemini_response(model, summary_prompt)
    return summary
def process_query(query, context=''):
    """Decide whether ``query`` needs a web search and, if so, run and summarise it.

    Returns:
        A dict with the keys ``original_query``, ``needs_search``,
        ``search_terms``, ``web_content`` and ``summary``; the last two stay
        ``None`` when no search was performed.
    """
    with st.spinner("Processing query..."):
        model = get_gemini_model()
        search_tool = DuckDuckGoSearchRun()
        needs_search, search_terms = check_search_needed(model, query, context)

        web_content = None
        summary = None
        if needs_search:
            with st.spinner(f"Searching the web for: {search_terms}"):
                web_content = search_tool.run(search_terms)
            with st.spinner("Summarizing search results..."):
                summary = summarize_content(model, web_content)

        return {
            "original_query": query,
            "needs_search": needs_search,
            "search_terms": search_terms,
            "web_content": web_content,
            "summary": summary,
        }
def display_header():
    """Render the page title and the one-line tagline."""
    # NOTE(review): the leading emoji was destroyed by a mis-encoding (only a
    # stray Greek letter survived); a magnifier is assumed — confirm.
    st.title("🔍 AI Research Assistant")
    st.markdown("Your all-in-one tool for research, document analysis, and web search")
def _render_sidebar():
    """Draw the navigation sidebar and return the selected mode name."""
    with st.sidebar:
        st.title("Navigation")
        app_mode = st.radio("Choose a mode:",
                            ["Research Assistant", "Document Q&A", "Web Search"])
        st.markdown("---")
        st.subheader("About")
        st.markdown(
            "\n"
            "This AI Research Assistant helps you find and analyze information from various sources:\n"
            "- arXiv papers\n"
            "- Wikipedia articles\n"
            "- Your own uploaded documents\n"
            "- Web search results\n"
        )
        st.markdown("---")
        st.subheader("API Status")
        # One status line per service, driven by whether its key was loaded.
        for service, key in (("Groq", groq_key), ("Gemini", gemini_key), ("Pinecone", pinecone_key)):
            if key:
                st.success(f"✅ {service} API connected")
            else:
                st.error(f"❌ {service} API key missing")
    return app_mode


def _render_research_mode():
    """Render the arXiv/Wikipedia question-answering page with its chat history."""
    st.header("Research Assistant")
    st.markdown("Ask research questions and get answers from arXiv papers and Wikipedia.")
    if "research_history" not in st.session_state:
        st.session_state.research_history = []
    if "research_assistant" not in st.session_state:
        with st.spinner("Initializing Research Assistant..."):
            st.session_state.research_assistant = ResearchAssistant()

    with st.form(key="research_form"):
        question = st.text_input("Your research question:", key="research_question")
        submit_button = st.form_submit_button("Search")
    if st.button("Clear Chat"):
        st.session_state.research_history = []
        st.rerun()

    if submit_button and question:
        try:
            answer, sources = st.session_state.research_assistant.chat(question)
        except Exception as e:
            # Same reporting style as the Document Q&A page; a failed request
            # is not recorded in the history.
            st.error(f"An error occurred: {str(e)}")
        else:
            st.session_state.research_history.append({"role": "user", "content": question})
            st.session_state.research_history.append({
                "role": "assistant",
                "content": answer,
                "sources": sources,
            })

    for message in st.session_state.research_history:
        if message["role"] == "user":
            st.write(f"👤 **You:** {message['content']}")
            continue
        st.write("🤖 **AI Assistant:**")
        st.markdown(message["content"])
        if message.get("sources"):
            with st.expander("View Sources"):
                for i, source in enumerate(message["sources"], 1):
                    st.markdown(f"**Source {i}:**")
                    st.markdown(source)
                    st.markdown("---")


def _render_document_qa_mode():
    """Render the PDF upload page and its question/answer chat."""
    st.header("Document Q&A")
    st.markdown("Upload a PDF document and ask questions about it.")
    model_name = st.selectbox(
        "Select Groq Model",
        [
            "llama3-70b-8192",
            "gemma2-9b-it",
            "llama-3.3-70b-versatile",
            "llama-3.1-8b-instant",
            "llama-guard-3-8b",
            "mixtral-8x7b-32768",
            "deepseek-r1-distill-llama-70b",
            "llama-3.2-1b-preview",
        ],
        index=0,
    )
    if 'document_conversation' not in st.session_state:
        st.session_state.document_conversation = []
    uploaded_file = st.file_uploader("Upload a PDF document", type="pdf")

    # Check the keys before touching the document: indexing needs all three
    # services, so attempting it without them can only fail.
    if not (groq_key and gemini_key and pinecone_key):
        st.warning("Please make sure all API keys are properly configured.")
        return
    if not uploaded_file:
        return

    try:
        chain = get_retrieval_chain(uploaded_file, model_name)
        st.success("Document processed successfully! You can now ask questions.")
        for q, a in st.session_state.document_conversation:
            with st.chat_message("user"):
                st.write(q)
            with st.chat_message("assistant"):
                st.write(a)
        question = st.chat_input("Ask a question about your document...")
        if question:
            with st.chat_message("user"):
                st.write(question)
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    additional_context = ""
                    result = chain.invoke({
                        "input": question,
                        "additional_context": additional_context,
                    })
                    answer = result['answer']
                st.write(answer)
            st.session_state.document_conversation.append((question, answer))
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


def _render_web_search_mode():
    """Render the web-search page: query form, search details and summary."""
    st.header("Web Search")
    st.markdown("Search the web for answers to your questions.")
    with st.form("web_query_form"):
        query = st.text_area("Enter your research question", height=100,
                             placeholder="E.g., What are the latest developments in quantum computing?")
        context = st.text_area("Optional: Add any context", height=100,
                               placeholder="Add any additional context that might help with the research")
        submit_button = st.form_submit_button("🔍 Research")
    if not (submit_button and query):
        return

    try:
        result = process_query(query, context)
    except Exception as e:
        # Same reporting style as the Document Q&A page.
        st.error(f"An error occurred: {str(e)}")
        return

    if not result["needs_search"]:
        st.info("Based on the analysis, no web search was needed for this query.")
        return
    st.success("Research completed!")
    with st.expander("Search Details", expanded=False):
        st.subheader("Search Terms Used")
        st.info(result["search_terms"])
        st.subheader("Raw Web Content")
        st.text_area("Web Content", result["web_content"], height=200)
    st.subheader("Summary of Findings")
    st.markdown(result["summary"])


def main():
    """Application entry point: header, sidebar, then the selected mode's page.

    NOTE(review): the emoji in the status lines, chat labels and search button
    were reconstructed from mis-encoded text, and the nesting of a few widgets
    was inferred because the original indentation was lost — confirm against
    the intended layout.
    """
    display_header()
    app_mode = _render_sidebar()
    if app_mode == "Research Assistant":
        _render_research_mode()
    elif app_mode == "Document Q&A":
        _render_document_qa_mode()
    else:
        _render_web_search_mode()
# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()