# datagouv-french-data-analyst / tools / retrieval_tools.py
import os
import pickle

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from rank_bm25 import BM25Okapi
from smolagents import CodeAgent, LiteLLMModel, tool
from unidecode import unidecode

load_dotenv()
# Module-level caches for the BM25 index, dataset, and LLM translator
_bm25_model = None
_precomputed_titles = None
_dataset_df = None
_llm_translator = None
def _initialize_retrieval_system():
"""Initialize the retrieval system with BM25 model and dataset"""
global _bm25_model, _precomputed_titles, _dataset_df, _llm_translator
# Load dataset if not already loaded
if _dataset_df is None:
try:
_dataset_df = pd.read_csv('filtered_dataset.csv')
print(f"✅ Loaded dataset with {len(_dataset_df)} entries")
        except FileNotFoundError:
            raise FileNotFoundError("filtered_dataset.csv not found. Please ensure the dataset file exists.")
# Initialize LLM translator if not already initialized
if _llm_translator is None:
try:
model = LiteLLMModel(
model_id="gemini/gemini-2.5-flash-preview-05-20",
api_key=os.getenv("GEMINI_API_KEY")
)
_llm_translator = CodeAgent(tools=[], model=model, max_steps=1)
print("✅ LLM translator initialized")
except Exception as e:
print(f"⚠️ Error initializing LLM translator: {e}")
# Load pre-computed BM25 model if available
if _bm25_model is None:
try:
with open('bm25_data.pkl', 'rb') as f:
bm25_data = pickle.load(f)
_bm25_model = bm25_data['bm25_model']
_precomputed_titles = bm25_data['titles']
print(f"✅ Loaded pre-computed BM25 model for {len(_precomputed_titles)} datasets")
except FileNotFoundError:
print("⚠️ Pre-computed BM25 model not found. Will compute at runtime.")
except Exception as e:
print(f"⚠️ Error loading pre-computed BM25 model: {e}")
def _translate_query_llm(query, target_lang='fr'):
"""Translate query using LLM"""
global _llm_translator
if _llm_translator is None:
return query, 'unknown'
try:
if target_lang == 'fr':
target_language = "French"
elif target_lang == 'en':
target_language = "English"
else:
target_language = target_lang
translation_prompt = f"""
Translate the following text to {target_language}.
If the text is already in {target_language}, return it as is.
Only return the translated text, nothing else.
Text to translate: "{query}"
"""
response = _llm_translator.run(translation_prompt)
translated_text = str(response).strip().strip('"').strip("'")
        # Crude language detection: if the translation left the text unchanged,
        # the source was already in the target language; otherwise assume the
        # other of en/fr.
if query.lower() == translated_text.lower():
source_lang = target_lang
else:
source_lang = 'en' if target_lang == 'fr' else 'fr'
return translated_text, source_lang
except Exception as e:
print(f"LLM translation error: {e}")
return query, 'unknown'
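
# Illustrative contract (the model outputs below are hypothetical):
#   _translate_query_llm("road accidents", target_lang='fr')
#   -> ("accidents de la route", 'en')
#   _translate_query_llm("qualité de l'air", target_lang='fr')
#   -> ("qualité de l'air", 'fr')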
def _simple_keyword_preprocessing(text):
    """Lowercase, strip accents, and crudely singularize ('s'/'x' endings) for keyword matching."""
text = unidecode(str(text).lower())
words = text.split()
processed_words = []
for word in words:
if word.endswith('s') and len(word) > 3 and not word.endswith('ss'):
word = word[:-1]
elif word.endswith('x') and len(word) > 3:
word = word[:-1]
processed_words.append(word)
return processed_words
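
# For illustration, the rules above behave as follows:
#   _simple_keyword_preprocessing("Données publiques")  -> ['donnee', 'publique']
#   _simple_keyword_preprocessing("réseaux")            -> ['reseau']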
@tool
def search_datasets(query: str, top_k: int = 5) -> str:
"""
Search for relevant datasets in the French public data catalog using BM25-based keyword matching.
Args:
query: The search query describing what kind of dataset you're looking for
top_k: Number of top results to return (default: 5)
Returns:
A formatted string containing the top matching datasets with their titles, URLs, and relevance scores
"""
try:
# Initialize the retrieval system
_initialize_retrieval_system()
global _bm25_model, _precomputed_titles, _dataset_df
# Translate query to French for better matching
translated_query, original_lang = _translate_query_llm(query, target_lang='fr')
# Combine original and translated queries for search
search_queries = [query, translated_query] if query != translated_query else [query]
# Get dataset titles
dataset_titles = _dataset_df['title'].fillna('').tolist()
# Use pre-computed BM25 model if available and matches current dataset
if (_bm25_model is not None and _precomputed_titles is not None and
len(dataset_titles) == len(_precomputed_titles) and dataset_titles == _precomputed_titles):
bm25 = _bm25_model
else:
# Build BM25 model at runtime
processed_titles = [_simple_keyword_preprocessing(title) for title in dataset_titles]
bm25 = BM25Okapi(processed_titles)
# Get scores for all search queries and find best matches
all_scores = []
for search_query in search_queries:
try:
processed_query = _simple_keyword_preprocessing(search_query)
scores = bm25.get_scores(processed_query)
all_scores.append(scores)
except Exception as e:
print(f"Error processing query '{search_query}': {e}")
continue
if not all_scores:
return "Error: Could not process any search queries"
# Combine scores (take maximum across all queries)
combined_scores = all_scores[0]
for scores in all_scores[1:]:
combined_scores = np.maximum(combined_scores, scores)
# Get top-k results
top_indices = combined_scores.argsort()[-top_k:][::-1]
# Format results
results = []
results.append(f"Top {top_k} datasets for query: '{query}'")
if query != translated_query:
results.append(f"(Translated to French: '{translated_query}')")
results.append("")
        for i, idx in enumerate(top_indices, 1):
            row = _dataset_df.iloc[idx]
            score = combined_scores[idx]
            title = row['title']
            url = row['url']
            organization = row.get('organization', 'N/A')
results.append(f"{i}. Score: {score:.2f}")
results.append(f" Title: {title}")
results.append(f" URL: {url}")
results.append(f" Organization: {organization}")
results.append("")
return "\n".join(results)
except Exception as e:
return f"Error during dataset search: {str(e)}"
@tool
def get_dataset_info(dataset_url: str) -> str:
"""
Get detailed information about a specific dataset from its data.gouv.fr URL.
Args:
dataset_url: The URL of the dataset page on data.gouv.fr
Returns:
Detailed information about the dataset including title, description, organization, and metadata
"""
try:
_initialize_retrieval_system()
global _dataset_df
# Find the dataset in our catalog
matching_rows = _dataset_df[_dataset_df['url'] == dataset_url]
if matching_rows.empty:
return f"Dataset not found in catalog for URL: {dataset_url}"
dataset = matching_rows.iloc[0]
# Format the dataset information
info_lines = []
info_lines.append("=== DATASET INFORMATION ===")
info_lines.append(f"Title: {dataset.get('title', 'N/A')}")
info_lines.append(f"URL: {dataset.get('url', 'N/A')}")
info_lines.append(f"Organization: {dataset.get('organization', 'N/A')}")
if 'description' in dataset and pd.notna(dataset['description']):
description = str(dataset['description'])
if len(description) > 500:
description = description[:500] + "..."
info_lines.append(f"Description: {description}")
if 'tags' in dataset and pd.notna(dataset['tags']):
info_lines.append(f"Tags: {dataset['tags']}")
if 'license' in dataset and pd.notna(dataset['license']):
info_lines.append(f"License: {dataset['license']}")
if 'temporal_coverage' in dataset and pd.notna(dataset['temporal_coverage']):
info_lines.append(f"Temporal Coverage: {dataset['temporal_coverage']}")
if 'spatial_coverage' in dataset and pd.notna(dataset['spatial_coverage']):
info_lines.append(f"Spatial Coverage: {dataset['spatial_coverage']}")
if 'quality_score' in dataset and pd.notna(dataset['quality_score']):
info_lines.append(f"Quality Score: {dataset['quality_score']}")
return "\n".join(info_lines)
except Exception as e:
return f"Error getting dataset info: {str(e)}"
@tool
def get_random_quality_dataset() -> str:
"""
Get a random high-quality dataset from the catalog, weighted by quality score.
Returns:
Information about a randomly selected high-quality dataset
"""
try:
_initialize_retrieval_system()
global _dataset_df
# Use quality_score as weights for random selection
if 'quality_score' in _dataset_df.columns:
weights = _dataset_df['quality_score'].fillna(0)
            weights = weights - weights.min() + 0.1  # shift so all weights are strictly positive, as sample() requires
else:
weights = None
# Randomly sample one dataset weighted by quality
selected_row = _dataset_df.sample(n=1, weights=weights).iloc[0]
# Return dataset info
return get_dataset_info(selected_row['url'])
except Exception as e:
return f"Error getting random dataset: {str(e)}"