"""
@author:jishnuprakash
"""
# Constants, text/data preprocessing helpers, and the Dataset, LightningDataModule
# and LightningModule classes used for training.
import re
import nltk
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import auroc
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from transformers import AutoTokenizer, AutoModel

# Random seed; not referenced in this file (presumably consumed by the training script).
random_seed = 42
# Number of training epochs; not referenced in this file.
num_epochs = 10
# Batch size; not referenced in this file (LexGlueDataModule itself defaults to 8).
batch = 1
# Decision threshold; presumably applied to the sigmoid probabilities returned by
# LexGlueTagger.forward to pick labels -- not referenced in this file, confirm in caller.
threshold = 0.5
# Token length used for padding/truncation (same value as the max_tokens defaults
# of LexGlueDataset and LexGlueDataModule).
max_tokens = 512
# Presumably passed as the `clean` flag of preprocess_data -- verify in caller.
clean_text = False

# Pretrained checkpoint loaded by LexGlueTagger via AutoModel.from_pretrained.
# Swap in the commented-out line to use the generic BERT checkpoint instead.
# bert_model = "bert-base-uncased"
bert_model = "nlpaueb/legal-bert-base-uncased"
# Checkpoint directory / file name; not referenced in this file.
checkpoint_dir = "checkpoints"
check_filename = "legal-full-data"

# Early-stopping settings; "val_loss" is the metric logged by
# LexGlueTagger.validation_step.
earlystop_monitor = "val_loss"
earlystop_patience = 2

# Label names, indexed by class id. Index 10 ("No Violation") is the label that
# preprocess_data assigns to rows whose label list is empty.
lex_classes = ["Article 2", "Article 3", "Article 5", "Article 6", 
               "Article 8", "Article 9", "Article 10", "Article 11",
               "Article 14", "Article 1 of Protocol 1", "No Violation"]

# Length of the multi-hot label vector built by LexGlueDataset.generateLabels.
num_classes = len(lex_classes)

# NLTK English stop-word list (a list of lower-case words), used by preprocess_data
# when cleaning text. Loading it requires the NLTK "stopwords" corpus to be installed.
stop_words = stopwords.words("english")
# WordNet lemmatizer shared by preprocess_text; lemmatizing needs the NLTK
# "wordnet" corpus.
lemmatizer = WordNetLemmatizer()

# Matches every run of characters that are not ASCII letters.
_NON_LETTERS = re.compile(r"[^A-Za-z]+")


def preprocess_text(text, remove_stopwords, stop_words):
    """
    Clean a piece of text.

    Lower-cases ``text`` and replaces every run of non-ASCII-letter characters
    (digits, punctuation, whitespace, accented letters, ...) with one space.
    When ``remove_stopwords`` is true, the text is also tokenized with
    ``nltk.word_tokenize``. Tokens contained in ``stop_words`` are then dropped,
    and the remaining tokens are lemmatized with the module-level ``lemmatizer``
    and re-joined with single spaces.

    Args:
        text: string to clean.
        remove_stopwords: whether to run stop-word removal and lemmatization.
        stop_words: collection of lower-case stop words (list or set). It is only
            consulted when ``remove_stopwords`` is true; ``None`` is treated
            as empty.

    Returns:
        The cleaned lower-case string with surrounding whitespace stripped.
    """
    text = text.lower()
    text = _NON_LETTERS.sub(" ", text)
    if remove_stopwords:
        # Use a set so each per-token membership test is O(1) rather than a
        # scan of the whole stop-word list. After the substitution above the
        # text is lower-case ASCII only, so tokens need no extra lower-casing.
        if isinstance(stop_words, (set, frozenset)):
            stop_set = stop_words
        else:
            stop_set = frozenset(stop_words or ())
        tokens = nltk.word_tokenize(text)
        text = " ".join(lemmatizer.lemmatize(tok) for tok in tokens if tok not in stop_set)
    return text.lower().strip()

