Didier Guillevic
Another fix
c3d3886
""" app.py
An agent with access to a hybrid search tool and a large language model.
The search tool has access to a collection of documents from the OECD related
to international tax crimes.
Agentic framework:
- smolagents
Retrieval model:
- LanceDB: support for hybrid search with reranking of results.
- Full text search (lexical): BM25
- Vector search (semantic dense vectors): BAAI/bge-m3
Rerankers:
- ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
Generation:
- Mistral
:author: Didier Guillevic
:date: 2025-01-05
"""
import gradio as gr
import lancedb
import smolagents
import os
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#
# LanceDB with the indexed documents
#
# Both calls below execute when this module is imported: the local
# "lance.db" database is opened, then its "documents" table.
lance_db = lancedb.connect("lance.db")
lance_tbl = lance_db.open_table("documents")
# Document schema
class Document(lancedb.pydantic.LanceModel):
    """Row schema of the LanceDB "documents" table.

    Search results are converted to instances of this model via
    ``.to_pydantic(Document)`` in ``search_table``.
    """
    text: str  # snippet text; the rerankers are configured on this column (column='text')
    vector: lancedb.pydantic.Vector(1024)  # 1024-dim embedding (BAAI/bge-m3 per the module docstring)
    file_name: str
    num_pages: int
    # Compared as a *string* against the year in search_table's filter;
    # presumably starts with a 4-digit year — TODO confirm the stored format.
    creation_date: str
    modification_date: str
#
# Retrieval: query types and reranker types
#
# User-facing search mode -> value of LanceDB's `query_type` argument
# ('fts', 'vector' and 'hybrid' are the values search_table() handles).
query_types = dict(
    lexical='fts',      # full-text search (BM25)
    semantic='vector',  # dense-vector similarity search
    hybrid='hybrid',    # both, combined by a reranker
)
# Define a few rerankers
# Each reranker is constructed once, when the module is imported.
# The first three score candidates using the 'text' column of the table.
colbert_reranker = lancedb.rerankers.ColbertReranker(column='text')
answerai_reranker = lancedb.rerankers.AnswerdotaiRerankers(column='text')
crossencoder_reranker = lancedb.rerankers.CrossEncoderReranker(column='text')
# hybrid search only: search_table() substitutes the cross encoder for fts/vector queries
reciprocal_rank_fusion_reranker = lancedb.rerankers.RRFReranker() # hybrid search only
# Name -> reranker instance; search_table() looks reranker_name up here.
reranker_types = {
    'ColBERT': colbert_reranker,
    'cross encoder': crossencoder_reranker,
    'AnswerAI': answerai_reranker,
    'Reciprocal Rank Fusion': reciprocal_rank_fusion_reranker
}
def search_table(
    table: "lancedb.table.Table",
    query: str,
    query_type: str = 'hybrid',
    reranker_name: str = 'cross encoder',
    filter_year: int = 2000,
    top_k: int = 5,
    overfetch_factor: int = 2,
):
    """Search ``table`` and return up to ``top_k`` reranked documents.

    Args:
        table: LanceDB table to query (``lance_tbl`` in this app).
        query: the search text.
        query_type: LanceDB query type: ``'fts'`` (lexical), ``'vector'``
            (semantic) or ``'hybrid'``. The user-facing aliases defined in
            ``query_types`` (``'lexical'``, ``'semantic'``) are also accepted.
        reranker_name: a key of ``reranker_types``. For ``'fts'``/``'vector'``
            searches the Reciprocal Rank Fusion reranker (hybrid-only) is
            replaced by the cross encoder.
        filter_year: rows are pre-filtered with
            ``creation_date >= '<filter_year>'`` (a string comparison).
        top_k: maximum number of documents returned.
        overfetch_factor: for ``'fts'``/``'vector'`` searches,
            ``top_k * overfetch_factor`` candidates are fetched and reranked
            before truncating to ``top_k``.

    Returns:
        A list of at most ``top_k`` ``Document`` instances in reranked order.

    Raises:
        ValueError: if ``reranker_name`` or ``query_type`` is not recognised
            (the reranker is checked first).
    """
    reranker = reranker_types.get(reranker_name)
    if reranker is None:
        logger.error(f"Invalid reranker name: {reranker_name}")
        raise ValueError(f"Invalid reranker selected: {reranker_name}")

    # Translate UI aliases ('lexical', 'semantic'); LanceDB names pass through.
    query_type = query_types.get(query_type, query_type)

    if query_type == "hybrid":
        limit = top_k
    elif query_type in ("vector", "fts"):
        # RRF fuses two rankings, so it only makes sense for hybrid search.
        if reranker is reciprocal_rank_fusion_reranker:
            reranker = crossencoder_reranker
        # Over-fetch so the reranker chooses from a larger candidate pool.
        limit = top_k * overfetch_factor
    else:
        # Fail with a clear error instead of an unbound-variable crash below.
        logger.error(f"Invalid query type: {query_type}")
        raise ValueError(f"Invalid query type selected: {query_type}")

    # The filter is a SQL expression string: double any single quote so the
    # value cannot terminate the string literal early.
    year_literal = str(filter_year).replace("'", "''")

    results = (
        table.search(query, query_type=query_type)
        .where(f"creation_date >= '{year_literal}'", prefilter=True)
        .rerank(reranker=reranker)
        .limit(limit)
        .to_pydantic(Document)
    )
    return results[:top_k]
