zaidmehdi committed on
Commit
8580c38
·
1 Parent(s): e16cc69

defining functions to train classification head

Browse files
Files changed (1) hide show
  1. src/model_training.py +17 -5
src/model_training.py CHANGED
@@ -1,7 +1,8 @@
1
- from datasets import DatasetDict, Dataset
2
  import numpy as np
3
  import pandas as pd
4
  import torch
 
 
5
  from sklearn.linear_model import LogisticRegression
6
  from sklearn.metrics import accuracy_score, f1_score
7
  from transformers import AutoModel, AutoTokenizer
@@ -58,6 +59,11 @@ class PreProcessor:
58
  serialize_data(data_hidden, output_path=self.output_path)
59
 
60
 
 
 
 
 
 
61
  class Model():
62
  def __init__(self, data_input_path:str, model_name:str):
63
  self.model_name = model_name
@@ -75,14 +81,20 @@ class Model():
75
  random_state=2024)
76
  lr_model.fit(X_train, y_train)
77
  return lr_model
 
 
 
 
78
 
79
  def train_model(self, output_path):
80
- if self.model_name != "lr":
 
 
 
 
81
  raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
82
 
83
- lr_model = self._train_logistic_regression(self.X_train, self.y_train)
84
- self.model = lr_model
85
- serialize_data(lr_model, output_path)
86
 
87
  def _get_metrics(self, y_true, y_preds):
88
  accuracy = accuracy_score(y_true, y_preds)
 
 
1
  import numpy as np
2
  import pandas as pd
3
  import torch
4
+ import torch.nn as nn
5
+ from datasets import DatasetDict, Dataset
6
  from sklearn.linear_model import LogisticRegression
7
  from sklearn.metrics import accuracy_score, f1_score
8
  from transformers import AutoModel, AutoTokenizer
 
59
  serialize_data(data_hidden, output_path=self.output_path)
60
 
61
 
62
+ class ClassificationHead(nn.Module):
63
+ def __init__(self, ) -> None:
64
+ super(ClassificationHead, self).__init__()
65
+
66
+
67
  class Model():
68
  def __init__(self, data_input_path:str, model_name:str):
69
  self.model_name = model_name
 
81
  random_state=2024)
82
  lr_model.fit(X_train, y_train)
83
  return lr_model
84
+
85
+ def _train_classification_head(X_train, y_train, base, train_base=False):
86
+
87
+ return
88
 
89
  def train_model(self, output_path):
90
+ if self.model_name == "lr":
91
+ self.model = self._train_logistic_regression(self.X_train, self.y_train)
92
+ elif self.model_name == "classification_head":
93
+ self.model = self._train_classification_head(self.X_train, self.y_train)
94
+ else:
95
  raise ValueError(f"Model name {self.model_name} does not exist. Please try 'lr'!")
96
 
97
+ serialize_data(self.model, output_path)
 
 
98
 
99
  def _get_metrics(self, y_true, y_preds):
100
  accuracy = accuracy_score(y_true, y_preds)