import PyPDF2
import faiss
import torch
import streamlit as st
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
# Load the sentence-embedding model used for retrieval
embedding_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L3-v2")

# Load a lightweight seq2seq LLM for answer generation (FLAN-T5 small)
llm_model_name = "google/flan-t5-small"
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
# fp16 halves memory on GPU; fall back to fp32 on CPU, where fp16 T5 inference is unreliable
llm_model = AutoModelForSeq2SeqLM.from_pretrained(
    llm_model_name,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# Function to extract text from a PDF
def extract_text_from_pdf(pdf_path):
    """Extract text from a research paper (PDF)."""
    text = ""
    with open(pdf_path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        for page in reader.pages:
            # extract_text() may return None for image-only pages
            text += (page.extract_text() or "") + "\n"
    return text
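# Illustrative usage (hypothetical file name, not part of the app flow):
#   raw_text = extract_text_from_pdf("example_paper.pdf")
#   print(raw_text[:200])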
# Function to split text into sections
def chunk_text(text):
    """Splits text into sections based on common paper headings."""
    sections = {}
    current_section = "Other"
    sections[current_section] = []
    for line in text.split("\n"):
        line = line.strip()
        if line.lower().startswith("abstract"):
            current_section = "Abstract"
            sections[current_section] = []
        elif line.lower().startswith("introduction"):
            current_section = "Introduction"
            sections[current_section] = []
        elif line.lower().startswith("conclusion"):
            current_section = "Conclusion"
            sections[current_section] = []
        sections[current_section].append(line)
    # Join each section's lines into a single text chunk
    for section in sections:
        sections[section] = " ".join(sections[section])
    return sections
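# Illustrative result for a typical paper (keys depend on which headings are found):
#   {"Other": "...title and front matter...", "Abstract": "Abstract ...",
#    "Introduction": "Introduction ...", "Conclusion": "Conclusion ..."}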
# Function to create the FAISS vector database
def build_vector_database(sections):
    """Builds a FAISS vector index over the paper's sections."""
    chunk_texts = list(sections.values())
    embeddings = embedding_model.encode(chunk_texts)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, embeddings, list(sections.keys()), chunk_texts
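# Note: IndexFlatL2 performs exact (brute-force) L2 nearest-neighbor search, which
# is fine at this scale: one vector per section, so only a handful per paper. FAISS
# expects float32 numpy input, which SentenceTransformer.encode returns by default.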
# Function to retrieve the most relevant context
def retrieve_context(query, index, section_titles, section_texts, top_k=1):
    """Retrieves the most relevant sections for a query."""
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(query_embedding, top_k)
    retrieved_contexts = [f"**{section_titles[idx]}**: {section_texts[idx]}" for idx in indices[0]]
    return "\n".join(retrieved_contexts)
# Function to generate a concise answer
def generate_answer_rag(question, context, max_context_chars=512):
    """Generates an answer to the question from the retrieved context."""
    # Pre-truncate the context (in characters); the tokenizer truncation below
    # enforces the model's actual 512-token input limit
    input_text = f"Question: {question}\nContext: {context[:max_context_chars]}"
    input_ids = llm_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).input_ids
    output_ids = llm_model.generate(input_ids, max_length=150)
    return llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
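# Illustrative call (continuing the retrieval example above; the answer is capped
# at 150 tokens by max_length in generate()):
#   answer = generate_answer_rag("What is the main contribution?", ctx)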
# Streamlit UI
def main():
    st.title("AI Research Paper RAG Chatbot")
    uploaded_pdf = st.file_uploader("Upload a Research Paper (PDF)", type=["pdf"])
    if uploaded_pdf is not None:
        pdf_path = "uploaded_paper.pdf"
        with open(pdf_path, "wb") as f:
            f.write(uploaded_pdf.read())
        st.success("PDF uploaded successfully!")

        # Extract and preprocess text
        text = extract_text_from_pdf(pdf_path)
        text_sections = chunk_text(text)

        # Build the FAISS vector database
        index, embeddings, section_titles, section_texts = build_vector_database(text_sections)
        st.write(f"Paper processed into {len(text_sections)} sections for efficient retrieval.")

        # User query input
        user_question = st.text_input("Ask a question about the paper:")
        if user_question:
            context = retrieve_context(user_question, index, section_titles, section_texts)
            answer = generate_answer_rag(user_question, context)
            st.write(f"**Retrieved Context:**\n{context}")
            st.write(f"**Generated Answer:**\n{answer}")

if __name__ == "__main__":
    main()
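# To run locally (dependencies match the imports above; "app.py" assumes this
# file's name):
#   pip install streamlit PyPDF2 faiss-cpu torch transformers sentence-transformers
#   streamlit run app.py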