zaidmehdi committed on
Commit
1c28db7
·
1 Parent(s): c112e25

serializing label encoder

Browse files
models/label_encoder.pkl ADDED
Binary file (456 Bytes). View file
 
src/model_training.py CHANGED
@@ -1,14 +1,10 @@
1
  import torch
2
- import torch.nn as nn
3
  import torch.optim as optim
4
- from datasets import DatasetDict, Dataset
5
- from sklearn.model_selection import train_test_split
6
- from sklearn.preprocessing import LabelEncoder
7
  from torch.utils.data import DataLoader
8
  from tqdm.auto import tqdm
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
 
11
- from utils import get_dataset, plot_training_history
12
 
13
 
14
  def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
@@ -102,6 +98,7 @@ def main():
102
  tokenizer = AutoTokenizer.from_pretrained(model_name)
103
 
104
  dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
 
105
 
106
  for data in dataset:
107
  dataset[data] = dataset[data].remove_columns(["tweet"])
@@ -123,7 +120,7 @@ def main():
123
 
124
  model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
125
  plot_training_history(history)
126
-
127
 
128
  if __name__ == "__main__":
129
  main()
 
1
  import torch
 
2
  import torch.optim as optim
 
 
 
3
  from torch.utils.data import DataLoader
4
  from tqdm.auto import tqdm
5
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
 
7
+ from utils import get_dataset, serialize_data, plot_training_history
8
 
9
 
10
  def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
 
98
  tokenizer = AutoTokenizer.from_pretrained(model_name)
99
 
100
  dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
101
+ serialize_data(label_encoder, "../models/label_encoder.pkl")
102
 
103
  for data in dataset:
104
  dataset[data] = dataset[data].remove_columns(["tweet"])
 
120
 
121
  model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
122
  plot_training_history(history)
123
+
124
 
125
  if __name__ == "__main__":
126
  main()