File size: 3,938 Bytes
2cb1421
 
1964ece
2f5f23f
c112e25
6b5fdb9
2f5f23f
1964ece
00cdd45
 
6b5fdb9
 
00cdd45
 
5099736
00cdd45
 
 
 
5099736
00cdd45
 
 
5099736
00cdd45
6b5fdb9
5099736
 
9a23b5c
 
6b5fdb9
ff82938
6b5fdb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d84597
 
 
ff82938
 
f8b3be6
2cb1421
 
 
f8b3be6
 
2cb1421
 
 
2f5f23f
 
 
 
 
 
 
 
 
 
 
c112e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1964ece
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import pickle

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import DatasetDict, Dataset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


def get_datasetdict_object(df_train, df_val, df_test):
    """Bundle the three split DataFrames into a single ``DatasetDict``.

    For each frame the ``#2_tweet`` and ``#3_country_label`` columns are
    renamed to ``tweet`` and ``label`` and every other column is dropped
    before the frame is converted with ``Dataset.from_pandas``.

    Returns:
        A ``DatasetDict`` keyed by ``'train'``, ``'val'`` and ``'test'``.
    """
    renames = {"#2_tweet": "tweet", "#3_country_label": "label"}
    kept = ["tweet", "label"]
    frames = {'train': df_train, 'val': df_val, 'test': df_test}

    return DatasetDict({
        split: Dataset.from_pandas(frame.rename(columns=renames)[kept])
        for split, frame in frames.items()
    })


def tokenize(batch, tokenizer, max_length=768):
    """Tokenize the ``tweet`` field of a batch to a fixed sequence length.

    Args:
        batch: Mapping with a ``"tweet"`` entry holding the text(s) to encode.
        tokenizer: Callable invoked as
            ``tokenizer(texts, padding=..., max_length=..., truncation=...)``.
        max_length: Length every sequence is padded/truncated to. Defaults to
            768, the value that used to be hard-coded here.
            NOTE(review): 768 exceeds the 512-position limit of many
            BERT-style checkpoints -- confirm it suits the model in use.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    return tokenizer(batch["tweet"], padding='max_length',
                     max_length=max_length, truncation=True)


def get_dataset(train_path: str, test_path: str, tokenizer,
                val_size: float = 0.23805, random_state: int = 42):
    """Load the TSV splits, encode labels and return a tokenized dataset.

    The training file is split into train/validation parts with a stratified
    split on the ``#3_country_label`` column; the test file is used as-is.

    Args:
        train_path: Path to the tab-separated training file.
        test_path: Path to the tab-separated test file.
        tokenizer: Tokenizer forwarded to ``tokenize`` for every batch.
        val_size: Fraction of the training file held out for validation
            (passed to ``train_test_split`` as ``test_size``). Defaults to
            the previously hard-coded 0.23805.
        random_state: Seed for the split. Defaults to the previously
            hard-coded 42.

    Returns:
        A ``(dataset, encoder)`` tuple: the tokenized ``DatasetDict`` in torch
        format exposing ``input_ids``, ``attention_mask`` and ``label``, and
        the ``LabelEncoder`` fitted on the training labels.
    """
    label_col = "#3_country_label"

    df_train = pd.read_csv(train_path, sep="\t")
    df_train, df_val = train_test_split(df_train, test_size=val_size,
                                        random_state=random_state,
                                        stratify=df_train[label_col])
    df_train = df_train.reset_index(drop=True)
    df_val = df_val.reset_index(drop=True)
    df_test = pd.read_csv(test_path, sep="\t")

    # The encoder is fitted on the training labels only; ``transform`` raises
    # for a label it has not seen, so val/test must not introduce new classes.
    encoder = LabelEncoder()
    df_train[label_col] = encoder.fit_transform(df_train[label_col])
    df_val[label_col] = encoder.transform(df_val[label_col])
    df_test[label_col] = encoder.transform(df_test[label_col])

    dataset = get_datasetdict_object(df_train, df_val, df_test)
    dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])

    return dataset, encoder


def serialize_data(data, output_path: str):
    """Pickle ``data`` into the file at ``output_path``.

    The file is opened in binary write mode, so existing content is replaced.
    """
    with open(output_path, mode="wb") as sink:
        pickle.dump(data, sink)


def load_data(input_path: str):
    """Return the object unpickled from the file at ``input_path``.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing;
    only call this on files from a trusted source.
    """
    with open(input_path, mode="rb") as source:
        restored = pickle.load(source)
    return restored


def plot_confusion_matrix(y_true, y_preds):
    """Show a heatmap of the confusion matrix for the given predictions.

    Args:
        y_true: Ground-truth class labels (array-like).
        y_preds: Predicted class labels (array-like), aligned with ``y_true``.

    The axes are labelled with the sorted union of the labels occurring in
    either input. The figure is displayed with ``plt.show()``; nothing is
    returned.
    """
    # Accept plain lists as well as arrays.
    y_true = np.asarray(y_true)
    y_preds = np.asarray(y_preds)

    labels = sorted(set(y_true.tolist()) | set(y_preds.tolist()))
    # Pass the label order explicitly so matrix rows/columns are guaranteed
    # to line up with the tick labels drawn below.
    cm = confusion_matrix(y_true, y_preds, labels=labels)

    plt.figure(figsize=(12, 10))
    # fmt="d" prints the raw integer counts; the default '.2g' format would
    # render counts of 100 or more in scientific notation.
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()


def plot_training_history(history, output_path='../docs/images/training_history.png'):
    """Plot loss and accuracy curves side by side and save them as an image.

    Args:
        history: Mapping with the per-epoch sequences ``"train_loss"``,
            ``"valid_loss"``, ``"train_accuracies"`` and
            ``"valid_accuracies"``. The number of epochs is taken from the
            length of ``"train_loss"``.
        output_path: Where the figure is written. Defaults to the location
            that used to be hard-coded; the target directory must exist.
    """
    epochs = np.arange(1, len(history["train_loss"]) + 1)
    plt.figure(figsize=(10, 5))

    # (subplot position, train key, valid key, quantity shown on the y axis)
    panels = (
        (1, "train_loss", "valid_loss", "Loss"),
        (2, "train_accuracies", "valid_accuracies", "Accuracy"),
    )
    for position, train_key, valid_key, quantity in panels:
        plt.subplot(1, 2, position)
        plt.plot(epochs, history[train_key], label=f'Train {quantity}')
        plt.plot(epochs, history[valid_key], label=f'Valid {quantity}')
        plt.xlabel('Epoch')
        plt.ylabel(quantity)
        plt.title(f'Training and Validation {quantity}')
        plt.legend()

    plt.tight_layout()
    plt.savefig(output_path)


def get_model_accuracy(model, test_loader):
    """Compute the accuracy of ``model`` over all batches of ``test_loader``.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    evaluation mode. Each batch is expected to be a mapping of tensors that
    can be passed to the model as keyword arguments and that contains a
    ``"labels"`` entry with the reference classes.
    NOTE(review): ``get_dataset`` exposes the column as ``"label"`` --
    presumably it is renamed to ``"labels"`` before the loader is built;
    verify against the caller.

    Returns:
        The result of ``metric.compute()`` for the ``evaluate`` accuracy
        metric.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target)

    accuracy = evaluate.load("accuracy")
    model.eval()
    for raw_batch in test_loader:
        on_device = {name: tensor.to(target) for name, tensor in raw_batch.items()}
        # Inference only: no gradients are needed for the forward pass.
        with torch.no_grad():
            result = model(**on_device)

        predicted = torch.argmax(result.logits, dim=-1)
        accuracy.add_batch(predictions=predicted, references=on_device["labels"])

    return accuracy.compute()