import ast
import json
import math
import os
import subprocess
import sys

# Pin Gradio before `import gradio` below. Presumably this is needed because
# the UI uses `Button.style(container=...)`, which newer Gradio releases
# dropped -- NOTE(review): confirm, and prefer pinning via requirements.txt.
# pip is run through the current interpreter (`sys.executable -m pip`) so the
# packages land in the environment this script actually runs in; a bare `pip`
# on PATH may belong to a different Python. Like the former os.system calls,
# failures are not raised (best effort).
subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.26.0"], check=False)


import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.naive_bayes import ComplementNB
from sklearn.pipeline import Pipeline


# Newsgroup names offered as choices in the "Categories" dropdown; the user's
# selection is forwarded to fetch_20newsgroups(categories=...) by train_model.
# Written one per line and split on whitespace to keep the list easy to scan.
CATEGORIES = """
    alt.atheism
    comp.graphics
    comp.os.ms-windows.misc
    comp.sys.ibm.pc.hardware
    comp.sys.mac.hardware
    comp.windows.x
    misc.forsale
    rec.autos
    rec.motorcycles
    rec.sport.baseball
    rec.sport.hockey
    sci.crypt
    sci.electronics
    sci.med
    sci.space
    soc.religion.christian
    talk.politics.guns
    talk.politics.mideast
    talk.politics.misc
    talk.religion.misc
""".split()


def shorten_param(param_name):
    """Drop everything up to and including the last ``__`` in *param_name*.

    For example ``"param_vect__max_df"`` becomes ``"max_df"``. Names that
    contain no ``__`` are returned unchanged.
    """
    _, _, short_name = param_name.rpartition("__")
    return short_name


def _parse_literal_list(text):
    """Parse comma-separated Python literals such as ``"0.2, 0.4, 1"``.

    Each non-empty item is decoded with ``ast.literal_eval``, so ``"1"`` yields
    the int ``1`` and ``"0.5"`` the float ``0.5``. Empty items (for example from
    a trailing comma) are skipped. ``literal_eval`` is used instead of ``eval``
    because the text comes straight from a web form.
    """
    return [ast.literal_eval(item.strip()) for item in text.split(",") if item.strip()]


def _parse_ngram_ranges(text):
    """Parse ``"(1, 1), (1, 2)"``-style text into a list of ngram_range tuples.

    A single pair such as ``"(1, 2)"`` is accepted too and wrapped in a list.
    Left bare, the tuple ``(1, 2)`` would be sampled element-wise, i.e. the
    search would try ``ngram_range=1`` and ``ngram_range=2``.
    """
    parsed = ast.literal_eval(text.strip())
    if isinstance(parsed, tuple) and all(isinstance(bound, int) for bound in parsed):
        return [parsed]
    return [tuple(pair) for pair in parsed]


def _fetch_newsgroups(subset, categories):
    """Load one 20 newsgroups subset with headers, footers and quotes removed."""
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )


