Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from datetime import datetime | |
| import json | |
# Make this app's directory, the duckdb-nsql directory next to it and that
# directory's eval/ sub-directory importable, so the ``eval.*`` imports below
# resolve. Entries are appended, so modules already importable under the same
# names take precedence.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir.joinpath("duckdb-nsql")
eval_dir = duckdb_nsql_dir.joinpath("eval")
sys.path.extend(str(entry) for entry in (current_dir, duckdb_nsql_dir, eval_dir))
| # Import necessary functions and classes from predict.py and evaluate.py | |
| from eval.predict import predict, console, get_manifest, DefaultLoader | |
| from eval.constants import PROMPT_FORMATTERS | |
| from eval.evaluate import evaluate, compute_metrics, get_to_print | |
| from eval.evaluate import test_suite_evaluation, read_tables_json | |
def _call_predict(**kwargs):
    """Call ``predict`` with only the keyword arguments its signature accepts.

    Filtering up front replaces a "call, catch ``TypeError``, retry" approach:
    a retry could repeat a partially finished (possibly paid) prediction run,
    and a ``TypeError`` raised *inside* ``predict`` would be mistaken for a
    signature mismatch.

    Returns:
        Sorted names of keyword arguments that were dropped because
        ``predict`` does not accept them (empty if it takes ``**kwargs``).
    """
    import inspect  # only needed here; keeps this helper self-contained

    params = inspect.signature(predict).parameters
    takes_var_kwargs = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values()
    )
    if takes_var_kwargs:
        accepted = dict(kwargs)
    else:
        accepted = {name: value for name, value in kwargs.items() if name in params}
    predict(**accepted)
    return sorted(set(kwargs) - set(accepted))


def _read_jsonl(path):
    """Return the JSON objects stored one per line in *path*, skipping blank lines."""
    with Path(path).open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _score_predictions(dataset_path, table_meta_path, pred_path, db_dir, model_name):
    """Score the predictions in *pred_path* against the gold examples.

    Args:
        dataset_path: JSON file holding a list of gold examples.
        table_meta_path: table metadata JSON used for schemas / foreign keys.
        pred_path: JSON-lines file with one ``{"pred": ...}`` object per line.
        db_dir: directory passed to ``compute_metrics`` as ``database_dir``.
        model_name: model identifier forwarded to ``compute_metrics``.

    Returns:
        ``(metrics, num_gold, num_pred)``; *metrics* is whatever
        ``compute_metrics`` returned.
    """
    kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(table_meta_path))
    db_schemas = read_tables_json(str(table_meta_path))

    with Path(dataset_path).open("r", encoding="utf-8") as handle:
        gold_examples = json.load(handle)
    pred_examples = _read_jsonl(pred_path)

    gold_sqls = [ex.get("query", ex.get("sql", "")) for ex in gold_examples]
    metrics = compute_metrics(
        gold_sqls=gold_sqls,
        pred_sqls=[ex["pred"] for ex in pred_examples],
        gold_dbs=[ex.get("db_id", ex.get("db", "")) for ex in gold_examples],
        setup_sqls=[ex["setup_sql"] for ex in gold_examples],
        validate_sqls=[ex["validation_sql"] for ex in gold_examples],
        kmaps=kmaps,
        db_schemas=db_schemas,
        database_dir=db_dir,
        lowercase_schema_match=False,
        model_name=model_name,
        categories=[ex.get("category", "") for ex in gold_examples],
    )
    return metrics, len(gold_sqls), len(pred_examples)


def _format_metrics(metrics, model_name, num_gold):
    """Render *metrics* as ``key: value`` lines, or a notice if there are none."""
    if not metrics:
        return "No evaluation metrics returned."
    to_print = get_to_print({"all": metrics}, "all", model_name, num_gold)
    body = "\n".join(
        f"{k}: {v}" for k, v in to_print.items() if k not in ("slice", "model")
    )
    return f"Evaluation metrics:\n{body}"


def run_evaluation(model_name):
    """Run DuckDB-NSQL prediction and evaluation for *model_name*, streaming progress.

    This is a generator (Gradio's button handler). Each yielded value is the
    *whole* log so far, with entries separated by blank lines, so the output
    textbox always shows the full history. The last yield carries the final
    metrics or the error that stopped the run.

    Args:
        model_name: OpenRouter model identifier, e.g. ``qwen/qwen-2.5-72b-instruct``.

    Yields:
        The accumulated progress log as a single string.
    """
    log = []

    def emit(message):
        # Record one entry and return the full log for Gradio to display.
        log.append(message)
        return "\n\n".join(log)

    if "OPENROUTER_API_KEY" not in os.environ:
        yield emit("Error: OPENROUTER_API_KEY not found in environment variables.")
        return
    model_name = (model_name or "").strip()
    if not model_name:
        yield emit("Error: please enter a model name.")
        return

    # Settings mirroring the predict.py CLI arguments.
    dataset_path = "duckdb-nsql/eval/data/dev.json"
    table_meta_path = "duckdb-nsql/eval/data/tables.json"
    output_dir = "duckdb-nsql/output/"
    db_dir = "duckdb-nsql/eval/data/databases/"
    prompt_format = "duckdbinstgraniteshort"
    manifest_client = "openrouter"
    manifest_connection = "http://localhost:5000"

    try:
        # Instantiate the loader, prompt formatter and manifest up front so that
        # configuration problems surface before the long prediction run; the
        # objects themselves are not used further here.
        data_formatter = DefaultLoader()
        PROMPT_FORMATTERS[prompt_format]()
        get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=model_name,
        )
        yield emit(f"Using model: {model_name}")

        yield emit("Loading metadata and data...")
        data_formatter.load_table_metadata(table_meta_path)
        data_formatter.load_data(dataset_path)

        # NOTE(review): this must match the file name predict() writes;
        # keep in sync with eval/predict.py.
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = (
            f"{prompt_format}_0docs_{model_name.split('/')[-1]}"
            f"_{Path(dataset_path).stem}_{date_today}.json"
        )
        pred_path = Path(output_dir) / pred_filename
        yield emit(f"Prediction will be saved to: {pred_path}")

        yield emit("Starting prediction...")
        dropped = _call_predict(
            dataset_path=dataset_path,
            table_meta_path=table_meta_path,
            output_dir=output_dir,
            prompt_format=prompt_format,
            stop_tokens=[";"],
            max_tokens=30000,
            temperature=0.1,
            num_beams=-1,
            manifest_client=manifest_client,
            manifest_engine=model_name,
            manifest_connection=manifest_connection,
            overwrite_manifest=True,
            parallel=False,
        )
        if dropped:
            yield emit(
                "Note: predict() does not accept "
                f"{', '.join(dropped)}; those settings were ignored."
            )
        yield emit("Prediction completed.")

        yield emit("Starting evaluation...")
        metrics, num_gold, num_pred = _score_predictions(
            dataset_path, table_meta_path, pred_path, db_dir, model_name
        )
        if num_pred != num_gold:
            yield emit(
                f"Warning: {num_pred} predictions for {num_gold} gold examples; "
                "metrics may be misaligned."
            )
        yield emit("Evaluation completed.")
        yield emit(_format_metrics(metrics, model_name, num_gold))
    except Exception as e:  # UI boundary: report the failure instead of crashing
        yield emit(f"An unexpected error occurred: {str(e)}")
# Gradio UI: one model-name input, a start button, and a streaming log output.
with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")
    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)
    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

if __name__ == "__main__":
    # run_evaluation is a generator that streams progress. Gradio 3.x only
    # streams generator output when the queue is enabled; 4.x enables it by
    # default, so calling queue() is harmless there.
    demo.queue().launch()