rivapereira123 committed
Commit c6a9cc3 · verified · 1 Parent(s): 13b35d8

Update finetune_flan_t5.py

Files changed (1)
  1. finetune_flan_t5.py +25 -22
finetune_flan_t5.py CHANGED
@@ -1,35 +1,34 @@
from datasets import load_dataset
from transformers import (
-    AutoTokenizer,
-    AutoModelForSeq2SeqLM,
+    T5ForConditionalGeneration, # Using specific model class
+    AutoTokenizer,
    TrainingArguments,
-    DataCollatorForSeq2Seq,
-    FlaxAutoModelForSeq2SeqLM # Added for explicit model loading
+    DataCollatorForSeq2Seq
)
from trl import SFTTrainer
import torch

-# 1. Load and prepare dataset
+
+
+# 2. Load and prepare dataset
dataset = load_dataset("json", data_files="data/med_q_n_a_converted.jsonl", split="train")

-# Add formatted text field
-dataset = dataset.map(lambda x: {
-    "text": f"### Instruction:\n{x['input']}\n\n### Response:\n{x['output']}"
-})
+# Create properly formatted text field
+def format_example(example):
+    return {
+        "text": f"Instruction: {example['input']}\nResponse: {example['output']}",
+        "input": example["input"],
+        "output": example["output"]
+    }
+
+dataset = dataset.map(format_example)

-# 2. Load model and tokenizer - METHOD 1: Explicit FLAN-T5 loading
+# 3. Load model and tokenizer
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-# METHOD 1: Load model directly without AutoModel
-from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained(model_name)

-# METHOD 2: Or install Japanese support (if needed)
-# pip install transformers[ja]
-# Then use AutoModel as before
-
-# 3. Training arguments
+# 4. Configure training
training_args = TrainingArguments(
    output_dir="./flan-t5-medical-finetuned",
    per_device_train_batch_size=4,
@@ -41,22 +40,26 @@ training_args = TrainingArguments(
    evaluation_strategy="no",
    fp16=torch.cuda.is_available(),
    report_to="none",
-    remove_unused_columns=False
+    remove_unused_columns=False,
+    # Add these to prevent version conflicts
+    dataloader_pin_memory=False,
+    dataloader_num_workers=0
)

-# 4. Initialize trainer
+# 5. Initialize trainer with proper config
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=training_args,
    dataset_text_field="text",
+    max_seq_length=512, # Explicitly set to avoid warning
    data_collator=DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
-        padding=True
+        padding="longest"
    )
)

-# 5. Start training
+# 6. Start training
trainer.train()
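
A minimal inference sketch for the result of this script, assuming the fine-tuned checkpoint is saved to the output_dir configured above (e.g. via trainer.save_model()). The prompt follows the same Instruction:/Response: format that format_example writes into the "text" field; the question itself is only an illustrative placeholder:

from transformers import AutoTokenizer, T5ForConditionalGeneration

# Assumed checkpoint location: the output_dir from TrainingArguments above
checkpoint = "./flan-t5-medical-finetuned"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Prompt mirrors the training format produced by format_example
prompt = "Instruction: What are the common symptoms of iron-deficiency anemia?\nResponse:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))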