import gradio as gr
import plotly.express as px
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
from PIL import Image
from io import BytesIO

def generate_plot(
    x_sequence: str,
    y_sequence: str,
    plot_type: str,
    x_label: str,
    y_label: str,
    width: int,
    height: int,
) -> Image.Image:
    """
    Generate a plot based on the provided x and y sequences and plot type.

    Parameters:
    - x_sequence (str): A comma-separated string of x values. For the
      'Confusion Matrix' plot these are the true labels (values > 0.5 count
      as class 1).
    - y_sequence (str): A comma-separated string of y values. For the
      'Confusion Matrix' plot these are the predicted scores (values > 0.5
      count as class 1).
    - plot_type (str): The type of plot to generate ('Bar', 'Scatter', 'Confusion Matrix').
    - x_label (str): Label for the x-axis. Empty means "use a default".
    - y_label (str): Label for the y-axis. Empty means "use a default".
    - width (int): Width of the plot in pixels; falsy values fall back to 800.
    - height (int): Height of the plot in pixels; falsy values fall back to 600.

    Returns:
    - Image.Image: A PIL Image object of the generated plot.

    Raises:
    - gr.Error: If a sequence cannot be parsed as numbers, the sequences
      differ in length, or the plot type is unknown. The output component is
      an image, so problems are raised rather than returned as text.
    """
    # Convert the input sequences to lists of numbers
    try:
        x_data = [float(token) for token in x_sequence.split(",")]
        y_data = [float(token) for token in y_sequence.split(",")]
    except (ValueError, AttributeError) as err:
        # AttributeError covers a sequence that arrives as None.
        raise gr.Error(
            "Invalid input. Please enter sequences of numbers separated by commas."
        ) from err

    # Ensure the x and y sequences have the same length
    if len(x_data) != len(y_data):
        raise gr.Error("The x and y sequences must have the same length.")

    # The numeric inputs may arrive as floats; pixel sizes must be whole numbers.
    width = int(width) if width else 800
    height = int(height) if height else 600

    # Generate the plot based on the selected type
    if plot_type in ("Bar", "Scatter"):
        df = pd.DataFrame({"x": x_data, "y": y_data})
        make_figure = px.bar if plot_type == "Bar" else px.scatter
        fig = make_figure(
            df,
            x="x",
            y="y",
            title=f"{plot_type} Plot",
            # An empty optional label would blank the axis title, so fall
            # back to the column name instead.
            labels={"x": x_label or "x", "y": y_label or "y"},
            width=width,
            height=height,
        )
    elif plot_type == "Confusion Matrix":
        # Binarise both sequences at 0.5: x is the ground truth, y the prediction.
        y_true = (np.asarray(x_data) > 0.5).astype(int)
        y_pred = (np.asarray(y_data) > 0.5).astype(int)
        # Fix the label set so the matrix is always 2x2, even when only one
        # class is present in the input.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        fig = px.imshow(
            cm,
            text_auto=True,
            title="Confusion Matrix",
            labels={"x": x_label or "Predicted", "y": y_label or "Actual"},
            x=["0", "1"],
            y=["0", "1"],
            width=width,
            height=height,
        )
    else:
        raise gr.Error("Invalid plot type selected.")

    # Convert the plot to a PNG image
    img_bytes = fig.to_image(
        format="png", width=width, height=height, scale=2, engine="kaleido"
    )
    return Image.open(BytesIO(img_bytes))


# Define the Gradio interface. The inputs are passed positionally to
# generate_plot, so their order must match its parameter order.
app = gr.Interface(
    fn=generate_plot,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter x sequence of numbers separated by commas",
            label="X",
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter y sequence of numbers separated by commas",
            label="Y",
        ),
        gr.Radio(["Bar", "Scatter", "Confusion Matrix"], label="Type", value="Bar"),
        gr.Textbox(
            placeholder="Enter x-axis label (optional)", label="X_Label", value=""
        ),
        gr.Textbox(
            placeholder="Enter y-axis label (optional)", label="Y_Label", value=""
        ),
        # precision=0 makes the component deliver whole-number pixel sizes.
        gr.Number(
            value=800,
            label="Width",
            precision=0,
        ),
        gr.Number(value=600, label="Height", precision=0),
    ],
    outputs=gr.Image(type="pil", label="Generated Plot"),
    title="Plotly Plot Generator",
    description="Generate plots using Plotly based on inputted sequences. Choose from Bar, Scatter, or Confusion Matrix plots.",
)

# Launch the app only when run as a script, so importing this module
# (e.g. to reuse generate_plot) does not start a web server.
if __name__ == "__main__":
    app.launch()