zaidmehdi committed on
Commit
ff82938
·
1 Parent(s): a10e0a0

Return the label encoder from get_dataset (also raises tokenizer max_length from 256 to 768)

Browse files
Files changed (1) hide show
  1. src/utils.py +3 -2
src/utils.py CHANGED
@@ -26,7 +26,7 @@ def get_datasetdict_object(df_train, df_val, df_test):
26
 
27
 
28
  def tokenize(batch, tokenizer):
29
- return tokenizer(batch["tweet"], padding='max_length', max_length=256, truncation=True)
30
 
31
 
32
  def get_dataset(train_path:str, test_path:str, tokenizer):
@@ -46,7 +46,8 @@ def get_dataset(train_path:str, test_path:str, tokenizer):
46
  dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
47
  dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
48
 
49
- return dataset
 
50
 
51
  def serialize_data(data, output_path:str):
52
  with open(output_path, "wb") as f:
 
26
 
27
 
28
  def tokenize(batch, tokenizer):
29
+ return tokenizer(batch["tweet"], padding='max_length', max_length=768, truncation=True)
30
 
31
 
32
  def get_dataset(train_path:str, test_path:str, tokenizer):
 
46
  dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
47
  dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
48
 
49
+ return dataset, encoder
50
+
51
 
52
  def serialize_data(data, output_path:str):
53
  with open(output_path, "wb") as f: