Spaces:
Sleeping
Sleeping
defining functions to train classification head
Browse files- src/model_training.py +17 -5
src/model_training.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
-
from datasets import DatasetDict, Dataset
|
2 |
import numpy as np
|
3 |
import pandas as pd
|
4 |
import torch
|
|
|
|
|
5 |
from sklearn.linear_model import LogisticRegression
|
6 |
from sklearn.metrics import accuracy_score, f1_score
|
7 |
from transformers import AutoModel, AutoTokenizer
|
@@ -58,6 +59,11 @@ class PreProcessor:
|
|
58 |
serialize_data(data_hidden, output_path=self.output_path)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
61 |
class Model():
|
62 |
def __init__(self, data_input_path:str, model_name:str):
|
63 |
self.model_name = model_name
|
@@ -75,14 +81,20 @@ class Model():
|
|
75 |
random_state=2024)
|
76 |
lr_model.fit(X_train, y_train)
|
77 |
return lr_model
|
|
|
|
|
|
|
|
|
78 |
|
79 |
def train_model(self, output_path):
|
80 |
-
if self.model_name
|
|
|
|
|
|
|
|
|
81 |
raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
|
82 |
|
83 |
-
|
84 |
-
self.model = lr_model
|
85 |
-
serialize_data(lr_model, output_path)
|
86 |
|
87 |
def _get_metrics(self, y_true, y_preds):
|
88 |
accuracy = accuracy_score(y_true, y_preds)
|
|
|
|
|
1 |
import numpy as np
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from datasets import DatasetDict, Dataset
|
6 |
from sklearn.linear_model import LogisticRegression
|
7 |
from sklearn.metrics import accuracy_score, f1_score
|
8 |
from transformers import AutoModel, AutoTokenizer
|
|
|
59 |
serialize_data(data_hidden, output_path=self.output_path)
|
60 |
|
61 |
|
62 |
+
class ClassificationHead(nn.Module):
|
63 |
+
def __init__(self, ) -> None:
|
64 |
+
super(ClassificationHead, self).__init__()
|
65 |
+
|
66 |
+
|
67 |
class Model():
|
68 |
def __init__(self, data_input_path:str, model_name:str):
|
69 |
self.model_name = model_name
|
|
|
81 |
random_state=2024)
|
82 |
lr_model.fit(X_train, y_train)
|
83 |
return lr_model
|
84 |
+
|
85 |
+
def _train_classification_head(X_train, y_train, base, train_base=False):
|
86 |
+
|
87 |
+
return
|
88 |
|
89 |
def train_model(self, output_path):
|
90 |
+
if self.model_name == "lr":
|
91 |
+
self.model = self._train_logistic_regression(self.X_train, self.y_train)
|
92 |
+
elif self.model_name == "classification_head":
|
93 |
+
self.model = self._train_classification_head(self.X_train, self.y_train)
|
94 |
+
else:
|
95 |
raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
|
96 |
|
97 |
+
serialize_data(self.model, output_path)
|
|
|
|
|
98 |
|
99 |
def _get_metrics(self, y_true, y_preds):
|
100 |
accuracy = accuracy_score(y_true, y_preds)
|