import re
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from nltk.corpus import stopwords
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle



def cos_dist(x, y):
    """Mean cosine distance between two batches of tensors.

    Both inputs are flattened per sample before the similarity is computed,
    so any trailing shape is accepted as long as the batch sizes match.
    """
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    batch_size = x.size(0)
    # Cosine similarity lies in [-1, 1], so 1 - cos lies in [0, 2];
    # the clamp only guards against tiny negative values from rounding.
    c = torch.clamp(1 - cos(x.view(batch_size, -1), y.view(batch_size, -1)),
                    min=0)
    return c.mean()
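
# Illustrative usage of cos_dist (a sketch, not part of the training code):
# identical batches give a distance of ~0, while unrelated random batches give
# some value in (0, 2).
#
#   a = torch.randn(4, 8)
#   b = torch.randn(4, 8)
#   cos_dist(a, a)  # tensor close to 0.0
#   cos_dist(a, b)  # strictly positive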




def tag_mapping(tags):
    """
    Create a frequency dictionary and ID mappings for a list of tags,
    sorted by decreasing frequency.
    """
    dico = Counter(tags)
    tag_to_id, id_to_tag = create_mapping(dico)
    print("Found %i unique tags" % len(dico))
    return dico, tag_to_id, id_to_tag


def create_mapping(dico):
    """
    Create a mapping (item to ID / ID to item) from a dictionary.
    Items are ordered by decreasing frequency.
    """
    sorted_items = sorted(dico.items(), key=lambda x: (-x[1], x[0]))
    id_to_item = {i: v[0] for i, v in enumerate(sorted_items)}
    item_to_id = {v: k for k, v in id_to_item.items()}
    return item_to_id, id_to_item
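
# Illustrative example (a sketch; the labels are made up):
#
#   dico, tag_to_id, id_to_tag = tag_mapping(['pos', 'neg', 'pos', 'neu'])
#   # dico      -> Counter({'pos': 2, 'neg': 1, 'neu': 1})
#   # tag_to_id -> {'pos': 0, 'neg': 1, 'neu': 2}   (ties broken alphabetically)
#   # id_to_tag -> {0: 'pos', 1: 'neg', 2: 'neu'}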




def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
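
# Illustrative behaviour (a sketch, matching the original script's quirks):
#
#   clean_str("Isn't this (really) great?")
#   -> "is n't this \( really \) great \?"
#
# Note that the parenthesis and question-mark tokens keep a leading backslash.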


def clean_doc(x, word_freq):
    """
    Clean a list of documents: tokenize/lowercase via clean_str, drop English
    stop words and rare words (corpus frequency < 5), and map words outside
    the 50,000 most frequent ones to "<UNK>".
    """
    stop_words = set(stopwords.words('english'))
    clean_docs = []
    most_commons = dict(word_freq.most_common(min(len(word_freq), 50000)))
    for doc_content in x:
        doc_words = []
        cleaned = clean_str(doc_content.strip())
        for word in cleaned.split():
            if word not in stop_words and word_freq[word] >= 5:
                if word in most_commons:
                    doc_words.append(word)
                else:
                    doc_words.append("<UNK>")
        doc_str = ' '.join(doc_words).strip()
        clean_docs.append(doc_str)
    return clean_docs
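
# Illustrative usage of clean_doc (a sketch): word_freq is expected to be a
# Counter over the clean_str-tokenised corpus, built by the caller, e.g.:
#
#   word_freq = Counter()
#   for doc in raw_docs:
#       word_freq.update(clean_str(doc).split())
#   cleaned_docs = clean_doc(raw_docs, word_freq)
#
# Requires the NLTK stopwords corpus (nltk.download('stopwords')).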



def load_dataset(dataset):
    """Return (train_sentences, val_sentences, test_sentences,
    train_labels, val_labels, test_labels) for the requested corpus.
    The out-of-distribution splits ('20news-5', 'wos-34') return None for
    the train/val entries."""
    known = ('sst', '20news', '20news-15', '20news-5',
             'wos', 'wos-100', 'wos-34', 'agnews')
    if dataset not in known:
        # Fail early instead of hitting a NameError at the final return.
        raise ValueError("Unknown dataset: %r" % (dataset,))

    if dataset == 'sst':
        df_train = pd.read_csv("./dataset/sst/SST-2/train.tsv", delimiter='\t', header=0)
        df_val = pd.read_csv("./dataset/sst/SST-2/dev.tsv", delimiter='\t', header=0)
        df_test = pd.read_csv("./dataset/sst/SST-2/sst-test.tsv", delimiter='\t',
                              header=None, names=['sentence', 'label'])

        train_sentences = df_train.sentence.values
        val_sentences = df_val.sentence.values
        test_sentences = df_test.sentence.values
        train_labels = df_train.label.values
        val_labels = df_val.label.values
        test_labels = df_test.label.values

    if dataset == '20news':
        
        VALIDATION_SPLIT = 0.8  # fraction of the official train split kept for training; the rest is validation
        newsgroups_train = fetch_20newsgroups(data_home='dataset/20news', subset='train',
                                              shuffle=True, random_state=0)
        print(newsgroups_train.target_names)
        print(len(newsgroups_train.data))

        newsgroups_test = fetch_20newsgroups(data_home='dataset/20news', subset='test', shuffle=False)
        print(len(newsgroups_test.data))

        train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))

        train_sentences = newsgroups_train.data[:train_len]
        val_sentences = newsgroups_train.data[train_len:]
        test_sentences = newsgroups_test.data
        train_labels = newsgroups_train.target[:train_len]
        val_labels = newsgroups_train.target[train_len:]
        test_labels = newsgroups_test.target



    if dataset == '20news-15':
        VALIDATION_SPLIT = 0.8
        cats = ['alt.atheism',
        'comp.graphics',
        'comp.os.ms-windows.misc',
        'comp.sys.ibm.pc.hardware',
        'comp.sys.mac.hardware',
        'comp.windows.x',
        'rec.autos',
        'rec.motorcycles',
        'rec.sport.baseball',
        'rec.sport.hockey',
        'misc.forsale',
        'sci.crypt',
        'sci.electronics',
        'sci.med',
        'sci.space']
        newsgroups_train = fetch_20newsgroups(data_home='dataset/20news', subset='train',
                                              shuffle=True, categories=cats, random_state=0)
        print(newsgroups_train.target_names)
        print(len(newsgroups_train.data))

        newsgroups_test = fetch_20newsgroups(data_home='dataset/20news', subset='test',
                                             shuffle=False, categories=cats)

        print(len(newsgroups_test.data))

        train_len = int(VALIDATION_SPLIT * len(newsgroups_train.data))

        train_sentences = newsgroups_train.data[:train_len]
        val_sentences = newsgroups_train.data[train_len:]
        test_sentences = newsgroups_test.data
        train_labels = newsgroups_train.target[:train_len]
        val_labels = newsgroups_train.target[train_len:]
        test_labels = newsgroups_test.target


    if dataset == '20news-5':
        cats = [
        'soc.religion.christian',
        'talk.politics.guns',
        'talk.politics.mideast',
        'talk.politics.misc',
        'talk.religion.misc']
              
        newsgroups_test = fetch_20newsgroups(data_home='dataset/20news', subset='test',
                                             shuffle=False, categories=cats)
        print(newsgroups_test.target_names)
        print(len(newsgroups_test.data))

        train_sentences = None
        val_sentences = None
        test_sentences = newsgroups_test.data
        train_labels = None
        val_labels = None
        test_labels = newsgroups_test.target

    if dataset == 'wos':
        TESTING_SPLIT = 0.6     # fraction of the data used for train+val; the rest is test
        VALIDATION_SPLIT = 0.8  # fraction of train+val kept for training
        file_path = './dataset/WebOfScience/WOS46985/X.txt'
        with open(file_path, 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        file_path = './dataset/WebOfScience/WOS46985/Y.txt'
        with open(file_path, 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))


        # No class filtering for this split: keep every document and label.
        x_in = x_all
        y_in = y_all


        train_val_len = int(TESTING_SPLIT * len(x_in))
        train_len = int(VALIDATION_SPLIT * train_val_len)

        train_sentences = x_in[:train_len]
        val_sentences = x_in[train_len:train_val_len]
        test_sentences = x_in[train_val_len:]

        train_labels = y_in[:train_len]
        val_labels = y_in[train_len:train_val_len]
        test_labels = y_in[train_val_len:]

        print(len(train_labels))
        print(len(val_labels))
        print(len(test_labels))


    if dataset == 'wos-100':
        TESTING_SPLIT = 0.6     # fraction of the data used for train+val; the rest is test
        VALIDATION_SPLIT = 0.8  # fraction of train+val kept for training
        file_path = './dataset/WebOfScience/WOS46985/X.txt'
        with open(file_path, 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        file_path = './dataset/WebOfScience/WOS46985/Y.txt'
        with open(file_path, 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))


        # Keep only documents whose label is in the first 100 classes
        # (the in-distribution subset).
        x_in = []
        y_in = []
        for i in range(len(x_all)):
            if y_all[i] in range(100):
                x_in.append(x_all[i])
                y_in.append(y_all[i])

        # (Per-class document counts can be inspected with Counter(y_in) if needed.)

        train_val_len = int(TESTING_SPLIT * len(x_in))
        train_len = int(VALIDATION_SPLIT * train_val_len)

        train_sentences = x_in[:train_len]
        val_sentences = x_in[train_len:train_val_len]
        test_sentences = x_in[train_val_len:]

        train_labels = y_in[:train_len]
        val_labels = y_in[train_len:train_val_len]
        test_labels = y_in[train_val_len:]

        print(len(train_labels))
        print(len(val_labels))
        print(len(test_labels))

    if dataset == 'wos-34':
        TESTING_SPLIT = 0.6     # fraction of the data used for train+val; the rest is test
        VALIDATION_SPLIT = 0.8  # fraction of train+val kept for training
        file_path = './dataset/WebOfScience/WOS46985/X.txt'
        with open(file_path, 'r') as read_file:
            x_all = [str(x) for x in read_file.readlines()]
        print(len(x_all))

        file_path = './dataset/WebOfScience/WOS46985/Y.txt'
        with open(file_path, 'r') as read_file:
            y_all = [int(y) for y in read_file.readlines()]
        print(len(y_all))
        print(max(y_all), min(y_all))

        # Keep only documents whose label is outside the first 100 classes
        # (the out-of-distribution subset).
        x_in = []
        y_in = []
        for i in range(len(x_all)):
            if y_all[i] not in range(100):
                x_in.append(x_all[i])
                y_in.append(y_all[i])

        # (Per-class document counts can be inspected with Counter(y_in) if needed.)

        train_val_len = int(TESTING_SPLIT * len(x_in))
        train_len = int(VALIDATION_SPLIT * train_val_len)

        # Only the held-out test portion is used for this out-of-distribution split.
        train_sentences = None
        val_sentences = None
        test_sentences = x_in[train_val_len:]
        
        train_labels = None
        val_labels = None
        test_labels = y_in[train_val_len:]

        print(len(test_labels))
        
    if dataset == 'agnews':

        VALIDATION_SPLIT = 0.8
        labels_in_domain = [1, 2]  # note: not used below

        train_df = pd.read_csv('./dataset/agnews/train.csv', header=None)
        train_df.rename(columns={0: 'label', 1: 'title', 2: 'sentence'}, inplace=True)
        # train_df = pd.concat([train_df, pd.get_dummies(train_df['label'], prefix='label')], axis=1)
        print(train_df.dtypes)
        train_in_df_sentence = []
        train_in_df_label = []

        # Labels in the CSV are 1-4; shift them to 0-3.
        for i in range(len(train_df.sentence.values)):
            train_in_df_sentence.append(str(train_df.sentence.values[i]))
            train_in_df_label.append(train_df.label.values[i] - 1)

        test_df = pd.read_csv('./dataset/agnews/test.csv', header=None)
        test_df.rename(columns={0: 'label', 1: 'title', 2: 'sentence'}, inplace=True)
        # test_df = pd.concat([test_df, pd.get_dummies(test_df['label'], prefix='label')], axis=1)
        test_in_df_sentence = []
        test_in_df_label = []
        for i in range(len(test_df.sentence.values)):
            test_in_df_sentence.append(str(test_df.sentence.values[i]))
            test_in_df_label.append(test_df.label.values[i] - 1)

        train_len = int(VALIDATION_SPLIT * len(train_in_df_sentence))

        train_sentences = train_in_df_sentence[:train_len]
        val_sentences = train_in_df_sentence[train_len:]
        test_sentences = test_in_df_sentence
        train_labels = train_in_df_label[:train_len]
        val_labels = train_in_df_label[train_len:]
        test_labels = test_in_df_label
        print(len(train_sentences))
        print(len(val_sentences))
        print(len(test_sentences))


    return train_sentences, val_sentences, test_sentences, train_labels, val_labels, test_labels
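

if __name__ == "__main__":
    # Minimal smoke test of the text-cleaning utilities. This is an
    # illustrative sketch only: it does not read any of the dataset files that
    # load_dataset() expects under ./dataset/, but it does assume the NLTK
    # stopwords corpus has been downloaded (nltk.download('stopwords')).
    raw_docs = ["the movie was great and the plot was great"] * 5

    # Build the word-frequency Counter the way clean_doc expects it:
    # counts over clean_str-tokenised text.
    word_freq = Counter()
    for doc in raw_docs:
        word_freq.update(clean_str(doc).split())

    # Stop words ("the", "was", "and") and words with frequency < 5 are dropped.
    print(clean_doc(raw_docs, word_freq))

    # To load a real corpus instead (requires the files under ./dataset/):
    # splits = load_dataset('20news')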