Didier Guillevic committed on
Commit fecab9d · 1 Parent(s): a3b9984

Adding app.py, llm_utils.py, and the build requirements.

Files changed (3)
  1. app.py +269 -0
  2. llm_utils.py +55 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,269 @@
+ """app.py
+
+ Question answering over a collection of PDF documents from OECD.org.
+
+ PDF text extraction:
+ - pypdf
+
+ Retrieval model:
+ - LanceDB: support for hybrid search with reranking of results.
+ - Full-text search (lexical): BM25
+ - Vector search (semantic dense vectors): BAAI/bge-m3
+
+ Rerankers:
+ - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
+
+ Generation:
+ - Mistral
+
+ :author: Didier Guillevic
+ :date: 2024-12-28
+ """
+
+ import logging
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ import gradio as gr
+ import lancedb
+ from lancedb.pydantic import LanceModel, Vector
+ from lancedb.rerankers import (
+     AnswerdotaiRerankers,
+     ColbertReranker,
+     CrossEncoderReranker,
+     RRFReranker,
+ )
+ from lancedb.table import Table
+
+ import llm_utils
+
+ #
+ # LanceDB with the indexed documents
+ #
+
+ # Connect to the database
+ lance_db = lancedb.connect("lance.db")
+ lance_tbl = lance_db.open_table("documents")
+
+ # Document schema
+ class Document(LanceModel):
+     text: str
+     vector: Vector(1024)  # BAAI/bge-m3 embedding dimension
+     file_name: str
+     num_pages: int
+     creation_date: str
+     modification_date: str
+
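This commit ships only the query side; the script that builds lance.db is not included. Below is a minimal sketch of how the "documents" table could be created and indexed to match this schema, assuming BAAI/bge-m3 is loaded through sentence-transformers; the file name, dates, and sample text are illustrative.

# index_documents.py (hypothetical companion script, not part of this commit)
import lancedb
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("BAAI/bge-m3")  # produces 1024-dim dense vectors

db = lancedb.connect("lance.db")
chunk = "Sample text extracted from an OECD PDF with pypdf."
rows = [{
    "text": chunk,
    "vector": encoder.encode(chunk).tolist(),
    "file_name": "oecd_tax_crime_report.pdf",  # illustrative
    "num_pages": 42,
    "creation_date": "2015-06-01",
    "modification_date": "2015-06-01",
}]
tbl = db.create_table("documents", data=rows, mode="overwrite")
tbl.create_fts_index("text")  # BM25 index backing the 'fts' and 'hybrid' query types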
+ #
+ # Retrieval: query types and reranker types
+ #
+
+ query_types = {
+     'lexical': 'fts',
+     'semantic': 'vector',
+     'hybrid': 'hybrid',
+ }
+
+ # Define a few rerankers
+ colbert_reranker = ColbertReranker(column='text')
+ answerai_reranker = AnswerdotaiRerankers(column='text')
+ crossencoder_reranker = CrossEncoderReranker(column='text')
+ reciprocal_rank_fusion_reranker = RRFReranker()  # hybrid search only
+
+ reranker_types = {
+     'ColBERT': colbert_reranker,
+     'cross encoder': crossencoder_reranker,
+     'AnswerAI': answerai_reranker,
+     'Reciprocal Rank Fusion': reciprocal_rank_fusion_reranker,
+ }
+
+ def search_table(
+     table: Table,
+     query: str,
+     query_type: str,
+     reranker_name: str,
+     filter_year: int,
+     top_k: int = 5,
+     overfetch_factor: int = 2
+ ):
+     # Get the reranker instance
+     reranker = reranker_types.get(reranker_name)
+     if reranker is None:
+         logger.error(f"Invalid reranker name: {reranker_name}")
+         raise ValueError(f"Invalid reranker selected: {reranker_name}")
+
+     if query_type in ["vector", "fts"]:
+         if reranker == reciprocal_rank_fusion_reranker:
+             # Reciprocal rank fusion applies to the 'hybrid' search type only
+             reranker = crossencoder_reranker
+         # Over-fetch, then truncate to top_k after reranking.
+         # creation_date is an ISO-style string, so comparing against the
+         # year prefix works lexicographically.
+         results = (
+             table.search(query, query_type=query_type)
+             .where(f"creation_date >= '{filter_year}'", prefilter=True)
+             .rerank(reranker=reranker)
+             .limit(top_k * overfetch_factor)
+             .to_pydantic(Document)
+         )
+     elif query_type == "hybrid":
+         results = (
+             table.search(query, query_type=query_type)
+             .where(f"creation_date >= '{filter_year}'", prefilter=True)
+             .rerank(reranker=reranker)
+             .limit(top_k)
+             .to_pydantic(Document)
+         )
+     else:
+         raise ValueError(f"Invalid query type: {query_type}")
+
+     return results[:top_k]
+
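For reference, a call to search_table might look like this (the question and cutoff year are arbitrary):

hits = search_table(
    lance_tbl,
    query='What are the "Ten Global Principles" for fighting tax crime?',
    query_type="hybrid",
    reranker_name="cross encoder",
    filter_year=2010,
)
for doc in hits:
    print(doc.file_name, doc.text[:80])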
+
+ #
+ # Generation: query + context --> response
+ #
+
+ def create_bulleted_list(texts: list[str]) -> str:
+     """
+     This function takes a list of strings and returns HTML with a bulleted list.
+     """
+     html_items = []
+     for item in texts:
+         html_items.append(f"<li>{item}</li>")
+     return "<ul>" + "".join(html_items) + "</ul>"
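For instance:

create_bulleted_list(["report_a.pdf", "report_b.pdf"])
# returns '<ul><li>report_a.pdf</li><li>report_b.pdf</li></ul>'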
+
+
+ def generate_response(
+     query: str,
+     query_type: str,
+     reranker_name: str,
+     filter_year: int
+ ):
+     """Generate a response given a query, search type, and reranker.
+
+     Args:
+         query: the user question.
+         query_type: one of 'fts', 'vector' or 'hybrid'.
+         reranker_name: one of the keys of reranker_types.
+         filter_year: keep only documents created in or after this year.
+
+     Yields:
+         - the response given the snippets extracted from the database
+         - (html string): the references (origin of the snippets of text used to generate the answer)
+         - (html string): the snippets of text used to generate the answer
+     """
+     # Get results from LanceDB
+     results = search_table(
+         lance_tbl,
+         query=query,
+         query_type=query_type,
+         reranker_name=reranker_name,
+         filter_year=filter_year
+     )
+
+     references = [result.file_name for result in results]
+     references_html = "<h4>References</h4>\n" + create_bulleted_list(references)
+
+     snippets = [result.text for result in results]
+     snippets_html = "<h4>Snippets</h4>\n" + create_bulleted_list(snippets)
+
+     # Generate the response from the LLM
+     stream_response = llm_utils.generate_chat_response_streaming(
+         query, '\n\n'.join(snippets)
+     )
+
+     model_response = ""
+     for chunk in stream_response:
+         # The final stream event may carry no content; guard against None
+         model_response += chunk.data.choices[0].delta.content or ""
+         yield model_response, references_html, snippets_html
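Because generate_response is a generator, Gradio streams each yielded triple to the UI as tokens arrive. It can also be driven directly, e.g.:

answer = ""
for answer, refs_html, snippets_html in generate_response(
    query="What is the OECD's role in combating offshore tax evasion?",
    query_type="hybrid",
    reranker_name="cross encoder",
    filter_year=2005,
):
    pass  # each iteration yields the answer accumulated so far
print(answer)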
+
+
+ #
+ # User interface
+ #
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+         # Hybrid search / reranking / Mistral
+         Document collection: OECD documents on international tax crimes.
+     """)
+
+     # Inputs: question
+     question = gr.Textbox(
+         label="Question to answer",
+         placeholder=""
+     )
+
+     # Response / references / snippets
+     response = gr.Textbox(
+         label="Response",
+         placeholder=""
+     )
+     with gr.Accordion("References & snippets", open=False):
+         references = gr.HTML(label="References")
+         snippets = gr.HTML(label="Snippets")
+
+     # Buttons
+     with gr.Row():
+         response_button = gr.Button("Submit", variant='primary')
+         clear_button = gr.Button("Clear", variant='secondary')
+
+     # Additional inputs (defined here, rendered later in the search-parameters accordion)
+     query_type = gr.Dropdown(
+         choices=list(query_types.items()),  # (label, value) pairs
+         value='hybrid',
+         label='Query type',
+         render=False
+     )
+     reranker_name = gr.Dropdown(
+         choices=list(reranker_types.keys()),
+         value='cross encoder',
+         label='Reranker',
+         render=False
+     )
+     filter_year = gr.Slider(
+         minimum=2005, maximum=2020, value=2005, step=1,
+         label='Creation date >=', render=False
+     )
+
+     with gr.Row():
+         # Example questions for the provided collection of PDF documents
+         with gr.Accordion("Sample questions", open=False):
+             gr.Examples(
+                 [
+                     ["What is the OECD's role in combating offshore tax evasion?",],
+                     ["What are the key tools used in fighting offshore tax evasion?",],
+                     ['What are "High Net Worth Individuals" (HNWIs) and how do they relate to tax compliance efforts?',],
+                     ["What is the significance of international financial centers (IFCs) in the context of tax evasion?",],
+                     ["What is being done to address the role of professional enablers in facilitating tax evasion?",],
+                     ["How does the OECD measure the effectiveness of international efforts to fight offshore tax evasion?",],
+                     ['What are the "Ten Global Principles" for fighting tax crime?',],
+                     ["What are some recent developments in the fight against offshore tax evasion?",],
+                 ],
+                 inputs=[question, query_type, reranker_name, filter_year],
+                 outputs=[response, references, snippets],
+                 fn=generate_response,
+                 cache_examples=False,
+                 label="Sample questions"
+             )
+
+     # Additional inputs: search parameters
+     with gr.Accordion("Search parameters", open=False):
+         with gr.Row():
+             query_type.render()
+             reranker_name.render()
+             filter_year.render()
+
+     # Documentation
+     with gr.Accordion("Documentation", open=False):
+         gr.Markdown("""
+             - Retrieval model
+                 - LanceDB: support for hybrid search with reranking of results.
+                 - Full-text search (lexical): BM25
+                 - Vector search (semantic dense vectors): BAAI/bge-m3
+             - Rerankers
+                 - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
+             - Generation
+                 - Mistral
+             - Examples
+                 - Generated using Google NotebookLM
+         """)
+
+     # Click actions
+     response_button.click(
+         fn=generate_response,
+         inputs=[question, query_type, reranker_name, filter_year],
+         outputs=[response, references, snippets]
+     )
+     clear_button.click(
+         fn=lambda: ('', '', '', ''),
+         inputs=[],
+         outputs=[question, response, references, snippets]
+     )
+
+
+ demo.launch(show_api=False)
llm_utils.py ADDED
@@ -0,0 +1,55 @@
+ """llm_utils.py
+
+ Utilities for working with Large Language Models.
+
+ :author: Didier Guillevic
+ :email: [email protected]
+ :creation: 2024-12-28
+ """
+
+ import logging
+ import os
+
+ from mistralai import Mistral
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ #
+ # Mistral AI client
+ #
+ api_key = os.environ["MISTRAL_API_KEY"]
+ client = Mistral(api_key=api_key)
+ model_id = "mistral-large-latest"  # 128k context window
+
+
+ #
+ # Some functions
+ #
+ def generate_chat_response_streaming(
+     query: str,
+     context: str,
+     max_new_tokens=1_024,
+     temperature=0.0
+ ):
+     """Stream a chat response grounded in the given context.
+
+     Args:
+         query: the user question.
+         context: the text snippets the model should ground its answer in.
+         max_new_tokens: maximum number of tokens to generate.
+         temperature: sampling temperature (0.0 for near-deterministic output).
+
+     Returns:
+         The streaming response from the Mistral chat API.
+     """
+     # Instruction
+     instruction = (
+         "You will be given a question and a list of context snippets that "
+         "might be relevant to the question. "
+         "Do not include facts not contained in the provided context. "
+         "If no relevant context is provided to answer the question, "
+         "then simply say so. Do not invent anything.\n\n"
+         f"Question: {query}\n\n\n"
+         f"Context:\n\n{context}"
+     )
+
+     # Messages
+     messages = [{'role': 'user', 'content': instruction}]
+     # logger.info(messages)
+
+     # Yield the model response as the tokens are being generated
+     stream_response = client.chat.stream(
+         model=model_id,
+         messages=messages,
+         temperature=temperature,
+         max_tokens=max_new_tokens,
+     )
+     return stream_response
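A minimal usage sketch, assuming MISTRAL_API_KEY is set in the environment; the context string here is illustrative:

stream = generate_chat_response_streaming(
    query="What are the Ten Global Principles?",
    context="The OECD's Ten Global Principles describe how jurisdictions can fight tax crime.",
)
for chunk in stream:
    print(chunk.data.choices[0].delta.content or "", end="", flush=True)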
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio
+ lancedb
+ sentence-transformers
+ torch  # the PyPI name for PyTorch is 'torch', not 'pytorch'
+ mistralai