# Hugging Face Space (status: Running)
""" app.py | |
An agent with access to a hybrid search tool and a large language model.

The search tool has access to a collection of documents from the OECD related
to international tax crimes.

Agentic framework:
- smolagents

Retrieval model:
- LanceDB: support for hybrid search with reranking of results.
- Full text search (lexical): BM25
- Vector search (semantic dense vectors): BAAI/bge-m3

Rerankers:
- ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI

Generation:
- Mistral

:author: Didier Guillevic
:date: 2025-01-05
""" | |
import gradio as gr | |
import lancedb | |
import smolagents | |
import os | |
import logging | |
# Configure root logging at INFO, then grab this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# | |
# LanceDB with the indexed documents | |
# | |
# Open the LanceDB database stored at "lance.db" and the "documents" table
# that every search in this module runs against.
lance_db = lancedb.connect("lance.db")
lance_tbl = lance_db.open_table("documents")
# Document schema | |
class Document(lancedb.pydantic.LanceModel):
    """One row of the ``documents`` table, as returned by ``search_table``.

    Search results are materialized into this model via ``to_pydantic``.
    """

    # Snippet text; this is the column the rerankers are configured on.
    text: str
    # Dense embedding of the snippet (1024 dimensions).
    vector: lancedb.pydantic.Vector(1024)
    # Source file metadata.
    file_name: str
    num_pages: int
    # Dates are stored as strings; ``search_table`` filters on
    # ``creation_date`` with a string comparison against a year.
    creation_date: str
    modification_date: str
# | |
# Retrieval: query types and reranker types | |
# | |
# Human-readable search mode -> value passed as ``query_type`` to LanceDB.
query_types = dict(
    lexical='fts',
    semantic='vector',
    hybrid='hybrid',
)
# Reranker instances shared by every search. The model-based rerankers score
# the 'text' column of each hit.
colbert_reranker = lancedb.rerankers.ColbertReranker(column='text')
answerai_reranker = lancedb.rerankers.AnswerdotaiRerankers(column='text')
crossencoder_reranker = lancedb.rerankers.CrossEncoderReranker(column='text')
# Reciprocal rank fusion is only meaningful for hybrid search.
reciprocal_rank_fusion_reranker = lancedb.rerankers.RRFReranker()

# Display name -> reranker instance (looked up by ``search_table``).
reranker_types = dict([
    ('ColBERT', colbert_reranker),
    ('cross encoder', crossencoder_reranker),
    ('AnswerAI', answerai_reranker),
    ('Reciprocal Rank Fusion', reciprocal_rank_fusion_reranker),
])
def search_table(
    table: lancedb.table.Table,
    query: str,
    query_type: str = 'hybrid',
    reranker_name: str = 'cross encoder',
    filter_year: int = 2000,
    top_k: int = 5,
    overfetch_factor: int = 2,
) -> "list[Document]":
    """Search a LanceDB table and return the best reranked documents.

    Args:
        table: the LanceDB table to search.
        query: the text of the query.
        query_type: one of "fts" (lexical), "vector" (semantic) or "hybrid".
        reranker_name: a key of ``reranker_types``.
        filter_year: only rows whose ``creation_date`` compares greater than
            or equal to this value are considered (applied as a pre-filter).
        top_k: maximum number of documents returned.
        overfetch_factor: for "vector" and "fts" searches,
            ``top_k * overfetch_factor`` rows are requested before the result
            is cut down to ``top_k``. Not used for "hybrid" searches.

    Returns:
        At most ``top_k`` ``Document`` instances.

    Raises:
        ValueError: if ``reranker_name`` or ``query_type`` is not supported.
    """
    # Resolve the reranker instance from its display name.
    reranker = reranker_types.get(reranker_name)
    if reranker is None:
        logger.error(f"Invalid reranker name: {reranker_name}")
        raise ValueError(f"Invalid reranker selected: {reranker_name}")

    # Reject unknown query types up front. Previously such a value skipped
    # both branches and failed later with an UnboundLocalError on `results`.
    if query_type not in ("vector", "fts", "hybrid"):
        logger.error(f"Invalid query type: {query_type}")
        raise ValueError(f"Invalid query type selected: {query_type}")

    if query_type == "hybrid":
        fetch_limit = top_k
    else:
        if reranker is reciprocal_rank_fusion_reranker:
            # Reciprocal rank fusion is for 'hybrid' search only; fall back
            # to the cross encoder for single-mode searches.
            reranker = crossencoder_reranker
        fetch_limit = top_k * overfetch_factor

    results = (
        table.search(query, query_type=query_type)
        .where(f"creation_date >= '{filter_year}'", prefilter=True)
        .rerank(reranker=reranker)
        .limit(fetch_limit)
        .to_pydantic(Document)
    )
    return results[:top_k]
