R-A-G / app.py
muhammadshaheryar's picture
Update app.py
a181300 verified
import faiss
from annoy import AnnoyIndex
# Build Annoy index
def create_annoy_index(embeddings, num_trees=10):
    """Build a queryable Annoy index (angular distance) over *embeddings*.

    Args:
        embeddings: 2-D array-like with one row per item; row ``i`` is stored
            under Annoy item id ``i``.
        num_trees: number of trees passed to ``AnnoyIndex.build``. Per the
            Annoy docs, more trees give more accurate results at the cost of
            build time and index size.

    Returns:
        A built ``AnnoyIndex``, ready for ``get_nns_by_vector``.
    """
    index = AnnoyIndex(embeddings.shape[1], 'angular')
    for item_id, emb in enumerate(embeddings):
        index.add_item(item_id, emb)
    # Annoy cannot be queried until build() has been called; this is also
    # where num_trees takes effect.
    index.build(num_trees)
    return index
# Query Annoy index
def retrieve_relevant_text(query, annoy_index, texts, top_k=3):
    """Return the ``top_k`` entries of *texts* nearest to *query* in *annoy_index*.

    The query is embedded with the module-level ``embedder``, and the Annoy
    item ids that come back are used as positions into *texts*.

    NOTE: ``retrieve_relevant_text`` is defined again further down this
    module, so this Annoy variant is shadowed once the module finishes loading.
    """
    encoded = embedder.encode([query])
    query_vector = encoded[0]
    neighbour_ids = annoy_index.get_nns_by_vector(query_vector, top_k)
    matches = [texts[neighbour] for neighbour in neighbour_ids]
    return matches
# Function to create an Annoy index from the embeddings
def create_annoy_index(embeddings, num_trees=10):
    """Build a queryable Annoy index (angular distance) over *embeddings*.

    NOTE: this module defines ``create_annoy_index`` more than once; the later
    definition is the one bound once the module finishes loading.

    Args:
        embeddings: 2-D array-like with one row per item; row ``i`` is stored
            under Annoy item id ``i``.
        num_trees: number of trees passed to ``AnnoyIndex.build``.

    Returns:
        A built ``AnnoyIndex``, ready for ``get_nns_by_vector``.
    """
    index = AnnoyIndex(embeddings.shape[1], 'angular')  # Using angular distance metric
    for item_id, emb in enumerate(embeddings):
        index.add_item(item_id, emb)
    # Annoy refuses queries on an unbuilt index; build() also consumes num_trees.
    index.build(num_trees)
    return index
# Function to retrieve the most relevant text using Annoy
def retrieve_relevant_text(query, annoy_index, texts, top_k=3):
    """Return up to ``top_k`` entries of *texts* nearest to *query*.

    NOTE: ``retrieve_relevant_text`` is defined again further down this
    module (FAISS variant), so this definition is shadowed once the module
    finishes loading.

    Args:
        query: the user's question.
        annoy_index: a built Annoy index whose item ``i`` embeds ``texts[i]``.
        texts: the chunks the index was built from, in item-id order.
        top_k: number of neighbours to request.

    Returns:
        The matching chunks, nearest first.
    """
    # Annoy reads the query as a plain sequence of floats. A NumPy row avoids
    # building a torch tensor (possibly on the encoder's GPU) only for Annoy
    # to convert it back element by element.
    query_embedding = embedder.encode([query], convert_to_numpy=True)[0]
    neighbour_ids = annoy_index.get_nns_by_vector(query_embedding, top_k)
    return [texts[i] for i in neighbour_ids]
import os
import fitz # PyMuPDF for PDF extraction
import faiss # for efficient vector search
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration
from sentence_transformers import SentenceTransformer
import streamlit as st
# Load the pre-trained RAG model and tokenizer.
# "facebook/rag-sequence-nq" is the checkpoint published for
# RagSequenceForGeneration ("facebook/rag-token-nq" pairs with
# RagTokenForGeneration).
model_name = "facebook/rag-sequence-nq"  # You can change this to a different open-source RAG model if needed


@st.cache_resource
def _load_rag(name):
    """Load the RAG tokenizer and sequence model once per server process."""
    return (
        RagTokenizer.from_pretrained(name),
        RagSequenceForGeneration.from_pretrained(name),
    )


@st.cache_resource
def _load_embedder(name):
    """Load the sentence-transformer encoder once per server process."""
    return SentenceTransformer(name)


# Streamlit re-executes this script from the top on every widget interaction;
# the cached loaders keep those reruns from reloading the model weights.
tokenizer, model = _load_rag(model_name)
# Initialize sentence transformer model for embeddings
embedder = _load_embedder('all-MiniLM-L6-v2')
# Function to extract text from a PDF file
def extract_text_from_pdf(pdf_file):
    """Return the concatenated plain text of every page of a PDF, in page order.

    Args:
        pdf_file: either an in-memory binary file object (anything with
            ``getvalue()`` or ``read()``, e.g. the object returned by
            ``st.file_uploader``) or anything ``fitz.open`` accepts as its
            first argument, such as a filesystem path.

    Returns:
        The ``"text"``-mode extraction of all pages joined together.
    """
    if hasattr(pdf_file, "getvalue"):
        # fitz.open()'s first argument is a filename, and an upload has no
        # file on disk, so open it from its bytes. getvalue() returns the
        # whole buffer regardless of the current read position.
        document = fitz.open(stream=pdf_file.getvalue(), filetype="pdf")
    elif hasattr(pdf_file, "read"):
        document = fitz.open(stream=pdf_file.read(), filetype="pdf")
    else:
        document = fitz.open(pdf_file)
    # The context manager closes the document even if extraction raises.
    with document:
        return "".join(page.get_text("text") for page in document)
