from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
import torch

# Load your dataset (from the converted JSONL file)
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

# Load tokenizer and model
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Preprocess dataset: tokenize the instruction as the encoder input and the
# answer as the decoder target. Padding is left to the data collator so that
# label padding is masked with -100 and excluded from the loss.
def preprocess(example):
    input_text = example["instruction"]
    target_text = example["output"]
    tokenized = tokenizer(
        input_text,
        max_length=512,
        truncation=True,
    )
    # `text_target` replaces the deprecated `as_target_tokenizer()` context manager
    tokenized["labels"] = tokenizer(
        text_target=target_text,
        max_length=128,
        truncation=True,
    )["input_ids"]
    return tokenized

tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

# Define training arguments (no eval dataset is provided, so evaluation stays disabled)
training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-medical",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    logging_dir="./logs",
    save_strategy="epoch",
    fp16=torch.cuda.is_available(),
)

# Define data collator (pads inputs dynamically and pads labels with -100)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Initialize trainer. TRL's SFTTrainer targets decoder-only (causal) models,
# so the standard Seq2SeqTrainer is used for the encoder-decoder FLAN-T5.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Start training
trainer.train()
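
# --- Optional sanity check after training (a minimal sketch; the sample
# --- question below is a hypothetical prompt, not part of the dataset) ---
# Save the fine-tuned model and tokenizer to the output directory used above,
# then generate an answer to confirm the checkpoint loads and produces text.
trainer.save_model("./flan-t5-medical")
tokenizer.save_pretrained("./flan-t5-medical")

sample_question = "What are the early symptoms of type 2 diabetes?"
inputs = tokenizer(sample_question, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))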