Update train.py
Browse files
train.py
CHANGED
@@ -6,9 +6,9 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
-
save_dir = "
|
10 |
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
-
trainer.save_model(save_dir)
|
12 |
|
13 |
# Load dataset
|
14 |
dataset = load_dataset("tatsu-lab/alpaca")
|
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
+
save_dir = "/home/user/t5-finetuned" # Change to a writable directory
|
10 |
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
+
trainer.save_model(save_dir)
|
12 |
|
13 |
# Load dataset
|
14 |
dataset = load_dataset("tatsu-lab/alpaca")
|