zaidmehdi commited on
Commit
c112e25
·
1 Parent(s): e12fb70

plot training history

Browse files
Files changed (2) hide show
  1. src/model_training.py +3 -5
  2. src/utils.py +26 -1
src/model_training.py CHANGED
@@ -1,6 +1,3 @@
1
- import matplotlib.pyplot as plt
2
- import numpy as np
3
- import pandas as pd
4
  import torch
5
  import torch.nn as nn
6
  import torch.optim as optim
@@ -11,7 +8,7 @@ from torch.utils.data import DataLoader
11
  from tqdm.auto import tqdm
12
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
13
 
14
- from utils import get_dataset
15
 
16
 
17
  def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
@@ -125,7 +122,8 @@ def main():
125
  num_epochs = 100
126
 
127
  model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
128
-
 
129
 
130
  if __name__ == "__main__":
131
  main()
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.optim as optim
 
8
  from tqdm.auto import tqdm
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
 
11
+ from utils import get_dataset, plot_training_history
12
 
13
 
14
  def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
 
122
  num_epochs = 100
123
 
124
  model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
125
+ plot_training_history(history)
126
+
127
 
128
  if __name__ == "__main__":
129
  main()
src/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import pickle
2
 
3
  import matplotlib.pyplot as plt
 
4
  import pandas as pd
5
  import seaborn as sns
6
  from datasets import DatasetDict, Dataset
@@ -68,4 +69,28 @@ def plot_confusion_matrix(y_true, y_preds):
68
  plt.xlabel('Predicted Label')
69
  plt.ylabel('True Label')
70
  plt.title('Confusion Matrix')
71
- plt.show()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pickle
2
 
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
5
  import pandas as pd
6
  import seaborn as sns
7
  from datasets import DatasetDict, Dataset
 
69
  plt.xlabel('Predicted Label')
70
  plt.ylabel('True Label')
71
  plt.title('Confusion Matrix')
72
+ plt.show()
73
+
74
+
75
+ def plot_training_history(history):
76
+ epochs = np.arange(1, len(history["train_loss"]) + 1)
77
+ plt.figure(figsize=(10, 5))
78
+
79
+ plt.subplot(1, 2, 1)
80
+ plt.plot(epochs, history["train_loss"], label='Train Loss')
81
+ plt.plot(epochs, history["valid_loss"], label='Valid Loss')
82
+ plt.xlabel('Epoch')
83
+ plt.ylabel('Loss')
84
+ plt.title('Training and Validation Loss')
85
+ plt.legend()
86
+
87
+ plt.subplot(1, 2, 2)
88
+ plt.plot(epochs, history["train_accuracies"], label='Train Accuracy')
89
+ plt.plot(epochs, history["valid_accuracies"], label='Valid Accuracy')
90
+ plt.xlabel('Epoch')
91
+ plt.ylabel('Accuracy')
92
+ plt.title('Training and Validation Accuracy')
93
+ plt.legend()
94
+
95
+ plt.tight_layout()
96
+ plt.savefig('../docs/images/training_history.png')