zaidmehdi's picture
plot training history
c112e25
raw
history blame
3.35 kB
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from datasets import DatasetDict, Dataset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
def get_datasetdict_object(df_train, df_val, df_test):
mapper = {"#2_tweet": "tweet", "#3_country_label": "label"}
columns_to_keep = ["tweet", "label"]
df_train = df_train.rename(columns=mapper)[columns_to_keep]
df_val = df_val.rename(columns=mapper)[columns_to_keep]
df_test = df_test.rename(columns=mapper)[columns_to_keep]
train_dataset = Dataset.from_pandas(df_train)
val_dataset = Dataset.from_pandas(df_val)
test_dataset = Dataset.from_pandas(df_test)
return DatasetDict({'train': train_dataset, 'val': val_dataset,
'test': test_dataset})
def tokenize(batch, tokenizer):
return tokenizer(batch["tweet"], padding='max_length', max_length=768, truncation=True)
def get_dataset(train_path:str, test_path:str, tokenizer):
df_train = pd.read_csv(train_path, sep="\t")
df_train, df_val = train_test_split(df_train, test_size=0.23805, random_state=42,
stratify=df_train["#3_country_label"])
df_train = df_train.reset_index(drop=True)
df_val = df_val.reset_index(drop=True)
df_test = pd.read_csv(test_path, sep="\t")
encoder = LabelEncoder()
df_train["#3_country_label"] = encoder.fit_transform(df_train["#3_country_label"])
df_val["#3_country_label"] = encoder.transform(df_val["#3_country_label"])
df_test["#3_country_label"] = encoder.transform(df_test["#3_country_label"])
dataset = get_datasetdict_object(df_train, df_val, df_test)
dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
return dataset, encoder
def serialize_data(data, output_path:str):
with open(output_path, "wb") as f:
pickle.dump(data, f)
def load_data(input_path:str):
with open(input_path, "rb") as f:
return pickle.load(f)
def plot_confusion_matrix(y_true, y_preds):
labels = sorted(set(y_true.tolist() + y_preds.tolist()))
cm = confusion_matrix(y_true, y_preds)
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, cmap="Blues",
xticklabels=labels, yticklabels=labels)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix')
plt.show()
def plot_training_history(history):
epochs = np.arange(1, len(history["train_loss"]) + 1)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs, history["train_loss"], label='Train Loss')
plt.plot(epochs, history["valid_loss"], label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs, history["train_accuracies"], label='Train Accuracy')
plt.plot(epochs, history["valid_accuracies"], label='Valid Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.tight_layout()
plt.savefig('../docs/images/training_history.png')