import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_huggingface.cross_encoders import HuggingFaceCrossEncoder
# import os
# import subprocess
# def install(package):
#     subprocess.check_call([os.sys.executable, "-m", "pip", "install", package])
# install("langchain-community")
# install("pypdf")
# --- Configuration ---
# Make sure your GROQ_API_KEY is set in your Hugging Face Space secrets
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
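# Optional hardening (a sketch, not part of the original app): fail fast when the
# key is missing so the error surfaces at startup rather than on the first question.
# if not GROQ_API_KEY:
#     raise RuntimeError("GROQ_API_KEY is not set; add it to your Space secrets.")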
# --- 1. Load Documents ---
def load_documents(file_path):
    """Loads a PDF, TXT, or MD file into a list of LangChain documents."""
    if file_path.lower().endswith(".pdf"):
        loader = PyPDFLoader(file_path)
    else:
        # Plain-text formats (TXT, MD) are handled by the generic text loader
        loader = TextLoader(file_path, encoding="utf-8")
    return loader.load()
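# Example usage (a sketch; assumes a local file named "sample.pdf" exists):
# docs = load_documents("sample.pdf")
# print(f"Loaded {len(docs)} page(s)")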
# --- 2. Text Splitting & Vector Store Creation ---
def create_vector_store(documents, chunk_size=1000, chunk_overlap=200):
    """Splits documents and creates a FAISS vector store."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    docs = text_splitter.split_documents(documents)
    # Use a popular sentence-transformer model for embeddings
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    # Create the vector store
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore
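# Optional (a sketch; the index path is illustrative): FAISS indexes can be saved
# and reloaded to skip re-embedding on restart.
# vectorstore.save_local("faiss_index")
# vectorstore = FAISS.load_local(
#     "faiss_index", embeddings, allow_dangerous_deserialization=True
# )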
# --- 3. RAG Chain Creation ---
def create_rag_chain(vectorstore, llm, retriever_k=4, use_reranker=False):
    """Creates a simple or advanced (reranked) RAG chain."""
    retriever = vectorstore.as_retriever(search_kwargs={"k": retriever_k})
    # Advanced RAG: add a cross-encoder reranker for more relevant results
    if use_reranker:
        # Fetch a larger candidate pool so the reranker has enough to reorder
        retriever = vectorstore.as_retriever(search_kwargs={"k": max(retriever_k, 10)})
        # A strong cross-encoder scores each (query, document) pair directly
        cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-large")
        # The compression retriever fetches the candidates and the reranker keeps the top 3
        reranker = CrossEncoderReranker(model=cross_encoder, top_n=3)
        retriever = ContextualCompressionRetriever(
            base_compressor=reranker, base_retriever=retriever
        )
    prompt = ChatPromptTemplate.from_template("""
    You are a helpful assistant for question-answering tasks.
    Use the following retrieved context to answer the question.
    If you don't know the answer, just say that you don't know.
    Use three sentences maximum and keep the answer concise.

    Question: {input}
    Context: {context}

    Answer:
    """)
    document_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, document_chain)
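# Example end-to-end usage (a sketch; assumes GROQ_API_KEY is set and "sample.pdf" exists):
# docs = load_documents("sample.pdf")
# store = create_vector_store(docs)
# llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", api_key=GROQ_API_KEY)
# chain = create_rag_chain(store, llm, use_reranker=True)
# print(chain.invoke({"input": "Summarize the document."})["answer"])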
# --- Gradio UI ---
def create_chat_interface(rag_chain_creator, use_reranker_flag):
    """A factory function to create the chat logic for a Gradio tab."""
    def chat_logic(message, history, file_upload, chunk_size, k_retriever):
        if rag_chain_creator.chain is None:
            if file_upload is None:
                return "Please upload a document first."
            # Process the document and build the RAG chain once; reuse it afterwards
            # gr.File yields an object with .name in Gradio 3.x and a plain path string in 4.x
            file_path = file_upload.name if hasattr(file_upload, "name") else file_upload
            docs = load_documents(file_path)
            vector_store = create_vector_store(docs, int(chunk_size))
            llm = ChatGroq(temperature=0, model_name="llama3-8b-8192", api_key=GROQ_API_KEY)
            rag_chain_creator.chain = create_rag_chain(vector_store, llm, int(k_retriever), use_reranker=use_reranker_flag)
        response = rag_chain_creator.chain.invoke({"input": message})
        return response["answer"]
    return chat_logic
class RAGChainState:
    """Simple class to hold the state of a RAG chain."""
    def __init__(self):
        self.chain = None

# Create separate state holders for each tab to keep them independent
simple_rag_state = RAGChainState()
advanced_rag_state = RAGChainState()
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
    gr.Markdown("# RAG Pipeline Demo with Groq & LangChain")
    gr.Markdown("Upload a document and ask questions. Compare the 'Simple' and 'Advanced' RAG pipelines.")
    with gr.Tabs():
        # --- Simple RAG Tab ---
        with gr.TabItem("Simple RAG"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload_simple = gr.File(label="Upload your TXT, MD, or PDF")
                    with gr.Accordion("Settings", open=False):
                        chunk_size_simple = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
                        k_simple = gr.Slider(1, 10, value=4, step=1, label="Docs to Retrieve (k)")
                with gr.Column(scale=2):
                    gr.ChatInterface(
                        fn=create_chat_interface(simple_rag_state, use_reranker_flag=False),
                        additional_inputs=[file_upload_simple, chunk_size_simple, k_simple],
                        title="Simple RAG",
                        description="Basic retriever. Good for general queries."
                    )
        # --- Advanced RAG Tab ---
        with gr.TabItem("Advanced RAG (with Reranker)"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload_advanced = gr.File(label="Upload your TXT, MD, or PDF")
                    with gr.Accordion("Settings", open=False):
                        chunk_size_advanced = gr.Slider(500, 2000, value=1000, step=100, label="Chunk Size")
                        k_advanced = gr.Slider(1, 10, value=4, step=1, label="Docs to Retrieve (k)")
                with gr.Column(scale=2):
                    gr.ChatInterface(
                        fn=create_chat_interface(advanced_rag_state, use_reranker_flag=True),
                        additional_inputs=[file_upload_advanced, chunk_size_advanced, k_advanced],
                        title="Advanced RAG",
                        description="Uses a Cross-Encoder to rerank results for higher accuracy."
                    )
demo.launch()
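# Note (a sketch of optional flags): demo.launch(share=True) creates a temporary
# public link, and demo.launch(debug=True) prints tracebacks to the console.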