import os
import pickle

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from rank_bm25 import BM25Okapi
from smolagents import CodeAgent, LiteLLMModel, tool
from unidecode import unidecode

load_dotenv()

# Module-level caches, populated lazily by _initialize_retrieval_system().
_bm25_model = None
_precomputed_titles = None
_dataset_df = None
_llm_translator = None


def _initialize_retrieval_system():
    """Initialize the retrieval system: dataset catalog, LLM translator, and BM25 model."""
    global _bm25_model, _precomputed_titles, _dataset_df, _llm_translator

    # Load the dataset catalog.
    if _dataset_df is None:
        try:
            _dataset_df = pd.read_csv('filtered_dataset.csv')
            print(f"✅ Loaded dataset with {len(_dataset_df)} entries")
        except FileNotFoundError:
            raise FileNotFoundError(
                "filtered_dataset.csv not found. Please ensure the dataset file exists."
            )

    # Initialize the LLM used for query translation.
    if _llm_translator is None:
        try:
            model = LiteLLMModel(
                model_id="gemini/gemini-2.5-flash-preview-05-20",
                api_key=os.getenv("GEMINI_API_KEY")
            )
            _llm_translator = CodeAgent(tools=[], model=model, max_steps=1)
            print("✅ LLM translator initialized")
        except Exception as e:
            print(f"⚠️ Error initializing LLM translator: {e}")

    # Load the pre-computed BM25 index if available; otherwise it is built at query time.
    if _bm25_model is None:
        try:
            with open('bm25_data.pkl', 'rb') as f:
                bm25_data = pickle.load(f)
            _bm25_model = bm25_data['bm25_model']
            _precomputed_titles = bm25_data['titles']
            print(f"✅ Loaded pre-computed BM25 model for {len(_precomputed_titles)} datasets")
        except FileNotFoundError:
            print("⚠️ Pre-computed BM25 model not found. Will compute at runtime.")
        except Exception as e:
            print(f"⚠️ Error loading pre-computed BM25 model: {e}")
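
# --- Example: building bm25_data.pkl offline (a sketch, not part of the tools) ---
# Assumes the pickle layout loaded above ('bm25_model' and 'titles' keys) and the
# same tokenization as _simple_keyword_preprocessing (defined below). Run once
# offline so search_datasets can reuse the index instead of rebuilding it:
#
#     df = pd.read_csv('filtered_dataset.csv')
#     titles = df['title'].fillna('').tolist()
#     bm25 = BM25Okapi([_simple_keyword_preprocessing(t) for t in titles])
#     with open('bm25_data.pkl', 'wb') as f:
#         pickle.dump({'bm25_model': bm25, 'titles': titles}, f)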


def _translate_query_llm(query, target_lang='fr'):
    """Translate a query using the LLM; returns (translated_text, detected_source_lang)."""
    global _llm_translator

    if _llm_translator is None:
        return query, 'unknown'

    try:
        if target_lang == 'fr':
            target_language = "French"
        elif target_lang == 'en':
            target_language = "English"
        else:
            target_language = target_lang

        translation_prompt = f"""
        Translate the following text to {target_language}.
        If the text is already in {target_language}, return it as is.
        Only return the translated text, nothing else.

        Text to translate: "{query}"
        """

        response = _llm_translator.run(translation_prompt)
        translated_text = str(response).strip().strip('"').strip("'")

        # Heuristic language detection: if translation left the text unchanged,
        # it was already in the target language; otherwise assume the other of fr/en.
        if query.lower() == translated_text.lower():
            source_lang = target_lang
        else:
            source_lang = 'en' if target_lang == 'fr' else 'fr'

        return translated_text, source_lang

    except Exception as e:
        print(f"LLM translation error: {e}")
        return query, 'unknown'


def _simple_keyword_preprocessing(text):
    """Simple preprocessing for keyword matching: lowercases, strips accents, and reduces basic plurals."""
    text = unidecode(str(text).lower())

    words = text.split()
    processed_words = []

    for word in words:
        # Strip common French/English plural endings: trailing 's' (but not 'ss')
        # and trailing 'x' (as in "bureaux"), so singular and plural forms match.
        if word.endswith('s') and len(word) > 3 and not word.endswith('ss'):
            word = word[:-1]
        elif word.endswith('x') and len(word) > 3:
            word = word[:-1]
        processed_words.append(word)

    return processed_words
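
# For instance, _simple_keyword_preprocessing("Données publiques françaises")
# returns ['donnee', 'publique', 'francaise']: accents removed, plurals reduced.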


@tool
def search_datasets(query: str, top_k: int = 5) -> str:
    """
    Search for relevant datasets in the French public data catalog using BM25-based keyword matching.

    Args:
        query: The search query describing what kind of dataset you're looking for
        top_k: Number of top results to return (default: 5)

    Returns:
        A formatted string containing the top matching datasets with their titles, URLs, and relevance scores
    """
    try:
        _initialize_retrieval_system()

        global _bm25_model, _precomputed_titles, _dataset_df

        # Translate the query to French so it matches the catalog's language.
        translated_query, original_lang = _translate_query_llm(query, target_lang='fr')

        # Search with both the original and the translated query when they differ.
        search_queries = [query, translated_query] if query != translated_query else [query]

        dataset_titles = _dataset_df['title'].fillna('').tolist()

        # Reuse the pre-computed BM25 index only if it matches the current catalog;
        # otherwise build one from scratch.
        if (_bm25_model is not None and _precomputed_titles is not None and
                len(dataset_titles) == len(_precomputed_titles) and dataset_titles == _precomputed_titles):
            bm25 = _bm25_model
        else:
            processed_titles = [_simple_keyword_preprocessing(title) for title in dataset_titles]
            bm25 = BM25Okapi(processed_titles)

        # Score every title against each query variant.
        all_scores = []
        for search_query in search_queries:
            try:
                processed_query = _simple_keyword_preprocessing(search_query)
                scores = bm25.get_scores(processed_query)
                all_scores.append(scores)
            except Exception as e:
                print(f"Error processing query '{search_query}': {e}")
                continue

        if not all_scores:
            return "Error: Could not process any search queries"

        # Combine query variants by taking the element-wise maximum score.
        combined_scores = all_scores[0]
        for scores in all_scores[1:]:
            combined_scores = np.maximum(combined_scores, scores)

        # Indices of the top_k highest-scoring datasets, best first.
        top_indices = combined_scores.argsort()[-top_k:][::-1]

        results = []
        results.append(f"Top {top_k} datasets for query: '{query}'")
        if query != translated_query:
            results.append(f"(Translated to French: '{translated_query}')")
        results.append("")

        for i, idx in enumerate(top_indices, 1):
            score = combined_scores[idx]
            title = _dataset_df.iloc[idx]['title']
            url = _dataset_df.iloc[idx]['url']
            organization = _dataset_df.iloc[idx].get('organization', 'N/A')

            results.append(f"{i}. Score: {score:.2f}")
            results.append(f"   Title: {title}")
            results.append(f"   URL: {url}")
            results.append(f"   Organization: {organization}")
            results.append("")

        return "\n".join(results)

    except Exception as e:
        return f"Error during dataset search: {str(e)}"


@tool
def get_dataset_info(dataset_url: str) -> str:
    """
    Get detailed information about a specific dataset from its data.gouv.fr URL.

    Args:
        dataset_url: The URL of the dataset page on data.gouv.fr

    Returns:
        Detailed information about the dataset including title, description, organization, and metadata
    """
    try:
        _initialize_retrieval_system()

        global _dataset_df

        # Look the dataset up by exact URL match.
        matching_rows = _dataset_df[_dataset_df['url'] == dataset_url]

        if matching_rows.empty:
            return f"Dataset not found in catalog for URL: {dataset_url}"

        dataset = matching_rows.iloc[0]

        info_lines = []
        info_lines.append("=== DATASET INFORMATION ===")
        info_lines.append(f"Title: {dataset.get('title', 'N/A')}")
        info_lines.append(f"URL: {dataset.get('url', 'N/A')}")
        info_lines.append(f"Organization: {dataset.get('organization', 'N/A')}")

        # Optional fields, included only when present and non-null.
        if 'description' in dataset and pd.notna(dataset['description']):
            description = str(dataset['description'])
            if len(description) > 500:
                description = description[:500] + "..."
            info_lines.append(f"Description: {description}")

        if 'tags' in dataset and pd.notna(dataset['tags']):
            info_lines.append(f"Tags: {dataset['tags']}")

        if 'license' in dataset and pd.notna(dataset['license']):
            info_lines.append(f"License: {dataset['license']}")

        if 'temporal_coverage' in dataset and pd.notna(dataset['temporal_coverage']):
            info_lines.append(f"Temporal Coverage: {dataset['temporal_coverage']}")

        if 'spatial_coverage' in dataset and pd.notna(dataset['spatial_coverage']):
            info_lines.append(f"Spatial Coverage: {dataset['spatial_coverage']}")

        if 'quality_score' in dataset and pd.notna(dataset['quality_score']):
            info_lines.append(f"Quality Score: {dataset['quality_score']}")

        return "\n".join(info_lines)

    except Exception as e:
        return f"Error getting dataset info: {str(e)}"


@tool
def get_random_quality_dataset() -> str:
    """
    Get a random high-quality dataset from the catalog, weighted by quality score.

    Returns:
        Information about a randomly selected high-quality dataset
    """
    try:
        _initialize_retrieval_system()

        global _dataset_df

        # Weight sampling by quality score, shifted so every weight is strictly
        # positive (pandas rejects sampling when all weights are zero).
        if 'quality_score' in _dataset_df.columns:
            weights = _dataset_df['quality_score'].fillna(0)
            weights = weights - weights.min() + 0.1
        else:
            weights = None

        selected_row = _dataset_df.sample(n=1, weights=weights).iloc[0]

        return get_dataset_info(selected_row['url'])

    except Exception as e:
        return f"Error getting random dataset: {str(e)}"