tmm1 committed
Commit 1f613e5 · 2 Parent(s): f319b0b 3513071

Merge branch 'main' into patch-4

README.md CHANGED
@@ -521,7 +521,7 @@ lr_quadratic_warmup:
  logging_steps:
  save_strategy: # set to `no` to skip checkpoint saves
  save_steps: # leave empty to save at each epoch
- eval_steps:
+ eval_steps: # leave empty to eval at each epoch
  save_total_limit: # checkpoints saved at a time
  max_steps:
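The example configs added later in this commit show the non-empty form of this key (`eval_steps: 20`). A minimal sketch of the two behaviours the new comment documents, for illustration only:

```yaml
eval_steps: 20   # run evaluation every 20 training steps
# eval_steps:    # left empty: evaluate at the end of each epoch instead
```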
 
deepspeed/zero2.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "zero_optimization": {
+     "stage": 2,
+     "offload_optimizer": {
+       "device": "cpu"
+     },
+     "contiguous_gradients": true,
+     "overlap_comm": true
+   },
+   "bf16": {
+     "enabled": "auto"
+   },
+   "fp16": {
+     "enabled": "auto",
+     "auto_cast": false,
+     "loss_scale": 0,
+     "initial_scale_power": 32,
+     "loss_scale_window": 1000,
+     "hysteresis": 2,
+     "min_loss_scale": 1
+   },
+   "optimizer": {
+     "type": "AdamW",
+     "params": {
+       "lr": "auto",
+       "betas": [
+         0.9,
+         0.999
+       ],
+       "eps": 1e-8,
+       "weight_decay": "auto"
+     }
+   },
+   "scheduler": {
+     "type": "WarmupDecayLR",
+     "params": {
+       "warmup_min_lr": "auto",
+       "warmup_max_lr": "auto",
+       "warmup_num_steps": "auto",
+       "total_num_steps": "auto"
+     }
+   },
+   "train_batch_size": "auto",
+   "train_micro_batch_size_per_gpu": "auto",
+   "wall_clock_breakdown": false
+ }
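The example YAMLs that follow all leave their `deepspeed:` key empty. A minimal sketch of pointing one of them at this new ZeRO-2 config, assuming axolotl treats the key as a path to a DeepSpeed JSON file handed to the trainer:

```yaml
# Sketch (assumption): reference the JSON added above from an example config.
deepspeed: deepspeed/zero2.json
```

The launch command would stay the same, e.g. `accelerate launch scripts/finetune.py examples/code-llama/7b/qlora.yml`.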
examples/code-llama/13b/lora.yml ADDED
@@ -0,0 +1,67 @@
+ base_model: codellama/CodeLlama-13b-hf
+ base_model_config: codellama/CodeLlama-13b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: true
+ load_in_4bit: false
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./lora-out
+
+ sequence_len: 100000
+ sample_packing: true
+
+ adapter: lora
+ lora_model_dir:
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: adamw_bnb_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/13b/qlora.yml ADDED
@@ -0,0 +1,69 @@
+ base_model: codellama/CodeLlama-13b-hf
+ base_model_config: codellama/CodeLlama-13b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./qlora-out
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 100000
+ sample_packing: true
+
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: paged_adamw_32bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/34b/lora.yml ADDED
@@ -0,0 +1,67 @@
+ base_model: codellama/CodeLlama-34b-hf
+ base_model_config: codellama/CodeLlama-34b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: true
+ load_in_4bit: false
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./lora-out
+
+ sequence_len: 100000
+ sample_packing: true
+
+ adapter: lora
+ lora_model_dir:
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: adamw_bnb_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/34b/qlora.yml ADDED
@@ -0,0 +1,69 @@
+ base_model: codellama/CodeLlama-34b-hf
+ base_model_config: codellama/CodeLlama-34b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./qlora-out
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 100000
+ sample_packing: true
+
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: paged_adamw_32bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/7b/lora.yml ADDED
@@ -0,0 +1,67 @@
+ base_model: codellama/CodeLlama-7b-hf
+ base_model_config: codellama/CodeLlama-7b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: true
+ load_in_4bit: false
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./lora-out
+
+ sequence_len: 100000
+ sample_packing: true
+
+ adapter: lora
+ lora_model_dir:
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: adamw_bnb_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/7b/qlora.yml ADDED
@@ -0,0 +1,69 @@
+ base_model: codellama/CodeLlama-7b-hf
+ base_model_config: codellama/CodeLlama-7b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: CodeLlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: mhenrichsen/alpaca_2k_test
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./qlora-out
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 100000
+ sample_packing: true
+
+ lora_r: 32
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 2
+ num_epochs: 3
+ optimizer: paged_adamw_32bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps:
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
examples/code-llama/README.md ADDED
@@ -0,0 +1,22 @@
+ # Overview
+
+ This is an example of CodeLLaMA configuration for 7b, 13b and 34b.
+
+ The 7b variant fits on any 24GB VRAM GPU and will take up about 17 GB of VRAM during training if using qlora and 20 GB if using lora. On a RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
+
+ The 13b variant will fit if you change these settings to these values:
+ gradient_accumulation_steps: 2
+ micro_batch_size: 1
+
+ The 34b variant does not fit on 24GB of VRAM - you will need something with +40 gb VRAM that also supports flash attention v2 - A6000 or A100 are good choices.
+
+ ```shell
+ accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
+
+ ```
+ or
+
+ ```shell
+ accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
+
+ ```
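For reference, a minimal sketch of the 13b adjustment described in the README above, applied to examples/code-llama/13b/qlora.yml or lora.yml (all other keys stay as in the files shown earlier; the values come straight from the README text):

```yaml
# Fit CodeLlama-13b on a 24GB card, per the README above:
gradient_accumulation_steps: 2   # down from the shipped default of 4
micro_batch_size: 1              # down from the shipped default of 2
```

The launch command is the same as for the other sizes, e.g. `accelerate launch scripts/finetune.py examples/code-llama/13b/qlora.yml`.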
examples/llama-2/relora.yml ADDED
@@ -0,0 +1,73 @@
+ base_model: meta-llama/Llama-2-7b-hf
+ base_model_config: meta-llama/Llama-2-7b-hf
+ model_type: LlamaForCausalLM
+ tokenizer_type: LlamaTokenizer
+ is_llama_derived_model: true
+
+ load_in_8bit: false
+ load_in_4bit: true
+ strict: false
+
+ datasets:
+   - path: teknium/GPT4-LLM-Cleaned
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.01
+ output_dir: ./relora-out
+
+ adapter: qlora
+ lora_model_dir:
+
+ sequence_len: 4096
+ sample_packing: true
+
+ lora_r: 8
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+ lora_target_linear: true
+ lora_fan_in_fan_out:
+
+ relora_steps: 150
+ relora_warmup_steps: 10
+ relora_cpu_offload: false
+
+ wandb_project:
+ wandb_entity:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model:
+
+ gradient_accumulation_steps: 4
+ micro_batch_size: 4
+ num_epochs: 3
+ optimizer: adamw_bnb_8bit
+ lr_scheduler: cosine
+ learning_rate: 0.0002
+
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ fp16: false
+ tf32: false
+
+ gradient_checkpointing: true
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ logging_steps: 1
+ xformers_attention:
+ flash_attention: true
+
+ warmup_steps: 10
+ eval_steps: 20
+ save_steps: 50
+ debug:
+ deepspeed:
+ weight_decay: 0.0
+ fsdp:
+ fsdp_config:
+ special_tokens:
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
scripts/finetune.py CHANGED
@@ -82,6 +82,8 @@ def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
              max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
          )
 
+     model = model.to(cfg.device)
+
      while True:
          print("=" * 80)
          # support for multiline inputs
src/axolotl/utils/trainer.py CHANGED
@@ -10,19 +10,13 @@ from functools import partial
  from pathlib import Path
  from typing import Optional, Union
 
- import bitsandbytes as bnb
  import numpy as np
  import torch.cuda
- import transformers
  from datasets import Dataset, set_caching_enabled
- from torch import nn
  from torch.optim.lr_scheduler import OneCycleLR
  from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
- from transformers.trainer_pt_utils import (
-     SequentialDistributedSampler,
-     get_parameter_names,
- )
+ from transformers.trainer_pt_utils import SequentialDistributedSampler
 
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
  from axolotl.utils.callbacks import (
@@ -32,10 +26,7 @@ from axolotl.utils.callbacks import (
  )
  from axolotl.utils.collators import DataCollatorForSeq2Seq
  from axolotl.utils.dataloader import MultipackDistributedDataloader
- from axolotl.utils.schedulers import (
-     InterpolatingLogScheduler,
-     get_cosine_schedule_with_quadratic_warmup,
- )
+ from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
 
  LOG = logging.getLogger("axolotl")
 
@@ -570,66 +561,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
      if Path(cfg.torchdistx_path).exists():
          sys.path.append(cfg.torchdistx_path)
          importlib.import_module("torchdistx")
-     if (
-         cfg.optimizer == "adamw_bnb_8bit"
-         and not cfg.gptq
-         and "deepspeed" not in training_arguments_kwargs
-         and not cfg.fsdp
-     ):
-         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
-         decay_parameters = [name for name in decay_parameters if "bias" not in name]
-         optimizer_grouped_parameters = [
-             {
-                 "params": [
-                     p
-                     for n, p in model.named_parameters()
-                     if (n in decay_parameters and p.requires_grad)
-                 ],
-                 "weight_decay": training_args.weight_decay,
-             },
-             {
-                 "params": [
-                     p
-                     for n, p in model.named_parameters()
-                     if (n not in decay_parameters and p.requires_grad)
-                 ],
-                 "weight_decay": 0.0,
-             },
-         ]
-
-         optimizer = bnb.optim.Adam8bit(
-             optimizer_grouped_parameters,
-             betas=(training_args.adam_beta1, training_args.adam_beta2),
-             eps=training_args.adam_epsilon,
-             lr=training_args.learning_rate,
-         )
-
-         if cfg.lr_scheduler == "one_cycle":
-             lr_scheduler_kwargs = (
-                 cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
-             )
-             lr_scheduler = OneCycleLR(
-                 optimizer,
-                 cfg.learning_rate,
-                 total_steps=total_num_steps,
-                 epochs=cfg.num_epochs,
-                 div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
-                 **lr_scheduler_kwargs,
-             )
-         elif cfg.lr_scheduler == "log_sweep":
-             lr_scheduler = InterpolatingLogScheduler(
-                 optimizer,
-                 cfg.warmup_steps,
-                 cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
-                 cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
-             )
-         else:
-             lr_scheduler = transformers.get_cosine_schedule_with_warmup(
-                 optimizer,
-                 training_args.warmup_steps,
-                 total_num_steps,
-             )
-         trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
 
      callbacks = []
      callbacks.append(GPUStatsCallback(cfg))