Spaces:
Runtime error
Runtime error
from .sql_runtime import SQLRuntime | |
from pydantic import BaseModel, Field | |
from .load_llm import load_llm | |
from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt | |
from langchain_core.runnables import chain | |
from typing import Optional | |
from dotenv import load_dotenv | |
# Structured-output schema: sql_generator passes this class to
# load_llm().with_structured_output(...) and reads `.queries` from the result.
# NOTE(review): the docstring and the Field description below are presumably
# forwarded to the model as part of the schema, so treat their wording as
# prompt text — confirm before rewording them.
class Generated_query(BaseModel):
    """
    The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
    """
    # SQL statements produced by the model; sql_generator hands this list to
    # SQLRuntime.execute_batch.
    queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")
# Structured-output schema: analyze_results passes this class to
# load_llm().with_structured_output(...) and returns the resulting instance.
# NOTE(review): the docstring and Field descriptions are presumably forwarded
# to the model as part of the schema, so treat their wording as prompt text —
# confirm before rewording them.
class QuerySummary(BaseModel):
    """
    The summary of the SQL query results
    """
    # Free-text analysis written by the model.
    summary: str = Field(description="The analysis of the SQL query results")
    # Execution errors the model reports back from the results it was shown.
    errors: list[str] = Field(description="The errors in the execution of the queries")
    # The executed queries together with their results, as restated by the model.
    queries: list[str] = Field(description="The SQL queries executed and their results")
def sql_generator(input: dict) -> dict:
    """Generate SQL for a natural-language question and execute it.

    Args:
        input: Mapping with two keys: "query" (the user's question) and
            "db_path" (passed to ``SQLRuntime`` as ``dbname``).

    Returns:
        A dict with "input" (the original question) and "results" (whatever
        ``SQLRuntime.execute_batch`` returns for the generated queries).

    Raises:
        KeyError: If "query" or "db_path" is missing from ``input``.
    """
    query, db_path = input["query"], input["db_path"]
    sql_runtime = SQLRuntime(dbname=db_path)
    query_generator_llm = load_llm().with_structured_output(Generated_query)

    # The database schema is supplied to the prompt as `db_schema`.
    schemas = sql_runtime.get_schemas()

    # Named `query_chain` (not `chain`) so the `chain` decorator imported from
    # langchain_core.runnables is not shadowed inside this function.
    query_chain = sql_query_prompt | query_generator_llm
    gen_queries = query_chain.invoke({
        "db_schema": schemas,
        "input": query,
    })

    # NOTE(review): the model-generated statements are executed as-is; the only
    # guard visible here is the "do not execute harmful queries" wording in the
    # Generated_query schema. Confirm SQLRuntime restricts what can run.
    res = sql_runtime.execute_batch(gen_queries.queries)
    return {
        "input": query,
        "results": res,
    }
def sql_formatter(input):
    """
    Render each executed query as a single text line.

    Entries whose "code" is 0 become "Query: ..., Result: ..."; every other
    entry becomes "Query: ..., Error: ..." using the recorded traceback.
    Returns a dict holding the original question under "query" and the
    rendered lines under "results".
    """
    rendered = [
        f"Query: {entry['msg']['input']}, Result: {entry['data']}"
        if entry["code"] == 0
        else f"Query: {entry['msg']['input']}, Error: {entry['msg']['traceback']}"
        for entry in input["results"]
    ]
    return {"query": input["input"], "results": rendered}
def analyze_results(input) -> QuerySummary:
    """
    Ask the LLM for a structured summary of the formatted query results.

    Reads "query" and "results" from ``input``, feeds them through the summary
    prompt, and returns the model's answer parsed as a QuerySummary.
    """
    summarizer_llm = load_llm().with_structured_output(QuerySummary)
    summary_pipeline = sql_query_summary_prompt | summarizer_llm
    prompt_values = {
        "query": input["query"],
        "results": input["results"],
    }
    return summary_pipeline.invoke(prompt_values)
if __name__ == '__main__':
    # Manual smoke test: generate SQL, run it, format the rows, summarise them.
    load_dotenv()

    # `|` composition is provided by LangChain Runnables, not by plain Python
    # functions, so the left-most step has to be a Runnable. Wrap it with the
    # imported `chain` helper unless it already exposes `invoke`; the functions
    # to the right are then coerced by the Runnable's `|` operator.
    first_step = sql_generator if hasattr(sql_generator, "invoke") else chain(sql_generator)
    pipeline = first_step | sql_formatter | analyze_results

    # analyze_results returns one QuerySummary, so bind it to a single name
    # instead of unpacking it into two.
    formatted_output = pipeline.invoke(
        {
            "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
            "db_path": "./data/elections.db"
        }
    )
    print(formatted_output.summary)
    print(formatted_output.errors)
    print(formatted_output.queries)