def preprocess_data(df, clean=False, no_violation_label=10):
    """
    Perform basic data preprocessing.

    Drops rows containing missing values or an empty ``text`` entry, and
    replaces an empty ``labels`` entry with ``[no_violation_label]``. When
    ``clean`` is true, it also runs ``preprocess_text`` (stop-word removal +
    lemmatization) on every element of each ``text`` entry.

    Args:
        df: DataFrame with a ``text`` column of sequences (iterated
            element-wise when ``clean`` is true) and a ``labels`` column of
            label-index sequences.
        clean: whether to clean each text element with ``preprocess_text``.
        no_violation_label: label index given to rows without labels. The
            default of 10 is the position of "No Violation" in ``lex_classes``.

    Returns:
        A new DataFrame with the original index kept (not reset). The input
        ``df`` is not modified.
    """
    # Drop missing values first: len() below would raise TypeError on a NaN/None cell.
    df = df.dropna()
    # Explicit copy so the column assignments below operate on an independent
    # frame instead of a filtered slice (avoids pandas' SettingWithCopy issue).
    df = df[df["text"].map(len) > 0].copy()
    df["labels"] = df["labels"].apply(
        lambda labels: labels if len(labels) > 0 else [no_violation_label]
    )
    if clean:
        df["text"] = df["text"].apply(
            lambda paragraphs: [preprocess_text(p, True, stop_words) for p in paragraphs]
        )
    return df

class LexGlueDataset(Dataset):
    """
    PyTorch ``Dataset`` view over a LexGLUE DataFrame.

    Each item tokenizes one row's ``text`` (handed to the tokenizer with
    ``is_split_into_words=True``, so it is expected to be a sequence of
    strings), padded/truncated to ``max_tokens``. The row's ``labels`` indices
    become a multi-hot float tensor of length ``num_classes``.
    """

    def __init__(self, data, tokenizer, max_tokens=512):
        super().__init__()
        self.tokenizer = tokenizer
        # Rows are fetched positionally with .iloc in __getitem__.
        self.data = data
        self.max_tokens = max_tokens

    def __len__(self):
        """Return the number of rows in the wrapped data."""
        return len(self.data)

    def generateLabels(self, labels):
        """
        Return a list of ``num_classes`` ints with a 1 at every index listed in
        ``labels`` and 0 elsewhere.
        """
        multi_hot = [0 for _ in range(num_classes)]
        for label_idx in labels:
            multi_hot[label_idx] = 1
        return multi_hot

    def __getitem__(self, index):
        """
        Build the sample at position ``index``.

        Returns a dict with the raw ``text`` (unchanged), the flattened
        ``input_ids`` and ``attention_mask`` tensors, and a float multi-hot
        ``labels`` tensor.
        NOTE(review): if ``text`` is a list, PyTorch's default collate requires
        equal lengths across a batch.
        """
        row = self.data.iloc[index]
        raw_text = row.text
        # Labels are computed before tokenizing (same order as before).
        target = torch.FloatTensor(self.generateLabels(row.labels))

        encoded = self.tokenizer.encode_plus(
            raw_text,
            is_split_into_words=True,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            max_length=self.max_tokens,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )

        return {
            "text": raw_text,
            "input_ids": encoded["input_ids"].flatten(),
            "attention_mask": encoded["attention_mask"].flatten(),
            "labels": target,
        }


def _collate_lex_batch(samples):
    """
    Collate LexGlueDataset items (dicts) into one batch dict.

    ``input_ids``, ``attention_mask`` and ``labels`` are stacked along a new
    leading batch dimension, as the default collate does for tensors.
    ``text`` is kept as a plain list with one entry per sample. PyTorch's
    default collate would treat a list-valued ``text`` as a sequence to
    transpose, and it raises when samples have different lengths (e.g.
    documents with different paragraph counts). That breaks any batch_size > 1.
    """
    collated = {"text": [sample["text"] for sample in samples]}
    for key in ("input_ids", "attention_mask", "labels"):
        collated[key] = torch.stack([sample[key] for sample in samples])
    return collated


