from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,  # Using specific model class
    AutoTokenizer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from trl import SFTTrainer
import torch

# 2. Load and prepare dataset
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

# Create properly formatted text field
def format_example(example):
    return {
        "text": f"Instruction: {example['input']}\nResponse: {example['output']}",
        "input": example["input"],
        "output": example["output"]
    }

dataset = dataset.map(format_example)

# 3. Load model and tokenizer
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 4. Configure training
training_args = TrainingArguments(
    output_dir="./flan-t5-medical-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    learning_rate=5e-5,
    logging_dir="./logs",
    save_strategy="epoch",
    evaluation_strategy="no",
    fp16=torch.cuda.is_available(),
    report_to="none",
    remove_unused_columns=False,
    # Add these to prevent version conflicts
    dataloader_pin_memory=False,
    dataloader_num_workers=0
)

# 5. Initialize trainer with proper config
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args,
    dataset_text_field="text",
    max_seq_length=512,  # Explicitly set to avoid warning
    data_collator=DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding="longest"
    )
)

# 6. Start training
trainer.train()
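
# 7. (Optional) Save the model and run a quick inference check.
# NOTE: this block is an illustrative sketch, not part of the original script.
# It assumes the fine-tuned weights and tokenizer are written to the training
# output directory before being reloaded; the prompt and generation settings
# below are placeholders, not tuned values.
trainer.save_model("./flan-t5-medical-finetuned")
tokenizer.save_pretrained("./flan-t5-medical-finetuned")

finetuned_tokenizer = AutoTokenizer.from_pretrained("./flan-t5-medical-finetuned")
finetuned_model = T5ForConditionalGeneration.from_pretrained("./flan-t5-medical-finetuned")

# The prompt mirrors the "Instruction: ...\nResponse:" template used in format_example
prompt = "Instruction: What are the early symptoms of iron deficiency?\nResponse:"
inputs = finetuned_tokenizer(prompt, return_tensors="pt")
outputs = finetuned_model.generate(**inputs, max_new_tokens=128)
print(finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True))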