import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import warnings

from sklearn.datasets import make_blobs
from sklearn.svm import LinearSVC
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.exceptions import ConvergenceWarning

def train_model(n_samples, C, penalty, loss, max_iter):
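    """Fit a LinearSVC on a two-class blobs dataset and return a matplotlib
    figure showing the decision boundary, the margins, and the support vectors."""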

    if penalty == "l1" and loss == "hinge":
        raise gr.Error("The combination of penalty='l1' and loss='hinge' is not supported")
    
    params = {"n_samples": n_samples,
              "C": C,
              "penalty": penalty,
              "loss": loss,
              "max_iter": max_iter}

    X, y = make_blobs(n_samples=params["n_samples"], centers=2, random_state=0)
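    # The blobs are 2-D, so the decision boundary can be drawn directly in the plane.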
    
    fig, ax = plt.subplots()

    # catch warnings related to convergence
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)

        # penalty: "l1" or "l2" (default "l2"); loss: "hinge" or "squared_hinge"
        # (default "squared_hinge"). penalty="l1" with loss="hinge" is unsupported.
        clf = LinearSVC(penalty=params["penalty"], C=params["C"],
                        loss=params["loss"],
                        max_iter=params["max_iter"],
                        random_state=42).fit(X, y)
        # obtain the support vectors through the decision function
        decision_function = clf.decision_function(X)
        # we can also calculate the decision function manually
        # decision_function = np.dot(X, clf.coef_[0]) + clf.intercept_[0]
        # The support vectors are the samples that lie within the margin
        # boundaries, whose size is conventionally constrained to 1
        support_vector_indices = np.where(np.abs(decision_function) <= 1 + 1e-15)[0]
        support_vectors = X[support_vector_indices]

        ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired)
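        # Draw the decision function contours: the boundary (level 0) and the margins (levels -1 and +1)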
        DecisionBoundaryDisplay.from_estimator(
            clf,
            X,
            ax=ax,
            grid_resolution=50,
            plot_method="contour",
            colors="k",
            levels=[-1, 0, 1],
            alpha=0.5,
            linestyles=["--", "-", "--"],
        )
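        # Highlight the support vectors with large hollow circles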
        ax.scatter(
            support_vectors[:, 0],
            support_vectors[:, 1],
            s=100,
            linewidth=1,
            facecolors="none",
            edgecolors="k",
        )
        ax.set_title(f"C={C}")

        return fig

def iter_grid(n_rows, n_cols):
    # Lay out an n_rows x n_cols grid of gradio Rows/Columns (helper; not used in this demo)
    for _ in range(n_rows):
        with gr.Row():
            for _ in range(n_cols):
                with gr.Column():
                    yield

title = "📈 Linear Support Vector Classification"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown("`LinearSVC` is a linear Support Vector Machine (SVM) classifier: \
                it looks for the linear decision boundary that best separates \
                the classes in the training data. This demo fits two models on \
                the same two-blob dataset so parameter settings can be compared \
                side by side; the support vectors are circled in each plot.")
    gr.Markdown("The most important parameters of `LinearSVC` are:")
    param_C_desc = "\
    1. `C`: The regularization parameter. The strength of the \
        regularization is inversely proportional to `C`: a smaller `C` \
        increases regularization, promoting simpler models, while a larger \
        `C` reduces regularization, allowing more complex models. It controls \
        the trade-off between fitting the training data and generalizing to \
        unseen data."
    param_loss = "\
    2. `loss`: The loss function used for training. The default is \
        `squared_hinge`, the square of the hinge loss. The combination of \
        `penalty='l1'` and `loss='hinge'` is not supported."
    param_penalty = "\
    3. `penalty`: The norm used for the regularization penalty. The default \
        is `l2`; `l1` leads to sparse coefficient vectors."
    param_dual = "\
    4. `dual`: Selects whether the dual or the primal optimization problem \
        is solved. Prefer `dual=False` when the number of samples exceeds \
        the number of features; for large-scale problems this is usually \
        more efficient."
    param_tol = "\
    5. `tol`: The tolerance for the stopping criterion; the solver stops \
        iterating once its convergence measure falls below this threshold."
    param_max_iter = "\
    6. `max_iter`: The maximum number of iterations the solver is allowed \
        to run. The default is 1000."
    gr.Markdown(param_C_desc)
    gr.Markdown(param_loss)
    gr.Markdown(param_penalty)
    gr.Markdown(param_dual)
    gr.Markdown(param_tol)
    gr.Markdown(param_max_iter)
    gr.Markdown("Read more in the \
    [LinearSVC documentation](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html#sklearn.svm.LinearSVC). \
    A minimal standalone usage sketch is shown below.")
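    # Illustrative snippet (shown in the demo text, not executed here): a minimal
    # sketch of fitting LinearSVC on the same kind of blobs data, assuming
    # scikit-learn defaults for the remaining parameters.
    gr.Markdown(
        "```python\n"
        "from sklearn.datasets import make_blobs\n"
        "from sklearn.svm import LinearSVC\n"
        "\n"
        "X, y = make_blobs(n_samples=20, centers=2, random_state=0)\n"
        "clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', max_iter=1000).fit(X, y)\n"
        "print(clf.coef_, clf.intercept_)\n"
        "```"
    )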

    
    n_samples = gr.Slider(minimum=20, maximum=100, step=5, value=20,
                          label="Number of Samples")

    
    with gr.Row():
        input_model = "LinearSVC"
        fn = train_model

        with gr.Row():
            penalty = gr.Dropdown(["l1", "l2"], value="l2", interactive=True, label="Penalty to prevent overfitting")
            loss = gr.Dropdown(["hinge", "squared_hinge"], value="hinge", interactive=True, label="Loss function")
        
        with gr.Row():
            max_iter = gr.Slider(minimum=100, maximum=2000, step=100, value=1000,
                                 label="Max. number of iterations")
            param_C = gr.Number(value=1,
                                label="Regularization parameter C",
                                info="A small C means stronger regularization (simpler model, "
                                     "possibly higher bias); a large C weakens regularization and "
                                     "fits the training data more closely, at the risk of overfitting.")

        with gr.Row():
            penalty2 = gr.Dropdown(["l1", "l2"], value="l2", interactive=True, label="Penalty to prevent overfitting")
            loss2 = gr.Dropdown(["hinge", "squared_hinge"], value="hinge", interactive=True, label="Loss function")
        
        with gr.Row():
            max_iter2 = gr.Slider(minimum=100, maximum=2000, step=100, value=1000,
                                  label="Max. number of iterations")
            param_C2 = gr.Number(value=100,
                                 label="Regularization parameter C")

    with gr.Row():
        plot = gr.Plot(label=input_model)
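        # Re-train the first model and refresh its plot whenever any of its controls changes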
        n_samples.change(fn=fn, inputs=[n_samples, param_C, penalty, loss, max_iter], outputs=plot)
        param_C.change(fn=fn, inputs=[n_samples, param_C, penalty, loss, max_iter], outputs=plot)
        penalty.change(fn=fn, inputs=[n_samples, param_C, penalty, loss, max_iter], outputs=plot)
        loss.change(fn=fn, inputs=[n_samples, param_C, penalty, loss, max_iter], outputs=plot)
        max_iter.change(fn=fn, inputs=[n_samples, param_C, penalty, loss, max_iter], outputs=plot)
    
        plot2 = gr.Plot(label=input_model)
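        # Same wiring for the second model, which defaults to C=100 for comparison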
        n_samples.change(fn=fn, inputs=[n_samples, param_C2, penalty2, loss2, max_iter2], outputs=plot2)
        param_C2.change(fn=fn, inputs=[n_samples, param_C2, penalty2, loss2, max_iter2], outputs=plot2)
        penalty2.change(fn=fn, inputs=[n_samples, param_C2, penalty2, loss2, max_iter2], outputs=plot2)
        loss2.change(fn=fn, inputs=[n_samples, param_C2, penalty2, loss2, max_iter2], outputs=plot2)
        max_iter2.change(fn=fn, inputs=[n_samples, param_C2, penalty2, loss2, max_iter2], outputs=plot2)

demo.launch()