# -*- coding: utf-8 -*-
""" Class average finetuning functions. Before using any of these finetuning
    functions, ensure that the model is set up with nb_classes=2.
"""
from __future__ import print_function

import uuid
from time import sleep
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torchmoji.global_variables import (
    FINETUNING_METHODS,
    WEIGHTS_DIR)
from torchmoji.finetuning import (
    freeze_layers,
    get_data_loader,
    fit_model,
    train_by_chain_thaw,
    find_f1_threshold)

def relabel(y, current_label_nr, nb_classes):
    """ Makes a binary classification for a specific class in a
        multi-class dataset.

    # Arguments:
        y: Outputs to be relabelled.
        current_label_nr: Current label number.
        nb_classes: Total number of classes.

    # Returns:
        Binary labels for the given multi-class dataset: 1 where the sample
        belongs to class current_label_nr, 0 otherwise.
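
    # Example:
        A minimal, illustrative sketch (assumes one-hot encoded labels):

        >>> import numpy as np
        >>> y = np.array([[1, 0, 0],
        ...               [0, 1, 0],
        ...               [0, 1, 0]])
        >>> relabel(y, current_label_nr=1, nb_classes=3)
        array([0., 1., 1.])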
    """

    # Handling binary classification
    if nb_classes == 2 and len(y.shape) == 1:
        return y

    y_new = np.zeros(len(y))
    y_cut = y[:, current_label_nr]
    label_pos = np.where(y_cut == 1)[0]
    y_new[label_pos] = 1
    return y_new


def class_avg_finetune(model, texts, labels, nb_classes, batch_size,
                       method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
                       verbose=True):
    """ Compiles and finetunes the given model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py. Note that the model
            should be defined accordingly (see docstring for torchmoji_transfer())
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Tuple of (model after finetuning, score after finetuning using the
        class average F1 metric).
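
    # Example:
        A usage sketch (assumes the pretrained weights and the three-way
        texts/labels splits have been prepared by the caller):

        >>> from torchmoji.model_def import torchmoji_transfer
        >>> from torchmoji.global_variables import PRETRAINED_PATH
        >>> model = torchmoji_transfer(nb_classes=2, weight_path=PRETRAINED_PATH)
        >>> model, f1 = class_avg_finetune(model, texts, labels, nb_classes=2,
        ...                                batch_size=32, method='last')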
    """

    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (class_avg_finetune): '
                         'Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))

    (X_train, y_train) = (texts[0], labels[0])
    (X_val, y_val) = (texts[1], labels[1])
    (X_test, y_test) = (texts[2], labels[2])

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    f1_init_path = '{}/torchmoji-f1-init-{}.bin' \
                   .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

    loss_op = nn.BCEWithLogitsLoss()
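    # Note: BCEWithLogitsLoss combines a sigmoid with binary cross-entropy,
    # so the model is expected to output raw logits.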

    # Freeze layers if using last
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')

    # Define optimizer, for chain-thaw we define it later (after freezing)
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Add L2 regularization on embeddings only
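        # (a separate parameter group lets Adam apply weight_decay to the
        # embedding weights only)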
        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
            ], lr=lr)

    # Training
    if verbose:
        print('Method:  {}'.format(method))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = class_avg_chainthaw(model, nb_classes=nb_classes,
                                     loss_op=loss_op,
                                     train=(X_train, y_train),
                                     val=(X_val, y_val),
                                     test=(X_test, y_test),
                                     batch_size=batch_size,
                                     epoch_size=epoch_size,
                                     nb_epochs=nb_epochs,
                                     checkpoint_weight_path=checkpoint_path,
                                     f1_init_weight_path=f1_init_path,
                                     verbose=verbose)
    else:
        result = class_avg_tune_trainable(model, nb_classes=nb_classes,
                                          loss_op=loss_op,
                                          optim_op=adam,
                                          train=(X_train, y_train),
                                          val=(X_val, y_val),
                                          test=(X_test, y_test),
                                          epoch_size=epoch_size,
                                          nb_epochs=nb_epochs,
                                          batch_size=batch_size,
                                          init_weight_path=f1_init_path,
                                          checkpoint_weight_path=checkpoint_path,
                                          verbose=verbose)
    return model, result


def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes):
    """ Relabels the training, validation and test outputs into a binary
        classification task for class number iter_i (see relabel()). """
    y_train_new = relabel(y_train, iter_i, nb_classes)
    y_val_new = relabel(y_val, iter_i, nb_classes)
    y_test_new = relabel(y_test, iter_i, nb_classes)
    return y_train_new, y_val_new, y_test_new

def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size):
    """ Creates a training data loader and a fixed, resampled validation set
        (a single batch of up to epoch_size samples) to avoid fluctuations
        in the validation score. """
    train_gen = get_data_loader(X_train, y_train_new, batch_size,
                                extended_batch_sampler=True)
    val_gen = get_data_loader(X_val, y_val_new, epoch_size,
                              extended_batch_sampler=True)
    X_val_resamp, y_val_resamp = next(iter(val_gen))
    return train_gen, X_val_resamp, y_val_resamp


def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test,
                             epoch_size, nb_epochs, batch_size,
                             init_weight_path, checkpoint_weight_path, patience=5,
                             verbose=True):
    """ Finetunes the given model using the F1 measure.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        batch_size: Batch size.
        init_weight_path: Filepath where weights will be initially saved before
            training each class. This file will be rewritten by the function.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        verbose: Verbosity flag.

    # Returns:
        Class average F1 score of the trained model.
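
    # Note:
        For nb_classes > 2 the labels are expected to be one-hot encoded so
        that relabel() can extract one binary task per class.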
    """
    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    # Save and reload initial weights after running for
    # each class to avoid learning across classes
    torch.save(model.state_dict(), init_weight_path)
    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
            prepare_generators(X_train, y_train_new, X_val, y_val_new,
                               batch_size, epoch_size)

        if verbose:
            print("Training..")
        fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)],
                  nb_epochs, checkpoint_weight_path, patience, verbose=0)

        # Reload the best weights found to avoid overfitting
        # Wait a bit to allow proper closing of weights file
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_weight_path))

        # Evaluate in inference mode so that no gradients are tracked
        model.eval()
        with torch.no_grad():
            y_pred_val = model(X_val).cpu().numpy()
            y_pred_test = model(X_test).cpu().numpy()

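        # find_f1_threshold picks the threshold that maximises F1 on the
        # validation set and reports the test F1 at that threshold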
        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)
        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter


def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size,
                        epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001, verbose=True):
    """ Finetunes given model using chain-thaw and evaluates using F1.
        For a dataset with multiple classes, the model is trained once for
        each class, relabeling those classes into a binary classification task.
        The result is an average of all F1 scores for each class.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        loss_op: Loss operation (e.g. nn.BCEWithLogitsLoss) used during training.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where weights will be saved to and
            reloaded from before training each class. This ensures that
            each class is trained independently. This file will be rewritten.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer)
        next_lr: Learning rate for every subsequent step.
        patience: Number of epochs with no improvement after which training
            is stopped early.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
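
    # Note:
        This function is normally invoked through class_avg_finetune() with
        method='chain-thaw', which prepares the loss and the weight paths
        before dispatching here.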
    """

    # Unpack args
    X_train, y_train = train
    X_val, y_val = val
    X_test, y_test = test

    total_f1 = 0
    nb_iter = nb_classes if nb_classes > 2 else 1

    torch.save(model.state_dict(), f1_init_weight_path)

    for i in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(i+1, nb_iter))

        model.load_state_dict(torch.load(f1_init_weight_path))
        y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val,
                                                            y_test, i, nb_classes)
        train_gen, X_val_resamp, y_val_resamp = \
                prepare_generators(X_train, y_train_new, X_val, y_val_new,
                                   batch_size, epoch_size)

        if verbose:
            print("Training..")

        # Train using chain-thaw
        train_by_chain_thaw(model=model, train_gen=train_gen,
                            val_gen=[(X_val_resamp, y_val_resamp)],
                            loss_op=loss_op, patience=patience,
                            nb_epochs=nb_epochs,
                            checkpoint_path=checkpoint_weight_path,
                            initial_lr=initial_lr, next_lr=next_lr,
                            verbose=verbose)

        # Evaluate in inference mode so that no gradients are tracked
        model.eval()
        with torch.no_grad():
            y_pred_val = model(X_val).cpu().numpy()
            y_pred_test = model(X_test).cpu().numpy()

        f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val,
                                            y_test_new, y_pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t:  {}'.format(best_t))
        total_f1 += f1_test

    return total_f1 / nb_iter