winglian committed
Commit cb9a887 · unverified · 2 parents: a15d823 a10a826

Merge pull request #13 from winglian/dev

TODO.md ADDED
@@ -0,0 +1,10 @@
+# todo list
+
+- [ ] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
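The first item above asks for up-front validation of known-bad option combinations. A minimal sketch of what such a check could look like, assuming an attrdict-style cfg like the one scripts/finetune.py builds; the function name and the exact cfg keys are illustrative, not part of this commit:

```python
import logging


def validate_config(cfg):
    # Hypothetical pre-flight check for the combinations listed above.
    # Assumes missing cfg keys resolve to falsy values (attrdict behavior).
    fsdp_offload = bool(cfg.fsdp) and bool(cfg.fsdp_config) and cfg.fsdp_config.get("offload_params")

    if fsdp_offload and cfg.gradient_checkpointing:
        # https://github.com/pytorch/pytorch/issues/82203
        raise ValueError("FSDP offload and gradient_checkpointing cannot be combined")

    if fsdp_offload and cfg.optimizer == "adamw_bnb_8bit":
        logging.warning("adamw_bnb_8bit doesn't play well with FSDP offload")
```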
ds_config.json CHANGED
@@ -10,21 +10,42 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
   "scheduler": {
-    "type": "OneCycle",
+    "type": "WarmupDecayLR",
     "params": {
-      "cycle_min_lr": 1e-7,
-      "cycle_max_lr": 1e-4
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "total_num_steps": "auto"
     }
   },
   "zero_optimization": {
     "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
     "overlap_comm": true,
     "allgather_partitions": true,
     "allgather_bucket_size": 5e8,
     "contiguous_gradients": true,
     "reduce_bucket_size": "auto",
     "reduce_scatter": true,
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
    "stage3_gather_16bit_weights_on_model_save": true
   },
   "gradient_accumulation_steps": "auto",
scripts/finetune.py CHANGED
@@ -1,5 +1,7 @@
+import importlib
 import logging
 import os
+import pathlib
 import random
 import signal
 import sys
@@ -11,6 +13,8 @@ import yaml
 from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -42,48 +46,20 @@ def choose_device(cfg):
         cfg.device_map = {"": cfg.device}
 
 
-def check_dataset_labels(dataset, tokenizer):
-    from termcolor import colored
-
-    # the dataset is already shuffled, so let's just check the first 5 elements
-    for idx in range(5):
-        # Get the input_ids, labels, and attention_mask from the dataset
-        input_ids = dataset[idx]["input_ids"]
-        labels = dataset[idx]["labels"]
-        attention_mask = dataset[idx]["attention_mask"]
-
-        # You can compare the input_ids and labels element-wise
-        # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
-        colored_tokens = []
-        for i, (input_id, label_id, mask) in enumerate(
-            zip(input_ids, labels, attention_mask)
-        ):
-            decoded_input_token = tokenizer.decode(input_id)
-            # Choose the color based on whether the label has the ignore value or not
-            color = (
-                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-            )
-            colored_token = colored(decoded_input_token, color) + colored(
-                f"({label_id}, {mask})", "white"
-            )
-            colored_tokens.append(colored_token)
-
-        logging.info(" ".join(colored_tokens))
-        logging.info("\n\n\n")
-
-
-def do_inference(cfg, model, tokenizer):
+def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
     tokenizer.add_special_tokens({"eos_token": "</s>"})
 
-    from axolotl.prompters import ReflectAlpacaPrompter
+    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
 
     while True:
-        instruction = str(input("Give me an instruction: "))
+        # support for multiline inputs
+        print("Give me an instruction (Ctrl + D to finish): ")
+        instruction = pathlib.Path("/proc/self/fd/0").read_text()
         if not instruction:
            return
-        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        prompt = prompter_module().build_prompt(instruction=instruction)
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
 
        model.eval()
@@ -174,8 +150,8 @@ def train(
         cfg.bf16 = False
 
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and lora_config...")
-    model, tokenizer, lora_config = load_model(
+    logging.info("loading model, tokenizer, and peft_config...")
+    model, tokenizer, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
@@ -190,6 +166,10 @@ def train(
         do_inference(cfg, model, tokenizer)
         return
 
+    if "shard" in kwargs:
+        model.save_pretrained(cfg.output_dir)
+        return
+
     train_dataset, eval_dataset = load_prepare_datasets(
         tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )
@@ -199,8 +179,9 @@ def train(
         return
 
     if cfg.debug:
+        logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
+            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
             tokenizer,
         )
 
@@ -213,9 +194,9 @@ def train(
         model = torch.compile(model)
 
     # go ahead and presave, so we have the adapter config available to inspect
-    if lora_config:
+    if peft_config:
         logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        lora_config.save_pretrained(cfg.output_dir)
+        peft_config.save_pretrained(cfg.output_dir)
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
@@ -234,12 +215,11 @@ def train(
         logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
-    if cfg.local_rank == 0:
-        # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(
-            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-        )
-        model.save_pretrained(cfg.output_dir)
+    logging.info(
+        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+    )
+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    trainer.save_model(cfg.output_dir)
 
 
 if __name__ == "__main__":
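do_inference now reads the whole prompt from /proc/self/fd/0 so multi-line instructions can be pasted and terminated with Ctrl+D, and it resolves the prompter class by name via importlib. That file path is Linux-specific; a small sketch of an equivalent that works on any platform, using sys.stdin (the helper names here are mine, not part of the commit):

```python
import importlib
import sys


def read_multiline_instruction() -> str:
    # Same effect as reading /proc/self/fd/0: consume stdin until EOF (Ctrl+D).
    print("Give me an instruction (Ctrl + D to finish): ")
    return sys.stdin.read()


def get_prompter(name: str = "AlpacaPrompter"):
    # Mirrors the dynamic lookup in do_inference: resolve a prompter class by name.
    return getattr(importlib.import_module("axolotl.prompters"), name)
```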
scripts/setup-runpod.sh CHANGED
@@ -26,6 +26,15 @@ if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet
     export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 fi
 
+# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install
+mkdir -p /workspace/wheels
+cd /workspace/wheels
+curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
+curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
+pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
+pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
+pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies
+
 cd /workspace/
 git clone https://github.com/winglian/axolotl.git
 cd axolotl
src/axolotl/prompters.py CHANGED
@@ -127,7 +127,7 @@ conv_vicuna_v1_1 = Conversation(
 
 
 class ShareGPTPrompter:
-    def build_prompt(self, source, tokenizer):
+    def build_prompt(self, source, tokenizer, sequence_len=2048):
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
             source.pop(0)
@@ -157,13 +157,14 @@ class ShareGPTPrompter:
             role = roles[sentence["from"]]
             assert role == conv.roles[j % 2]
             conv.append_message(role, sentence["value"])
+        # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
         conversation = conv.get_prompt()
 
         # Tokenize conversations
         tokenized_result = tokenizer(
             conversation,
             truncation=True,
-            max_length=2048,  # FIXME
+            max_length=sequence_len,  # FIXME
             padding=False,
             return_tensors=None,
         )
@@ -173,7 +174,9 @@ class ShareGPTPrompter:
         sep = conv.sep + conv.roles[1] + ": "
 
         rounds = conversation.split(conv.sep2)
+        rounds = [r + conv.sep2 for r in rounds]
         cur_len = 1
+        target[0] = IGNORE_TOKEN_ID  # mask out the bos
         for i, rou in enumerate(rounds):
             if rou == "":
                 break
@@ -182,19 +185,27 @@ class ShareGPTPrompter:
             if len(parts) != 2:
                 break
             parts[0] += sep
-            round_len = len(tokenizer(rou)["input_ids"])
-            instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
+            round_len = len(tokenizer(rou)["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+            # we have to strip the initial part, any dangling whitespace creates an additional ghost token
+            instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1  # -1 ignores the bos_token generated for this
             target[cur_len : cur_len + instruction_len] = [
                 IGNORE_TOKEN_ID
             ] * instruction_len
 
             cur_len += round_len
-        target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+            if cur_len >= sequence_len:
+                break
+
+        # Fix: Truncate the target to have the same length as input_ids
+        target = target[:len(tokenized_result["input_ids"])]
+        # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+
         attention_mask = [
             1 if x != tokenizer.pad_token_id else 0
             for x in tokenized_result["input_ids"]
         ]
 
+        # TODO truncate len to sequence_len
         return dict(
             input_ids=tokenized_result["input_ids"],
             labels=target,
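The masking changes above are easier to follow with concrete numbers. Below is a toy, self-contained version of the per-round loop, with token counts passed in directly instead of coming from a real tokenizer (so the -1 BOS adjustments are out of scope); it only illustrates how instruction tokens get hidden from the loss while response tokens keep their ids:

```python
IGNORE_TOKEN_ID = -100  # same sentinel value the prompter uses


def mask_rounds(input_ids, round_lens, instruction_lens):
    # Toy version of the ShareGPT masking loop: for each conversation round,
    # mask the instruction portion and leave the response tokens as labels.
    target = list(input_ids)
    target[0] = IGNORE_TOKEN_ID  # mask out the bos
    cur_len = 1
    for round_len, instruction_len in zip(round_lens, instruction_lens):
        target[cur_len : cur_len + instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
        cur_len += round_len
    return target[: len(input_ids)]


# two rounds of six tokens each, the first three of which are the instruction
print(mask_rounds(list(range(13)), [6, 6], [3, 3]))
# -> [-100, -100, -100, -100, 4, 5, 6, -100, -100, -100, 10, 11, 12]
```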
src/axolotl/utils/models.py CHANGED
@@ -53,7 +53,7 @@ def load_model(
         logging.info("patching with xformers attention")
         hijack_llama_attention()
 
-    torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
+    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
     try:
         if cfg.load_4bit:
             from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -101,30 +101,23 @@ def load_model(
             )
             load_in_8bit = False
         elif is_llama_derived_model and "LlamaForCausalLM" in globals():
-            if not cfg.load_in_8bit:
-                model = LlamaForCausalLM.from_pretrained(
-                    base_model,
-                    device_map=cfg.device_map,
-                )
-            else:
-                model = LlamaForCausalLM.from_pretrained(
-                    base_model,
-                    load_in_8bit=cfg.load_in_8bit,
-                    torch_dtype=torch_dtype,
-                    device_map=cfg.device_map,
-                )
-
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
         elif model_type:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
-                load_in_8bit=cfg.load_in_8bit,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
                 base_model,
-                load_in_8bit=cfg.load_in_8bit,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                 torch_dtype=torch_dtype,
                 device_map=cfg.device_map,
             )
@@ -135,7 +128,7 @@ def load_model(
         logging.exception(e)
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            load_in_8bit=cfg.load_in_8bit,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
@@ -147,7 +140,7 @@ def load_model(
         else:
             tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
     except:
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model_config)
 
     logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
     logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
@@ -161,12 +154,12 @@ def load_model(
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            setattr(tokenizer, k, v)
+    if cfg.tokens:
+        for k, v in cfg.tokens.items():
+            tokenizer.add_special_tokens({k: v})
 
-    if load_in_8bit and not cfg.load_4bit:
-        logging.info("converting model w/ prepare_model_for_int8_training")
+    if cfg.adapter and load_in_8bit and not cfg.load_4bit:
+        logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
     model, lora_config = load_adapter(model, cfg, adapter)
@@ -186,6 +179,11 @@ def load_model(
             m.scales = m.scales.half()
             m.bias = m.bias.half()
 
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+
     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
 
@@ -197,11 +195,41 @@ def load_adapter(model, cfg, adapter):
         return model, None
     if adapter == "lora":
         return load_lora(model, cfg)
-    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
+    if adapter == "llama-adapter":
+        return load_llama_adapter(model, cfg)
 
     raise NotImplementedError(f"{adapter} peft adapter not available")
 
 
+def load_llama_adapter(model, cfg):
+    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    from peft import (
+        AdaptionPromptConfig,
+        get_peft_model,
+        PeftModel,
+    )
+
+    peft_config = AdaptionPromptConfig(
+        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
+        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
+        task_type="CAUSAL_LM",
+    )
+
+    if cfg.peft_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, peft_config)
+
+    model.print_trainable_parameters()
+
+    return model, peft_config
+
+
 def load_lora(model, cfg):
     # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
@@ -213,27 +241,26 @@ def load_lora(model, cfg):
 
     lora_config = None
 
-    if cfg.adapter == "lora":
-        lora_config = LoraConfig(
-            r=cfg.lora_r,
-            lora_alpha=cfg.lora_alpha,
-            target_modules=cfg.lora_target_modules,
-            lora_dropout=cfg.lora_dropout,
-            fan_in_fan_out=cfg.lora_fan_in_fan_out,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
+    lora_config = LoraConfig(
+        r=cfg.lora_r,
+        lora_alpha=cfg.lora_alpha,
+        target_modules=cfg.lora_target_modules,
+        lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
 
-        if cfg.lora_model_dir:
-            model = PeftModel.from_pretrained(
-                model,
-                cfg.lora_model_dir,
-                device_map=cfg.device_map,
-                torch_dtype=torch.float16,
-            )
-        else:
-            model = get_peft_model(model, lora_config)
+    if cfg.lora_model_dir:
+        model = PeftModel.from_pretrained(
+            model,
+            cfg.lora_model_dir,
+            device_map=cfg.device_map,
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = get_peft_model(model, lora_config)
 
-        model.print_trainable_parameters()
+    model.print_trainable_parameters()
 
     return model, lora_config
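load_llama_adapter wires peft's AdaptionPrompt ("llama-adapter") support into the existing adapter selection. For orientation, a standalone sketch of the same peft calls outside axolotl's cfg plumbing; the checkpoint name and the layer / prompt-length values below are placeholders, not values this commit prescribes:

```python
import torch
from peft import AdaptionPromptConfig, get_peft_model
from transformers import LlamaForCausalLM

# placeholder checkpoint; any LLaMA-style causal LM should behave the same way
model = LlamaForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    torch_dtype=torch.float16,
    device_map="auto",
)

peft_config = AdaptionPromptConfig(
    adapter_layers=30,  # layers (L) -> cfg.peft_adapter.layers
    adapter_len=10,     # prompt length (K) -> cfg.peft_adapter.len
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the adaption prompts and gates are trainable
```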
src/axolotl/utils/schedulers.py ADDED
@@ -0,0 +1,33 @@
+from torch.optim.lr_scheduler import LRScheduler
+
+
+class InterpolatingLogScheduler(LRScheduler):
+    def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
+        """A scheduler that interpolates learning rates in a logarithmic fashion
+
+        Args:
+        - optimizer: pytorch optimizer
+        - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
+        - min_lr: float, the minimum learning rate
+        - max_lr: float, the maximum learning rate
+
+        Usage:
+            fc = nn.Linear(1,1)
+            optimizer = optim.Adam(fc.parameters())
+            lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
+        """
+        self.num_steps = num_steps
+        self.min_lr = min_lr
+        self.max_lr = max_lr
+        self.q = (max_lr / min_lr) ** (1 / (num_steps - 1))
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if self.last_epoch <= 0:
+            lrs = [self.min_lr for base_lr in self.base_lrs]
+        elif self.last_epoch < self.num_steps:
+            lrs = [self.min_lr * (self.q ** (self.last_epoch - 1)) for base_lr in self.base_lrs]
+        else:
+            lrs = [self.max_lr for base_lr in self.base_lrs]
+
+        return lrs
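A short usage sketch of the new scheduler, matching how trainer.py uses it for the "log_sweep" learning-rate range test; the toy model and step counts are arbitrary, and the public LRScheduler base class assumes a recent PyTorch (older releases only expose _LRScheduler):

```python
import torch.nn as nn
import torch.optim as optim

from axolotl.utils.schedulers import InterpolatingLogScheduler

model = nn.Linear(1, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-6)
scheduler = InterpolatingLogScheduler(optimizer, num_steps=10, min_lr=1e-6, max_lr=1e-2)

for step in range(12):
    optimizer.step()                      # normally preceded by a forward/backward pass
    scheduler.step()
    print(step, scheduler.get_last_lr())  # climbs geometrically, then holds at max_lr
```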
src/axolotl/utils/tokenization.py ADDED
@@ -0,0 +1,33 @@
+from termcolor import colored
+import logging
+
+def check_dataset_labels(dataset, tokenizer):
+    # the dataset is already shuffled, so let's just check the first 5 elements
+    for idx in range(5):
+        check_example_labels(dataset[idx], tokenizer)
+
+
+def check_example_labels(example, tokenizer):
+    # Get the input_ids, labels, and attention_mask from the dataset
+    input_ids = example["input_ids"]
+    labels = example["labels"]
+    attention_mask = example["attention_mask"]
+
+    # You can compare the input_ids and labels element-wise
+    # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
+    colored_tokens = []
+    for i, (input_id, label_id, mask) in enumerate(
+        zip(input_ids, labels, attention_mask)
+    ):
+        decoded_input_token = tokenizer.decode(input_id)
+        # Choose the color based on whether the label has the ignore value or not
+        color = (
+            "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
+        )
+        colored_token = colored(decoded_input_token, color) + colored(
+            f"({label_id}, {mask}, {input_id})", "white"
+        )
+        colored_tokens.append(colored_token)
+
+    logging.info(" ".join(colored_tokens))
+    logging.info("\n\n\n")
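check_dataset_labels is what scripts/finetune.py now calls when cfg.debug is set. A hedged sketch of exercising the coloring on a single hand-built example outside a full training run; the tokenizer checkpoint is a placeholder, and logging has to be enabled to see the output:

```python
import logging

from transformers import AutoTokenizer
from axolotl.utils.tokenization import check_example_labels

logging.basicConfig(level=logging.INFO)
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # placeholder checkpoint

encoded = tokenizer("Below is an instruction. Respond politely.")
example = {
    "input_ids": encoded["input_ids"],
    # mask the first token from the loss, keep the rest as-is
    "labels": [-100] + encoded["input_ids"][1:],
    "attention_mask": encoded["attention_mask"],
}
check_example_labels(example, tokenizer)  # red = ignored by the loss, green = trained on
```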
src/axolotl/utils/trainer.py CHANGED
@@ -1,5 +1,7 @@
+import importlib
 import math
 import os
+import sys
 from pathlib import Path
 
 import bitsandbytes as bnb
@@ -10,14 +12,33 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback
 from transformers.trainer_pt_utils import get_parameter_names
 
+from axolotl.utils.schedulers import InterpolatingLogScheduler
+
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
-    warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100)
-    logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1)
-    save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)
+    warmup_steps = (
+        cfg.warmup_steps
+        if cfg.warmup_steps is not None
+        else min(int(0.03 * total_num_steps), 100)
+    )
+    logging_steps = (
+        cfg.logging_steps
+        if cfg.logging_steps is not None
+        else max(min(int(0.005 * total_num_steps), 10), 1)
+    )
+    save_steps = (
+        cfg.save_steps
+        if cfg.save_steps is not None
+        else min(int(0.05 * total_num_steps), 200)
+    )
+    eval_steps = (
+        cfg.eval_steps
+        if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
+        else save_steps
+    )
 
     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -29,15 +50,32 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     training_arguments_kwargs["logging_steps"] = logging_steps
     if cfg.gradient_checkpointing is not None:
         if cfg.load_4bit:
-            from alpaca_lora_4bit.gradient_checkpointing import apply_gradient_checkpointing
-            gradient_checkpointing_ratio = cfg.gradient_checkpointing_ratio if cfg.gradient_checkpointing_ratio else 1.0
-            apply_gradient_checkpointing(model, checkpoint_ratio=gradient_checkpointing_ratio)
-        else:
-            training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+            from alpaca_lora_4bit.gradient_checkpointing import (
+                apply_gradient_checkpointing,
+            )
 
+            gradient_checkpointing_ratio = (
+                cfg.gradient_checkpointing_ratio
+                if cfg.gradient_checkpointing_ratio
+                else 1.0
+            )
+            apply_gradient_checkpointing(
+                model, checkpoint_ratio=gradient_checkpointing_ratio
+            )
+        else:
+            training_arguments_kwargs[
+                "gradient_checkpointing"
+            ] = cfg.gradient_checkpointing
+    if cfg.fsdp:
+        training_arguments_kwargs["fsdp"] = cfg.fsdp
+        if cfg.fsdp_config:
+            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
     # deepspeed
-    if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
+    if (
+        os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
+        and torch.cuda.device_count() > 1
+    ):
         if cfg.deepspeed:
             training_arguments_kwargs["deepspeed"] = cfg.deepspeed
         else:
@@ -49,6 +87,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+        eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
@@ -57,31 +96,51 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False,
+        load_best_model_at_end=True
+        if cfg.val_set_size > 0 and save_steps % eval_steps == 0
+        else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
+        optim=cfg.optimizer if cfg.optimizer else None,
+        lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
+        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
         **training_arguments_kwargs,
     )
 
     trainer_kwargs = {}
 
-    if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
+    if cfg.optimizer == "adamw_anyprecision":
+        if Path(cfg.torchdistx_path).exists():
+            sys.path.append(cfg.torchdistx_path)
+            importlib.import_module("torchdistx")
+    if (
+        cfg.optimizer == "adamw_bnb_8bit"
+        and not cfg.load_4bit
+        and not "deepspeed" in training_arguments_kwargs
+    ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+                "params": [
+                    p
+                    for n, p in model.named_parameters()
+                    if (n in decay_parameters and p.requires_grad)
+                ],
                 "weight_decay": training_args.weight_decay,
             },
             {
                 "params": [
-                    p for n, p in model.named_parameters() if n not in decay_parameters
+                    p
+                    for n, p in model.named_parameters()
+                    if (n not in decay_parameters and p.requires_grad)
                 ],
                 "weight_decay": 0.0,
             },
         ]
+
         optimizer = bnb.optim.Adam8bit(
             optimizer_grouped_parameters,
             betas=(training_args.adam_beta1, training_args.adam_beta2),
@@ -97,8 +156,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
             optimizer,
             cfg.learning_rate,
             total_steps=total_num_steps,
+            epochs=cfg.num_epochs,
             **lr_scheduler_kwargs,
         )
+    elif cfg.lr_scheduler == "log_sweep":
+        lr_scheduler = InterpolatingLogScheduler(
+            optimizer,
+            cfg.warmup_steps,
+            cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
+            cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
+        )
     else:
         lr_scheduler = transformers.get_cosine_schedule_with_warmup(
             optimizer,
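The adamw_bnb_8bit branch above builds the optimizer by hand using the standard decay / no-decay parameter split. A condensed, standalone sketch of that pattern with a toy model (bitsandbytes expects a CUDA-capable environment; the hyperparameters are placeholders):

```python
import bitsandbytes as bnb
import torch.nn as nn
from transformers.trainer_pt_utils import get_parameter_names

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))

# everything except LayerNorm weights and biases gets weight decay
decay_parameters = [
    n for n in get_parameter_names(model, [nn.LayerNorm]) if "bias" not in n
]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters()
                   if n in decay_parameters and p.requires_grad],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if n not in decay_parameters and p.requires_grad],
        "weight_decay": 0.0,
    },
]
optimizer = bnb.optim.Adam8bit(
    optimizer_grouped_parameters, betas=(0.9, 0.999), eps=1e-8, lr=2e-5
)
```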