File size: 5,350 Bytes
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bebe878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c18375
 
 
 
 
 
 
42aa62b
1c18375
 
 
 
 
 
 
 
 
 
 
b9e2642
 
 
c7a2aef
1c18375
 
 
 
 
 
6f7cb14
1c18375
 
 
 
 
 
 
 
29d567f
b9e2642
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f7cb14
 
 
 
 
 
 
 
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a2aef
 
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
""" app.py

Question answering over a collection of PDF documents, using the late-interaction
ColBERT model for retrieval and DSPy + Mistral for answer generation.

:author: Didier Guillevic
:date: 2024-12-22
"""

import gradio as gr

import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import os
import pdf_utils # utilities for pdf processing
import colbert_utils # utilities for to build a ColBERT retrieval model
import dspy_utils # utilities for building a DSPy based retrieval generation model

from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')


# RAG model shared by build_rag_model() (writer) and generate_response() (reader).
# NOTE(review): this is a process-wide global, so all users of the web app share
# (and overwrite) the same model and ColBERT index.
dspy_rag_model = None

def build_rag_model(files: list[str]) -> str:
    """Build a retrieval augmented generation (RAG) model from the given PDF files.

    Extracts the text and metadata of each PDF file, indexes the files that
    yielded text with a ColBERT retrieval model, and wraps that retriever in a
    DSPy based RAG model stored in the module-level ``dspy_rag_model``.

    Args:
        files: the PDF files to index, as handed over by the Gradio ``File``
            component (presumably file paths). ``None`` or an empty list,
            which Gradio produces when nothing was uploaded, is accepted.

    Returns:
        A status message for the "Build status" textbox. When no file was
        provided, or no text could be extracted from any file, the message
        says so and the previously built model (if any) is left unchanged.
    """
    global dspy_rag_model

    # Gradio passes None when the user clicks "Build" without uploading files.
    if not files:
        return "No PDF files provided. Please upload at least one PDF file."

    # Get the text from the pdf files; files without extractable text are skipped
    # so that documents and metadatas stay aligned.
    documents = []
    metadatas = []
    for pdf_file in files:
        logger.info("Processing %s", pdf_file)
        metadata = pdf_utils.get_metadata_info(pdf_file)
        text = pdf_utils.get_text_from_pdf(pdf_file)
        if text:
            documents.append(text)
            metadatas.append(metadata)

    # Do not try to index an empty corpus.
    if not documents:
        return "No text could be extracted from the provided PDF files."

    # Build the ColBERT retrieval model
    colbert_base_model = 'antoinelouis/colbert-xm' # multilingual model
    colbert_index_name = 'OECD_HNW' # for web app, generate unique name with uuid.uuid4()
    retrieval_model = colbert_utils.build_colbert_model(
        documents,
        metadatas,
        pretrained_model=colbert_base_model,
        index_name=colbert_index_name
    )

    # Instantiate the DSPy based RAG model on top of the retriever
    dspy_rag_model = dspy_utils.DSPyRagModel(retrieval_model)

    return "Done building RAG model."


def generate_response(question: str) -> tuple[str, str, str]:
    """Answer a question with the RAG model built by build_rag_model().

    Args:
        question: the user's question (content of the "Question" textbox).

    Returns:
        A ``(response, references, snippets)`` triple feeding the "Response"
        textbox and the "References" / "Snippets" HTML components. When the
        question is blank or no model has been built yet, the response is a
        help message and references/snippets are empty strings (the HTML
        components expect text, not lists).
    """
    # Reject blank input; the textbox placeholder can look like a pre-filled question.
    if not question or not question.strip():
        return "Please enter a question.", "", ""

    if dspy_rag_model is None:
        return "RAG model not built. Please build the model first.", "", ""

    # Generate response (k=5 is presumably the number of retrieved passages —
    # confirm against dspy_utils.DSPyRagModel.generate_response).
    responses, references, snippets = dspy_rag_model.generate_response(
        question=question, k=5, method='chain_of_thought')

    return responses, references, snippets


# Gradio user interface: upload PDFs, build the RAG model, then ask questions.
with gr.Blocks() as demo:
    gr.Markdown("""
        # Retrieval (ColBERT) + Generation (DSPy & Mistral)
        - Note:
            - building the retrieval model will be slow on **free CPU** (expect 5+ minutes).
            - first question/answer will be slow (2 minutes for model loading). Subsequent question (approx. 10 seconds)
        - Usage: upload a few PDF files to index. Build the model. Ask questions.
    """)

    # Input files (pre-populated with the default OECD report) and build status
    with gr.Row():
        upload_files = gr.File(
            label="Upload PDF files to index", file_count="multiple",
            value=["OECD_Dividend_tax_fraud_2023-en.pdf",],
            scale=5)
        build_status = gr.Textbox(label="Build status", placeholder="", scale=2)

    # button
    build_button = gr.Button("Build retrieval generation model", variant='primary')

    # Question to answer, generated response, and supporting references/snippets
    question = gr.Textbox(
        label="Question",
        placeholder="What is dividend stripping?"
    )
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")

    # button
    response_button = gr.Button("Submit", variant='primary')

    # Example questions given default provided PDF file
    # NOTE(review): with cache_examples=False, Gradio presumably only fills in the
    # question when an example is clicked (fn/outputs run only if run_on_click=True).
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                ["What is dividend stripping?",],
                ["What are the most common types of dividend stripping schemes?",],
                ["How do authorities detect dividend stripping?",],
                ["What are some indicators of potential dividend stripping?",],
                ["What are the consequences of dividend stripping?",],
                ["How can countries combat dividend stripping?",],
                ["What is the role of professional enablers in dividend stripping?",],
                ["How can countries address the role of professional enablers in dividend stripping?",],
            ],
            inputs=[question,],
            outputs=[response, references, snippets],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions"
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
            - What
                - Retrieval augmented generation (RAG) model based on ColBERT and DSPy.
                - Retrieval base model:  'antoinelouis/colbert-xm' (multilingual model)
                - Generation framework: DSPy and Mistral.
            - How
                - Upload PDF files to index.
                - Build the retrieval generation model (might take a few minutes)
                - Ask a question about the content of those uploaded documents.
        """)

    # Click actions
    build_button.click(
        fn=build_rag_model,
        inputs=[upload_files],
        outputs=[build_status]
    )
    response_button.click(
        fn=generate_response,
        inputs=[question],
        outputs=[response, references, snippets]
    )
    # Pressing Enter in the question box behaves like clicking "Submit".
    question.submit(
        fn=generate_response,
        inputs=[question],
        outputs=[response, references, snippets]
    )


# Only start the (blocking) web server when run as a script, e.g. `python app.py`;
# importing this module just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(show_api=False)