def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
    """Tune a TF-IDF + ComplementNB pipeline on 20 newsgroups and plot the search.

    Parameters
    ----------
    categories : list of str
        Newsgroup names to train and test on.
    vect__max_df, vect__min_df : str
        Comma-separated numeric literals, e.g. ``"0.2, 0.4"`` or ``"1, 3, 5"``.
    vect__ngram_range : str
        Comma-separated tuples, e.g. ``"(1, 1), (1, 2)"``.
    vect__norm : str
        Comma-separated norm names, e.g. ``"l1, l2"``.

    Returns
    -------
    tuple
        ``(fig, fig2, best_parameters, test_accuracy)``. ``fig`` is a plotly
        scatter of mean scoring time against CV score, and ``fig2`` is a plotly
        parallel-coordinates plot of the search. ``best_parameters`` is the
        best estimator's ``get_params()`` as a JSON string, and
        ``test_accuracy`` is the search's score on the test subset.

    Raises
    ------
    ValueError
        If no category is selected, or a grid field holds a malformed literal.
    SyntaxError
        If a grid field cannot be parsed by ``ast.literal_eval``.
    """
    if not categories:
        # An empty selection would otherwise fail inside fetch_20newsgroups /
        # scikit-learn with a much less helpful error.
        raise ValueError("Please select at least one category.")

    parameters_grid = {
        "vect__max_df": _parse_literal_list(vect__max_df),
        "vect__min_df": _parse_literal_list(vect__min_df),
        "vect__ngram_range": _parse_ngram_ranges(vect__ngram_range),  # unigrams or bigrams
        "vect__norm": [value.strip() for value in vect__norm.split(",") if value.strip()],
        "clf__alpha": np.logspace(-6, 6, 13),
    }

    print(parameters_grid)

    data_train = _fetch_newsgroups("train", categories)
    data_test = _fetch_newsgroups("test", categories)

    pipeline = Pipeline(
        [
            ("vect", TfidfVectorizer()),
            ("clf", ComplementNB()),
        ]
    )

    random_search = RandomizedSearchCV(
        estimator=pipeline,
        param_distributions=parameters_grid,
        n_iter=40,
        random_state=0,
        n_jobs=2,
        verbose=1,
    )

    random_search.fit(data_train.data, data_train.target)
    best_parameters = json.dumps(
        random_search.best_estimator_.get_params(),
        indent=4,
        sort_keys=True,
        default=str,  # estimator objects and numpy values are not JSON-native
    )

    test_accuracy = random_search.score(data_test.data, data_test.target)

    # Column names such as "param_vect__max_df" are shortened to "max_df".
    cv_results = pd.DataFrame(random_search.cv_results_).rename(shorten_param, axis=1)

    param_names = [shorten_param(name) for name in parameters_grid]
    labels = {
        "mean_score_time": "CV Score time (s)",
        "mean_test_score": "CV score (accuracy)",
    }
    fig = px.scatter(
        cv_results,
        x="mean_score_time",
        y="mean_test_score",
        error_x="std_score_time",
        error_y="std_test_score",
        hover_data=param_names,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "trade-off between scoring time and mean test score",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )

    column_results = param_names + ["mean_test_score", "mean_score_time"]

    # Parallel coordinates need numeric axes: identity by default, with
    # per-column conversions for the non-numeric / wide-range parameters.
    transform_funcs = dict.fromkeys(column_results, lambda x: x)
    # Using a logarithmic scale for alpha
    transform_funcs["alpha"] = math.log10
    # L1 norms are mapped to index 1, and L2 norms to index 2
    transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
    # ngram_range is plotted by its upper bound (1 = unigrams, 2 = up to bigrams)
    transform_funcs["ngram_range"] = lambda x: x[1]

    fig2 = px.parallel_coordinates(
        cv_results[column_results].apply(transform_funcs),
        color="mean_test_score",
        color_continuous_scale=px.colors.sequential.Viridis_r,
        labels=labels,
    )
    fig2.update_layout(
        title={
            "text": "Parallel coordinates plot of text classifier pipeline",
            "y": 0.99,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )

    return fig, fig2, best_parameters, test_accuracy


def load_description(name):
    """Return the text of ``./descriptions/{name}.md``.

    *name* may include a sub-directory, e.g. ``"parameter_grid/alpha"``. The
    path is resolved relative to the current working directory. The file is
    decoded as UTF-8 explicitly, so the result does not depend on the
    platform's locale encoding.
    """
    with open(f"./descriptions/{name}.md", "r", encoding="utf-8") as f:
        return f.read()


# Markdown credit line rendered via gr.Markdown in the app's header row.
# The leading and trailing newlines are part of the value.
AUTHOR = (
    "\n"
    "Created by [@dominguesm](https://huggingface.co/dominguesm)"
    " based on "
    "[scikit-learn docs]"
    "(https://scikit-learn.org/stable/auto_examples/model_selection/"
    "plot_grid_search_text_feature_extraction.html)"
    "\n"
)


with gr.Blocks(theme=gr.themes.Soft()) as app:
    # --- Header: title, long-form description and credits ------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
            gr.Markdown(load_description("description_part1"))
            gr.Markdown(load_description("description_part2"))
            gr.Markdown(AUTHOR)

    # --- Category picker (at most two newsgroups) --------------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## CATEGORY SELECTION""")
            gr.Markdown(load_description("description_category_selection"))
            drop_categories = gr.Dropdown(
                CATEGORIES,
                value=["alt.atheism", "talk.religion.misc"],
                multiselect=True,
                label="Categories",
                info="Please select up to two categories that you want to receive training on.",
                max_choices=2,
                interactive=True,
            )

    # --- Hyper-parameter grid inputs (free text parsed by train_model) -----
    with gr.Row():
        with gr.Tab("PARAMETERS GRID"):
            gr.Markdown(load_description("description_parameter_grid"))
            with gr.Row():
                with gr.Column():
                    # Read-only: train_model always searches clf__alpha over
                    # np.logspace(-6, 6, 13), so display exactly that grid
                    # (1e-06 ... 1e+06). Keep in sync with train_model.
                    clf__alpha = gr.Textbox(
                        label="Classifier Alpha (clf__alpha)",
                        value=", ".join(f"{alpha:.0e}" for alpha in np.logspace(-6, 6, 13)),
                        info="Due to practical considerations, this parameter was kept constant.",
                        interactive=False,
                    )
                    vect__max_df = gr.Textbox(
                        label="Vectorizer max_df (vect__max_df)",
                        value="0.2, 0.4, 0.6, 0.8, 1.0",
                        info="Values ranging from 0 to 1.0, separated by a comma.",
                        interactive=True,
                    )
                    vect__min_df = gr.Textbox(
                        label="Vectorizer min_df (vect__min_df)",
                        value="1, 3, 5, 10",
                        info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
                        interactive=True,
                    )
                with gr.Column():
                    vect__ngram_range = gr.Textbox(
                        label="Vectorizer ngram_range (vect__ngram_range)",
                        value="(1, 1), (1, 2)",
                        info="""Tuples of integer values separated by a comma. For example an `ngram_range` of `(1, 1)` means only unigrams, `(1, 2)` means unigrams and bigrams, and `(2, 2)` means only bigrams.""",
                        interactive=True,
                    )
                    vect__norm = gr.Textbox(
                        label="Vectorizer norm (vect__norm)",
                        value="l1, l2",
                        info="'l1' or 'l2', separated by a comma",
                        interactive=True,
                    )

        with gr.Tab("DESCRIPTION OF PARAMETERS"):
            gr.Markdown("""### Classifier Alpha""")
            gr.Markdown(load_description("parameter_grid/alpha"))
            gr.Markdown("""### Vectorizer max_df""")
            gr.Markdown(load_description("parameter_grid/max_df"))
            gr.Markdown("""### Vectorizer min_df""")
            gr.Markdown(load_description("parameter_grid/min_df"))
            gr.Markdown("""### Vectorizer ngram_range""")
            gr.Markdown(load_description("parameter_grid/ngram_range"))
            gr.Markdown("""### Vectorizer norm""")
            gr.Markdown(load_description("parameter_grid/norm"))

    # --- Static snippet of the pipeline that train_model tunes -------------
    with gr.Row():
        gr.Markdown(
            """
            ## MODEL PIPELINE
            ```python
            pipeline = Pipeline(
                [
                    ("vect", TfidfVectorizer()),
                    ("clf", ComplementNB()),
                ]
            )
            ```
            """
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## TRAINING""")
            with gr.Row():
                brn_train = gr.Button("Train").style(container=False)

    # --- Outputs ------------------------------------------------------------
    gr.Markdown("## RESULTS")
    with gr.Row():
        best_parameters = gr.Textbox(label="Best parameters")
        test_accuracy = gr.Textbox(label="Test accuracy")

    plot_trade = gr.Plot(label="")
    plot_coordinates = gr.Plot(label="")

    # `inputs` follow train_model's parameter order; `outputs` must match its
    # return tuple (fig, fig2, best_parameters, test_accuracy).
    brn_train.click(
        train_model,
        inputs=[
            drop_categories,
            vect__max_df,
            vect__min_df,
            vect__ngram_range,
            vect__norm,
        ],
        outputs=[plot_trade, plot_coordinates, best_parameters, test_accuracy],
    )

app.launch()