cconsti commited on
Commit
16f00d1
·
verified ·
1 Parent(s): cee0285

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +2 -3
train.py CHANGED
@@ -6,10 +6,9 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
6
  os.environ["HF_HOME"] = "/app/hf_cache"
7
  os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
8
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
9
- save_dir = "/home/user/t5-finetuned" # Change to a writable directory
10
  os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
11
- trainer.save_model(save_dir)
12
-
13
  # Load dataset
14
  dataset = load_dataset("tatsu-lab/alpaca")
15
  dataset["train"] = dataset["train"].select(range(2000))
 
6
  os.environ["HF_HOME"] = "/app/hf_cache"
7
  os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
8
  os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
9
+ save_dir = "/tmp/t5-finetuned" # Use /tmp/, which is writable
10
  os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
11
+ trainer.save_model(save_dir) # Save the model
 
12
  # Load dataset
13
  dataset = load_dataset("tatsu-lab/alpaca")
14
  dataset["train"] = dataset["train"].select(range(2000))