zaidmehdi committed on
Commit
2d22962
·
1 Parent(s): 6b5fdb9

defining model with classification head

Browse files
Files changed (1) hide show
  1. src/model_training.py +20 -5
src/model_training.py CHANGED
@@ -1,18 +1,33 @@
1
- from transformers import AutoTokenizer, AutoModel
 
 
2
 
3
  from utils import get_dataset
4
 
5
 
6
- class Model():
7
- def __init__(self) -> None:
8
- pass
 
 
9
 
 
 
 
 
 
 
 
 
10
 
11
  def main():
12
  model_name = "moussaKam/AraBART"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- model = AutoModel.from_pretrained(model_name)
 
15
  dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
 
 
16
 
17
  print(dataset["train"])
18
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModel, AutoConfig
4
 
5
  from utils import get_dataset
6
 
7
 
8
class Model(nn.Module):
    """Pretrained transformer backbone topped with a linear classification head.

    The backbone is loaded through ``AutoModel.from_pretrained``; the head maps
    the hidden state found at sequence position 0 to ``num_labels`` scores.
    """

    def __init__(self, model_name, config, num_labels):
        """Load the backbone and build the classification head.

        Args:
            model_name: Name or path handed to ``AutoModel.from_pretrained``.
            config: Model configuration; must expose ``hidden_size``.
            num_labels: Number of target classes (output width of the head).

        Raises:
            ValueError: If ``num_labels`` is smaller than 1.
        """
        super().__init__()
        # Fail early: a zero-width head would only surface later as an
        # empty softmax output that is hard to trace back to its cause.
        if num_labels < 1:
            raise ValueError(f"num_labels must be at least 1, got {num_labels}")
        self.model = AutoModel.from_pretrained(model_name, config=config)
        self.classification_head = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, *, return_logits=False):
        """Classify a batch of token id sequences.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Optional mask forwarded to the backbone so that
                padded positions can be ignored. ``None`` keeps the
                backbone's own default behaviour.
            return_logits: When true, return the raw head outputs instead of
                probabilities. Use this with losses that apply their own
                (log-)softmax, such as ``nn.CrossEntropyLoss``.

        Returns:
            Softmax probabilities over the last dimension, or the raw logits
            when ``return_logits`` is true.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # NOTE(review): position 0 is taken as the sequence summary; confirm
        # that this is the right pooling position for the chosen backbone.
        pooled_output = outputs.last_hidden_state[:, 0]
        logits = self.classification_head(pooled_output)
        if return_logits:
            return logits

        return F.softmax(logits, dim=-1)
21
+
22
 
23
  def main():
24
  model_name = "moussaKam/AraBART"
25
  tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ config = AutoConfig.from_pretrained(model_name)
27
+
28
  dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv", tokenizer)
29
+ num_labels = len(set(dataset["train"]["label"]))
30
+ model = Model(model_name, config, num_labels)
31
 
32
  print(dataset["train"])
33