# | |
# Define a retriever tool | |
# | |
class RetrieverTool(smolagents.Tool):
    """Agent tool that runs ``search_table`` over the documents table and
    returns the retrieved snippets as one formatted string."""

    name = "retriever"
    description = "Uses hybrid search to retrieve snippets from OECD documents that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, query: str) -> str:
        """Search the collection for ``query`` and format the hits.

        Each hit is rendered under a numbered "===== Document i =====" header
        (numbering starts at 0), after a "Retrieved documents:" heading.
        """
        assert isinstance(query, str), "Your search query must be a string"

        # Uses the defaults of search_table (hybrid search, cross encoder).
        hits = search_table(table=lance_tbl, query=query)
        body = "".join(
            f"\n\n===== Document {str(rank)} =====\n" + hit.text
            for rank, hit in enumerate(hits)
        )
        return "\nRetrieved documents:\n" + body


# Single tool instance handed to the agent.
retriever_tool = RetrieverTool()
# | |
# Define a language model | |
# | |
# The API key is read from the environment; a missing MISTRAL_API_KEY raises
# KeyError at import time.
mistral_api_key = os.environ["MISTRAL_API_KEY"]
mistral_model_id = "mistral/mistral-large-latest"  # 128k context window
#mistral_model_id = "mistral/codestral-latest"
mistral_model = smolagents.LiteLLMModel(
    api_key=mistral_api_key,
    model_id=mistral_model_id,
)

#
# Define an agent with access to tool(s) and language model.
#
# NOTE(review): newer smolagents releases appear to have renamed
# `max_iterations` / `verbose` (to `max_steps` / `verbosity_level`) — confirm
# against the pinned smolagents version before upgrading.
agent = smolagents.CodeAgent(
    model=mistral_model,
    tools=[retriever_tool],
    verbose=True,
    max_iterations=4,
)
# | |
# app | |
# | |
def generate_response(query: str) -> str:
    """Answer a question with the retrieval-augmented agent.

    Args:
        query: the question typed by the user.

    Returns:
        The output of the agent (which has access to a retriever tool over
        the document collection and to a large language model), or a short
        prompt asking for a question when ``query`` is empty or blank.
    """
    # A blank question would still start a full agent run (LLM calls and
    # retrieval); answer locally instead.
    if not query or not query.strip():
        return "Please enter a question."

    agent_output = agent.run(query)
    return agent_output
# | |
# User interface | |
# | |
with gr.Blocks() as demo:
    # Page header
    gr.Markdown("""
    # Agentic Hybrid search
    Document collection: OECD documents on international tax crimes.
    """)

    # Input: the user's question
    question = gr.Textbox(placeholder="", label="Question to answer")

    # Output: the agent's response
    response = gr.Textbox(placeholder="", label="Response")

    # Action buttons, laid out on one row
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Example questions given default provided PDF file
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            examples=[
                ["What is the OECD's role in combating offshore tax evasion?"],
                ["What are the key tools used in fighting offshore tax evasion?"],
                ['What are "High Net Worth Individuals" (HNWIs) and how do they relate to tax compliance efforts?'],
                ["What is the significance of international financial centers (IFCs) in the context of tax evasion?"],
                ["What is being done to address the role of professional enablers in facilitating tax evasion?"],
                ["How does the OECD measure the effectiveness of international efforts to fight offshore tax evasion?"],
                ['What are the "Ten Global Principles" for fighting tax crime?'],
                ["What are some recent developments in the fight against offshore tax evasion?"],
            ],
            fn=generate_response,
            inputs=[question],
            outputs=[response],
            cache_examples=False,
            label="Sample questions",
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - Agentic framework
            - Hugging Face's smolagents
        - Retrieval model
            - LanceDB: support for hybrid search search with reranking of results.
            - Full text search (lexical): BM25
            - Vector search (semantic dense vectors): BAAI/bge-m3
        - Rerankers
            - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
        - Generation
            - Mistral
        - Examples
            - Generated using Google NotebookLM
        """)

    # Click actions: "Submit" runs the agent, "Clear" empties both text boxes.
    response_button.click(
        fn=generate_response,
        inputs=[question],
        outputs=[response],
    )
    clear_button.click(
        fn=lambda: ('', ''),
        inputs=[],
        outputs=[question, response],
    )

demo.launch(show_api=False)