rivapereira123 committed on
Commit 8a20cad · verified · 1 Parent(s): 07b1338

Update finetune_flan_t5.py

Files changed (1)
    finetune_flan_t5.py  +29 -50
finetune_flan_t5.py CHANGED
@@ -1,75 +1,54 @@
  from datasets import load_dataset
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments
- from trl import SFTTrainer  # ✅ from trl
- from transformers import DataCollatorForSeq2Seq  # ✅ from transformers
  import torch

- # Load your dataset (from the converted JSONL file)
  dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

- # Load tokenizer and model
  model_name = "google/flan-t5-base"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

- # Preprocess dataset
- def preprocess(example):
-     input_text = example["input"]
-     target_text = example["output"]
-
-     model_inputs = tokenizer(
-         input_text,
-         max_length=512,
-         truncation=True,
-         padding="max_length"
-     )
-
-     labels = tokenizer(
-         target_text,
-         max_length=128,
-         truncation=True,
-         padding="max_length"
-     )["input_ids"]
-
-     model_inputs["labels"] = labels
-     return model_inputs
-
-

- # Apply preprocessing
- tokenized_dataset = dataset.map(preprocess)
-
-
- # Define training arguments
  training_args = TrainingArguments(
-     output_dir="./flan-t5-medical",
      per_device_train_batch_size=4,
      gradient_accumulation_steps=2,
      num_train_epochs=3,
      logging_dir="./logs",
      save_strategy="epoch",
      evaluation_strategy="no",
-     fp16=torch.cuda.is_available()
  )

- # Define data collator
- data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
-
- def formatting_func(example):
-     return [f"Input: {example['input']}\nOutput: {example['output']}"]
-
- from trl import SFTTrainer
- from transformers import DataCollatorForSeq2Seq
-
  trainer = SFTTrainer(
      model=model,
      tokenizer=tokenizer,
-     train_dataset=tokenized_dataset,  # already tokenized
      args=training_args,
-     data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True),
-     packing=False,
-     tokenized_dataset=True  # ✅ Now supported after upgrade
  )

- # Start training
- trainer.train()
  from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSeq2SeqLM,
+     TrainingArguments,
+     DataCollatorForSeq2Seq
+ )
+ from trl import SFTTrainer
  import torch

+ # 1. Load dataset
  dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

+ # 2. Load model and tokenizer
  model_name = "google/flan-t5-base"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

+ # 3. Formatting function for SFTTrainer
+ def format_instruction(example):
+     return f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"

+ # 4. Training arguments
  training_args = TrainingArguments(
+     output_dir="./flan-t5-medical-finetuned",
      per_device_train_batch_size=4,
      gradient_accumulation_steps=2,
      num_train_epochs=3,
+     learning_rate=5e-5,
      logging_dir="./logs",
      save_strategy="epoch",
      evaluation_strategy="no",
+     fp16=torch.cuda.is_available(),
+     report_to="none"
  )

+ # 5. Initialize SFTTrainer correctly
  trainer = SFTTrainer(
      model=model,
      tokenizer=tokenizer,
+     train_dataset=dataset,
      args=training_args,
+     max_seq_length=512,
+     formatting_func=format_instruction,
+     data_collator=DataCollatorForSeq2Seq(
+         tokenizer,
+         pad_to_multiple_of=8,
+         return_tensors="pt",
+         padding=True
+     )
  )

+ # 6. Start training
+ trainer.train()
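
Note on the data this script expects: the combination of load_dataset("json", ...) and format_instruction implies that each line of data/med_q_n_a_converted.jsonl is a JSON object with "input" and "output" fields. A minimal sketch of such a record and of the prompt the formatting function would build from it (the question/answer text below is invented for illustration and is not taken from the dataset):

    import json

    # Hypothetical JSONL record; the real dataset contents are not part of this commit.
    sample_line = '{"input": "What are the early symptoms of iron-deficiency anemia?", "output": "Common early symptoms include fatigue, pallor, and shortness of breath."}'
    example = json.loads(sample_line)

    # Same template as format_instruction() in the updated script.
    prompt = f"### Instruction:\n{example['input']}\n\n### Response:\n{example['output']}"
    print(prompt)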