import inspect

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
from torch.utils.data import Dataset 
from transformers import BertTokenizer, BertForSequenceClassification, Trainer,TrainingArguments

# Data files are resolved relative to the current working directory.
# Plain relative names (no Windows-only ".\" prefix) work on every OS.

# Alternative: train without augmentation.
# df = pd.read_csv("train_set.csv")

# Train with the augmented training set.
df = pd.read_csv("cleaned_combined_aug_set.csv")
# Show the class balance of the training set.
value_counts = df['label'].value_counts()
print(value_counts)

# Held-out set, used as the evaluation set during training.
test_df = pd.read_csv("test_set.csv")
# Print the test-set class balance (previously computed and silently discarded).
print(test_df['label'].value_counts())

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Fall back to CPU when no CUDA device is available instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Independent variable: the article text.
X = list(df['article'])
X_test = list(test_df['article'])

# Dependent variable: the class label (num_labels=2 above; presumably 0/1 — confirm).
y = list(df['label'])
y_test = list(test_df['label'])

# Every article is truncated or padded to exactly max_length tokens.
max_length = 512
train_encodings = tokenizer(X, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
test_encodings = tokenizer(X_test, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')

class CustomDataset(Dataset):
    """Map-style dataset that pairs pre-computed tokenizer encodings with labels.

    Args:
        encodings: mapping of feature name -> indexable column (e.g. the
            ``input_ids`` / ``attention_mask`` tensors returned by the tokenizer).
        labels: sequence of labels, one per example.

    The dataset length is taken from ``labels``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the features of example ``idx`` plus its label as a tensor."""
        sample = {}
        for feature_name, column in self.encodings.items():
            sample[feature_name] = column[idx]
        # Trainer looks for the target under the 'labels' key.
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

torch_train_dataset = CustomDataset(train_encodings, y)
torch_test_dataset = CustomDataset(test_encodings, y_test)

# transformers deprecated `evaluation_strategy` in favour of `eval_strategy`,
# and recent releases no longer accept the old keyword. Pick whichever one the
# installed version's TrainingArguments actually declares.
_eval_strategy_kw = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)

training_args = TrainingArguments(
    output_dir='./results/fake-news-bert-aug',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    # Evaluate on the eval dataset at the end of every epoch.
    **{_eval_strategy_kw: 'epoch'},
)

def compute_metrics(p):
    """Compute binary-classification metrics for a Trainer evaluation pass.

    Args:
        p: an ``EvalPrediction`` or any ``(predictions, label_ids)`` pair, where
            ``predictions`` holds one row of per-class scores (logits) per
            example and ``label_ids`` holds the gold labels.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``. Precision,
        recall and F1 use scikit-learn's default binary averaging.
    """
    scores, labels = p
    # Some models return a tuple (logits, extra outputs...); logits come first.
    if isinstance(scores, tuple):
        scores = scores[0]
    pred = np.argmax(scores, axis=-1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    # zero_division=0 yields the same 0.0 the default would, without a warning
    # when an epoch predicts no positives (or the eval set contains none).
    recall = recall_score(y_true=labels, y_pred=pred, zero_division=0)
    precision = precision_score(y_true=labels, y_pred=pred, zero_division=0)
    f1 = f1_score(y_true=labels, y_pred=pred, zero_division=0)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Wire the model, datasets and metric callback together, then fine-tune.
# The test split doubles as the evaluation set scored by compute_metrics.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=torch_train_dataset,
    eval_dataset=torch_test_dataset,
)

trainer.train()

def predict(text):
    """Run the fine-tuned classifier on new input.

    Args:
        text: a single article string, a non-empty list/tuple of article
            strings, or an already-prepared dataset. A dataset is anything
            ``Trainer.predict`` accepts and is passed through unchanged.

    Returns:
        The ``PredictionOutput`` returned by ``trainer.predict``. Its
        ``predictions`` field holds one row of logits per input, so
        ``predictions.argmax(-1)`` gives the predicted label ids. Raw text
        carries no gold labels, so only the logits are meaningful for it.
    """
    texts = [text] if isinstance(text, str) else text
    is_raw_text = (
        isinstance(texts, (list, tuple))
        and len(texts) > 0
        and all(isinstance(t, str) for t in texts)
    )
    if not is_raw_text:
        # Already a dataset: keep the original pass-through behaviour.
        return trainer.predict(text)

    # Tokenize exactly as the training data was (same truncation/padding/length).
    encodings = tokenizer(
        list(texts),
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    # Trainer.predict needs indexable per-example feature dicts, not raw strings.
    dataset = [
        {key: val[i] for key, val in encodings.items()}
        for i in range(len(texts))
    ]
    return trainer.predict(dataset)