cconsti committed
Commit 9810d0f · verified · 1 Parent(s): 816facc

Update train.py

Files changed (1)
  1. train.py +19 -4
train.py CHANGED
@@ -34,10 +34,9 @@ tokenizer = T5Tokenizer.from_pretrained(model_name)
 model = T5ForConditionalGeneration.from_pretrained(model_name)
 
 # Tokenization function
+# Define tokenization function before mapping
 def tokenize_function(examples):
-    print("Sample data structure:", examples)  # Move print inside function
-
-    inputs = examples["input"]  # Make sure "input" matches dataset keys
+    inputs = examples["input"]  # Ensure this matches dataset key
     targets = examples["output"]
 
     model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
@@ -46,7 +45,23 @@ def tokenize_function(examples):
     model_inputs["labels"] = labels["input_ids"]
     return model_inputs
 
-    return model_inputs
+# Check dataset structure
+print("Dataset splits available:", dataset)
+
+# If "test" split is missing, create one
+if "test" not in dataset:
+    dataset = dataset["train"].train_test_split(test_size=0.1)
+
+# Tokenize dataset
+tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+# Assign train and eval datasets
+train_dataset = tokenized_datasets["train"]
+eval_dataset = tokenized_datasets["test"]
+
+# Debug output
+print("Dataset successfully split and tokenized")
+
 
 # Apply tokenization
 tokenized_datasets = dataset.map(tokenize_function, batched=True)
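
For readers following the change, the split-then-tokenize flow this commit introduces can be tried in isolation. The sketch below is not part of the repository: it uses a toy in-memory dataset with the same "input"/"output" columns the script expects and the public t5-small checkpoint, both of which are assumptions made for the demo.

# Standalone sketch of the commit's split-then-tokenize flow (illustrative only).
# Assumptions: a toy in-memory dataset with "input"/"output" columns and the
# public "t5-small" checkpoint; the real train.py loads its own data and model.
from datasets import Dataset, DatasetDict
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Toy rows standing in for the project's real dataset (hypothetical values).
dataset = DatasetDict({
    "train": Dataset.from_dict({
        "input": ["translate English to German: Hello"] * 10,
        "output": ["Hallo"] * 10,
    })
})

# If "test" split is missing, carve one out of "train", as the commit does.
if "test" not in dataset:
    dataset = dataset["train"].train_test_split(test_size=0.1)

def tokenize_function(examples):
    # Tokenize inputs and targets to fixed-length, padded sequences.
    model_inputs = tokenizer(examples["input"], max_length=512,
                             truncation=True, padding="max_length")
    labels = tokenizer(examples["output"], max_length=512,
                       truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["test"]
print(tokenized_datasets)  # DatasetDict with "train" and "test" splits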