# train.py: fine-tune t5-large on a 5,000-example slice of the Alpaca dataset
import os

# Point the Hugging Face caches at a writable directory. These must be set
# before importing datasets/transformers, which read them at import time.
os.environ["HF_HOME"] = "/app/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/app/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/hf_cache"

from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
output_dir = "/tmp/t5-finetuned"
os.makedirs(output_dir, exist_ok=True)
# Load dataset
dataset = load_dataset("tatsu-lab/alpaca")
dataset["train"] = dataset["train"].select(range(5000))
# Check dataset structure
print("Dataset splits available:", dataset)
print("Sample row:", dataset["train"][0])
# If no 'test' split exists, create one
if "test" not in dataset:
dataset = dataset["train"].train_test_split(test_size=0.1)
# Load tokenizer & model
model_name = "t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
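# Note: t5-large has roughly 770M parameters, so the fp32 weights alone take about 3 GB of memory.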
# Leave gradient checkpointing disabled (the default) to keep training fast;
# gradient_checkpointing_enable() does not accept a boolean argument.
model.gradient_checkpointing_disable()
# Define tokenization function
def tokenize_function(examples):
    # Alpaca rows have "instruction", "input", and "output" columns; "input" is
    # often empty, so fold the instruction and input into a single source string.
    inputs = [f"{ins} {ctx}".strip() for ins, ctx in zip(examples["instruction"], examples["input"])]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(text_target=examples["output"], max_length=512, truncation=True, padding="max_length")
    # Replace label padding with -100 so the loss ignores padded positions.
    model_inputs["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs
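# Illustrative sanity check (assumes the Alpaca column names used above): tokenize a
# single row and confirm the padded lengths before mapping over the whole dataset.
_probe = tokenize_function(dataset["train"][:1])
print("Probe lengths:", len(_probe["input_ids"][0]), len(_probe["labels"][0]))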
# Tokenize dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Assign train & eval datasets
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["test"]
print("Dataset successfully split and tokenized.")
# Define training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,             # train for one epoch only
    gradient_accumulation_steps=2,  # accumulate to an effective batch size of 2
    logging_steps=100,              # log less frequently
    save_steps=500,                 # only used if save_strategy="steps"
    evaluation_strategy="epoch",
    save_strategy="epoch",
    push_to_hub=False,
    fp16=True,                      # mixed precision; requires a CUDA GPU
)
# Set up Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Start fine-tuning
trainer.train()
print("Fine-tuning complete!")
# Save model locally
trainer.save_model("./t5-finetuned")
tokenizer.save_pretrained("./t5-finetuned")  # keep the tokenizer alongside the model weights
print("Model saved successfully!")