Update train.py
Browse files
train.py
CHANGED
@@ -6,8 +6,10 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, Train
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
-
|
10 |
-
os.makedirs(
|
|
|
|
|
11 |
# Load dataset
|
12 |
dataset = load_dataset("tatsu-lab/alpaca")
|
13 |
dataset["train"] = dataset["train"].select(range(2000))
|
@@ -48,19 +50,20 @@ print("Dataset successfully split and tokenized.")
|
|
48 |
|
49 |
# Define training arguments
|
50 |
training_args = TrainingArguments(
|
51 |
-
output_dir="
|
52 |
-
per_device_train_batch_size=1,
|
53 |
-
per_device_eval_batch_size=1,
|
54 |
-
num_train_epochs=1, # ✅ Train for 1 epoch only
|
55 |
-
gradient_accumulation_steps=2, # ✅ Reduce steps to speed up
|
56 |
-
logging_steps=100, # ✅ Log less frequently
|
57 |
-
save_steps=500, # ✅ Save less frequently
|
58 |
evaluation_strategy="epoch",
|
59 |
save_strategy="epoch",
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
fp16=True
|
62 |
)
|
63 |
|
|
|
64 |
# Set up Trainer
|
65 |
trainer = Trainer(
|
66 |
model=model,
|
|
|
6 |
os.environ["HF_HOME"] = "/app/hf_cache"
|
7 |
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
|
8 |
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"
|
9 |
+
save_dir = "./models/t5-finetuned"
|
10 |
+
os.makedirs(save_dir, exist_ok=True) # Ensure the directory exists
|
11 |
+
trainer.save_model(save_dir)
|
12 |
+
|
13 |
# Load dataset
|
14 |
dataset = load_dataset("tatsu-lab/alpaca")
|
15 |
dataset["train"] = dataset["train"].select(range(2000))
|
|
|
50 |
|
51 |
# Define training arguments
|
52 |
training_args = TrainingArguments(
|
53 |
+
output_dir="./results",
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
evaluation_strategy="epoch",
|
55 |
save_strategy="epoch",
|
56 |
+
learning_rate=5e-6, # Reduce from 5e-5 to 5e-6
|
57 |
+
per_device_train_batch_size=8, # Keep batch size reasonable
|
58 |
+
per_device_eval_batch_size=8,
|
59 |
+
num_train_epochs=3,
|
60 |
+
weight_decay=0.01,
|
61 |
+
logging_dir="./logs",
|
62 |
+
logging_steps=10,
|
63 |
fp16=True
|
64 |
)
|
65 |
|
66 |
+
|
67 |
# Set up Trainer
|
68 |
trainer = Trainer(
|
69 |
model=model,
|