Spaces:
Sleeping
Sleeping
""" app.py

Question / answer over a collection of PDF documents, using a late-interaction
ColBERT model for retrieval and DSPy + Mistral for answer generation.

:author: Didier Guillevic
:date: 2024-12-22
"""
import gradio as gr | |
import logging | |
logger = logging.getLogger(__name__) | |
logging.basicConfig(level=logging.INFO) | |
import os | |
import pdf_utils # utilities for pdf processing | |
import colbert_utils # utilities to build a ColBERT retrieval model
import dspy_utils # utilities for building a DSPy based retrieval generation model | |
from tqdm.notebook import tqdm | |
import warnings | |
warnings.filterwarnings('ignore') | |
# RAG model shared by the UI callbacks: stays None until build_rag_model()
# assigns it, and generate_response() checks for None before using it.
dspy_rag_model = None
def build_rag_model(files: list[str]) -> str:
    """Build a retrieval augmented generation (RAG) model from PDF files.

    For each file, the metadata and text are read through ``pdf_utils``.
    Files that yield no text are skipped. The remaining documents are indexed
    with a ColBERT retrieval model, which is wrapped in a DSPy based RAG
    model and stored in the module-level ``dspy_rag_model``.

    Args:
        files: Paths of the PDF files to index. ``None`` or an empty list
            (e.g. nothing selected in the upload widget) is tolerated.

    Returns:
        A status message: "Done building RAG model." on success, otherwise
        an explanation of why no model was built. When nothing is built,
        ``dspy_rag_model`` is left untouched.
    """
    global dspy_rag_model

    # Guard against a missing selection: iterating over None would raise.
    if not files:
        return "No files provided. Please upload at least one PDF file."

    # Get the text from the pdf files (documents and metadatas stay aligned)
    documents = []
    metadatas = []
    for pdf_file in files:
        logger.info(f"Processing {pdf_file}")
        metadata = pdf_utils.get_metadata_info(pdf_file)
        text = pdf_utils.get_text_from_pdf(pdf_file)
        if text:
            documents.append(text)
            metadatas.append(metadata)
        else:
            logger.warning(f"No text extracted from {pdf_file}; skipping it")

    # Do not attempt to build an index over an empty collection.
    if not documents:
        return "No text could be extracted from the provided files."

    # Build the ColBERT retrieval model
    colbert_base_model = 'antoinelouis/colbert-xm'  # multilingual model
    colbert_index_name = 'OECD_HNW'  # for web app, generate unique name with uuid.uuid4()
    retrieval_model = colbert_utils.build_colbert_model(
        documents,
        metadatas,
        pretrained_model=colbert_base_model,
        index_name=colbert_index_name
    )

    # Instantiate the DSPy based RAG model
    dspy_rag_model = dspy_utils.DSPyRagModel(retrieval_model)

    return "Done building RAG model."
def generate_response(question: str) -> tuple[str, str, str]:
    """Generate a response to a given question using the RAG model.

    Args:
        question: The question to answer from the indexed documents.

    Returns:
        A ``(response, references, snippets)`` triple, one value per output
        component wired to this callback. When the RAG model has not been
        built yet, the response slot carries an explanatory message and the
        other two slots are empty strings.
    """
    if dspy_rag_model is None:
        # Always return three values: the UI binds this function to three
        # output components, so a single string would not fill them all.
        return "RAG model not built. Please build the model first.", "", ""

    # Generate response
    responses, references, snippets = dspy_rag_model.generate_response(
        question=question, k=5, method='chain_of_thought')

    return responses, references, snippets
# User interface: index some PDF files, build the model, then ask questions.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Retrieval (ColBERT) + Generation (DSPy & Mistral)
    - Note: building the retrieval model might take a few minutes.
    - Usage: upload a few PDF files to index. Build the model. Ask questions.
    """)

    # Files to index, shown next to the status reported by the build step
    with gr.Row():
        upload_files = gr.File(
            label="Upload PDF files to index",
            file_count="multiple",
            value=["OECD_Engaging_with_HNW_individuals_tax_compliance_(2009).pdf"],
            scale=5,
        )
        build_status = gr.Textbox(
            label="Build status",
            placeholder="",
            scale=2,
        )

    # Triggers the indexing of the uploaded files
    build_button = gr.Button(
        "Build retrieval generation model",
        variant="primary",
    )

    # Where the user types a question and reads the generated answer
    question = gr.Textbox(
        label="Question about the content of the documents uploaded",
        placeholder="How do tax administrations address aggressive tax planning by HNWIs?",
    )
    response = gr.Textbox(label="Response", placeholder="")

    # Supporting material for the answer, collapsed by default
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")

    # Triggers the answer generation
    response_button = gr.Button("Submit", variant="primary")

    # Ready-made questions about the PDF file provided by default
    with gr.Accordion("Sample questions", open=False):
        sample_questions = [
            "What are the tax risks associated with high net worth individuals (HNWIs)?",
            "How do tax administrations address aggressive tax planning by HNWIs?",
            "How can tax administrations engage with HNWIs to improve tax compliance?",
            "What are the benefits of establishing dedicated HNWI units within tax administrations?",
            "How can international cooperation help address offshore tax risks associated with HNWIs?",
        ]
        gr.Examples(
            # One row per example, one column per input component
            examples=[[sample] for sample in sample_questions],
            inputs=[question],
            outputs=[response, references, snippets],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions",
        )

    # Short description of the app and of how to use it
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - What
        - Retrieval augmented generation (RAG) model based on ColBERT and DSPy.
        - Retrieval base model: 'antoinelouis/colbert-xm' (multilingual model)
        - Generation framework: DSPy and Mistral.
        - How
        - Upload PDF files to index.
        - Build the retrieval generation model (might take a few minutes)
        - Ask a question about the content of those uploaded documents.
        """)

    # Wire each button to its callback
    build_button.click(build_rag_model, [upload_files], [build_status])
    response_button.click(
        generate_response,
        [question],
        [response, references, snippets],
    )

demo.launch(show_api=False)