Spaces:
Sleeping
Sleeping
""" app.py | |
Question / answer over a collection of PDF documents from OECD.org.
PDF text extraction:
- pypdf
Retrieval model:
- LanceDB: support for hybrid search with reranking of results.
- Full text search (lexical): BM25
- Vector search (semantic dense vectors): BAAI/bge-m3
Rerankers:
- ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
Generation:
- Mistral
:author: Didier Guillevic
:date: 2024-12-28
""" | |
import html
import logging
from collections.abc import Iterator

import gradio as gr
import lancedb

import llm_utils
# Send INFO-and-above records to the default handler, then grab this
# module's own logger for the rest of the file to use.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# | |
# LanceDB with the indexed documents | |
# | |
# Connect to the database | |
# Both handles are module-level so every request reuses the same connection
# and the same opened table.
lance_db = lancedb.connect(uri="lance.db")
lance_tbl = lance_db.open_table(name="documents")
# Document schema | |
class Document(lancedb.pydantic.LanceModel):
    """Schema used to materialize LanceDB search results as typed objects."""
    # Passage text: the column the rerankers score and the context sent to the LLM.
    text: str
    # Dense embedding of `text`, 1024 dimensions.
    # NOTE(review): presumably BAAI/bge-m3 (per the module header) - confirm.
    vector: lancedb.pydantic.Vector(1024)
    # Name of the source file; displayed to the user as the reference.
    file_name: str
    # NOTE(review): presumably the page count of the source PDF - confirm
    # against the indexing script.
    num_pages: int
    # Dates are stored as strings; `creation_date` is compared against a
    # quoted year string when search results are filtered.
    creation_date: str
    modification_date: str
# | |
# Retrieval: query types and reranker types | |
# | |
# UI label -> value passed to LanceDB as `query_type`.
query_types = dict(
    lexical='fts',
    semantic='vector',
    hybrid='hybrid',
)
# Define a few rerankers | |
# Each reranker is instantiated once, at import time, and shared by all
# requests. The content-based ones are pointed at the `text` column.
colbert_reranker = lancedb.rerankers.ColbertReranker(column='text')
crossencoder_reranker = lancedb.rerankers.CrossEncoderReranker(column='text')
answerai_reranker = lancedb.rerankers.AnswerdotaiRerankers(column='text')
# Usable with hybrid search only.
reciprocal_rank_fusion_reranker = lancedb.rerankers.RRFReranker()

# UI label -> reranker instance. Insertion order is the order in which the
# names are offered in the "Reranker" dropdown.
reranker_types = dict([
    ('ColBERT', colbert_reranker),
    ('cross encoder', crossencoder_reranker),
    ('AnswerAI', answerai_reranker),
    ('Reciprocal Rank Fusion', reciprocal_rank_fusion_reranker),
])
def search_table(
    table: lancedb.table.Table,
    query: str,
    query_type: str,
    reranker_name: str,
    filter_year: int,
    top_k: int=5,
    overfetch_factor: int=2
) -> list[Document]:
    """Search `table` for `query` and return the best reranked documents.

    Args:
        table: LanceDB table to search.
        query: the user's question.
        query_type: one of 'fts', 'vector' or 'hybrid'.
        reranker_name: a key of `reranker_types`.
        filter_year: only rows whose `creation_date` compares >= this year
            (as a quoted string) are searched.
        top_k: maximum number of documents returned.
        overfetch_factor: for 'fts' and 'vector' searches,
            `top_k * overfetch_factor` candidates are requested before the
            result is cut down to `top_k`.

    Returns:
        At most `top_k` `Document` instances.

    Raises:
        ValueError: if `reranker_name` or `query_type` is not recognised.
    """
    # Get the instance of reranker
    reranker = reranker_types.get(reranker_name)
    if reranker is None:
        logger.error(f"Invalid reranker name: {reranker_name}")
        raise ValueError(f"Invalid reranker selected: {reranker_name}")

    if query_type in ("vector", "fts"):
        if reranker is reciprocal_rank_fusion_reranker:
            # Reciprocal rank fusion is for 'hybrid' search only; fall back
            # to the cross encoder for single-mode searches.
            reranker = crossencoder_reranker
        limit = top_k * overfetch_factor
    elif query_type == "hybrid":
        limit = top_k
    else:
        # Previously an unknown query type fell through both branches and
        # surfaced as an UnboundLocalError on `results`.
        logger.error(f"Invalid query type: {query_type}")
        raise ValueError(f"Invalid query type selected: {query_type}")

    # Coerce to int so that a float coming from the UI slider does not turn
    # into '2005.0', and so that nothing but digits can reach the filter
    # expression.
    min_year = int(filter_year)

    results = (
        table.search(query, query_type=query_type)
        .where(f"creation_date >= '{min_year}'", prefilter=True)
        .rerank(reranker=reranker)
        .limit(limit)
        .to_pydantic(Document)
    )
    return results[:top_k]