#
# Define a retriever tool
#
class RetrieverTool(smolagents.Tool):
    """smolagents tool exposing ``search_table`` over ``lance_tbl`` to the agent.

    The agent passes a free-text query and receives the retrieved snippets
    concatenated into a single string.
    """
    name = "retriever"
    description = "Uses hybrid search to retrieve snippets from OECD documents that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Search the document table and format the hits for the model.

        Args:
            query: the search text chosen by the agent.

        Returns:
            ``"\\nRetrieved documents:\\n"`` followed by one
            ``===== Document i =====`` section per snippet (``i`` from 0).

        Raises:
            TypeError: if ``query`` is not a string.
        """
        # Explicit check rather than `assert`, which `python -O` strips out.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        results = search_table(table=lance_tbl, query=query)
        sections = [
            f"\n\n===== Document {i} =====\n{result.text}"
            for i, result in enumerate(results)
        ]
        return "\nRetrieved documents:\n" + "".join(sections)
# Single tool instance handed to the agent below.
retriever_tool = RetrieverTool()
#
# Define a language model
#
# Required setting: this lookup raises KeyError at import time if
# MISTRAL_API_KEY is not set in the environment.
mistral_api_key = os.environ["MISTRAL_API_KEY"]
mistral_model_id = "mistral/mistral-large-latest" # 128k context window
#mistral_model_id = "mistral/codestral-latest"
# Mistral model accessed through smolagents' LiteLLM wrapper.
mistral_model = smolagents.LiteLLMModel(
    model_id=mistral_model_id, api_key=mistral_api_key)
#
# Define an agent with access to tool(s) and language model.
#
# NOTE(review): current smolagents documentation names the step limit
# `max_steps`; confirm the installed version accepts `max_iterations`
# (and `verbose`) rather than rejecting or ignoring them.
agent = smolagents.CodeAgent(
    tools=[retriever_tool],
    model=mistral_model,
    max_iterations=4,
    verbose=True
)
#
# app
#
def generate_response(query: str) -> str:
    """Answer ``query`` by running the agent.

    The agent has access to a retriever tool over the OECD document
    collection and to a large language model.

    Args:
        query: the user's question, as typed in the UI.

    Returns:
        The agent's final answer as a string, or ``""`` when ``query`` is
        empty or whitespace-only (the agent is not invoked in that case).
    """
    # Avoid LLM round-trips when there is nothing to answer.
    if not query or not query.strip():
        return ""
    agent_output = agent.run(query)
    # Coerce to text so the return value matches the `-> str` annotation.
    return "" if agent_output is None else str(agent_output)
#
# User interface
#
# Gradio page: question box, answer box, Submit/Clear buttons, sample
# questions and a documentation panel. Component creation order below is the
# on-page layout order. gr.Markdown strips the common leading indentation of
# its text, so the Markdown strings can be indented with the code.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Agentic Hybrid search
    Document collection: OECD documents on international tax crimes.
    """)

    # Input: the user's question
    question = gr.Textbox(
        label="Question to answer",
        placeholder=""
    )

    # Output: the agent's answer
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )

    # Buttons
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Sample questions about the OECD document collection
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                ["What is the OECD's role in combating offshore tax evasion?",],
                ["What are the key tools used in fighting offshore tax evasion?",],
                ['What are "High Net Worth Individuals" (HNWIs) and how do they relate to tax compliance efforts?',],
                ["What is the significance of international financial centers (IFCs) in the context of tax evasion?",],
                ["What is being done to address the role of professional enablers in facilitating tax evasion?",],
                ["How does the OECD measure the effectiveness of international efforts to fight offshore tax evasion?",],
                ['What are the "Ten Global Principles" for fighting tax crime?',],
                ["What are some recent developments in the fight against offshore tax evasion?",],
            ],
            inputs=[question,],
            outputs=[response,],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions"
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - Agentic framework
            - Hugging Face's smolagents
        - Retrieval model
            - LanceDB: support for hybrid search with reranking of results.
                - Full text search (lexical): BM25
                - Vector search (semantic dense vectors): BAAI/bge-m3
        - Rerankers
            - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
        - Generation
            - Mistral
        - Examples
            - Generated using Google NotebookLM
        """)

    # Click actions
    response_button.click(
        fn=generate_response,
        inputs=[question,],
        outputs=[response,]
    )
    # Reset both the question and the response boxes.
    clear_button.click(
        fn=lambda: ('', ''),
        inputs=[],
        outputs=[question, response]
    )

demo.launch(show_api=False)