# MODIFY AS REQUIRED
import logging

import evaluate
import pandas as pd
import torch
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DataCollatorWithPadding, get_scheduler

from text_preprocessing import clean_tweet, clear_reply_mentions, normalizeTweet
from custom_model import CustomModel

logging.basicConfig(level=logging.INFO)

'''
DATA_PATH = "../../data"

PROCESSED_PATH = f"{DATA_PATH}/processed"

PROCESSED_PATH_VIRAL = f'{DATA_PATH}/new/processed/viral'
PROCESSED_PATH_COVID = f'{DATA_PATH}/new/processed/covid'
'''

# Different models
BERT_BASE_UNCASED = "bert-base-uncased"
BERT_BASE_CASED = "bert-base-cased"
ROBERTA_BASE = "roberta-base"
BERT_TWEET = "vinai/bertweet-base"

# TODO: Don't forget to cite the corresponding papers if you use any of these models
BERT_TINY = "prajjwal1/bert-tiny"

TWEET_MAX_LENGTH = 280  # Twitter's limit, in characters (not tokens)

# TEST SPLIT RATIO + MODELS (ADD MORE MODELS FROM ABOVE)
MODELS = [BERT_TWEET, BERT_TINY, BERT_BASE_CASED, ROBERTA_BASE]
TEST_RATIO = 0.2

TOP_FEATURES = ["verified", "tweet_length", "possibly_sensitive", "sentiment", "nb_of_hashtags", "has_media", "nb_of_mentions"]

def preprocess_data(dataset):
    dataset.loc[:, 'has_media'] = dataset.has_media.astype("int")
    dataset.loc[:, 'possibly_sensitive'] = dataset.possibly_sensitive.astype("int")

    #dataset = dataset[dataset.sentiment_score > 0.7]
    dataset.loc[:, 'sentiment'] = dataset.sentiment.replace({'POSITIVE': 1, 'NEGATIVE': 0})
    dataset.loc[:, 'verified'] = dataset['verified'].astype(int)

    # remove tweets with 0 retweets (to eliminate their effects)
    #dataset = dataset[dataset.retweet_count > 0]

    ## UPDATE: keep only tweets posted by a user on the same day they posted a viral tweet

    # Get the date from the datetime:
    # normalize() sets the time component to midnight, which is equivalent to keeping only the date part
    dataset['date'] = dataset.created_at.dt.normalize()

    viral_tweets = dataset[dataset.viral]
    non_viral_tweets = dataset[~dataset.viral]

    temp = non_viral_tweets.merge(viral_tweets[['author_id', 'date', 'id', 'viral']], on=['author_id', 'date'], suffixes=(None, '_y'))
    same_day_viral_ids = temp.id_y.unique()

    same_day_viral_tweets = viral_tweets[viral_tweets.id.isin(same_day_viral_ids)].drop_duplicates(subset=['author_id', 'date'])
    same_day_non_viral_tweets = temp.drop_duplicates(subset=['author_id', 'date'])
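    # Illustrative example (assumed data): if user 42 posted viral tweet V and
    # non-viral tweets N1, N2 on 2022-11-03, the merge above yields one viral /
    # non-viral pair (V, N1) for that author and day.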

    logging.info(f"Number of viral tweets tweeted on the same day {len(same_day_viral_tweets)}")
    logging.info(f"Number of non viral tweets tweeted on the same day {len(same_day_non_viral_tweets)}")

    dataset = pd.concat([same_day_viral_tweets, same_day_non_viral_tweets], axis=0)
    dataset = dataset[['id', 'text'] + TOP_FEATURES + ['viral']]

    # Balance classes to have as many viral as non viral ones
    #dataset = pd.concat([positives, negatives.sample(n=len(positives))])
    #dataset = pd.concat([positives.iloc[:100], negatives.sample(n=len(positives)).iloc[:200]])

    # Clean text to prepare for tokenization
    #dataset = dataset.dropna()
    dataset.loc[:, "viral"] = dataset.viral.astype(int)

    # TODO: comment out the next line to keep the text as is
    dataset["cleaned_text"] = dataset.text.apply(lambda x: clean_tweet(x, demojize_emojis=False))

    dataset = dataset.dropna()
    dataset.loc[:, "extra_features"] = dataset[TOP_FEATURES].values.tolist()
    dataset = dataset[['id', 'cleaned_text', 'extra_features', 'viral']]
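    # Illustrative (assumed) row of the returned frame:
    #   id        cleaned_text       extra_features             viral
    #   1588...   "example tweet"    [1, 118, 0, 1, 2, 1, 0]    1
    # where extra_features follows the order of TOP_FEATURES.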

    return dataset

def prepare_dataset(sample_data, balance=False):
    # Split into train and test sets such that each keeps the same proportion of viral tweets (stratified split)
    train_dataset, eval_dataset = train_test_split(sample_data, test_size=TEST_RATIO, random_state=42, stratify=sample_data.viral)

    # Balance test set
    if balance:
        eval_virals = eval_dataset[eval_dataset.viral == 1]
        eval_non_virals = eval_dataset[eval_dataset.viral == 0]
        eval_dataset = pd.concat([eval_virals, eval_non_virals.sample(n=len(eval_virals))])
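        # e.g. (assumed counts) 60 viral / 500 non-viral eval tweets become
        # 60 viral / 60 non-viral after downsampling the negatives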

    logging.info('{:>5,} training samples with {:>5,} positives and {:>5,} negatives'.format(
        len(train_dataset), len(train_dataset[train_dataset.viral == 1]), len(train_dataset[train_dataset.viral == 0])))
    logging.info('{:>5,} validation samples with {:>5,} positives and {:>5,} negatives'.format(
        len(eval_dataset), len(eval_dataset[eval_dataset.viral == 1]), len(eval_dataset[eval_dataset.viral == 0])))

    train_dataset.to_parquet("train.parquet.gzip", compression='gzip')
    eval_dataset.to_parquet("test.parquet.gzip", compression='gzip')

    ds = load_dataset("parquet", data_files={'train': 'train.parquet.gzip', 'test': 'test.parquet.gzip'})
    return ds

def tokenize_function(example, tokenizer):
    # Truncate to the model's max length; padding is deferred to DataCollatorWithPadding
    # (dynamic padding): https://huggingface.co/course/chapter3/2?fw=pt#dynamic-padding
    return tokenizer(example["cleaned_text"], truncation=True)


def test_all_models(ds, nb_extra_dims, models=MODELS):
    models_losses = {}
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
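    # The device selection assumes an Apple Silicon machine; a more general
    # fallback (a sketch, not part of the original setup) could be:
    #   device = torch.device(
    #       "cuda" if torch.cuda.is_available()
    #       else "mps" if torch.backends.mps.is_available()
    #       else "cpu"
    #   )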

    output = ""

    for checkpoint in models:
        if device.type == "mps":
            torch.mps.empty_cache()  # free cached GPU memory between models
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        custom_model = CustomModel(checkpoint, num_extra_dims=nb_extra_dims, num_labels=2)
        custom_model.to(device)

        tokenized_datasets = ds.map(lambda x: tokenize_function(x, tokenizer=tokenizer), batched=True)
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
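        # Illustrative (assumed) tokenized example before collation:
        #   {'input_ids': [0, 8310, 11, 5, 2788, 2], 'attention_mask': [1, 1, 1, 1, 1, 1], ...}
        # The collator pads each batch to the length of its longest sequence
        # (dynamic padding) instead of padding everything to a fixed maximum.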

        tokenized_datasets = tokenized_datasets.remove_columns(["__index_level_0__", "cleaned_text", "id"])
        tokenized_datasets = tokenized_datasets.rename_column("viral", "labels")
        tokenized_datasets.set_format("torch")

        batch_size = 32

        train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator)
        eval_dataloader = DataLoader(tokenized_datasets["test"], batch_size=batch_size, collate_fn=data_collator)

        criterion = BCEWithLogitsLoss()
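        # BCEWithLogitsLoss applies the sigmoid internally, so the model is
        # expected to output raw logits; sigmoid is only applied explicitly at
        # evaluation time below.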
        optimizer = AdamW(custom_model.parameters(), lr=5e-5)

        num_epochs = 15
        num_training_steps = num_epochs * len(train_dataloader)
        lr_scheduler = get_scheduler(
            name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
        )
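        # With a linear schedule and no warmup, the learning rate decays from
        # 5e-5 to 0 over training: roughly lr(step) = 5e-5 * (1 - step / num_training_steps)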

        progress_bar = tqdm(range(num_training_steps))

        losses = []
        custom_model.train()
        for epoch in range(num_epochs):
            for batch in train_dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                logits = custom_model(**batch).squeeze()

                loss = criterion(logits, batch['labels'].float())
                #losses.append(loss.cpu().item())
                losses.append(loss.item())
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

        models_losses[checkpoint] = losses

        metric = evaluate.combine(["accuracy", "recall", "precision", "f1"])
        custom_model.eval()
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                logits = custom_model(**batch)

            #predictions = torch.argmax(outputs, dim=-1)
            # Threshold the sigmoid probability at 0.5; squeeze so the shape matches the labels
            predictions = torch.round(torch.sigmoid(logits)).squeeze(-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        output += f"checkpoint: {checkpoint}: {metric.compute()}\n"
    logging.info(output)
    with open("same_day_as_viral_with_features_train_test_balanced_accuracy.txt", "w") as text_file:
        text_file.write(output)
    return models_losses

def main():
    # The data file is expected to sit next to this script
    all_tweets_labeled = pd.read_parquet('final_dataset_since_october_2022.parquet.gzip')

    dataset = preprocess_data(all_tweets_labeled)
    ds = prepare_dataset(dataset, balance=True)

    nb_extra_dims = len(TOP_FEATURES)
    test_all_models(ds, nb_extra_dims=nb_extra_dims)

if __name__ == "__main__":
    main()