# %% Importing the dependencies we need
import numpy as np
import torch
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix, 
                            ConfusionMatrixDisplay, classification_report)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from skops import card, hub_utils
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
# for model hosting and requirements
from pathlib import Path 
import transformers
import skorch
import sklearn 
import torch

# %%
# Choose a tokenizer and BERT model that work together (same checkpoint name,
# so the token ids produced by the tokenizer match the model's vocabulary).
TOKENIZER = "distilbert-base-uncased"
PRETRAINED_MODEL = "distilbert-base-uncased"

# model hyper-parameters
OPTIMIZER = torch.optim.AdamW
# Backward-compatible alias for the original (misspelled) constant name.
OPTMIZER = OPTIMIZER
LR = 5e-5
MAX_EPOCHS = 3
CRITERION = nn.CrossEntropyLoss
BATCH_SIZE = 8

# device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# %% Load the dataset, define features & labels and split
# subset="train" is sklearn's default; it is spelled out so it is obvious that
# BOTH the train and test splits below are carved from the training subset.
dataset = fetch_20newsgroups(subset="train")

print(dataset.DESCR.split('Usage')[0])
# A bare expression here would display nothing (it is not the last line of
# the cell, and a script never echoes expressions), so print it explicitly.
print(dataset.target_names)

X = dataset.data
y = dataset.target
# Default test_size (25%); stratify keeps class proportions equal in both parts.
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
# Length of the linear LR decay: an upper bound on the number of batches seen
# during fit. NOTE(review): the skorch net is built without `train_split`, so
# its default internal validation split (presumably ValidSplit(5), i.e. 20%)
# removes part of X_train; the real step count is then smaller and the LR
# never decays all the way to zero -- confirm that is intended.
num_training_steps = MAX_EPOCHS * (len(X_train) // BATCH_SIZE + 1)

# %% 
# Defining learning rate scheduler & BERT in nn.Module

def lr_schedule(current_step, total_steps=None):
    """Return the multiplicative learning-rate factor for a linear decay.

    Passed as ``lr_lambda`` to ``torch.optim.lr_scheduler.LambdaLR``, which
    calls it with the number of scheduler steps taken so far. This script
    steps the scheduler once per batch.

    Args:
        current_step: Number of scheduler steps already taken.
        total_steps: Length of the decay. Defaults to the module-level
            ``num_training_steps`` so existing single-argument calls keep
            working.

    Returns:
        ``(total_steps - current_step) / max(1, total_steps)``. This is a float
        in ``(0, 1]`` for ``0 <= current_step < total_steps``.

    Raises:
        RuntimeError: If the factor is not positive, i.e. more steps were taken
            than planned and the learning rate would become zero or negative.
    """
    if total_steps is None:
        total_steps = num_training_steps
    factor = float(total_steps - current_step) / float(max(1, total_steps))
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently allow a zero/negative learning rate.
    if factor <= 0:
        raise RuntimeError(
            f"LR schedule exhausted: step {current_step} >= {total_steps} planned steps"
        )
    return factor

class BertModule(nn.Module):
    """Thin ``nn.Module`` wrapper around a Hugging Face sequence classifier.

    skorch builds this class from the ``module__*`` arguments of the
    ``NeuralNetClassifier``. The wrapper exists so that ``forward`` returns the
    plain logits tensor rather than the model's output object.

    Args:
        name: Checkpoint identifier handed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes for the classification head.
    """

    def __init__(self, name, num_labels):
        super().__init__()
        self.name, self.num_labels = name, num_labels
        self.reset_weights()

    def reset_weights(self):
        """Load the pretrained checkpoint ``self.name`` into ``self.bert``."""
        loader = AutoModelForSequenceClassification.from_pretrained
        self.bert = loader(self.name, num_labels=self.num_labels)

    def forward(self, **kwargs):
        """Feed the tokenizer's keyword outputs to the model; return ``.logits``."""
        return self.bert(**kwargs).logits

# %% Chaining tokenizer and BERT in one pipeline
# Step "tokenizer" turns raw strings into model inputs; step "net" is the
# skorch-wrapped transformer, trained with the linear-decay LR schedule that
# is stepped once per batch.
pipeline = Pipeline(
    steps=[
        (
            "tokenizer",
            HuggingfacePretrainedTokenizer(TOKENIZER),
        ),
        (
            "net",
            NeuralNetClassifier(
                BertModule,
                # constructor arguments routed to BertModule(name, num_labels)
                module__name=PRETRAINED_MODEL,
                module__num_labels=len(np.unique(y_train)),
                # optimisation settings
                optimizer=OPTMIZER,
                lr=LR,
                criterion=CRITERION,
                max_epochs=MAX_EPOCHS,
                batch_size=BATCH_SIZE,
                iterator_train__shuffle=True,
                device=DEVICE,
                callbacks=[
                    LRScheduler(LambdaLR, lr_lambda=lr_schedule, step_every="batch"),
                    ProgressBar(),
                ],
            ),
        ),
    ]
)

# Seed the torch (CPU and CUDA) and NumPy generators before fitting, so that
# batch shuffling and any random initialisation are repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)

# %% Training
# Plain-Python timing instead of the IPython `%time` / `%%time` magics, which
# made this .py file a SyntaxError for anything other than IPython.
import time

t0 = time.perf_counter()
pipeline.fit(X_train, y_train)
print(f"Training wall time: {time.perf_counter() - t0:.1f} s")

# %% Evaluate the model
t0 = time.perf_counter()
# Predict without recording autograd history.
with torch.inference_mode():
    y_pred = pipeline.predict(X_test)
print(f"Prediction wall time: {time.perf_counter() - t0:.1f} s")

# Printed explicitly: a bare expression is only echoed inside a notebook.
print(f"Test accuracy: {accuracy_score(y_test, y_pred):.4f}")

# %% Save the model
# Serialize the whole fitted pipeline (tokenizer + net) into "model.pkl", the
# file that `hub_utils.init(model="model.pkl", ...)` packages for the Hub.
import pickle

with Path("model.pkl").open(mode="wb") as model_file:
    pickle.dump(pipeline, model_file)

# %% Initialize the repository for Hub
local_repo = "model_repo"
# Record the installed version of every library the pickled pipeline needs.
# NOTE(review): pins use a single "=" (conda style, not pip's "==") and may
# carry local suffixes such as "+cu118" for torch -- confirm the format skops
# and the Hub inference backend expect.
hub_utils.init(
    model="model.pkl",
    requirements=[
        f"{package}={module.__version__}"
        for package, module in (
            ("scikit-learn", sklearn),
            ("transformers", transformers),
            ("torch", torch),
            ("skorch", skorch),
        )
    ],
    dst=local_repo,
    task="text-classification",
    data=X_test,
)

# %% Create model card
# Build the card from the fitted pipeline, taking its metadata from the config
# in the local repository. Reuse `local_repo` instead of repeating the literal
# path, so renaming the repository folder cannot desynchronise the two.
model_card = card.Card(pipeline, metadata=card.metadata_from_config(Path(local_repo)))

# %% Describe the model and its limitations in the card
model_description = (
    "This is a neural net classifier and distilbert model chained with sklearn "
    "Pipeline trained on 20 news groups dataset."
)
limitations = "This model is trained for a tutorial and is not ready to be used in production."
# Sections are passed in this order: description first, then limitations.
model_card.add(**{"model_description": model_description, "limitations": limitations})

# %% We can add plots, evaluation results and more!
# The evaluation data is NOT the official 20 newsgroups test split:
# `fetch_20newsgroups()` loads the training subset by default, and
# `train_test_split` holds out 25% of it (its default test_size). The card
# text says so explicitly.
eval_descr = (
    "The model is evaluated on a held-out 25% split of the 20 newsgroups training"
    " subset, using accuracy and F1-score with micro average."
)
model_card.add(eval_method=eval_descr)

accuracy = accuracy_score(y_test, y_pred)
# NOTE: for single-label multiclass data, micro-averaged F1 equals accuracy.
f1 = f1_score(y_test, y_pred, average="micro")
model_card.add_metrics(**{"accuracy": accuracy, "f1 score": f1})


# Label the axes with newsgroup names rather than bare integer ids. The
# integer targets in `pipeline.classes_` index into `dataset.target_names`,
# matching the names used by the classification report.
class_names = [dataset.target_names[label] for label in pipeline.classes_]
cm = confusion_matrix(y_test, y_pred, labels=pipeline.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
# Twenty long class names do not fit the default figure: rotate the x tick
# labels and enlarge the figure so that the cells and labels stay readable.
disp.plot(xticks_rotation="vertical")
disp.figure_.set_size_inches(12, 12)

# bbox_inches="tight" keeps the rotated labels from being clipped in the PNG.
disp.figure_.savefig(Path(local_repo) / "confusion_matrix.png", bbox_inches="tight")
model_card.add_plot(**{"Confusion matrix": "confusion_matrix.png"})

clf_report = classification_report(
    y_test, y_pred, output_dict=True, target_names=dataset.target_names
)
# %% We can add classification report as a table
# We first need to convert classification report to DataFrame to add it as a table
import pandas as pd

# With output_dict=True, every entry maps a row name (class name, "macro avg",
# "weighted avg") to a dict of scores. The exception is "accuracy", a bare
# float that does not fit a per-row table, so it is filtered out. Unlike
# `del`, the filter does not raise if the key is absent.
clf_report = (
    pd.DataFrame({row: scores for row, scores in clf_report.items() if row != "accuracy"})
    .T
    # Name the row-label column "class" instead of the anonymous "index".
    .rename_axis("class")
    .reset_index()
)
model_card.add_table(
    folded=True,
    **{
        "Classification Report": clf_report,
    },
)

# %% Write the model card as the repository's README
model_card.save(Path(local_repo).joinpath("README.md"))

# %% Ship this training script alongside the model
# NOTE(review): `__file__` is normally only defined when the file is executed
# as a module/script; in a plain Jupyter kernel this line would raise NameError.
hub_utils.add_files(__file__, dst=local_repo)

# %% Push to Hub! This requires us to authenticate ourselves first.
from huggingface_hub import notebook_login

# Presumably an interactive, notebook-oriented login prompt -- confirm how
# authentication should work when running outside a notebook.
notebook_login()

# NOTE(review): the target lives under the "scikit-learn" namespace; pushing
# needs write access there, so point `repo_id` at your own namespace if needed.
hub_utils.push(
    source=local_repo,
    repo_id="scikit-learn/skorch-text-classification",
    create_remote=True,
)