zaidmehdi committed on
Commit
d57f40c
·
1 Parent(s): 5db7813
Files changed (1) hide show
  1. src/model_training.py +13 -0
src/model_training.py CHANGED
@@ -1,20 +1,33 @@
1
  import pandas as pd
2
  from sklearn.model_selection import train_test_split
 
3
 
4
  from utils import get_datasetdict_object
5
 
6
 
 
 
 
 
7
  def get_dataset(train_path:str, test_path:str):
8
  df_train = pd.read_csv(train_path, sep="\t")
9
  df_train, df_val = train_test_split(df_train, test_size=0.23805, random_state=42,
10
  stratify=df_train["#3_country_label"])
 
 
11
  df_test = pd.read_csv(test_path, sep="\t")
12
 
13
  return get_datasetdict_object(df_train, df_val, df_test)
14
 
15
 
 
 
 
 
16
  def main():
17
  dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv")
 
 
18
  print(dataset)
19
 
20
 
 
1
  import pandas as pd
2
  from sklearn.model_selection import train_test_split
3
+ from transformers import AutoTokenizer
4
 
5
  from utils import get_datasetdict_object
6
 
7
 
8
# Pretrained model identifier handed to AutoTokenizer.from_pretrained below;
# the resulting module-level tokenizer is the one tokenize() calls.
+ model_name = "moussaKam/AraBART"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+
11
+
12
  def get_dataset(train_path:str, test_path:str):
13
  df_train = pd.read_csv(train_path, sep="\t")
14
  df_train, df_val = train_test_split(df_train, test_size=0.23805, random_state=42,
15
  stratify=df_train["#3_country_label"])
16
+ df_train = df_train.reset_index(drop=True)
17
+ df_val = df_val.reset_index(drop=True)
18
  df_test = pd.read_csv(test_path, sep="\t")
19
 
20
  return get_datasetdict_object(df_train, df_val, df_test)
21
 
22
 
23
def tokenize(batch):
    """Tokenize the ``tweet`` entries of a batch with the module-level tokenizer.

    Args:
        batch: Mapping with a ``"tweet"`` entry holding the texts to encode.

    Returns:
        The tokenizer's output for the batch.
    """
    # padding=True pads to the longest sequence in the batch; truncation=True
    # additionally caps every sequence at the tokenizer's maximum length so an
    # overlong text cannot exceed what the model accepts.
    return tokenizer(batch["tweet"], padding=True, truncation=True)
25
+
26
+
27
  def main():
28
  dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv")
29
+ dataset = dataset.map(tokenize, batched=True)
30
+
31
  print(dataset)
32
 
33