# demo-a-rag / app.py
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_huggingface.cross_encoders import HuggingFaceCrossEncoder
# import os
# import subprocess
# def install(package):
#     subprocess.check_call([os.sys.executable, "-m", "pip", "install", package])
# install("langchain-community")
# install("pypdf")
# --- Configuration ---
# Make sure your GROQ_API_KEY is set in your Hugging Face Space secrets
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
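# Fail-fast check (an addition, not in the original): warn early instead of
# surfacing a cryptic authentication error from Groq on the first chat message.
if not GROQ_API_KEY:
    print("Warning: GROQ_API_KEY is not set; LLM calls will fail until it is configured.")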
# --- 1. Load Documents ---
def load_documents(file_path):
    """Loads a PDF, TXT, or MD file into LangChain documents."""
    # The UI advertises TXT/MD uploads as well as PDF, so dispatch on extension
    if file_path.lower().endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    else:
        loader = TextLoader(file_path)
    return loader.load()
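# Note: PyPDFLoader yields one Document per page (with page metadata), while
# TextLoader yields a single Document for the whole file.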
# --- 2. Text Splitting & Vector Store Creation ---
def create_vector_store(documents, chunk_size=1000, chunk_overlap=200):
    """Splits documents and creates a FAISS vector store."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    docs = text_splitter.split_documents(documents)
    # Use a popular sentence-transformer model for embeddings
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Create the vector store
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore
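# Optional extension (not in the original): FAISS stores can be persisted with
# vectorstore.save_local("faiss_index") and reloaded via
# FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True),
# which avoids re-embedding the same document on every restart.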
# --- 3. RAG Chain Creation ---
def create_rag_chain(vectorstore, llm, retriever_k=4, use_reranker=False):
    """Creates a simple or advanced RAG chain."""
    retriever = vectorstore.as_retriever(search_kwargs={"k": retriever_k})
    # Advanced RAG: add a reranker for more relevant results
    if use_reranker:
        # A cross-encoder scores each (query, document) pair; the compression
        # retriever then keeps only the top 3 of the k retrieved documents.
        cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-large")
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=3)
        retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=retriever
        )
    prompt = ChatPromptTemplate.from_template("""
You are a helpful assistant for question-answering tasks.
Use the following retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise.

Question: {input}
Context: {context}

Answer:
""")
    document_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, document_chain)
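# create_retrieval_chain expects {"input": ...} and returns a dict whose
# "answer" key holds the LLM response and whose "context" key holds the
# retrieved documents, which is why chat_logic below reads response["answer"].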
# --- Gradio UI ---
def create_chat_interface(rag_chain_creator, use_reranker_flag):
    """A factory function to create the chat logic for a Gradio tab."""
    def chat_logic(message, history, file_upload, chunk_size, k_retriever):
        if rag_chain_creator.chain is None:
            if file_upload is None:
                return "Please upload a document first."
            # Process the document and create the RAG chain.
            # gr.File may return a path string or a tempfile-like object
            # depending on the Gradio version, so handle both.
            file_path = file_upload if isinstance(file_upload, str) else file_upload.name
            docs = load_documents(file_path)
            # Sliders return numbers that may be floats; cast for the splitter and FAISS
            vector_store = create_vector_store(docs, int(chunk_size))
            llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", api_key=GROQ_API_KEY)
            rag_chain_creator.chain = create_rag_chain(vector_store, llm, int(k_retriever), use_reranker=use_reranker_flag)
        response = rag_chain_creator.chain.invoke({"input": message})
        return response["answer"]
    return chat_logic
class RAGChainState:
    """Simple class to hold the state of a RAG chain."""
    def __init__(self):
        self.chain = None
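# Holding the chain on a mutable object (rather than in a local variable) lets
# the chat_logic closure build it lazily on the first message and reuse it on
# subsequent turns.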
# Create separate state holders for each tab to keep them independent
simple_rag_state = RAGChainState()
advanced_rag_state = RAGChainState()
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
    gr.Markdown("# 🚀 RAG Pipeline Demo with Groq & LangChain")
    gr.Markdown("Upload a document and ask questions. Compare the 'Simple' and 'Advanced' RAG pipelines.")
    with gr.Tabs():
        # --- Simple RAG Tab ---
        with gr.TabItem("Simple RAG"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload_simple = gr.File(label="Upload your TXT, MD, or PDF")
                    with gr.Accordion("Settings", open=False):
                        chunk_size_simple = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
                        k_simple = gr.Slider(1, 10, value=4, step=1, label="Docs to Retrieve (k)")
                with gr.Column(scale=2):
                    gr.ChatInterface(
                        fn=create_chat_interface(simple_rag_state, use_reranker_flag=False),
                        additional_inputs=[file_upload_simple, chunk_size_simple, k_simple],
                        title="Simple RAG",
                        description="Basic retriever. Good for general queries."
                    )
        # --- Advanced RAG Tab ---
        with gr.TabItem("Advanced RAG (with Reranker)"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload_advanced = gr.File(label="Upload your TXT, MD, or PDF")
                    with gr.Accordion("Settings", open=False):
                        chunk_size_advanced = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
                        k_advanced = gr.Slider(1, 10, value=4, step=1, label="Docs to Retrieve (k)")
                with gr.Column(scale=2):
                    gr.ChatInterface(
                        fn=create_chat_interface(advanced_rag_state, use_reranker_flag=True),
                        additional_inputs=[file_upload_advanced, chunk_size_advanced, k_advanced],
                        title="Advanced RAG",
                        description="Uses a Cross-Encoder to rerank results for higher accuracy."
                    )

demo.launch()