# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "ell-ai==0.0.14",
#     "marimo",
#     "openai==1.53.0",
#     "polars==1.12.0",
#     "altair==5.4.1",
# ]
# ///

import marimo

# marimo version recorded when this notebook file was generated.
__generated_with = "0.9.20"
# The notebook object; each `@app.cell` function below registers a cell on it.
app = marimo.App(width="medium")


@app.cell
def __(mo):
    # Notebook heading, rendered as markdown output of this cell.
    mo.md("# Generative UI Chatbot")
    return


@app.cell
def __(mo):
    # Editable dataset location; the loading cell passes its value to
    # pl.read_csv. Defaults to the scikit-learn Fish CSV on the Hugging Face Hub.
    dataset_input = mo.ui.text(
        full_width=True,
        value="hf://datasets/scikit-learn/Fish/Fish.csv",
    )
    return (dataset_input,)


@app.cell
def __(dataset_input, mo):
    # Interpolating the text widget into the markdown renders it inline,
    # so the dataset path can be edited from this description.
    _intro = f"""
    This chatbot can answer questions about the following dataset: {dataset_input}
    """
    mo.md(_intro)
    return


@app.cell
def __(dataset_input, mo, pl):
    # Load the dataset named by the text input. On any failure, fall back to
    # an empty frame so dependent cells still run, and show the error as a
    # danger callout in place of the success message.
    try:
        df = pl.read_csv(dataset_input.value)
        _n_rows, _n_cols = df.shape
        mo.output.replace(
            mo.md(f"Loaded dataset with {_n_rows} rows and {_n_cols} columns.")
        )
    except Exception as _load_error:
        df = pl.DataFrame()
        _message = f"""**Error loading dataset**:\n\n{_load_error}"""
        mo.output.replace(mo.md(_message).callout(kind="danger"))
    return (df,)


@app.cell
def __():
    # Shared imports, exported to the other cells.
    # Standard library:
    import os

    # Third-party:
    import marimo as mo
    import polars as pl

    return mo, os, pl


@app.cell
def __(mo, os):
    # Password field for the OpenAI key, pre-filled from the OPENAI_API_KEY
    # environment variable (empty string when it is unset).
    _env_key = os.getenv("OPENAI_API_KEY", "")
    api_key_input = mo.ui.text(
        value=_env_key,
        kind="password",
        label="OpenAI API Key",
    )
    return (api_key_input,)


@app.cell
def __(api_key_input):
    # Show the API key input as this cell's output.
    api_key_input
    return


@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Halt here with a hint until a key has been entered.
    _key = api_key_input.value
    mo.stop(not _key, mo.md("_Missing API key_"))

    client = Client(api_key=_key)
    return Client, client


@app.cell
def __(df, mo):
    import ell

    # Tools the chat model may call. ell uses each function's docstring as the
    # tool description, so the docstrings are prompt text and left verbatim.

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        # Imported inside the tool so altair is only loaded when a chart is drawn.
        import altair as alt

        # Scatter plot of the loaded dataset using the model-chosen columns.
        return (
            alt.Chart(df)
            .mark_circle()
            .encode(x=x_encoding, y=y_encoding, color=color)
            .properties(width=500)
        )

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        from polars.exceptions import PolarsError

        # The query is written by the model and may be malformed or reference
        # columns that do not exist; report that in the chat as a callout
        # instead of letting the exception escape the chat handler.
        # NOTE(review): polars SQL may also support file-reading table
        # functions (e.g. read_csv); confirm that is acceptable for
        # model-generated queries.
        try:
            filtered = df.sql(sql_query, table_name="data")
        except PolarsError as err:
            return mo.md(
                f"**Query failed**:\n\n```sql\n{sql_query}\n```\n\n{err}"
            ).callout(kind="danger")
        return mo.ui.table(
            filtered,
            label=f"```sql\n{sql_query}\n```",
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset


@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    # LMP that sees the dataset schema plus the conversation; its response
    # exposes `.tool_calls` and `.text`, used by `my_model` below.
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset], client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        """Chat handler passed to ``mo.ui.chat``.

        Args:
            messages: Conversation so far, as supplied by ``mo.ui.chat``; each
                entry is read through its ``role`` and ``content`` attributes.

        Returns:
            The output of the requested tool (stacked vertically when the
            model requested more than one), otherwise the model's text reply.
        """

        def _as_text(content):
            # Earlier assistant turns may hold rendered tool output (tables,
            # charts); omit those instead of dumping their representation.
            if isinstance(content, str):
                return content
            return "[rendered output omitted]"

        # `analyze_dataset` expects a string. Interpolating the message list
        # directly would hand the model Python reprs of the message objects,
        # so build a plain role-prefixed transcript instead.
        transcript = "\n".join(
            f"{message.role}: {_as_text(message.content)}" for message in messages
        )
        response = analyze_dataset(transcript)
        if response.tool_calls:
            # Execute every requested tool, not only the first one.
            outputs = [tool_call() for tool_call in response.tool_calls]
            return outputs[0] if len(outputs) == 1 else mo.vstack(outputs)
        return response.text

    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model


# Running this file directly runs the notebook via `app.run()`.
if __name__ == "__main__":
    app.run()