Update train.py
train.py (changed)
@@ -34,10 +34,9 @@ tokenizer = T5Tokenizer.from_pretrained(model_name)
 model = T5ForConditionalGeneration.from_pretrained(model_name)
 
 # Tokenization function
+# Define tokenization function before mapping
 def tokenize_function(examples):
-
-
-    inputs = examples["input"]  # Make sure "input" matches dataset keys
+    inputs = examples["input"]  # Ensure this matches dataset key
     targets = examples["output"]
 
     model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
@@ -46,7 +45,23 @@ def tokenize_function(examples):
     model_inputs["labels"] = labels["input_ids"]
     return model_inputs
 
-
+# Check dataset structure
+print("Dataset splits available:", dataset)
+
+# If "test" split is missing, create one
+if "test" not in dataset:
+    dataset = dataset["train"].train_test_split(test_size=0.1)
+
+# Tokenize dataset
+tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+# Assign train and eval datasets
+train_dataset = tokenized_datasets["train"]
+eval_dataset = tokenized_datasets["test"]
+
+# Debug output
+print("Dataset successfully split and tokenized")
+
 
 # Apply tokenization
 tokenized_datasets = dataset.map(tokenize_function, batched=True)