MahaNeta / utils /query_generator.py
ankush-003's picture
init
10757ec
from .sql_runtime import SQLRuntime
from pydantic import BaseModel, Field
from .load_llm import load_llm
from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt
from langchain_core.runnables import chain
from typing import Optional
from dotenv import load_dotenv
class Generated_query(BaseModel):
    """
    The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
    """

    # NOTE(review): this model is handed to `with_structured_output(...)` in
    # this module, so the docstring above and the description below presumably
    # reach the LLM as schema text. Treat them as prompt content, not as
    # ordinary documentation, and edit them with care.
    queries: list[str] = Field(
        description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries"
    )
class QuerySummary(BaseModel):
    """
    The summary of the SQL query results
    """

    # NOTE(review): this model is handed to `with_structured_output(...)` in
    # this module, so the docstring and field descriptions presumably reach the
    # LLM as schema text. Treat them as prompt content when editing.

    # Free-text analysis of the query results.
    summary: str = Field(
        description="The analysis of the SQL query results"
    )
    # Errors reported while the queries were executed.
    errors: list[str] = Field(
        description="The errors in the execution of the queries"
    )
    # The executed queries together with their results.
    queries: list[str] = Field(
        description="The SQL queries executed and their results"
    )
@chain
def sql_generator(input: dict) -> dict:
    """
    Generates SQL queries for a natural-language question and executes them

    Args:
        input: mapping with two keys:
            "query"   - the question to translate into SQL
            "db_path" - passed to SQLRuntime as `dbname`

    Returns:
        A dict with:
            "input"   - the original question, unchanged
            "results" - whatever `SQLRuntime.execute_batch` returned for the
                        generated queries (consumed by `sql_formatter`)
    """
    query, db_path = input["query"], input["db_path"]
    sql_runtime = SQLRuntime(dbname=db_path)
    query_generator_llm = load_llm().with_structured_output(Generated_query)

    # The schemas are given to the LLM so it can write queries against them.
    schemas = sql_runtime.get_schemas()

    # Named `generation_chain` (not `chain`) so the imported `chain` decorator
    # is not shadowed inside this function.
    generation_chain = sql_query_prompt | query_generator_llm
    gen_queries = generation_chain.invoke({
        "db_schema": schemas,
        "input": query,
    })

    # Run every generated query in one batch.
    res = sql_runtime.execute_batch(gen_queries.queries)
    return {
        "input": query,
        "results": res,
    }
@chain
def sql_formatter(input):
    """
    Turns each raw query result into a single human-readable line

    Items whose "code" is 0 are rendered with their "data"; all other items
    are rendered with the traceback stored under "msg". The original question
    is passed through under the "query" key.
    """
    lines = [
        f"Query: {item['msg']['input']}, Result: {item['data']}"
        if item["code"] == 0
        else f"Query: {item['msg']['input']}, Error: {item['msg']['traceback']}"
        for item in input["results"]
    ]
    return {
        "query": input["input"],
        "results": lines,
    }
@chain
def analyze_results(input) -> QuerySummary:
    """
    Summarizes formatted SQL results with the LLM

    Pipes `sql_query_summary_prompt` into an LLM configured for QuerySummary
    structured output and invokes it with the "query" and "results" entries
    of `input`.
    """
    structured_llm = load_llm().with_structured_output(QuerySummary)
    summarizer = sql_query_summary_prompt | structured_llm

    payload = {
        "query": input["query"],
        "results": input["results"],
    }
    return summarizer.invoke(payload)
if __name__ == '__main__':
    # Load environment variables from a local .env file (python-dotenv).
    load_dotenv()

    # Full pipeline: generate and run the SQL, format each result, summarize.
    pipeline = sql_generator | sql_formatter | analyze_results

    # The final step (analyze_results) returns ONE QuerySummary, so the result
    # is bound to a single name. Unpacking it into two names iterates the
    # model's (field, value) pairs and fails.
    summary = pipeline.invoke(
        {
            "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
            "db_path": "./data/elections.db",
        }
    )

    print(summary.summary)
    print(summary.errors)
    print(summary.queries)