class LexGlueDataModule(pl.LightningDataModule):
    """
    Data module to load LexGlueDataset for training, validating and testing
    """

    def __init__(self, train, test, tokenizer, batch_size=8, max_tokens=512, val=None):
        """
        Args:
            train: training DataFrame.
            test: test DataFrame.
            tokenizer: tokenizer handed to every LexGlueDataset.
            batch_size: batch size for all dataloaders.
            max_tokens: padding/truncation length used when tokenizing.
            val: optional validation DataFrame. When omitted, the test split is
                also used for validation (previous behaviour). In that case any
                selection/early stopping on validation metrics is effectively
                done on test data.
        """
        super().__init__()
        self.batch_size = batch_size
        self.train = train
        self.test = test
        self.val = val
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens

    def setup(self, stage=None):
        """Wrap the DataFrames in LexGlueDataset instances."""
        self.train_dataset = LexGlueDataset(self.train, self.tokenizer, self.max_tokens)
        self.test_dataset = LexGlueDataset(self.test, self.tokenizer, self.max_tokens)
        if self.val is None:
            # No dedicated validation split: fall back to the test split.
            self.val_dataset = self.test_dataset
        else:
            self.val_dataset = LexGlueDataset(self.val, self.tokenizer, self.max_tokens)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True,
                          collate_fn=_collate_lex_batch)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=_collate_lex_batch)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=_collate_lex_batch)


class LexGlueTagger(pl.LightningModule):
    """
    Model and training instance for the LexGLUE multi-label task as a PyTorch
    Lightning module. It uses the pretrained encoder named by ``bert_model``
    followed by a linear classification head over the encoder's pooled output.
    """

    def __init__(self, num_classes, training_steps=None, warmup_steps=None):
        """
        Args:
            num_classes: number of output labels (width of the classifier).
            training_steps: total optimizer steps for the linear LR schedule.
            warmup_steps: warm-up steps for the linear LR schedule.
                Both are passed to get_linear_schedule_with_warmup in
                configure_optimizers, so set them before training.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.training_steps = training_steps
        self.warmup_steps = warmup_steps
        # The loss is computed on raw logits. BCEWithLogitsLoss fuses the sigmoid
        # into the loss, which is numerically stabler than sigmoid + BCELoss.
        # Unlike BCELoss, it is also permitted under autocast (mixed precision).
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass.

        Returns:
            (loss, probabilities): ``loss`` is the binary cross-entropy against
            ``labels`` (0 when ``labels`` is None). ``probabilities`` are the
            per-class sigmoid outputs.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        logits = self.classifier(output.pooler_output)
        probabilities = torch.sigmoid(logits)
        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, probabilities

    def _forward_batch(self, batch):
        """Run forward on a batch dict; return (loss, probabilities, labels)."""
        labels = batch["labels"]
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
        return loss, outputs, labels

    def training_step(self, batch, batch_idx):
        loss, outputs, labels = self._forward_batch(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        # Predictions are only needed for epoch-end metrics; detach them so the
        # stored outputs do not reference the autograd graph.
        return {"loss": loss, "predictions": outputs.detach(), "labels": labels}

    def validation_step(self, batch, batch_idx):
        loss, _, _ = self._forward_batch(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, _, _ = self._forward_batch(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_epoch_end(self, outputs):
        """
        Log the per-class ROC AUC of this epoch's training predictions through
        ``self.logger.experiment.add_scalar``. This assumes a TensorBoard-style
        logger; confirm for other loggers.
        """
        labels = torch.cat([out["labels"].detach().cpu() for out in outputs]).int()
        predictions = torch.cat([out["predictions"].detach().cpu() for out in outputs])

        for i, name in enumerate(lex_classes):
            class_roc_auc = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """
        Optimizer and learning-rate scheduler: AdamW (lr 2e-5) with a linear
        warm-up/decay schedule stepped every optimizer step.
        """
        optimizer = AdamW(self.parameters(), lr=2e-5)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.warmup_steps,
                                                    num_training_steps=self.training_steps)
        return dict(optimizer=optimizer,
                    lr_scheduler=dict(scheduler=scheduler,
                                      interval='step'))