# | |
# Generation: query + context --> response
# | |
def create_bulleted_list(texts: list[str]) -> str:
    """Return an HTML unordered list with one `<li>` per entry of `texts`.

    Each entry is HTML-escaped first: the entries are file names and text
    extracted from PDFs, so a stray `<` or `&` must be displayed literally
    rather than interpreted as markup.

    Args:
        texts: the items to display, in order.

    Returns:
        The `<ul>...</ul>` markup; `<ul></ul>` when `texts` is empty.
    """
    # quote=False: the text only ever lands in element content, never inside
    # an attribute, so quotes can stay as they are.
    html_items = [
        f"<li>{html.escape(str(item), quote=False)}</li>" for item in texts
    ]
    return "<ul>" + "".join(html_items) + "</ul>"
def generate_response(
    query: str,
    query_type: str,
    reranker_name: str,
    filter_year: int
) -> Iterator[tuple[str, str, str]]:
    """Stream an answer to `query` built from snippets retrieved in LanceDB.

    Args:
        query: the user's question.
        query_type: search mode handed to `search_table`
            ('fts', 'vector' or 'hybrid').
        reranker_name: a key of `reranker_types`.
        filter_year: minimum creation year of the documents searched.

    Yields:
        After every chunk received from the LLM, a 3-tuple:
        - the response text accumulated so far
        - (html string): the references (origin of the snippets of text
          used to generate the answer)
        - (html string): the snippets of text used to generate the answer
    """
    # Get results from LanceDB
    results = search_table(
        lance_tbl,
        query=query,
        query_type=query_type,
        reranker_name=reranker_name,
        filter_year=filter_year
    )

    references = [result.file_name for result in results]
    references_html = "<h4>References</h4>\n" + create_bulleted_list(references)

    snippets = [result.text for result in results]
    snippets_html = "<h4>Snippets</h4>\n" + create_bulleted_list(snippets)

    # Generate the response from the LLM
    stream_response = llm_utils.generate_chat_response_streaming(
        query, '\n\n'.join(snippets)
    )

    model_response = ""
    for chunk in stream_response:
        delta_text = chunk.data.choices[0].delta.content
        # A chunk may carry no text (None or ""); appending None would raise
        # a TypeError and abort the stream.
        if delta_text:
            model_response += delta_text
        yield model_response, references_html, snippets_html
# | |
# User interface | |
# | |
# NOTE(review): the nesting below is reconstructed from the order of the
# calls; confirm that the three lower accordions are meant to share one row.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Hybrid search / reranking / Mistral
    Document collection: OECD documents on international tax crimes.
    """)

    # Question typed by the user
    question = gr.Textbox(label="Question to answer", placeholder="")

    # Answer, with its supporting material in a collapsed panel
    response = gr.Textbox(label="Response", placeholder="")
    with gr.Accordion("References & snippets", open=False):
        references = gr.HTML(label="References")
        snippets = gr.HTML(label="Snippets")

    # Action buttons
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Search parameters: created now with render=False so they can be wired
    # into the examples and the click handler, and rendered further down.
    query_type = gr.Dropdown(
        label='Query type',
        choices=query_types.items(),
        value='hybrid',
        render=False,
    )
    reranker_name = gr.Dropdown(
        label='Reranker',
        choices=list(reranker_types.keys()),
        value='cross encoder',
        render=False,
    )
    filter_year = gr.Slider(
        label='Creation date >=',
        minimum=2005,
        maximum=2020,
        value=2005,
        step=1,
        render=False,
    )

    # The same inputs and outputs are used by the examples and by "Submit"
    search_inputs = [question, query_type, reranker_name, filter_year]
    answer_outputs = [response, references, snippets]

    # One question per example row; only the `question` input is filled in
    sample_questions = [
        "What is the OECD's role in combating offshore tax evasion?",
        "What are the key tools used in fighting offshore tax evasion?",
        'What are "High Net Worth Individuals" (HNWIs) and how do they relate to tax compliance efforts?',
        "What is the significance of international financial centers (IFCs) in the context of tax evasion?",
        "What is being done to address the role of professional enablers in facilitating tax evasion?",
        "How does the OECD measure the effectiveness of international efforts to fight offshore tax evasion?",
        'What are the "Ten Global Principles" for fighting tax crime?',
        "What are some recent developments in the fight against offshore tax evasion?",
    ]

    with gr.Row():
        # Example questions given default provided PDF file
        with gr.Accordion("Sample questions", open=False):
            gr.Examples(
                examples=[[sample] for sample in sample_questions],
                inputs=search_inputs,
                outputs=answer_outputs,
                fn=generate_response,
                cache_examples=False,
                label="Sample questions",
            )

        # Additional inputs: search parameters
        with gr.Accordion("Search parameters", open=False):
            with gr.Row():
                query_type.render()
                reranker_name.render()
                filter_year.render()

        # Documentation
        with gr.Accordion("Documentation", open=False):
            gr.Markdown("""
            - Retrieval model
            - LanceDB: support for hybrid search search with reranking of results.
            - Full text search (lexical): BM25
            - Vector search (semantic dense vectors): BAAI/bge-m3
            - Rerankers
            - ColBERT, cross encoder, reciprocal rank fusion, AnswerDotAI
            - Generation
            - Mistral
            - Examples
            - Generated using Google NotebookLM
            """)

    # Click actions
    response_button.click(
        fn=generate_response,
        inputs=search_inputs,
        outputs=answer_outputs,
    )
    # "Clear" blanks the question and the three result panes
    clear_button.click(
        fn=lambda: ('',) * 4,
        inputs=[],
        outputs=[question, *answer_outputs],
    )

demo.launch(show_api=False)