# firstaid/finetune_flan_t5.py
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,  # kept for the alternative loading path in METHOD 2 below
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from trl import SFTTrainer
import torch
# 1. Load and prepare dataset
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")
# Add formatted text field
dataset = dataset.map(lambda x: {
"text": f"### Instruction:\n{x['input']}\n\n### Response:\n{x['output']}"
})
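# The JSONL file is assumed to hold one object per line with "input" and "output" keys,
# e.g. (illustrative record, not taken from the actual dataset):
#   {"input": "How do I treat a minor burn?", "output": "Cool it under running water..."}
# After the map() above, each example also carries a "text" field in the
# "### Instruction: ... ### Response: ..." format that the trainer consumes.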
# 2. Load model and tokenizer - METHOD 1: Explicit FLAN-T5 loading
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# METHOD 1: Load model directly without AutoModel
from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained(model_name)
# METHOD 2: Or load via the Auto class; if loading complains about a missing
# sentencepiece dependency, install it first (pip install transformers[sentencepiece]):
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
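# Optional sanity check (not in the original script): a quick, commented-out generation
# to confirm the tokenizer and base model load before committing to a training run.
# The prompt is an arbitrary example in the same format used for the dataset.
# sample = tokenizer(
#     "### Instruction:\nWhat should I do for a nosebleed?\n\n### Response:\n",
#     return_tensors="pt",
# )
# print(tokenizer.decode(model.generate(**sample, max_new_tokens=64)[0], skip_special_tokens=True))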
# 3. Training arguments
training_args = TrainingArguments(
output_dir="./flan-t5-medical-finetuned",
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
num_train_epochs=3,
learning_rate=5e-5,
logging_dir="./logs",
save_strategy="epoch",
    # evaluation is left at its default ("no") since no eval dataset is provided;
    # note the old evaluation_strategy kwarg was renamed to eval_strategy in newer transformers
fp16=torch.cuda.is_available(),
report_to="none",
remove_unused_columns=False
)
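# Effective batch size: per_device_train_batch_size (4) x gradient_accumulation_steps (2)
# = 8 examples per optimizer step on each device.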
# 4. Initialize trainer
# NOTE: passing tokenizer= and dataset_text_field= directly to SFTTrainer matches older
# TRL releases; newer TRL versions expect these via SFTConfig / processing_class instead.
# If training fails with a "model did not return a loss" error, the custom
# DataCollatorForSeq2Seq below is the likely culprit; dropping it lets SFTTrainer fall
# back to its default language-modeling collator, which builds labels from input_ids.
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
args=training_args,
dataset_text_field="text",
data_collator=DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True
)
)
# 5. Start training
trainer.train()
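# 6. (Optional, not in the original script) Persist the final model and tokenizer in one
# place, in addition to the per-epoch checkpoints that save_strategy="epoch" already
# writes under output_dir. The "final" sub-directory name is just an example.
# trainer.save_model("./flan-t5-medical-finetuned/final")
# tokenizer.save_pretrained("./flan-t5-medical-finetuned/final")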