zaidmehdi commited on
Commit
5db7813
·
1 Parent(s): 5099736

get dataset

Browse files
Files changed (1) hide show
  1. src/model_training.py +22 -0
src/model_training.py CHANGED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from sklearn.model_selection import train_test_split
3
+
4
+ from utils import get_datasetdict_object
5
+
6
+
7
+ def get_dataset(train_path:str, test_path:str):
8
+ df_train = pd.read_csv(train_path, sep="\t")
9
+ df_train, df_val = train_test_split(df_train, test_size=0.23805, random_state=42,
10
+ stratify=df_train["#3_country_label"])
11
+ df_test = pd.read_csv(test_path, sep="\t")
12
+
13
+ return get_datasetdict_object(df_train, df_val, df_test)
14
+
15
+
16
+ def main():
17
+ dataset = get_dataset("data/DA_train_labeled.tsv", "data/DA_dev_labeled.tsv")
18
+ print(dataset)
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()