rivapereira123 commited on
Commit
4c7874b
·
verified ·
1 Parent(s): 1515e9b

Update finetune_flan_t5.py

Browse files
Files changed (1) hide show
  1. finetune_flan_t5.py +9 -9
finetune_flan_t5.py CHANGED
@@ -16,12 +16,10 @@ model_name = "google/flan-t5-base"
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
 
19
- # 3. CORRECTED Formatting function - must return a list
20
- def format_instruction(examples):
21
- texts = []
22
- for input_text, output_text in zip(examples["input"], examples["output"]):
23
- texts.append(f"### Instruction:\n{input_text}\n\n### Response:\n{output_text}")
24
- return {"text": texts} # Return dict with "text" key containing list
25
 
26
  # 4. Training arguments
27
  training_args = TrainingArguments(
@@ -37,20 +35,22 @@ training_args = TrainingArguments(
37
  report_to="none"
38
  )
39
 
40
- # 5. Initialize SFTTrainer
41
  trainer = SFTTrainer(
42
  model=model,
43
  tokenizer=tokenizer,
44
  train_dataset=dataset,
45
  args=training_args,
46
  max_seq_length=512,
47
- formatting_func=format_instruction,
48
  data_collator=DataCollatorForSeq2Seq(
49
  tokenizer,
 
50
  pad_to_multiple_of=8,
51
  return_tensors="pt",
52
  padding=True
53
- )
 
54
  )
55
 
56
  # 6. Start training
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18
 
19
+ # 3. CORRECTED Formatting function - returns single string per example
20
+ def format_instruction(example):
21
+ # Return a single formatted string
22
+ return f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"
 
 
23
 
24
  # 4. Training arguments
25
  training_args = TrainingArguments(
 
35
  report_to="none"
36
  )
37
 
38
+ # 5. Initialize SFTTrainer with correct parameters
39
  trainer = SFTTrainer(
40
  model=model,
41
  tokenizer=tokenizer,
42
  train_dataset=dataset,
43
  args=training_args,
44
  max_seq_length=512,
45
+ formatting_func=format_instruction, # Now returns single string
46
  data_collator=DataCollatorForSeq2Seq(
47
  tokenizer,
48
+ model=model, # Added model reference
49
  pad_to_multiple_of=8,
50
  return_tensors="pt",
51
  padding=True
52
+ ),
53
+ dataset_text_field="text" # Explicit field name
54
  )
55
 
56
  # 6. Start training