Update train.py
Browse files
train.py
CHANGED
@@ -6,10 +6,9 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
-
save_dir = "/
|
10 |
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
-
trainer.save_model(save_dir)
|
12 |
-
|
13 |
# Load dataset
|
14 |
dataset = load_dataset("tatsu-lab/alpaca")
|
15 |
dataset["train"] = dataset["train"].select(range(2000))
|
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
+
save_dir = "/tmp/t5-finetuned" # Use /tmp/, which is writable
|
10 |
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
+
trainer.save_model(save_dir) # Save the model
|
|
|
12 |
# Load dataset
|
13 |
dataset = load_dataset("tatsu-lab/alpaca")
|
14 |
dataset["train"] = dataset["train"].select(range(2000))
|