File size: 5,184 Bytes
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bebe878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a2aef
 
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a2aef
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a2aef
 
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
""" app.py

Question / answer over a collection of PDF documents using late interaction
ColBERT model for retrieval and DSPy+Mistral for answer generation.

:author: Didier Guillevic
:date: 2024-12-22
"""

import gradio as gr

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import os
import pdf_utils # utilities for pdf processing
import colbert_utils # utilities to build a ColBERT retrieval model
import dspy_utils # utilities for building a DSPy based retrieval generation model

from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')


# Module-level handle to the RAG model: None until build_rag_model() assigns
# it; generate_response() checks for None before answering.
dspy_rag_model = None

def build_rag_model(files: list[str]) -> str:
    """Build a retrieval augmented model using given files to index.

    For each file, the metadata and text are extracted with ``pdf_utils``;
    files yielding no text are skipped. The remaining texts are indexed with
    ColBERT and the resulting DSPy RAG model is stored in the module-level
    ``dspy_rag_model``.

    Args:
        files: PDF files to index. May be ``None`` or empty when nothing was
            uploaded in the UI.

    Returns:
        A status message to display in the UI. When there is nothing to
        index, an explanatory message is returned and any previously built
        model is left untouched.
    """
    global dspy_rag_model

    # The upload widget may hand over None / [] when no file was selected.
    if not files:
        return "No files provided. Please upload at least one PDF file."

    # Get the text from the pdf files
    documents = []
    metadatas = []
    for pdf_file in files:
        logger.info("Processing %s", pdf_file)
        metadata = pdf_utils.get_metadata_info(pdf_file)
        text = pdf_utils.get_text_from_pdf(pdf_file)
        # Keep documents and metadatas aligned: skip files without text.
        if text:
            documents.append(text)
            metadatas.append(metadata)

    # Do not attempt to build an index over zero documents.
    if not documents:
        return "No text could be extracted from the uploaded files."

    # Build the ColBERT retrieval model
    colbert_base_model = 'antoinelouis/colbert-xm' # multilingual model
    colbert_index_name = 'OECD_HNW' # for web app, generate unique name with uuid.uuid4()
    retrieval_model = colbert_utils.build_colbert_model(
        documents,
        metadatas,
        pretrained_model=colbert_base_model,
        index_name=colbert_index_name
    )

    # Instantiate the DSPy based RAG model
    dspy_rag_model = dspy_utils.DSPyRagModel(retrieval_model)

    return "Done building RAG model."


def generate_response(question: str) -> tuple[str, str, str]:
    """Generate a response to a given question using the RAG model.

    Args:
        question: The question to answer from the indexed documents.

    Returns:
        A ``(responses, references, snippets)`` triple, passed through
        unchanged from ``dspy_rag_model.generate_response``. When the model
        has not been built yet, the first element is an explanatory message
        and the other two are empty strings, so that the three UI outputs
        bound to this function always receive three values.
    """
    if dspy_rag_model is None:
        # Must return one value per bound output component (three of them),
        # otherwise the UI framework reports an output-count error instead
        # of showing this message.
        return "RAG model not built. Please build the model first.", "", ""

    # Generate response
    responses, references, snippets = dspy_rag_model.generate_response(
        question=question, k=5, method='chain_of_thought')

    return responses, references, snippets


# Assemble the Gradio UI: upload/build widgets, question/response widgets,
# sample questions, documentation, and finally the click handlers that wire
# the buttons to build_rag_model() and generate_response().
with gr.Blocks() as demo:
    gr.Markdown("""
        # Retrieval (ColBERT) + Generation (DSPy & Mistral)
        - Note: building the retrieval model might take a few minutes.
        - Usage: upload a few PDF files to index. Build the model. Ask questions.
    """)

    # Input files and build status
    with gr.Row():
        upload_files = gr.File(
            label="Upload PDF files to index", file_count="multiple",
            # Default document pre-loaded in the upload widget; relative path,
            # presumably shipped next to this script -- TODO confirm.
            value=["OECD_Engaging_with_HNW_individuals_tax_compliance_(2009).pdf",],
            scale=5)
        # Receives the status string returned by build_rag_model().
        build_status = gr.Textbox(label="Build status", placeholder="", scale=2)
    
    # Button that triggers build_rag_model() (wired under "Click actions")
    build_button = gr.Button("Build retrieval generation model", variant='primary')
    
    # Question to answer
    question = gr.Textbox(
        label="Question about the content of the documents uploaded",
        placeholder="How do tax administrations address aggressive tax planning by HNWIs?"
    )
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )
    # Second and third outputs of generate_response(), collapsed by default.
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")
    
    # Button that triggers generate_response() (wired under "Click actions")
    response_button = gr.Button("Submit", variant='primary')
    
    # Example questions given default provided PDF file
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                ["What are the tax risks associated with high net worth individuals (HNWIs)?",],
                ["How do tax administrations address aggressive tax planning by HNWIs?",],
                ["How can tax administrations engage with HNWIs to improve tax compliance?",],
                ["What are the benefits of establishing dedicated HNWI units within tax administrations?",],
                ["How can international cooperation help address offshore tax risks associated with HNWIs?",],
            ],
            inputs=[question,],
            outputs=[response, references, snippets],
            fn=generate_response,
            # Examples are not pre-computed; presumably because answers
            # depend on the model built at runtime -- confirm.
            cache_examples=False,
            label="Sample questions"
        )
    
    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
            - What
                - Retrieval augmented generation (RAG) model based on ColBERT and DSPy.
                - Retrieval base model:  'antoinelouis/colbert-xm' (multilingual model)
                - Generation framework: DSPy and Mistral.
            - How
                - Upload PDF files to index.
                - Build the retrieval generation model (might take a few minutes)
                - Ask a question about the content of those uploaded documents.
        """)

    # Click actions
    # Build: uploaded files -> build_rag_model() -> status text box.
    build_button.click(
        fn=build_rag_model,
        inputs=[upload_files],
        outputs=[build_status]
    )
    # Ask: question -> generate_response() -> response, references, snippets.
    response_button.click(
        fn=generate_response,
        inputs=[question],
        outputs=[response, references, snippets]
    )


# Launched unconditionally at module level (no __main__ guard).
# show_api=False: presumably hides the "Use via API" page -- see Gradio
# launch() documentation to confirm.
demo.launch(show_api=False)