import gradio as gr
from huggingface_hub import InferenceClient
import pandas as pd
import matplotlib.pyplot as plt
import io
import sqlite3
from PIL import Image
# Initialize the InferenceClient with the specified model
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Path to the CSV file bundled with the Space
csv_file_path = 'Movies.csv'

# Load the dataset into a dataframe
df = pd.read_csv(csv_file_path)
# Function to generate an SQL query from a natural-language prompt
def generate_sql_query(prompt, table_metadata, system_message, max_tokens=512, temperature=0.7, top_p=0.95):
    input_text = (
        f"{system_message} "
        f"Generate an SQL query for the table 'data' with the following structure: {table_metadata}. "
        f"Return only the SQL query. Prompt: {prompt}"
    )
    response = ""
    for message in client.chat_completion(
        messages=[{"role": "system", "content": input_text}],
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content
        if token:  # the final streamed chunk may carry no content
            response += token
    return response.strip()
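# Illustrative example (a sketch; assumes Movies.csv has 'title' and 'rating' columns, which may differ):
#   generate_sql_query("top 5 highest rated movies",
#                      "{'title': dtype('O'), 'rating': dtype('float64')}",
#                      "You are an AI assistant that generates SQL queries.")
#   might return something like: SELECT title, rating FROM data ORDER BY rating DESC LIMIT 5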
# Function to execute an SQL query against the dataframe via an in-memory SQLite database
def execute_query(df, query):
    try:
        with sqlite3.connect(':memory:') as conn:
            df.to_sql('data', conn, index=False, if_exists='replace')
            result_df = pd.read_sql_query(query, conn)
        return result_df
    except Exception as e:
        return str(e)
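# Illustrative example: execute_query(df, "SELECT * FROM data LIMIT 5") returns the first five rows
# as a dataframe; a malformed query returns the SQLite error message as a string instead.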
# Function to create a plot image from the result dataframe
def create_plot(df):
    fig, ax = plt.subplots()
    df.plot(ax=ax)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)  # free the figure to avoid leaking memory across requests
    buf.seek(0)
    return Image.open(buf)  # gr.Image accepts a PIL image
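# Illustrative example: create_plot(pd.DataFrame({'rating': [7.1, 8.3, 6.5]})) yields a PIL image of
# a default line plot; a result set with no numeric columns would raise inside df.plot().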
# Gradio function to handle user input and interaction
def respond(user_prompt, system_message, max_tokens, temperature, top_p):
    table_metadata = str(df.dtypes.to_dict())
    sql_query = generate_sql_query(user_prompt, table_metadata, system_message, max_tokens, temperature, top_p)
    result_df = execute_query(df, sql_query)
    if isinstance(result_df, str):
        return sql_query, result_df, None  # Return the error message instead of a result table
    plot = create_plot(result_df)
    return sql_query, result_df.head().to_html(), plot
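# Illustrative example (the prompt and column names are assumptions about Movies.csv):
#   respond("average rating per year", "You are an AI assistant that generates SQL queries.", 512, 0.7, 0.95)
#   returns the generated SQL string, an HTML table of the first rows, and a plot image
#   (or None for the plot when query execution fails).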
# Gradio UI components
def create_demo():
    with gr.Blocks() as demo:
        user_prompt = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="User Prompt")
        system_message = gr.Textbox(value="You are an AI assistant that generates SQL queries based on user prompts.", label="System message")
        max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
        output_sql_query = gr.Textbox(label="Generated SQL Query")
        output_result_df = gr.HTML(label="Query Result")
        output_plot = gr.Image(label="Result Plot")
        submit_btn = gr.Button("Submit")
        submit_btn.click(
            respond,
            inputs=[user_prompt, system_message, max_tokens, temperature, top_p],
            outputs=[output_sql_query, output_result_df, output_plot],
        )
    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()