Spaces:
Sleeping
Sleeping
plot training history
Browse files- src/model_training.py +3 -5
- src/utils.py +26 -1
src/model_training.py
CHANGED
@@ -1,6 +1,3 @@
|
|
1 |
-
import matplotlib.pyplot as plt
|
2 |
-
import numpy as np
|
3 |
-
import pandas as pd
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.optim as optim
|
@@ -11,7 +8,7 @@ from torch.utils.data import DataLoader
|
|
11 |
from tqdm.auto import tqdm
|
12 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
13 |
|
14 |
-
from utils import get_dataset
|
15 |
|
16 |
|
17 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
@@ -125,7 +122,8 @@ def main():
|
|
125 |
num_epochs = 100
|
126 |
|
127 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
128 |
-
|
|
|
129 |
|
130 |
if __name__ == "__main__":
|
131 |
main()
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
import torch.optim as optim
|
|
|
8 |
from tqdm.auto import tqdm
|
9 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
10 |
|
11 |
+
from utils import get_dataset, plot_training_history
|
12 |
|
13 |
|
14 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
|
|
122 |
num_epochs = 100
|
123 |
|
124 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
125 |
+
plot_training_history(history)
|
126 |
+
|
127 |
|
128 |
if __name__ == "__main__":
|
129 |
main()
|
src/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import pickle
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
|
|
4 |
import pandas as pd
|
5 |
import seaborn as sns
|
6 |
from datasets import DatasetDict, Dataset
|
@@ -68,4 +69,28 @@ def plot_confusion_matrix(y_true, y_preds):
|
|
68 |
plt.xlabel('Predicted Label')
|
69 |
plt.ylabel('True Label')
|
70 |
plt.title('Confusion Matrix')
|
71 |
-
plt.show()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pickle
|
2 |
|
3 |
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
import pandas as pd
|
6 |
import seaborn as sns
|
7 |
from datasets import DatasetDict, Dataset
|
|
|
69 |
plt.xlabel('Predicted Label')
|
70 |
plt.ylabel('True Label')
|
71 |
plt.title('Confusion Matrix')
|
72 |
+
plt.show()
|
73 |
+
|
74 |
+
|
75 |
+
def plot_training_history(history):
|
76 |
+
epochs = np.arange(1, len(history["train_loss"]) + 1)
|
77 |
+
plt.figure(figsize=(10, 5))
|
78 |
+
|
79 |
+
plt.subplot(1, 2, 1)
|
80 |
+
plt.plot(epochs, history["train_loss"], label='Train Loss')
|
81 |
+
plt.plot(epochs, history["valid_loss"], label='Valid Loss')
|
82 |
+
plt.xlabel('Epoch')
|
83 |
+
plt.ylabel('Loss')
|
84 |
+
plt.title('Training and Validation Loss')
|
85 |
+
plt.legend()
|
86 |
+
|
87 |
+
plt.subplot(1, 2, 2)
|
88 |
+
plt.plot(epochs, history["train_accuracies"], label='Train Accuracy')
|
89 |
+
plt.plot(epochs, history["valid_accuracies"], label='Valid Accuracy')
|
90 |
+
plt.xlabel('Epoch')
|
91 |
+
plt.ylabel('Accuracy')
|
92 |
+
plt.title('Training and Validation Accuracy')
|
93 |
+
plt.legend()
|
94 |
+
|
95 |
+
plt.tight_layout()
|
96 |
+
plt.savefig('../docs/images/training_history.png')
|