Update train.py
train.py CHANGED
@@ -6,7 +6,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
 os.environ["HF_HOME"] = "/app/hf_cache"
 os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
 os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
-
+output_dir = "/tmp/t5-finetuned"
+os.makedirs(output_dir, exist_ok=True)
 # Load dataset
 dataset = load_dataset("tatsu-lab/alpaca")  # Change if using your dataset
 
@@ -45,7 +46,7 @@ print("Dataset successfully split and tokenized.")
 
 # Define training arguments
 training_args = TrainingArguments(
-    output_dir=
+    output_dir=output_dir,
     per_device_train_batch_size=2,  # Lowered to avoid memory issues
     per_device_eval_batch_size=2,
     num_train_epochs=1,  # Test run (increase for full fine-tuning)