Spaces:
Sleeping
Sleeping
serializing label encoder
Browse files- models/label_encoder.pkl +0 -0
- src/model_training.py +3 -6
models/label_encoder.pkl
ADDED
Binary file (456 Bytes). View file
|
|
src/model_training.py
CHANGED
@@ -1,14 +1,10 @@
|
|
1 |
import torch
|
2 |
-
import torch.nn as nn
|
3 |
import torch.optim as optim
|
4 |
-
from datasets import DatasetDict, Dataset
|
5 |
-
from sklearn.model_selection import train_test_split
|
6 |
-
from sklearn.preprocessing import LabelEncoder
|
7 |
from torch.utils.data import DataLoader
|
8 |
from tqdm.auto import tqdm
|
9 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
10 |
|
11 |
-
from utils import get_dataset, plot_training_history
|
12 |
|
13 |
|
14 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
@@ -102,6 +98,7 @@ def main():
|
|
102 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
103 |
|
104 |
dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
|
|
|
105 |
|
106 |
for data in dataset:
|
107 |
dataset[data] = dataset[data].remove_columns(["tweet"])
|
@@ -123,7 +120,7 @@ def main():
|
|
123 |
|
124 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
125 |
plot_training_history(history)
|
126 |
-
|
127 |
|
128 |
if __name__ == "__main__":
|
129 |
main()
|
|
|
1 |
import torch
|
|
|
2 |
import torch.optim as optim
|
|
|
|
|
|
|
3 |
from torch.utils.data import DataLoader
|
4 |
from tqdm.auto import tqdm
|
5 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
6 |
|
7 |
+
from utils import get_dataset, serialize_data, plot_training_history
|
8 |
|
9 |
|
10 |
def train_model(model, optimizer, train_loader, val_loader, num_epochs=100, patience=10):
|
|
|
98 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
99 |
|
100 |
dataset, label_encoder = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
|
101 |
+
serialize_data(label_encoder, "../models/label_encoder.pkl")
|
102 |
|
103 |
for data in dataset:
|
104 |
dataset[data] = dataset[data].remove_columns(["tweet"])
|
|
|
120 |
|
121 |
model, history = train_model(model, optimizer, train_loader, val_loader, num_epochs=num_epochs)
|
122 |
plot_training_history(history)
|
123 |
+
|
124 |
|
125 |
if __name__ == "__main__":
|
126 |
main()
|