Spaces:
Sleeping
Sleeping
get dataset
Browse files
- src/model_training.py +22 -0
src/model_training.py
CHANGED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from sklearn.model_selection import train_test_split
|
3 |
+
|
4 |
+
from utils import get_datasetdict_object
|
5 |
+
|
6 |
+
|
7 |
+
def get_dataset(
    train_path: str,
    test_path: str,
    *,
    val_size: float = 0.23805,
    seed: int = 42,
    label_column: str = "#3_country_label",
):
    """Load the train/test TSV files and bundle train, validation and test splits.

    The file at ``train_path`` is read as tab-separated data and split into
    a training part and a validation part, stratified on ``label_column``.
    The file at ``test_path`` is read the same way and passed through
    unchanged as the test split.

    Args:
        train_path: Path to the tab-separated file that is split into the
            training and validation parts.
        test_path: Path to the tab-separated file used as the test split.
        val_size: Forwarded to ``train_test_split`` as ``test_size``; the
            share of the training file held out for validation.
        seed: Forwarded to ``train_test_split`` as ``random_state`` so the
            split is reproducible.
        label_column: Name of the column in the training file whose values
            drive the stratified split.

    Returns:
        Whatever ``get_datasetdict_object`` builds from the
        ``(train, validation, test)`` data frames. That helper lives in
        ``utils``; its return type is not visible from this module.

    Raises:
        KeyError: If ``label_column`` is not a column of the training file.
    """
    labeled = pd.read_csv(train_path, sep="\t")

    # Stratify so train and validation keep the same label distribution.
    train_part, val_part = train_test_split(
        labeled,
        test_size=val_size,
        random_state=seed,
        stratify=labeled[label_column],
    )

    test_part = pd.read_csv(test_path, sep="\t")

    return get_datasetdict_object(train_part, val_part, test_part)
|
14 |
+
|
15 |
+
|
16 |
+
def main():
    """Build the dataset from the default data files and print it."""
    train_file = "data/DA_train_labeled.tsv"
    # The labeled dev file is the one passed as the test split.
    test_file = "data/DA_dev_labeled.tsv"
    print(get_dataset(train_file, test_file))
|
19 |
+
|
20 |
+
|
21 |
+
# Only run the dataset check when executed as a script, not on import.
if __name__ == "__main__":
    main()
|