# Function to create embeddings from text data
def create_embeddings(text_data):
    """Embed each string in *text_data* with the module-level ``embedder``.

    Args:
        text_data: list of strings to embed.

    Returns:
        A float32 NumPy matrix with one row per input string, the layout
        FAISS's Python API consumes.
    """
    # convert_to_tensor=True would return a torch tensor (on the encoder's
    # device), which FAISS cannot take directly; ask for NumPy instead.
    embeddings = embedder.encode(text_data, convert_to_numpy=True)
    return np.asarray(embeddings, dtype=np.float32)
# Function to create a FAISS index from the embeddings
def create_faiss_index(embeddings):
    """Build an exact L2 FAISS index populated with *embeddings*.

    Args:
        embeddings: 2-D matrix with one row per chunk, as a NumPy array or a
            torch tensor; row ``i`` is stored under FAISS id ``i``.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of *embeddings*.
    """
    if hasattr(embeddings, "detach"):
        # Torch tensor (e.g. from encode(convert_to_tensor=True)): FAISS
        # needs a host-memory NumPy array.
        embeddings = embeddings.detach().cpu().numpy()
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(vectors.shape[1])  # Using L2 distance metric
    # Without add() the index stays empty and every search returns -1 ids.
    index.add(vectors)
    return index
# Function to retrieve the most relevant text using FAISS
def retrieve_relevant_text(query, faiss_index, texts, top_k=3):
    """Return up to ``top_k`` entries of *texts* nearest to *query*.

    This is the last ``retrieve_relevant_text`` defined in the module, so it
    is the one ``get_answer_from_pdf`` reaches.

    Args:
        query: the user's question.
        faiss_index: FAISS index whose id ``i`` is the embedding of ``texts[i]``.
        texts: the chunks the index was built from, in insertion order.
        top_k: maximum number of chunks to return.

    Returns:
        The matching chunks, nearest first. Fewer than ``top_k`` are returned
        when the index holds fewer vectors, and none when it is empty.
    """
    k = min(top_k, faiss_index.ntotal)
    if k <= 0:
        return []
    # FAISS's Python API takes a float32 NumPy matrix, not a torch tensor.
    query_embedding = np.asarray(
        embedder.encode([query], convert_to_numpy=True), dtype=np.float32
    )
    _distances, indices = faiss_index.search(query_embedding, k)
    # FAISS pads missing results with -1; never let that wrap to texts[-1].
    return [texts[i] for i in indices[0] if i >= 0]
# Main function to answer questions based on uploaded PDF
def get_answer_from_pdf(pdf_file, query):
    """Answer *query* from passages retrieved out of *pdf_file*.

    Pipeline: extract the PDF text, split it into non-empty lines, embed and
    index those lines with FAISS, retrieve the closest ones, then let the RAG
    model's seq2seq generator write an answer conditioned on them.

    Args:
        pdf_file: anything ``extract_text_from_pdf`` accepts (path or upload).
        query: the user's question.

    Returns:
        The generated answer string, or a short notice when the PDF yields no
        extractable text.
    """
    # Step 1: Extract text from the uploaded PDF file
    document_text = extract_text_from_pdf(pdf_file)
    # Step 2: Split into line-sized chunks. Blank lines carry no content and
    # would otherwise be embedded and retrieved.
    text_chunks = [line.strip() for line in document_text.split('\n') if line.strip()]
    if not text_chunks:
        # e.g. a scanned PDF without a text layer: nothing to embed or index.
        return "No extractable text was found in the PDF."
    # Step 3: Create embeddings for each chunk of text
    embeddings = create_embeddings(text_chunks)
    # Step 4: Create a FAISS index for efficient retrieval
    faiss_index = create_faiss_index(embeddings)
    # Step 5: Retrieve relevant text from the document based on the query
    relevant_texts = retrieve_relevant_text(query, faiss_index, text_chunks)
    # Step 6: Generate the answer.
    # Why the generator is called directly:
    # - RagSequenceForGeneration.generate() without a RagRetriever also
    #   requires context_attention_mask and doc_scores.
    # - Its context ids must use the generator's vocabulary; RagTokenizer's
    #   __call__ uses the question-encoder tokenizer.
    # - FAISS has already done the retrieval, so the seq2seq generator can be
    #   fed "<context> // <question>" directly.
    context = " ".join(relevant_texts)
    generator_inputs = tokenizer.generator(
        f"{context} // {query}",
        return_tensors="pt",
        truncation=True,
    )
    outputs = model.generator.generate(
        input_ids=generator_inputs["input_ids"],
        attention_mask=generator_inputs["attention_mask"],
    )
    answer = tokenizer.generator.decode(outputs[0], skip_special_tokens=True)
    return answer
# Streamlit UI
def main():
    """Render the PDF Q&A page: upload a PDF, ask a question, show the answer."""
    st.title("RAG Application - PDF Q&A")

    uploaded_file = st.file_uploader("Upload a PDF Document", type="pdf")
    if uploaded_file is None:
        return  # nothing to query until a PDF has been uploaded

    question = st.text_input("Ask a question based on the document:")
    if not question:
        return  # wait until the user has typed a question

    answer = get_answer_from_pdf(uploaded_file, question)
    st.write("Answer: ", answer)
if __name__ == "__main__":