# firstaid/finetune_flan_t5.py
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
import torch

# Note: DataCollatorForSeq2Seq lives in transformers (not trl), and flan-t5 is
# an encoder-decoder model, so the seq2seq Trainer classes are used here in
# place of trl's SFTTrainer, which targets causal (decoder-only) LMs.

# Load your dataset (from the converted JSONL file)
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")
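
# The preprocessing below assumes each JSONL record has "instruction" and
# "output" fields, e.g. (hypothetical record, for illustration only):
#   {"instruction": "What are the first-aid steps for a minor burn?",
#    "output": "Cool the burn under cool running water for 10-20 minutes..."}
# Optional sanity check that both fields are present:
assert {"instruction", "output"} <= set(dataset.column_names), dataset.column_names
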
# Load tokenizer and model
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
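
# flan-t5-base is an encoder-decoder model (roughly 250M parameters), small
# enough to fine-tune on a single GPU with the batch settings below.
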
# Preprocess dataset: tokenize inputs and targets separately. Padding is left
# to the data collator, which pads dynamically per batch and masks label
# padding with -100 so padded positions are ignored by the loss.
def preprocess(example):
    input_text = example["instruction"]
    target_text = example["output"]
    tokenized = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    tokenized["labels"] = tokenizer(
        text_target=target_text,
        max_length=128,
        truncation=True,
    )["input_ids"]
    return tokenized
tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
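
# Optional spot-check: round-trip decode one example to confirm the fields
# were tokenized as expected (purely a sanity check, not required for training).
sample = tokenized_dataset[0]
print(tokenizer.decode(sample["input_ids"], skip_special_tokens=True)[:200])
print(tokenizer.decode(sample["labels"], skip_special_tokens=True)[:200])
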
# Define training arguments (effective batch size = 4 * 2 = 8 per device).
training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-medical",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    logging_dir="./logs",
    save_strategy="epoch",
    evaluation_strategy="no",
    fp16=torch.cuda.is_available(),
)
# Define data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
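
# Optional smoke test: the collator pads each batch to its own max length and
# replaces label padding with -100 (ignored by the loss). Assumes the dataset
# contains at least two examples.
batch = data_collator([tokenized_dataset[i] for i in range(2)])
print(batch["input_ids"].shape, batch["labels"].shape)
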
# Initialize trainer. Seq2SeqTrainer is used because flan-t5 is an
# encoder-decoder model; trl's SFTTrainer is intended for causal LMs.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# Start training
trainer.train()
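
# Save the fine-tuned model and tokenizer, then run a quick generation check.
# The prompt is a hypothetical example, not taken from the dataset.
trainer.save_model("./flan-t5-medical")
tokenizer.save_pretrained("./flan-t5-medical")

model.eval()
inputs = tokenizer("How do I treat a sprained ankle?", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))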