Spaces: Sleeping
Commit message: return label encoder
Browse files — src/utils.py (+3, −2)
src/utils.py CHANGED
@@ -26,7 +26,7 @@ def get_datasetdict_object(df_train, df_val, df_test):
|
|
26 |
|
27 |
|
28 |
def tokenize(batch, tokenizer):
|
29 |
-
return tokenizer(batch["tweet"], padding='max_length', max_length=
|
30 |
|
31 |
|
32 |
def get_dataset(train_path:str, test_path:str, tokenizer):
|
@@ -46,7 +46,8 @@ def get_dataset(train_path:str, test_path:str, tokenizer):
|
|
46 |
dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
|
47 |
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
|
48 |
|
49 |
-
return dataset
|
|
|
50 |
|
51 |
def serialize_data(data, output_path:str):
|
52 |
with open(output_path, "wb") as f:
|
|
|
26 |
|
27 |
|
28 |
def tokenize(batch, tokenizer):
    """Tokenize the "tweet" column of a dataset batch.

    Reconstructed from the "+" side of the diff: pads every example to the
    fixed length and truncates anything longer, so all outputs share one shape.

    Args:
        batch: Mapping with a "tweet" key holding a list of raw text strings
            (the shape `datasets.Dataset.map(..., batched=True)` passes in).
        tokenizer: A Hugging Face-style tokenizer callable accepting
            `padding`, `max_length`, and `truncation` keyword arguments.

    Returns:
        Whatever the tokenizer returns for the batch (typically a mapping
        with "input_ids" and "attention_mask"), padded/truncated to 768
        tokens.  NOTE(review): 768 is unusually long for tweets and equals a
        typical hidden size — confirm it was not confused with model dim.
    """
    return tokenizer(batch["tweet"], padding='max_length', max_length=768, truncation=True)
|
30 |
|
31 |
|
32 |
def get_dataset(train_path:str, test_path:str, tokenizer):
|
|
|
46 |
dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
|
47 |
dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
|
48 |
|
49 |
+
return dataset, encoder
|
50 |
+
|
51 |
|
52 |
def serialize_data(data, output_path:str):
|
53 |
with open(output_path, "wb") as f:
|