rivapereira123 committed (verified)
Commit 0bbb8d4 · Parent: 4c7874b

Update finetune_flan_t5.py

Files changed (1): finetune_flan_t5.py (+14 / -13)
finetune_flan_t5.py CHANGED
@@ -8,20 +8,22 @@ from transformers import (
 from trl import SFTTrainer
 import torch
 
-# 1. Load dataset
+# 1. Load and prepare dataset
 dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")
 
+# Add 'text' field containing the formatted examples
+def add_text_field(example):
+    example['text'] = f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"
+    return example
+
+dataset = dataset.map(add_text_field)
+
 # 2. Load model and tokenizer
 model_name = "google/flan-t5-base"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
-# 3. CORRECTED Formatting function - returns single string per example
-def format_instruction(example):
-    # Return a single formatted string
-    return f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"
-
-# 4. Training arguments
+# 3. Training arguments
 training_args = TrainingArguments(
     output_dir="./flan-t5-medical-finetuned",
     per_device_train_batch_size=4,
@@ -35,23 +37,22 @@ training_args = TrainingArguments(
     report_to="none"
 )
 
-# 5. Initialize SFTTrainer with correct parameters
+# 4. Initialize SFTTrainer with correct configuration
 trainer = SFTTrainer(
     model=model,
     tokenizer=tokenizer,
     train_dataset=dataset,
     args=training_args,
     max_seq_length=512,
-    formatting_func=format_instruction,  # Now returns single string
+    dataset_text_field="text",  # Field we created
     data_collator=DataCollatorForSeq2Seq(
         tokenizer,
-        model=model,  # Added model reference
+        model=model,
         pad_to_multiple_of=8,
         return_tensors="pt",
         padding=True
-    ),
-    dataset_text_field="text"  # Explicit field name
+    )
 )
 
-# 6. Start training
+# 5. Start training
 trainer.train()
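
For reference, this is what the new add_text_field mapping produces for a single record. The record below is made up; the diff only shows that each JSONL row in data/med_q_n_a_converted.jsonl carries 'input' and 'output' keys, and the formatting logic is copied from the commit.

# Minimal sketch of the prompt formatting this commit switches to.
# The record is hypothetical; only the 'input'/'output' keys come from the script.
example = {
    "input": "What are common symptoms of anemia?",
    "output": "Fatigue, pallor, and shortness of breath are typical.",
}

def add_text_field(example):
    example["text"] = f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"
    return example

print(add_text_field(example)["text"])
# ### Instruction:
# What are common symptoms of anemia?
#
# ### Response:
# Fatigue, pallor, and shortness of breath are typical.

Precomputing this 'text' column and pointing dataset_text_field at it replaces the old formatting_func path, so SFTTrainer reads the finished string directly instead of formatting examples at tokenization time.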
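
A quick sanity check before launching training can confirm the map ran as intended. This snippet is not part of the commit; it just uses the standard datasets API on the dataset defined above.

# Assumption: run after dataset.map(add_text_field).
print(dataset.column_names)      # expect the list to include "text"
print(dataset[0]["text"][:200])  # inspect one formatted example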
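
One caveat when re-running this script: the keyword set used here (tokenizer=, max_seq_length=, and dataset_text_field= passed straight to SFTTrainer) matches older TRL releases. Newer TRL versions moved the SFT-specific options into SFTConfig and renamed tokenizer to processing_class. A rough equivalent under that newer API might look like the sketch below; treat the exact parameter names as assumptions to verify against the installed TRL version.

# Sketch only: assumes a newer TRL where SFTConfig carries the SFT-specific
# options. Parameter names have shifted between releases, so verify locally.
from transformers import DataCollatorForSeq2Seq
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    output_dir="./flan-t5-medical-finetuned",
    per_device_train_batch_size=4,
    report_to="none",
    max_seq_length=512,          # moved out of SFTTrainer(...)
    dataset_text_field="text",   # moved out of SFTTrainer(...)
)

trainer = SFTTrainer(
    model=model,                 # model, tokenizer, dataset as defined above
    processing_class=tokenizer,
    train_dataset=dataset,
    args=sft_config,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model, pad_to_multiple_of=8),
)
trainer.train()

Either way, keeping model=model on DataCollatorForSeq2Seq is worthwhile: with the model reference, the collator can build decoder_input_ids from the padded labels, which T5-style seq2seq models expect during training.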