zaidmehdi commited on
Commit
8d84597
·
1 Parent(s): 9d3bdb3

add max_length to padding and convert dataset to tensors

Browse files
Files changed (1) hide show
  1. src/utils.py +5 -4
src/utils.py CHANGED
@@ -26,7 +26,7 @@ def get_datasetdict_object(df_train, df_val, df_test):
26
 
27
 
28
  def tokenize(batch, tokenizer):
29
- return tokenizer(batch["tweet"], padding='max_length')
30
 
31
 
32
  def get_dataset(train_path:str, test_path:str, tokenizer):
@@ -43,9 +43,10 @@ def get_dataset(train_path:str, test_path:str, tokenizer):
43
  df_test["#3_country_label"] = encoder.transform(df_test["#3_country_label"])
44
 
45
  dataset = get_datasetdict_object(df_train, df_val, df_test)
46
-
47
- return dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
48
-
 
49
 
50
  def serialize_data(data, output_path:str):
51
  with open(output_path, "wb") as f:
 
26
 
27
 
28
  def tokenize(batch, tokenizer):
29
+ return tokenizer(batch["tweet"], padding='max_length', max_length=256)
30
 
31
 
32
  def get_dataset(train_path:str, test_path:str, tokenizer):
 
43
  df_test["#3_country_label"] = encoder.transform(df_test["#3_country_label"])
44
 
45
  dataset = get_datasetdict_object(df_train, df_val, df_test)
46
+ dataset = dataset.map(lambda x: tokenize(x, tokenizer), batched=True)
47
+ dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
48
+
49
+ return dataset
50
 
51
  def serialize_data(data, output_path:str):
52
  with open(output_path, "wb") as f: