Text Generation
Transformers
Safetensors
qwen3
Generated from Trainer
conversational
text-generation-inference
sumuks commited on
Commit
08328b6
·
verified ·
1 Parent(s): 42af185

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. README.md +143 -0
  2. added_tokens.json +28 -0
  3. chat_template.jinja +87 -0
  4. checkpoint-2000/added_tokens.json +28 -0
  5. checkpoint-2000/chat_template.jinja +87 -0
  6. checkpoint-2000/config.json +59 -0
  7. checkpoint-2000/generation_config.json +13 -0
  8. checkpoint-2000/latest +1 -0
  9. checkpoint-2000/merges.txt +0 -0
  10. checkpoint-2000/special_tokens_map.json +31 -0
  11. checkpoint-2000/tokenizer_config.json +239 -0
  12. checkpoint-2000/trainer_state.json +2883 -0
  13. checkpoint-2000/vocab.json +0 -0
  14. checkpoint-2000/zero_to_fp32.py +760 -0
  15. checkpoint-2500/added_tokens.json +28 -0
  16. checkpoint-2500/chat_template.jinja +87 -0
  17. checkpoint-2500/config.json +59 -0
  18. checkpoint-2500/generation_config.json +13 -0
  19. checkpoint-2500/latest +1 -0
  20. checkpoint-2500/merges.txt +0 -0
  21. checkpoint-2500/special_tokens_map.json +31 -0
  22. checkpoint-2500/tokenizer_config.json +239 -0
  23. checkpoint-2500/vocab.json +0 -0
  24. checkpoint-2500/zero_to_fp32.py +760 -0
  25. checkpoint-3000/added_tokens.json +28 -0
  26. checkpoint-3000/chat_template.jinja +87 -0
  27. checkpoint-3000/config.json +59 -0
  28. checkpoint-3000/latest +1 -0
  29. checkpoint-3000/merges.txt +0 -0
  30. checkpoint-3000/special_tokens_map.json +31 -0
  31. checkpoint-3000/tokenizer_config.json +239 -0
  32. checkpoint-3000/trainer_state.json +0 -0
  33. checkpoint-3000/vocab.json +0 -0
  34. checkpoint-3000/zero_to_fp32.py +760 -0
  35. checkpoint-3297/added_tokens.json +28 -0
  36. checkpoint-3297/chat_template.jinja +87 -0
  37. checkpoint-3297/config.json +59 -0
  38. checkpoint-3297/generation_config.json +13 -0
  39. checkpoint-3297/latest +1 -0
  40. checkpoint-3297/merges.txt +0 -0
  41. checkpoint-3297/special_tokens_map.json +31 -0
  42. checkpoint-3297/tokenizer_config.json +239 -0
  43. checkpoint-3297/trainer_state.json +0 -0
  44. checkpoint-3297/vocab.json +0 -0
  45. checkpoint-3297/zero_to_fp32.py +760 -0
  46. config.json +59 -0
  47. generation_config.json +13 -0
  48. merges.txt +0 -0
  49. special_tokens_map.json +31 -0
  50. tokenizer_config.json +239 -0
README.md ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apache-2.0
4
+ base_model: Qwen/Qwen3-1.7B
5
+ tags:
6
+ - generated_from_trainer
7
+ datasets:
8
+ - sumuks/essential-web-v1.0-sample-100M-with-cleaned-responses-sft
9
+ model-index:
10
+ - name: output/1.7B-Instruct-Tuned-New-Data
11
+ results: []
12
+ ---
13
+
14
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
15
+ should probably proofread and complete it, then remove this comment. -->
16
+
17
+ [<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
18
+ <details><summary>See axolotl config</summary>
19
+
20
+ axolotl version: `0.11.0`
21
+ ```yaml
22
+ base_model: Qwen/Qwen3-1.7B
23
+
24
+ # plugins:
25
+ # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
26
+ strict: false
27
+
28
+ # plugins:
29
+ # - axolotl.integrations.liger.LigerPlugin
30
+
31
+ # liger_rope: true
32
+ # liger_rms_norm: true
33
+ # liger_glu_activation: true
34
+ # liger_layer_norm: true
35
+ # liger_fused_linear_cross_entropy: true
36
+
37
+ datasets:
38
+ - path: sumuks/essential-web-v1.0-sample-100M-with-cleaned-responses-sft
39
+ type: chat_template
40
+ field_messages: conversations
41
+ split: train
42
+ val_set_size: 0.05
43
+ dataset_prepared_path: dataset/prepared_dataset_1.7b
44
+
45
+ train_on_inputs: false
46
+ output_dir: ./output/1.7B-Instruct-Tuned-New-Data
47
+ chat_template: qwen3
48
+ sequence_len: 8192
49
+ sample_packing: true
50
+ eval_sample_packing: true
51
+ # pad_to_sequence_len: true
52
+
53
+ wandb_project: essential-web-sft
54
+ wandb_name: qwen3-1.7b-sft-new-data
55
+
56
+ gradient_accumulation_steps: 4
57
+ gradient_checkpointing: true
58
+ gradient_checkpointing_kwargs:
59
+ use_reentrant: false
60
+ flash_attention: true
61
+ micro_batch_size: 1
62
+ optimizer: paged_adamw_8bit
63
+ lr_scheduler: cosine
64
+ learning_rate: 2e-5
65
+ num_epochs: 1
66
+
67
+ load_best_model_at_end: true
68
+ metric_for_best_model: loss
69
+ greater_is_better: false
70
+
71
+ early_stopping_patience: 3
72
+ bf16: auto
73
+ tf32: true
74
+
75
+ logging_steps: 5
76
+
77
+ deepspeed: ./configs_prod/zero3.json
78
+
79
+ save_steps: 500
80
+ eval_steps: 500
81
+
82
+ warmup_ratio: 0.05
83
+ # save_first_step: true
84
+ ```
85
+
86
+ </details><br>
87
+
88
+ # output/1.7B-Instruct-Tuned-New-Data
89
+
90
+ This model is a fine-tuned version of [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) on the sumuks/essential-web-v1.0-sample-100M-with-cleaned-responses-sft dataset.
91
+ It achieves the following results on the evaluation set:
92
+ - Loss: 0.3669
93
+
94
+ ## Model description
95
+
96
+ More information needed
97
+
98
+ ## Intended uses & limitations
99
+
100
+ More information needed
101
+
102
+ ## Training and evaluation data
103
+
104
+ More information needed
105
+
106
+ ## Training procedure
107
+
108
+ ### Training hyperparameters
109
+
110
+ The following hyperparameters were used during training:
111
+ - learning_rate: 2e-05
112
+ - train_batch_size: 1
113
+ - eval_batch_size: 1
114
+ - seed: 42
115
+ - distributed_type: multi-GPU
116
+ - num_devices: 2
117
+ - gradient_accumulation_steps: 4
118
+ - total_train_batch_size: 8
119
+ - total_eval_batch_size: 2
120
+ - optimizer: Use OptimizerNames.PAGED_ADAMW_8BIT with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
121
+ - lr_scheduler_type: cosine
122
+ - lr_scheduler_warmup_steps: 164
123
+ - training_steps: 3297
124
+
125
+ ### Training results
126
+
127
+ | Training Loss | Epoch | Step | Validation Loss |
128
+ |:-------------:|:------:|:----:|:---------------:|
129
+ | No log | 0 | 0 | 0.8829 |
130
+ | 0.3689 | 0.1517 | 500 | 0.4088 |
131
+ | 0.3919 | 0.3033 | 1000 | 0.3952 |
132
+ | 0.386 | 0.4550 | 1500 | 0.3839 |
133
+ | 0.409 | 0.6066 | 2000 | 0.3755 |
134
+ | 0.3473 | 0.7583 | 2500 | 0.3694 |
135
+ | 0.3518 | 0.9099 | 3000 | 0.3669 |
136
+
137
+
138
+ ### Framework versions
139
+
140
+ - Transformers 4.53.1
141
+ - Pytorch 2.7.1+cu126
142
+ - Datasets 3.6.0
143
+ - Tokenizers 0.21.2
added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
chat_template.jinja ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- else %}
85
+ {{- '<think>\n\n' }}
86
+ {%- endif %}
87
+ {%- endif %}
checkpoint-2000/added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
checkpoint-2000/chat_template.jinja ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- else %}
85
+ {{- '<think>\n\n' }}
86
+ {%- endif %}
87
+ {%- endif %}
checkpoint-2000/config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 151645,
8
+ "head_dim": 128,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "layer_types": [
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention"
42
+ ],
43
+ "max_position_embeddings": 40960,
44
+ "max_window_layers": 28,
45
+ "model_type": "qwen3",
46
+ "num_attention_heads": 16,
47
+ "num_hidden_layers": 28,
48
+ "num_key_value_heads": 8,
49
+ "rms_norm_eps": 1e-06,
50
+ "rope_scaling": null,
51
+ "rope_theta": 1000000,
52
+ "sliding_window": null,
53
+ "tie_word_embeddings": true,
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.53.1",
56
+ "use_cache": false,
57
+ "use_sliding_window": false,
58
+ "vocab_size": 151936
59
+ }
checkpoint-2000/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.53.1"
13
+ }
checkpoint-2000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step2000
checkpoint-2000/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-2000/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
checkpoint-2000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
checkpoint-2000/trainer_state.json ADDED
@@ -0,0 +1,2883 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": 2000,
3
+ "best_metric": 0.3754875957965851,
4
+ "best_model_checkpoint": "./output/1.7B-Instruct-Tuned-New-Data/checkpoint-2000",
5
+ "epoch": 0.6066120715802245,
6
+ "eval_steps": 500,
7
+ "global_step": 2000,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0,
14
+ "eval_loss": 0.8828833103179932,
15
+ "eval_runtime": 183.2487,
16
+ "eval_samples_per_second": 48.759,
17
+ "eval_steps_per_second": 24.382,
18
+ "step": 0
19
+ },
20
+ {
21
+ "epoch": 0.001516530178950561,
22
+ "grad_norm": 24.921934804682877,
23
+ "learning_rate": 4.878048780487805e-07,
24
+ "loss": 0.8639,
25
+ "step": 5
26
+ },
27
+ {
28
+ "epoch": 0.003033060357901122,
29
+ "grad_norm": 22.993835600118015,
30
+ "learning_rate": 1.0975609756097562e-06,
31
+ "loss": 0.8551,
32
+ "step": 10
33
+ },
34
+ {
35
+ "epoch": 0.004549590536851683,
36
+ "grad_norm": 21.595580129534707,
37
+ "learning_rate": 1.707317073170732e-06,
38
+ "loss": 0.8283,
39
+ "step": 15
40
+ },
41
+ {
42
+ "epoch": 0.006066120715802244,
43
+ "grad_norm": 5.010673102150899,
44
+ "learning_rate": 2.317073170731708e-06,
45
+ "loss": 0.559,
46
+ "step": 20
47
+ },
48
+ {
49
+ "epoch": 0.0075826508947528055,
50
+ "grad_norm": 2.8725236445814146,
51
+ "learning_rate": 2.926829268292683e-06,
52
+ "loss": 0.5412,
53
+ "step": 25
54
+ },
55
+ {
56
+ "epoch": 0.009099181073703366,
57
+ "grad_norm": 2.187108009423291,
58
+ "learning_rate": 3.5365853658536588e-06,
59
+ "loss": 0.521,
60
+ "step": 30
61
+ },
62
+ {
63
+ "epoch": 0.010615711252653927,
64
+ "grad_norm": 1.4262904456959642,
65
+ "learning_rate": 4.146341463414634e-06,
66
+ "loss": 0.5138,
67
+ "step": 35
68
+ },
69
+ {
70
+ "epoch": 0.012132241431604488,
71
+ "grad_norm": 1.5492098631988767,
72
+ "learning_rate": 4.75609756097561e-06,
73
+ "loss": 0.5069,
74
+ "step": 40
75
+ },
76
+ {
77
+ "epoch": 0.01364877161055505,
78
+ "grad_norm": 1.2785619453952701,
79
+ "learning_rate": 5.365853658536586e-06,
80
+ "loss": 0.4792,
81
+ "step": 45
82
+ },
83
+ {
84
+ "epoch": 0.015165301789505611,
85
+ "grad_norm": 1.182973976230221,
86
+ "learning_rate": 5.9756097560975615e-06,
87
+ "loss": 0.4029,
88
+ "step": 50
89
+ },
90
+ {
91
+ "epoch": 0.01668183196845617,
92
+ "grad_norm": 1.1563458384082506,
93
+ "learning_rate": 6.585365853658538e-06,
94
+ "loss": 0.385,
95
+ "step": 55
96
+ },
97
+ {
98
+ "epoch": 0.018198362147406732,
99
+ "grad_norm": 1.3349292384420064,
100
+ "learning_rate": 7.1951219512195125e-06,
101
+ "loss": 0.4144,
102
+ "step": 60
103
+ },
104
+ {
105
+ "epoch": 0.019714892326357293,
106
+ "grad_norm": 0.9335126123634385,
107
+ "learning_rate": 7.804878048780489e-06,
108
+ "loss": 0.4261,
109
+ "step": 65
110
+ },
111
+ {
112
+ "epoch": 0.021231422505307854,
113
+ "grad_norm": 1.30792221561379,
114
+ "learning_rate": 8.414634146341464e-06,
115
+ "loss": 0.4606,
116
+ "step": 70
117
+ },
118
+ {
119
+ "epoch": 0.022747952684258416,
120
+ "grad_norm": 1.6097633219159984,
121
+ "learning_rate": 9.02439024390244e-06,
122
+ "loss": 0.4276,
123
+ "step": 75
124
+ },
125
+ {
126
+ "epoch": 0.024264482863208977,
127
+ "grad_norm": 1.1673685397111215,
128
+ "learning_rate": 9.634146341463415e-06,
129
+ "loss": 0.445,
130
+ "step": 80
131
+ },
132
+ {
133
+ "epoch": 0.025781013042159538,
134
+ "grad_norm": 1.3515026545633433,
135
+ "learning_rate": 1.024390243902439e-05,
136
+ "loss": 0.496,
137
+ "step": 85
138
+ },
139
+ {
140
+ "epoch": 0.0272975432211101,
141
+ "grad_norm": 1.5361335286977482,
142
+ "learning_rate": 1.0853658536585368e-05,
143
+ "loss": 0.4189,
144
+ "step": 90
145
+ },
146
+ {
147
+ "epoch": 0.02881407340006066,
148
+ "grad_norm": 1.3323687104870683,
149
+ "learning_rate": 1.1463414634146342e-05,
150
+ "loss": 0.4461,
151
+ "step": 95
152
+ },
153
+ {
154
+ "epoch": 0.030330603579011222,
155
+ "grad_norm": 1.3055093124799528,
156
+ "learning_rate": 1.2073170731707317e-05,
157
+ "loss": 0.4558,
158
+ "step": 100
159
+ },
160
+ {
161
+ "epoch": 0.03184713375796178,
162
+ "grad_norm": 1.4071077040506643,
163
+ "learning_rate": 1.2682926829268294e-05,
164
+ "loss": 0.4608,
165
+ "step": 105
166
+ },
167
+ {
168
+ "epoch": 0.03336366393691234,
169
+ "grad_norm": 1.140971116003336,
170
+ "learning_rate": 1.329268292682927e-05,
171
+ "loss": 0.4046,
172
+ "step": 110
173
+ },
174
+ {
175
+ "epoch": 0.034880194115862906,
176
+ "grad_norm": 1.3427459277886848,
177
+ "learning_rate": 1.3902439024390244e-05,
178
+ "loss": 0.3874,
179
+ "step": 115
180
+ },
181
+ {
182
+ "epoch": 0.036396724294813464,
183
+ "grad_norm": 1.1775148547297358,
184
+ "learning_rate": 1.451219512195122e-05,
185
+ "loss": 0.4605,
186
+ "step": 120
187
+ },
188
+ {
189
+ "epoch": 0.03791325447376403,
190
+ "grad_norm": 1.2616670504697265,
191
+ "learning_rate": 1.5121951219512196e-05,
192
+ "loss": 0.4781,
193
+ "step": 125
194
+ },
195
+ {
196
+ "epoch": 0.039429784652714586,
197
+ "grad_norm": 1.4344564952473762,
198
+ "learning_rate": 1.5731707317073173e-05,
199
+ "loss": 0.4337,
200
+ "step": 130
201
+ },
202
+ {
203
+ "epoch": 0.04094631483166515,
204
+ "grad_norm": 1.7662014691737036,
205
+ "learning_rate": 1.6341463414634145e-05,
206
+ "loss": 0.4636,
207
+ "step": 135
208
+ },
209
+ {
210
+ "epoch": 0.04246284501061571,
211
+ "grad_norm": 1.3917668573790878,
212
+ "learning_rate": 1.6951219512195124e-05,
213
+ "loss": 0.4651,
214
+ "step": 140
215
+ },
216
+ {
217
+ "epoch": 0.043979375189566274,
218
+ "grad_norm": 1.1345687800745432,
219
+ "learning_rate": 1.75609756097561e-05,
220
+ "loss": 0.3976,
221
+ "step": 145
222
+ },
223
+ {
224
+ "epoch": 0.04549590536851683,
225
+ "grad_norm": 1.0517487527616944,
226
+ "learning_rate": 1.8170731707317075e-05,
227
+ "loss": 0.4172,
228
+ "step": 150
229
+ },
230
+ {
231
+ "epoch": 0.047012435547467396,
232
+ "grad_norm": 1.0174150315487598,
233
+ "learning_rate": 1.878048780487805e-05,
234
+ "loss": 0.3947,
235
+ "step": 155
236
+ },
237
+ {
238
+ "epoch": 0.048528965726417954,
239
+ "grad_norm": 1.3496351705883955,
240
+ "learning_rate": 1.9390243902439026e-05,
241
+ "loss": 0.4429,
242
+ "step": 160
243
+ },
244
+ {
245
+ "epoch": 0.05004549590536852,
246
+ "grad_norm": 1.1321668963681037,
247
+ "learning_rate": 2e-05,
248
+ "loss": 0.4044,
249
+ "step": 165
250
+ },
251
+ {
252
+ "epoch": 0.051562026084319076,
253
+ "grad_norm": 1.0951786995673565,
254
+ "learning_rate": 1.999987431366603e-05,
255
+ "loss": 0.3749,
256
+ "step": 170
257
+ },
258
+ {
259
+ "epoch": 0.05307855626326964,
260
+ "grad_norm": 1.1089458391034641,
261
+ "learning_rate": 1.999949725782353e-05,
262
+ "loss": 0.4227,
263
+ "step": 175
264
+ },
265
+ {
266
+ "epoch": 0.0545950864422202,
267
+ "grad_norm": 1.1845203588062014,
268
+ "learning_rate": 1.9998868841950646e-05,
269
+ "loss": 0.3867,
270
+ "step": 180
271
+ },
272
+ {
273
+ "epoch": 0.056111616621170764,
274
+ "grad_norm": 0.8571685139076549,
275
+ "learning_rate": 1.9997989081844044e-05,
276
+ "loss": 0.4351,
277
+ "step": 185
278
+ },
279
+ {
280
+ "epoch": 0.05762814680012132,
281
+ "grad_norm": 0.9900250625278807,
282
+ "learning_rate": 1.999685799961849e-05,
283
+ "loss": 0.4125,
284
+ "step": 190
285
+ },
286
+ {
287
+ "epoch": 0.059144676979071886,
288
+ "grad_norm": 1.1898649968753983,
289
+ "learning_rate": 1.999547562370629e-05,
290
+ "loss": 0.4217,
291
+ "step": 195
292
+ },
293
+ {
294
+ "epoch": 0.060661207158022444,
295
+ "grad_norm": 1.1463755374202382,
296
+ "learning_rate": 1.99938419888566e-05,
297
+ "loss": 0.4564,
298
+ "step": 200
299
+ },
300
+ {
301
+ "epoch": 0.06217773733697301,
302
+ "grad_norm": 1.3479323790987279,
303
+ "learning_rate": 1.9991957136134542e-05,
304
+ "loss": 0.5052,
305
+ "step": 205
306
+ },
307
+ {
308
+ "epoch": 0.06369426751592357,
309
+ "grad_norm": 1.475205459236139,
310
+ "learning_rate": 1.9989821112920155e-05,
311
+ "loss": 0.38,
312
+ "step": 210
313
+ },
314
+ {
315
+ "epoch": 0.06521079769487413,
316
+ "grad_norm": 1.1629718970331968,
317
+ "learning_rate": 1.9987433972907225e-05,
318
+ "loss": 0.4202,
319
+ "step": 215
320
+ },
321
+ {
322
+ "epoch": 0.06672732787382468,
323
+ "grad_norm": 1.0931736708866155,
324
+ "learning_rate": 1.998479577610193e-05,
325
+ "loss": 0.3596,
326
+ "step": 220
327
+ },
328
+ {
329
+ "epoch": 0.06824385805277525,
330
+ "grad_norm": 1.0829052216446347,
331
+ "learning_rate": 1.9981906588821322e-05,
332
+ "loss": 0.4087,
333
+ "step": 225
334
+ },
335
+ {
336
+ "epoch": 0.06976038823172581,
337
+ "grad_norm": 1.247069136713069,
338
+ "learning_rate": 1.997876648369168e-05,
339
+ "loss": 0.4044,
340
+ "step": 230
341
+ },
342
+ {
343
+ "epoch": 0.07127691841067638,
344
+ "grad_norm": 1.0789938627982327,
345
+ "learning_rate": 1.9975375539646656e-05,
346
+ "loss": 0.3966,
347
+ "step": 235
348
+ },
349
+ {
350
+ "epoch": 0.07279344858962693,
351
+ "grad_norm": 1.2874248378984725,
352
+ "learning_rate": 1.997173384192532e-05,
353
+ "loss": 0.3919,
354
+ "step": 240
355
+ },
356
+ {
357
+ "epoch": 0.07430997876857749,
358
+ "grad_norm": 1.251206148462947,
359
+ "learning_rate": 1.9967841482070002e-05,
360
+ "loss": 0.4433,
361
+ "step": 245
362
+ },
363
+ {
364
+ "epoch": 0.07582650894752806,
365
+ "grad_norm": 1.2436949055413618,
366
+ "learning_rate": 1.996369855792398e-05,
367
+ "loss": 0.431,
368
+ "step": 250
369
+ },
370
+ {
371
+ "epoch": 0.07734303912647862,
372
+ "grad_norm": 1.2591168062725386,
373
+ "learning_rate": 1.9959305173629056e-05,
374
+ "loss": 0.4249,
375
+ "step": 255
376
+ },
377
+ {
378
+ "epoch": 0.07885956930542917,
379
+ "grad_norm": 1.1569304716868531,
380
+ "learning_rate": 1.9954661439622894e-05,
381
+ "loss": 0.4274,
382
+ "step": 260
383
+ },
384
+ {
385
+ "epoch": 0.08037609948437974,
386
+ "grad_norm": 1.1825334487051853,
387
+ "learning_rate": 1.994976747263628e-05,
388
+ "loss": 0.3794,
389
+ "step": 265
390
+ },
391
+ {
392
+ "epoch": 0.0818926296633303,
393
+ "grad_norm": 1.4809085739387609,
394
+ "learning_rate": 1.9944623395690162e-05,
395
+ "loss": 0.4105,
396
+ "step": 270
397
+ },
398
+ {
399
+ "epoch": 0.08340915984228087,
400
+ "grad_norm": 1.0467676858303625,
401
+ "learning_rate": 1.9939229338092584e-05,
402
+ "loss": 0.3861,
403
+ "step": 275
404
+ },
405
+ {
406
+ "epoch": 0.08492569002123142,
407
+ "grad_norm": 1.1992281733632604,
408
+ "learning_rate": 1.99335854354354e-05,
409
+ "loss": 0.3317,
410
+ "step": 280
411
+ },
412
+ {
413
+ "epoch": 0.08644222020018198,
414
+ "grad_norm": 1.3404254075615103,
415
+ "learning_rate": 1.9927691829590903e-05,
416
+ "loss": 0.4457,
417
+ "step": 285
418
+ },
419
+ {
420
+ "epoch": 0.08795875037913255,
421
+ "grad_norm": 1.3258249809215283,
422
+ "learning_rate": 1.992154866870824e-05,
423
+ "loss": 0.4478,
424
+ "step": 290
425
+ },
426
+ {
427
+ "epoch": 0.08947528055808311,
428
+ "grad_norm": 1.215305619652228,
429
+ "learning_rate": 1.9915156107209673e-05,
430
+ "loss": 0.4032,
431
+ "step": 295
432
+ },
433
+ {
434
+ "epoch": 0.09099181073703366,
435
+ "grad_norm": 1.1528812652880362,
436
+ "learning_rate": 1.9908514305786733e-05,
437
+ "loss": 0.3796,
438
+ "step": 300
439
+ },
440
+ {
441
+ "epoch": 0.09250834091598423,
442
+ "grad_norm": 1.1857942136064035,
443
+ "learning_rate": 1.990162343139616e-05,
444
+ "loss": 0.435,
445
+ "step": 305
446
+ },
447
+ {
448
+ "epoch": 0.09402487109493479,
449
+ "grad_norm": 1.220916675694686,
450
+ "learning_rate": 1.989448365725569e-05,
451
+ "loss": 0.3798,
452
+ "step": 310
453
+ },
454
+ {
455
+ "epoch": 0.09554140127388536,
456
+ "grad_norm": 0.985856274220688,
457
+ "learning_rate": 1.988709516283974e-05,
458
+ "loss": 0.3925,
459
+ "step": 315
460
+ },
461
+ {
462
+ "epoch": 0.09705793145283591,
463
+ "grad_norm": 1.0859800217625073,
464
+ "learning_rate": 1.987945813387486e-05,
465
+ "loss": 0.4437,
466
+ "step": 320
467
+ },
468
+ {
469
+ "epoch": 0.09857446163178647,
470
+ "grad_norm": 1.0765506678345211,
471
+ "learning_rate": 1.9871572762335085e-05,
472
+ "loss": 0.3847,
473
+ "step": 325
474
+ },
475
+ {
476
+ "epoch": 0.10009099181073704,
477
+ "grad_norm": 1.1464592931749376,
478
+ "learning_rate": 1.9863439246437108e-05,
479
+ "loss": 0.4129,
480
+ "step": 330
481
+ },
482
+ {
483
+ "epoch": 0.10160752198968759,
484
+ "grad_norm": 1.077561921028752,
485
+ "learning_rate": 1.985505779063528e-05,
486
+ "loss": 0.3808,
487
+ "step": 335
488
+ },
489
+ {
490
+ "epoch": 0.10312405216863815,
491
+ "grad_norm": 0.8205386721296154,
492
+ "learning_rate": 1.98464286056165e-05,
493
+ "loss": 0.3598,
494
+ "step": 340
495
+ },
496
+ {
497
+ "epoch": 0.10464058234758872,
498
+ "grad_norm": 1.1795375866850706,
499
+ "learning_rate": 1.9837551908294887e-05,
500
+ "loss": 0.3986,
501
+ "step": 345
502
+ },
503
+ {
504
+ "epoch": 0.10615711252653928,
505
+ "grad_norm": 1.2649777397213469,
506
+ "learning_rate": 1.9828427921806358e-05,
507
+ "loss": 0.4592,
508
+ "step": 350
509
+ },
510
+ {
511
+ "epoch": 0.10767364270548983,
512
+ "grad_norm": 1.2819761889608279,
513
+ "learning_rate": 1.9819056875502986e-05,
514
+ "loss": 0.423,
515
+ "step": 355
516
+ },
517
+ {
518
+ "epoch": 0.1091901728844404,
519
+ "grad_norm": 1.0892900872427516,
520
+ "learning_rate": 1.980943900494727e-05,
521
+ "loss": 0.4253,
522
+ "step": 360
523
+ },
524
+ {
525
+ "epoch": 0.11070670306339096,
526
+ "grad_norm": 1.0090566077966177,
527
+ "learning_rate": 1.979957455190618e-05,
528
+ "loss": 0.41,
529
+ "step": 365
530
+ },
531
+ {
532
+ "epoch": 0.11222323324234153,
533
+ "grad_norm": 1.1111913464704366,
534
+ "learning_rate": 1.9789463764345113e-05,
535
+ "loss": 0.4219,
536
+ "step": 370
537
+ },
538
+ {
539
+ "epoch": 0.11373976342129208,
540
+ "grad_norm": 0.9666799079863133,
541
+ "learning_rate": 1.9779106896421627e-05,
542
+ "loss": 0.4593,
543
+ "step": 375
544
+ },
545
+ {
546
+ "epoch": 0.11525629360024264,
547
+ "grad_norm": 1.2048376185649605,
548
+ "learning_rate": 1.9768504208479077e-05,
549
+ "loss": 0.4312,
550
+ "step": 380
551
+ },
552
+ {
553
+ "epoch": 0.11677282377919321,
554
+ "grad_norm": 1.3164491474278361,
555
+ "learning_rate": 1.975765596704006e-05,
556
+ "loss": 0.3745,
557
+ "step": 385
558
+ },
559
+ {
560
+ "epoch": 0.11828935395814377,
561
+ "grad_norm": 0.8809824695823545,
562
+ "learning_rate": 1.9746562444799712e-05,
563
+ "loss": 0.397,
564
+ "step": 390
565
+ },
566
+ {
567
+ "epoch": 0.11980588413709432,
568
+ "grad_norm": 1.270641853983644,
569
+ "learning_rate": 1.9735223920618857e-05,
570
+ "loss": 0.4513,
571
+ "step": 395
572
+ },
573
+ {
574
+ "epoch": 0.12132241431604489,
575
+ "grad_norm": 1.291191652244439,
576
+ "learning_rate": 1.9723640679517015e-05,
577
+ "loss": 0.4401,
578
+ "step": 400
579
+ },
580
+ {
581
+ "epoch": 0.12283894449499545,
582
+ "grad_norm": 1.0763253903156915,
583
+ "learning_rate": 1.9711813012665198e-05,
584
+ "loss": 0.4089,
585
+ "step": 405
586
+ },
587
+ {
588
+ "epoch": 0.12435547467394602,
589
+ "grad_norm": 1.0687463364798888,
590
+ "learning_rate": 1.9699741217378625e-05,
591
+ "loss": 0.424,
592
+ "step": 410
593
+ },
594
+ {
595
+ "epoch": 0.12587200485289657,
596
+ "grad_norm": 1.038145624225704,
597
+ "learning_rate": 1.9687425597109238e-05,
598
+ "loss": 0.3912,
599
+ "step": 415
600
+ },
601
+ {
602
+ "epoch": 0.12738853503184713,
603
+ "grad_norm": 0.9636211495905858,
604
+ "learning_rate": 1.9674866461438065e-05,
605
+ "loss": 0.3856,
606
+ "step": 420
607
+ },
608
+ {
609
+ "epoch": 0.1289050652107977,
610
+ "grad_norm": 1.2745603021068876,
611
+ "learning_rate": 1.966206412606745e-05,
612
+ "loss": 0.4363,
613
+ "step": 425
614
+ },
615
+ {
616
+ "epoch": 0.13042159538974826,
617
+ "grad_norm": 1.0775019059829327,
618
+ "learning_rate": 1.964901891281312e-05,
619
+ "loss": 0.3764,
620
+ "step": 430
621
+ },
622
+ {
623
+ "epoch": 0.13193812556869883,
624
+ "grad_norm": 0.9079407408157604,
625
+ "learning_rate": 1.9635731149596075e-05,
626
+ "loss": 0.4002,
627
+ "step": 435
628
+ },
629
+ {
630
+ "epoch": 0.13345465574764936,
631
+ "grad_norm": 1.0314514042029734,
632
+ "learning_rate": 1.962220117043436e-05,
633
+ "loss": 0.3707,
634
+ "step": 440
635
+ },
636
+ {
637
+ "epoch": 0.13497118592659993,
638
+ "grad_norm": 1.2219058440800525,
639
+ "learning_rate": 1.9608429315434683e-05,
640
+ "loss": 0.4086,
641
+ "step": 445
642
+ },
643
+ {
644
+ "epoch": 0.1364877161055505,
645
+ "grad_norm": 1.1625549762597092,
646
+ "learning_rate": 1.959441593078383e-05,
647
+ "loss": 0.3905,
648
+ "step": 450
649
+ },
650
+ {
651
+ "epoch": 0.13800424628450106,
652
+ "grad_norm": 1.1998854874780447,
653
+ "learning_rate": 1.9580161368739984e-05,
654
+ "loss": 0.3846,
655
+ "step": 455
656
+ },
657
+ {
658
+ "epoch": 0.13952077646345162,
659
+ "grad_norm": 1.0469274329601646,
660
+ "learning_rate": 1.956566598762388e-05,
661
+ "loss": 0.4229,
662
+ "step": 460
663
+ },
664
+ {
665
+ "epoch": 0.1410373066424022,
666
+ "grad_norm": 1.0010991478406723,
667
+ "learning_rate": 1.955093015180979e-05,
668
+ "loss": 0.3698,
669
+ "step": 465
670
+ },
671
+ {
672
+ "epoch": 0.14255383682135275,
673
+ "grad_norm": 1.315626156939816,
674
+ "learning_rate": 1.9535954231716334e-05,
675
+ "loss": 0.3568,
676
+ "step": 470
677
+ },
678
+ {
679
+ "epoch": 0.14407036700030332,
680
+ "grad_norm": 0.9897859436282127,
681
+ "learning_rate": 1.952073860379722e-05,
682
+ "loss": 0.3947,
683
+ "step": 475
684
+ },
685
+ {
686
+ "epoch": 0.14558689717925385,
687
+ "grad_norm": 1.0830619205943097,
688
+ "learning_rate": 1.950528365053174e-05,
689
+ "loss": 0.4243,
690
+ "step": 480
691
+ },
692
+ {
693
+ "epoch": 0.14710342735820442,
694
+ "grad_norm": 1.0427173223016224,
695
+ "learning_rate": 1.9489589760415186e-05,
696
+ "loss": 0.392,
697
+ "step": 485
698
+ },
699
+ {
700
+ "epoch": 0.14861995753715498,
701
+ "grad_norm": 0.8720070105985259,
702
+ "learning_rate": 1.9473657327949055e-05,
703
+ "loss": 0.3858,
704
+ "step": 490
705
+ },
706
+ {
707
+ "epoch": 0.15013648771610555,
708
+ "grad_norm": 1.0930260154511207,
709
+ "learning_rate": 1.9457486753631152e-05,
710
+ "loss": 0.4229,
711
+ "step": 495
712
+ },
713
+ {
714
+ "epoch": 0.1516530178950561,
715
+ "grad_norm": 1.0011322130603957,
716
+ "learning_rate": 1.9441078443945525e-05,
717
+ "loss": 0.3689,
718
+ "step": 500
719
+ },
720
+ {
721
+ "epoch": 0.1516530178950561,
722
+ "eval_loss": 0.4088059365749359,
723
+ "eval_runtime": 173.5642,
724
+ "eval_samples_per_second": 51.48,
725
+ "eval_steps_per_second": 25.743,
726
+ "step": 500
727
+ },
728
+ {
729
+ "epoch": 0.15316954807400668,
730
+ "grad_norm": 0.9905645606543211,
731
+ "learning_rate": 1.9424432811352224e-05,
732
+ "loss": 0.412,
733
+ "step": 505
734
+ },
735
+ {
736
+ "epoch": 0.15468607825295724,
737
+ "grad_norm": 1.3173648440740526,
738
+ "learning_rate": 1.940755027427696e-05,
739
+ "loss": 0.4601,
740
+ "step": 510
741
+ },
742
+ {
743
+ "epoch": 0.1562026084319078,
744
+ "grad_norm": 0.992923300776095,
745
+ "learning_rate": 1.939043125710057e-05,
746
+ "loss": 0.4151,
747
+ "step": 515
748
+ },
749
+ {
750
+ "epoch": 0.15771913861085834,
751
+ "grad_norm": 1.0345346572748841,
752
+ "learning_rate": 1.937307619014836e-05,
753
+ "loss": 0.4082,
754
+ "step": 520
755
+ },
756
+ {
757
+ "epoch": 0.1592356687898089,
758
+ "grad_norm": 1.0277113400586657,
759
+ "learning_rate": 1.9355485509679274e-05,
760
+ "loss": 0.3652,
761
+ "step": 525
762
+ },
763
+ {
764
+ "epoch": 0.16075219896875947,
765
+ "grad_norm": 1.1428063392505712,
766
+ "learning_rate": 1.9337659657874943e-05,
767
+ "loss": 0.4294,
768
+ "step": 530
769
+ },
770
+ {
771
+ "epoch": 0.16226872914771004,
772
+ "grad_norm": 0.818605782882043,
773
+ "learning_rate": 1.9319599082828554e-05,
774
+ "loss": 0.4165,
775
+ "step": 535
776
+ },
777
+ {
778
+ "epoch": 0.1637852593266606,
779
+ "grad_norm": 1.2266082158969924,
780
+ "learning_rate": 1.9301304238533608e-05,
781
+ "loss": 0.4088,
782
+ "step": 540
783
+ },
784
+ {
785
+ "epoch": 0.16530178950561117,
786
+ "grad_norm": 1.040073592097893,
787
+ "learning_rate": 1.9282775584872485e-05,
788
+ "loss": 0.4384,
789
+ "step": 545
790
+ },
791
+ {
792
+ "epoch": 0.16681831968456173,
793
+ "grad_norm": 1.0997516079813272,
794
+ "learning_rate": 1.926401358760489e-05,
795
+ "loss": 0.457,
796
+ "step": 550
797
+ },
798
+ {
799
+ "epoch": 0.16833484986351227,
800
+ "grad_norm": 1.2258344173220719,
801
+ "learning_rate": 1.924501871835616e-05,
802
+ "loss": 0.3966,
803
+ "step": 555
804
+ },
805
+ {
806
+ "epoch": 0.16985138004246284,
807
+ "grad_norm": 0.981904205499286,
808
+ "learning_rate": 1.9225791454605392e-05,
809
+ "loss": 0.3813,
810
+ "step": 560
811
+ },
812
+ {
813
+ "epoch": 0.1713679102214134,
814
+ "grad_norm": 1.084662004004182,
815
+ "learning_rate": 1.9206332279673437e-05,
816
+ "loss": 0.4593,
817
+ "step": 565
818
+ },
819
+ {
820
+ "epoch": 0.17288444040036396,
821
+ "grad_norm": 1.0398606480794417,
822
+ "learning_rate": 1.9186641682710774e-05,
823
+ "loss": 0.4093,
824
+ "step": 570
825
+ },
826
+ {
827
+ "epoch": 0.17440097057931453,
828
+ "grad_norm": 1.1672550355106228,
829
+ "learning_rate": 1.9166720158685187e-05,
830
+ "loss": 0.4052,
831
+ "step": 575
832
+ },
833
+ {
834
+ "epoch": 0.1759175007582651,
835
+ "grad_norm": 1.0087878638040124,
836
+ "learning_rate": 1.9146568208369346e-05,
837
+ "loss": 0.366,
838
+ "step": 580
839
+ },
840
+ {
841
+ "epoch": 0.17743403093721566,
842
+ "grad_norm": 1.2216691544147262,
843
+ "learning_rate": 1.91261863383282e-05,
844
+ "loss": 0.451,
845
+ "step": 585
846
+ },
847
+ {
848
+ "epoch": 0.17895056111616622,
849
+ "grad_norm": 1.3584786997202387,
850
+ "learning_rate": 1.9105575060906254e-05,
851
+ "loss": 0.436,
852
+ "step": 590
853
+ },
854
+ {
855
+ "epoch": 0.18046709129511676,
856
+ "grad_norm": 0.9189515390724051,
857
+ "learning_rate": 1.908473489421468e-05,
858
+ "loss": 0.3632,
859
+ "step": 595
860
+ },
861
+ {
862
+ "epoch": 0.18198362147406733,
863
+ "grad_norm": 1.0321679875275023,
864
+ "learning_rate": 1.9063666362118324e-05,
865
+ "loss": 0.3914,
866
+ "step": 600
867
+ },
868
+ {
869
+ "epoch": 0.1835001516530179,
870
+ "grad_norm": 1.039032214603117,
871
+ "learning_rate": 1.9042369994222487e-05,
872
+ "loss": 0.3524,
873
+ "step": 605
874
+ },
875
+ {
876
+ "epoch": 0.18501668183196845,
877
+ "grad_norm": 0.9559021915379782,
878
+ "learning_rate": 1.902084632585965e-05,
879
+ "loss": 0.436,
880
+ "step": 610
881
+ },
882
+ {
883
+ "epoch": 0.18653321201091902,
884
+ "grad_norm": 1.0082278999199734,
885
+ "learning_rate": 1.8999095898076012e-05,
886
+ "loss": 0.3322,
887
+ "step": 615
888
+ },
889
+ {
890
+ "epoch": 0.18804974218986958,
891
+ "grad_norm": 1.1936896778472303,
892
+ "learning_rate": 1.8977119257617878e-05,
893
+ "loss": 0.3529,
894
+ "step": 620
895
+ },
896
+ {
897
+ "epoch": 0.18956627236882015,
898
+ "grad_norm": 0.9718605188148373,
899
+ "learning_rate": 1.8954916956917922e-05,
900
+ "loss": 0.3966,
901
+ "step": 625
902
+ },
903
+ {
904
+ "epoch": 0.1910828025477707,
905
+ "grad_norm": 1.0441882225736692,
906
+ "learning_rate": 1.8932489554081295e-05,
907
+ "loss": 0.3683,
908
+ "step": 630
909
+ },
910
+ {
911
+ "epoch": 0.19259933272672125,
912
+ "grad_norm": 1.0302242941912132,
913
+ "learning_rate": 1.8909837612871615e-05,
914
+ "loss": 0.4015,
915
+ "step": 635
916
+ },
917
+ {
918
+ "epoch": 0.19411586290567182,
919
+ "grad_norm": 1.0668469700274468,
920
+ "learning_rate": 1.8886961702696765e-05,
921
+ "loss": 0.3682,
922
+ "step": 640
923
+ },
924
+ {
925
+ "epoch": 0.19563239308462238,
926
+ "grad_norm": 0.9494238968317048,
927
+ "learning_rate": 1.8863862398594606e-05,
928
+ "loss": 0.4081,
929
+ "step": 645
930
+ },
931
+ {
932
+ "epoch": 0.19714892326357294,
933
+ "grad_norm": 1.1243612355939872,
934
+ "learning_rate": 1.8840540281218506e-05,
935
+ "loss": 0.3883,
936
+ "step": 650
937
+ },
938
+ {
939
+ "epoch": 0.1986654534425235,
940
+ "grad_norm": 1.075535429789229,
941
+ "learning_rate": 1.881699593682275e-05,
942
+ "loss": 0.3991,
943
+ "step": 655
944
+ },
945
+ {
946
+ "epoch": 0.20018198362147407,
947
+ "grad_norm": 1.1424504644700166,
948
+ "learning_rate": 1.8793229957247808e-05,
949
+ "loss": 0.4006,
950
+ "step": 660
951
+ },
952
+ {
953
+ "epoch": 0.20169851380042464,
954
+ "grad_norm": 1.20832013347783,
955
+ "learning_rate": 1.8769242939905446e-05,
956
+ "loss": 0.3781,
957
+ "step": 665
958
+ },
959
+ {
960
+ "epoch": 0.20321504397937518,
961
+ "grad_norm": 0.9949915408461836,
962
+ "learning_rate": 1.874503548776372e-05,
963
+ "loss": 0.4073,
964
+ "step": 670
965
+ },
966
+ {
967
+ "epoch": 0.20473157415832574,
968
+ "grad_norm": 1.0184539078882147,
969
+ "learning_rate": 1.8720608209331813e-05,
970
+ "loss": 0.404,
971
+ "step": 675
972
+ },
973
+ {
974
+ "epoch": 0.2062481043372763,
975
+ "grad_norm": 1.0772016502245503,
976
+ "learning_rate": 1.8695961718644743e-05,
977
+ "loss": 0.4306,
978
+ "step": 680
979
+ },
980
+ {
981
+ "epoch": 0.20776463451622687,
982
+ "grad_norm": 1.0757591761900838,
983
+ "learning_rate": 1.8671096635247914e-05,
984
+ "loss": 0.3891,
985
+ "step": 685
986
+ },
987
+ {
988
+ "epoch": 0.20928116469517744,
989
+ "grad_norm": 1.134500144137449,
990
+ "learning_rate": 1.864601358418157e-05,
991
+ "loss": 0.3858,
992
+ "step": 690
993
+ },
994
+ {
995
+ "epoch": 0.210797694874128,
996
+ "grad_norm": 1.3231464100942203,
997
+ "learning_rate": 1.8620713195965052e-05,
998
+ "loss": 0.3896,
999
+ "step": 695
1000
+ },
1001
+ {
1002
+ "epoch": 0.21231422505307856,
1003
+ "grad_norm": 0.9378926668917021,
1004
+ "learning_rate": 1.8595196106580973e-05,
1005
+ "loss": 0.372,
1006
+ "step": 700
1007
+ },
1008
+ {
1009
+ "epoch": 0.21383075523202913,
1010
+ "grad_norm": 1.156227424640605,
1011
+ "learning_rate": 1.856946295745921e-05,
1012
+ "loss": 0.4282,
1013
+ "step": 705
1014
+ },
1015
+ {
1016
+ "epoch": 0.21534728541097967,
1017
+ "grad_norm": 1.2036566060604865,
1018
+ "learning_rate": 1.8543514395460806e-05,
1019
+ "loss": 0.4055,
1020
+ "step": 710
1021
+ },
1022
+ {
1023
+ "epoch": 0.21686381558993023,
1024
+ "grad_norm": 1.0835648170732943,
1025
+ "learning_rate": 1.8517351072861682e-05,
1026
+ "loss": 0.3864,
1027
+ "step": 715
1028
+ },
1029
+ {
1030
+ "epoch": 0.2183803457688808,
1031
+ "grad_norm": 1.0077498655883261,
1032
+ "learning_rate": 1.8490973647336255e-05,
1033
+ "loss": 0.4268,
1034
+ "step": 720
1035
+ },
1036
+ {
1037
+ "epoch": 0.21989687594783136,
1038
+ "grad_norm": 0.9546564688055839,
1039
+ "learning_rate": 1.8464382781940918e-05,
1040
+ "loss": 0.3632,
1041
+ "step": 725
1042
+ },
1043
+ {
1044
+ "epoch": 0.22141340612678193,
1045
+ "grad_norm": 0.8840597627792552,
1046
+ "learning_rate": 1.843757914509734e-05,
1047
+ "loss": 0.38,
1048
+ "step": 730
1049
+ },
1050
+ {
1051
+ "epoch": 0.2229299363057325,
1052
+ "grad_norm": 1.158885040155564,
1053
+ "learning_rate": 1.8410563410575696e-05,
1054
+ "loss": 0.3923,
1055
+ "step": 735
1056
+ },
1057
+ {
1058
+ "epoch": 0.22444646648468305,
1059
+ "grad_norm": 0.9247547809922593,
1060
+ "learning_rate": 1.838333625747771e-05,
1061
+ "loss": 0.3865,
1062
+ "step": 740
1063
+ },
1064
+ {
1065
+ "epoch": 0.22596299666363362,
1066
+ "grad_norm": 1.0016406439137249,
1067
+ "learning_rate": 1.835589837021959e-05,
1068
+ "loss": 0.3776,
1069
+ "step": 745
1070
+ },
1071
+ {
1072
+ "epoch": 0.22747952684258416,
1073
+ "grad_norm": 1.043626508615065,
1074
+ "learning_rate": 1.8328250438514837e-05,
1075
+ "loss": 0.3608,
1076
+ "step": 750
1077
+ },
1078
+ {
1079
+ "epoch": 0.22899605702153472,
1080
+ "grad_norm": 0.9577907090027657,
1081
+ "learning_rate": 1.830039315735688e-05,
1082
+ "loss": 0.3894,
1083
+ "step": 755
1084
+ },
1085
+ {
1086
+ "epoch": 0.23051258720048529,
1087
+ "grad_norm": 0.9580721684764322,
1088
+ "learning_rate": 1.827232722700163e-05,
1089
+ "loss": 0.3614,
1090
+ "step": 760
1091
+ },
1092
+ {
1093
+ "epoch": 0.23202911737943585,
1094
+ "grad_norm": 0.9477520019373953,
1095
+ "learning_rate": 1.8244053352949866e-05,
1096
+ "loss": 0.3582,
1097
+ "step": 765
1098
+ },
1099
+ {
1100
+ "epoch": 0.23354564755838642,
1101
+ "grad_norm": 0.9783981899125124,
1102
+ "learning_rate": 1.82155722459295e-05,
1103
+ "loss": 0.3686,
1104
+ "step": 770
1105
+ },
1106
+ {
1107
+ "epoch": 0.23506217773733698,
1108
+ "grad_norm": 1.2531749551765516,
1109
+ "learning_rate": 1.8186884621877726e-05,
1110
+ "loss": 0.435,
1111
+ "step": 775
1112
+ },
1113
+ {
1114
+ "epoch": 0.23657870791628755,
1115
+ "grad_norm": 1.0378676377391638,
1116
+ "learning_rate": 1.815799120192299e-05,
1117
+ "loss": 0.3919,
1118
+ "step": 780
1119
+ },
1120
+ {
1121
+ "epoch": 0.23809523809523808,
1122
+ "grad_norm": 0.9221216623517181,
1123
+ "learning_rate": 1.8128892712366916e-05,
1124
+ "loss": 0.3703,
1125
+ "step": 785
1126
+ },
1127
+ {
1128
+ "epoch": 0.23961176827418865,
1129
+ "grad_norm": 0.9471155975175013,
1130
+ "learning_rate": 1.8099589884665986e-05,
1131
+ "loss": 0.4248,
1132
+ "step": 790
1133
+ },
1134
+ {
1135
+ "epoch": 0.2411282984531392,
1136
+ "grad_norm": 1.0217123358982474,
1137
+ "learning_rate": 1.80700834554132e-05,
1138
+ "loss": 0.3706,
1139
+ "step": 795
1140
+ },
1141
+ {
1142
+ "epoch": 0.24264482863208978,
1143
+ "grad_norm": 1.1240824857344656,
1144
+ "learning_rate": 1.804037416631954e-05,
1145
+ "loss": 0.381,
1146
+ "step": 800
1147
+ },
1148
+ {
1149
+ "epoch": 0.24416135881104034,
1150
+ "grad_norm": 0.9053272736194982,
1151
+ "learning_rate": 1.801046276419534e-05,
1152
+ "loss": 0.3612,
1153
+ "step": 805
1154
+ },
1155
+ {
1156
+ "epoch": 0.2456778889899909,
1157
+ "grad_norm": 1.1627818653327249,
1158
+ "learning_rate": 1.7980350000931494e-05,
1159
+ "loss": 0.3961,
1160
+ "step": 810
1161
+ },
1162
+ {
1163
+ "epoch": 0.24719441916894147,
1164
+ "grad_norm": 0.9805469968339321,
1165
+ "learning_rate": 1.7950036633480557e-05,
1166
+ "loss": 0.3902,
1167
+ "step": 815
1168
+ },
1169
+ {
1170
+ "epoch": 0.24871094934789204,
1171
+ "grad_norm": 1.1280037352121748,
1172
+ "learning_rate": 1.7919523423837743e-05,
1173
+ "loss": 0.3845,
1174
+ "step": 820
1175
+ },
1176
+ {
1177
+ "epoch": 0.2502274795268426,
1178
+ "grad_norm": 0.9614220888721607,
1179
+ "learning_rate": 1.788881113902174e-05,
1180
+ "loss": 0.3896,
1181
+ "step": 825
1182
+ },
1183
+ {
1184
+ "epoch": 0.25174400970579314,
1185
+ "grad_norm": 0.9351755310616024,
1186
+ "learning_rate": 1.7857900551055448e-05,
1187
+ "loss": 0.39,
1188
+ "step": 830
1189
+ },
1190
+ {
1191
+ "epoch": 0.25326053988474373,
1192
+ "grad_norm": 1.1017280153749585,
1193
+ "learning_rate": 1.7826792436946562e-05,
1194
+ "loss": 0.377,
1195
+ "step": 835
1196
+ },
1197
+ {
1198
+ "epoch": 0.25477707006369427,
1199
+ "grad_norm": 0.9933284106135498,
1200
+ "learning_rate": 1.779548757866804e-05,
1201
+ "loss": 0.4235,
1202
+ "step": 840
1203
+ },
1204
+ {
1205
+ "epoch": 0.2562936002426448,
1206
+ "grad_norm": 1.0616915093996913,
1207
+ "learning_rate": 1.7763986763138467e-05,
1208
+ "loss": 0.3812,
1209
+ "step": 845
1210
+ },
1211
+ {
1212
+ "epoch": 0.2578101304215954,
1213
+ "grad_norm": 1.1439917497224341,
1214
+ "learning_rate": 1.7732290782202244e-05,
1215
+ "loss": 0.4254,
1216
+ "step": 850
1217
+ },
1218
+ {
1219
+ "epoch": 0.25932666060054593,
1220
+ "grad_norm": 0.8582079212321349,
1221
+ "learning_rate": 1.7700400432609695e-05,
1222
+ "loss": 0.3699,
1223
+ "step": 855
1224
+ },
1225
+ {
1226
+ "epoch": 0.2608431907794965,
1227
+ "grad_norm": 1.2134674630319888,
1228
+ "learning_rate": 1.7668316515997047e-05,
1229
+ "loss": 0.4338,
1230
+ "step": 860
1231
+ },
1232
+ {
1233
+ "epoch": 0.26235972095844706,
1234
+ "grad_norm": 1.163884571927593,
1235
+ "learning_rate": 1.7636039838866278e-05,
1236
+ "loss": 0.3866,
1237
+ "step": 865
1238
+ },
1239
+ {
1240
+ "epoch": 0.26387625113739765,
1241
+ "grad_norm": 1.0199336931679066,
1242
+ "learning_rate": 1.760357121256482e-05,
1243
+ "loss": 0.4029,
1244
+ "step": 870
1245
+ },
1246
+ {
1247
+ "epoch": 0.2653927813163482,
1248
+ "grad_norm": 0.946151313424608,
1249
+ "learning_rate": 1.757091145326521e-05,
1250
+ "loss": 0.3595,
1251
+ "step": 875
1252
+ },
1253
+ {
1254
+ "epoch": 0.26690931149529873,
1255
+ "grad_norm": 1.1932767975805696,
1256
+ "learning_rate": 1.7538061381944524e-05,
1257
+ "loss": 0.43,
1258
+ "step": 880
1259
+ },
1260
+ {
1261
+ "epoch": 0.2684258416742493,
1262
+ "grad_norm": 0.9649699002383715,
1263
+ "learning_rate": 1.7505021824363767e-05,
1264
+ "loss": 0.3986,
1265
+ "step": 885
1266
+ },
1267
+ {
1268
+ "epoch": 0.26994237185319986,
1269
+ "grad_norm": 0.9468284901526932,
1270
+ "learning_rate": 1.7471793611047114e-05,
1271
+ "loss": 0.3551,
1272
+ "step": 890
1273
+ },
1274
+ {
1275
+ "epoch": 0.27145890203215045,
1276
+ "grad_norm": 1.026534302980973,
1277
+ "learning_rate": 1.743837757726103e-05,
1278
+ "loss": 0.3821,
1279
+ "step": 895
1280
+ },
1281
+ {
1282
+ "epoch": 0.272975432211101,
1283
+ "grad_norm": 0.9999149595040162,
1284
+ "learning_rate": 1.7404774562993268e-05,
1285
+ "loss": 0.3755,
1286
+ "step": 900
1287
+ },
1288
+ {
1289
+ "epoch": 0.2744919623900516,
1290
+ "grad_norm": 0.9887626409688292,
1291
+ "learning_rate": 1.7370985412931766e-05,
1292
+ "loss": 0.3929,
1293
+ "step": 905
1294
+ },
1295
+ {
1296
+ "epoch": 0.2760084925690021,
1297
+ "grad_norm": 1.1520336639240936,
1298
+ "learning_rate": 1.7337010976443404e-05,
1299
+ "loss": 0.416,
1300
+ "step": 910
1301
+ },
1302
+ {
1303
+ "epoch": 0.2775250227479527,
1304
+ "grad_norm": 1.2647475168427802,
1305
+ "learning_rate": 1.730285210755265e-05,
1306
+ "loss": 0.4117,
1307
+ "step": 915
1308
+ },
1309
+ {
1310
+ "epoch": 0.27904155292690325,
1311
+ "grad_norm": 1.1162345721181772,
1312
+ "learning_rate": 1.7268509664920115e-05,
1313
+ "loss": 0.4133,
1314
+ "step": 920
1315
+ },
1316
+ {
1317
+ "epoch": 0.2805580831058538,
1318
+ "grad_norm": 1.1083917949237179,
1319
+ "learning_rate": 1.7233984511820937e-05,
1320
+ "loss": 0.3991,
1321
+ "step": 925
1322
+ },
1323
+ {
1324
+ "epoch": 0.2820746132848044,
1325
+ "grad_norm": 1.1415307366312324,
1326
+ "learning_rate": 1.7199277516123098e-05,
1327
+ "loss": 0.4018,
1328
+ "step": 930
1329
+ },
1330
+ {
1331
+ "epoch": 0.2835911434637549,
1332
+ "grad_norm": 0.9966551314508225,
1333
+ "learning_rate": 1.7164389550265607e-05,
1334
+ "loss": 0.3764,
1335
+ "step": 935
1336
+ },
1337
+ {
1338
+ "epoch": 0.2851076736427055,
1339
+ "grad_norm": 1.1395841818497014,
1340
+ "learning_rate": 1.7129321491236578e-05,
1341
+ "loss": 0.4094,
1342
+ "step": 940
1343
+ },
1344
+ {
1345
+ "epoch": 0.28662420382165604,
1346
+ "grad_norm": 0.9290839897008805,
1347
+ "learning_rate": 1.709407422055116e-05,
1348
+ "loss": 0.3757,
1349
+ "step": 945
1350
+ },
1351
+ {
1352
+ "epoch": 0.28814073400060664,
1353
+ "grad_norm": 0.9613783627220345,
1354
+ "learning_rate": 1.70586486242294e-05,
1355
+ "loss": 0.4067,
1356
+ "step": 950
1357
+ },
1358
+ {
1359
+ "epoch": 0.2896572641795572,
1360
+ "grad_norm": 0.9710982407207631,
1361
+ "learning_rate": 1.7023045592773968e-05,
1362
+ "loss": 0.3521,
1363
+ "step": 955
1364
+ },
1365
+ {
1366
+ "epoch": 0.2911737943585077,
1367
+ "grad_norm": 1.095728541816799,
1368
+ "learning_rate": 1.6987266021147763e-05,
1369
+ "loss": 0.3749,
1370
+ "step": 960
1371
+ },
1372
+ {
1373
+ "epoch": 0.2926903245374583,
1374
+ "grad_norm": 1.0605469280918307,
1375
+ "learning_rate": 1.695131080875142e-05,
1376
+ "loss": 0.3562,
1377
+ "step": 965
1378
+ },
1379
+ {
1380
+ "epoch": 0.29420685471640884,
1381
+ "grad_norm": 0.8408891042454621,
1382
+ "learning_rate": 1.691518085940071e-05,
1383
+ "loss": 0.3377,
1384
+ "step": 970
1385
+ },
1386
+ {
1387
+ "epoch": 0.29572338489535943,
1388
+ "grad_norm": 0.9467466462817451,
1389
+ "learning_rate": 1.6878877081303805e-05,
1390
+ "loss": 0.3498,
1391
+ "step": 975
1392
+ },
1393
+ {
1394
+ "epoch": 0.29723991507430997,
1395
+ "grad_norm": 1.1449510440698407,
1396
+ "learning_rate": 1.6842400387038464e-05,
1397
+ "loss": 0.3524,
1398
+ "step": 980
1399
+ },
1400
+ {
1401
+ "epoch": 0.29875644525326056,
1402
+ "grad_norm": 0.948414328408764,
1403
+ "learning_rate": 1.6805751693529083e-05,
1404
+ "loss": 0.3784,
1405
+ "step": 985
1406
+ },
1407
+ {
1408
+ "epoch": 0.3002729754322111,
1409
+ "grad_norm": 1.0675528322798367,
1410
+ "learning_rate": 1.676893192202364e-05,
1411
+ "loss": 0.3182,
1412
+ "step": 990
1413
+ },
1414
+ {
1415
+ "epoch": 0.30178950561116163,
1416
+ "grad_norm": 1.129087718059458,
1417
+ "learning_rate": 1.673194199807057e-05,
1418
+ "loss": 0.3825,
1419
+ "step": 995
1420
+ },
1421
+ {
1422
+ "epoch": 0.3033060357901122,
1423
+ "grad_norm": 1.0937265212418354,
1424
+ "learning_rate": 1.6694782851495444e-05,
1425
+ "loss": 0.3919,
1426
+ "step": 1000
1427
+ },
1428
+ {
1429
+ "epoch": 0.3033060357901122,
1430
+ "eval_loss": 0.3952370584011078,
1431
+ "eval_runtime": 181.3096,
1432
+ "eval_samples_per_second": 49.28,
1433
+ "eval_steps_per_second": 24.643,
1434
+ "step": 1000
1435
+ },
1436
+ {
1437
+ "epoch": 0.30482256596906276,
1438
+ "grad_norm": 1.1557787421563186,
1439
+ "learning_rate": 1.6657455416377654e-05,
1440
+ "loss": 0.3941,
1441
+ "step": 1005
1442
+ },
1443
+ {
1444
+ "epoch": 0.30633909614801336,
1445
+ "grad_norm": 1.1101059403333342,
1446
+ "learning_rate": 1.661996063102689e-05,
1447
+ "loss": 0.4257,
1448
+ "step": 1010
1449
+ },
1450
+ {
1451
+ "epoch": 0.3078556263269639,
1452
+ "grad_norm": 0.8880119467039661,
1453
+ "learning_rate": 1.6582299437959577e-05,
1454
+ "loss": 0.4088,
1455
+ "step": 1015
1456
+ },
1457
+ {
1458
+ "epoch": 0.3093721565059145,
1459
+ "grad_norm": 0.9624888079810805,
1460
+ "learning_rate": 1.6544472783875173e-05,
1461
+ "loss": 0.3875,
1462
+ "step": 1020
1463
+ },
1464
+ {
1465
+ "epoch": 0.310888686684865,
1466
+ "grad_norm": 0.9935490083526604,
1467
+ "learning_rate": 1.650648161963237e-05,
1468
+ "loss": 0.3899,
1469
+ "step": 1025
1470
+ },
1471
+ {
1472
+ "epoch": 0.3124052168638156,
1473
+ "grad_norm": 1.1083315990997267,
1474
+ "learning_rate": 1.6468326900225204e-05,
1475
+ "loss": 0.3855,
1476
+ "step": 1030
1477
+ },
1478
+ {
1479
+ "epoch": 0.31392174704276615,
1480
+ "grad_norm": 0.9176932289492212,
1481
+ "learning_rate": 1.6430009584759036e-05,
1482
+ "loss": 0.3805,
1483
+ "step": 1035
1484
+ },
1485
+ {
1486
+ "epoch": 0.3154382772217167,
1487
+ "grad_norm": 1.159934652109041,
1488
+ "learning_rate": 1.6391530636426447e-05,
1489
+ "loss": 0.4393,
1490
+ "step": 1040
1491
+ },
1492
+ {
1493
+ "epoch": 0.3169548074006673,
1494
+ "grad_norm": 0.9549890829191309,
1495
+ "learning_rate": 1.6352891022483025e-05,
1496
+ "loss": 0.3853,
1497
+ "step": 1045
1498
+ },
1499
+ {
1500
+ "epoch": 0.3184713375796178,
1501
+ "grad_norm": 0.8753544532645915,
1502
+ "learning_rate": 1.631409171422306e-05,
1503
+ "loss": 0.3783,
1504
+ "step": 1050
1505
+ },
1506
+ {
1507
+ "epoch": 0.3199878677585684,
1508
+ "grad_norm": 1.145468835413513,
1509
+ "learning_rate": 1.6275133686955107e-05,
1510
+ "loss": 0.4251,
1511
+ "step": 1055
1512
+ },
1513
+ {
1514
+ "epoch": 0.32150439793751895,
1515
+ "grad_norm": 1.0318505639013171,
1516
+ "learning_rate": 1.6236017919977495e-05,
1517
+ "loss": 0.4074,
1518
+ "step": 1060
1519
+ },
1520
+ {
1521
+ "epoch": 0.32302092811646954,
1522
+ "grad_norm": 0.9665743048952797,
1523
+ "learning_rate": 1.61967453965537e-05,
1524
+ "loss": 0.3765,
1525
+ "step": 1065
1526
+ },
1527
+ {
1528
+ "epoch": 0.3245374582954201,
1529
+ "grad_norm": 0.9170197925688426,
1530
+ "learning_rate": 1.615731710388761e-05,
1531
+ "loss": 0.3538,
1532
+ "step": 1070
1533
+ },
1534
+ {
1535
+ "epoch": 0.3260539884743706,
1536
+ "grad_norm": 1.3931816470935723,
1537
+ "learning_rate": 1.6117734033098744e-05,
1538
+ "loss": 0.4235,
1539
+ "step": 1075
1540
+ },
1541
+ {
1542
+ "epoch": 0.3275705186533212,
1543
+ "grad_norm": 1.07985655044954,
1544
+ "learning_rate": 1.6077997179197314e-05,
1545
+ "loss": 0.4324,
1546
+ "step": 1080
1547
+ },
1548
+ {
1549
+ "epoch": 0.32908704883227174,
1550
+ "grad_norm": 0.8912377259192088,
1551
+ "learning_rate": 1.6038107541059216e-05,
1552
+ "loss": 0.3746,
1553
+ "step": 1085
1554
+ },
1555
+ {
1556
+ "epoch": 0.33060357901122234,
1557
+ "grad_norm": 0.9135076769501183,
1558
+ "learning_rate": 1.5998066121400925e-05,
1559
+ "loss": 0.3453,
1560
+ "step": 1090
1561
+ },
1562
+ {
1563
+ "epoch": 0.3321201091901729,
1564
+ "grad_norm": 1.0231218093240508,
1565
+ "learning_rate": 1.5957873926754294e-05,
1566
+ "loss": 0.3952,
1567
+ "step": 1095
1568
+ },
1569
+ {
1570
+ "epoch": 0.33363663936912347,
1571
+ "grad_norm": 1.128624243072491,
1572
+ "learning_rate": 1.5917531967441235e-05,
1573
+ "loss": 0.3548,
1574
+ "step": 1100
1575
+ },
1576
+ {
1577
+ "epoch": 0.335153169548074,
1578
+ "grad_norm": 0.8957025806927242,
1579
+ "learning_rate": 1.587704125754835e-05,
1580
+ "loss": 0.3723,
1581
+ "step": 1105
1582
+ },
1583
+ {
1584
+ "epoch": 0.33666969972702454,
1585
+ "grad_norm": 1.1269942168579994,
1586
+ "learning_rate": 1.583640281490141e-05,
1587
+ "loss": 0.4201,
1588
+ "step": 1110
1589
+ },
1590
+ {
1591
+ "epoch": 0.33818622990597513,
1592
+ "grad_norm": 1.133173361691646,
1593
+ "learning_rate": 1.5795617661039794e-05,
1594
+ "loss": 0.4232,
1595
+ "step": 1115
1596
+ },
1597
+ {
1598
+ "epoch": 0.33970276008492567,
1599
+ "grad_norm": 0.8611310832137484,
1600
+ "learning_rate": 1.5754686821190797e-05,
1601
+ "loss": 0.3969,
1602
+ "step": 1120
1603
+ },
1604
+ {
1605
+ "epoch": 0.34121929026387626,
1606
+ "grad_norm": 0.8942802074964451,
1607
+ "learning_rate": 1.5713611324243858e-05,
1608
+ "loss": 0.3453,
1609
+ "step": 1125
1610
+ },
1611
+ {
1612
+ "epoch": 0.3427358204428268,
1613
+ "grad_norm": 1.0128940099871961,
1614
+ "learning_rate": 1.5672392202724702e-05,
1615
+ "loss": 0.41,
1616
+ "step": 1130
1617
+ },
1618
+ {
1619
+ "epoch": 0.3442523506217774,
1620
+ "grad_norm": 0.8082512004172621,
1621
+ "learning_rate": 1.5631030492769385e-05,
1622
+ "loss": 0.3404,
1623
+ "step": 1135
1624
+ },
1625
+ {
1626
+ "epoch": 0.34576888080072793,
1627
+ "grad_norm": 0.9097763994809697,
1628
+ "learning_rate": 1.5589527234098247e-05,
1629
+ "loss": 0.3542,
1630
+ "step": 1140
1631
+ },
1632
+ {
1633
+ "epoch": 0.3472854109796785,
1634
+ "grad_norm": 0.8970233702221374,
1635
+ "learning_rate": 1.5547883469989767e-05,
1636
+ "loss": 0.3783,
1637
+ "step": 1145
1638
+ },
1639
+ {
1640
+ "epoch": 0.34880194115862906,
1641
+ "grad_norm": 1.0373858768564683,
1642
+ "learning_rate": 1.5506100247254363e-05,
1643
+ "loss": 0.3424,
1644
+ "step": 1150
1645
+ },
1646
+ {
1647
+ "epoch": 0.3503184713375796,
1648
+ "grad_norm": 0.9339798828268197,
1649
+ "learning_rate": 1.5464178616208046e-05,
1650
+ "loss": 0.4106,
1651
+ "step": 1155
1652
+ },
1653
+ {
1654
+ "epoch": 0.3518350015165302,
1655
+ "grad_norm": 0.9599549402427533,
1656
+ "learning_rate": 1.5422119630646043e-05,
1657
+ "loss": 0.35,
1658
+ "step": 1160
1659
+ },
1660
+ {
1661
+ "epoch": 0.3533515316954807,
1662
+ "grad_norm": 1.0567048479540737,
1663
+ "learning_rate": 1.5379924347816296e-05,
1664
+ "loss": 0.3996,
1665
+ "step": 1165
1666
+ },
1667
+ {
1668
+ "epoch": 0.3548680618744313,
1669
+ "grad_norm": 1.1009203123989402,
1670
+ "learning_rate": 1.533759382839288e-05,
1671
+ "loss": 0.4392,
1672
+ "step": 1170
1673
+ },
1674
+ {
1675
+ "epoch": 0.35638459205338185,
1676
+ "grad_norm": 0.9592441175249292,
1677
+ "learning_rate": 1.5295129136449362e-05,
1678
+ "loss": 0.3941,
1679
+ "step": 1175
1680
+ },
1681
+ {
1682
+ "epoch": 0.35790112223233245,
1683
+ "grad_norm": 0.974862722178181,
1684
+ "learning_rate": 1.5252531339432033e-05,
1685
+ "loss": 0.3848,
1686
+ "step": 1180
1687
+ },
1688
+ {
1689
+ "epoch": 0.359417652411283,
1690
+ "grad_norm": 0.800718600370757,
1691
+ "learning_rate": 1.5209801508133077e-05,
1692
+ "loss": 0.3603,
1693
+ "step": 1185
1694
+ },
1695
+ {
1696
+ "epoch": 0.3609341825902335,
1697
+ "grad_norm": 1.044944939187582,
1698
+ "learning_rate": 1.516694071666367e-05,
1699
+ "loss": 0.4112,
1700
+ "step": 1190
1701
+ },
1702
+ {
1703
+ "epoch": 0.3624507127691841,
1704
+ "grad_norm": 0.9861578107971376,
1705
+ "learning_rate": 1.5123950042426958e-05,
1706
+ "loss": 0.4303,
1707
+ "step": 1195
1708
+ },
1709
+ {
1710
+ "epoch": 0.36396724294813465,
1711
+ "grad_norm": 0.9236514908326175,
1712
+ "learning_rate": 1.5080830566090986e-05,
1713
+ "loss": 0.3705,
1714
+ "step": 1200
1715
+ },
1716
+ {
1717
+ "epoch": 0.36548377312708524,
1718
+ "grad_norm": 1.0547458629444628,
1719
+ "learning_rate": 1.5037583371561538e-05,
1720
+ "loss": 0.3558,
1721
+ "step": 1205
1722
+ },
1723
+ {
1724
+ "epoch": 0.3670003033060358,
1725
+ "grad_norm": 1.0149155289164045,
1726
+ "learning_rate": 1.4994209545954884e-05,
1727
+ "loss": 0.3455,
1728
+ "step": 1210
1729
+ },
1730
+ {
1731
+ "epoch": 0.3685168334849864,
1732
+ "grad_norm": 0.975465418286225,
1733
+ "learning_rate": 1.4950710179570442e-05,
1734
+ "loss": 0.3852,
1735
+ "step": 1215
1736
+ },
1737
+ {
1738
+ "epoch": 0.3700333636639369,
1739
+ "grad_norm": 1.0918284001664949,
1740
+ "learning_rate": 1.49070863658634e-05,
1741
+ "loss": 0.4167,
1742
+ "step": 1220
1743
+ },
1744
+ {
1745
+ "epoch": 0.37154989384288745,
1746
+ "grad_norm": 0.963559934243437,
1747
+ "learning_rate": 1.4863339201417195e-05,
1748
+ "loss": 0.4196,
1749
+ "step": 1225
1750
+ },
1751
+ {
1752
+ "epoch": 0.37306642402183804,
1753
+ "grad_norm": 0.8857039743407772,
1754
+ "learning_rate": 1.4819469785915972e-05,
1755
+ "loss": 0.3796,
1756
+ "step": 1230
1757
+ },
1758
+ {
1759
+ "epoch": 0.3745829542007886,
1760
+ "grad_norm": 0.8586023962029455,
1761
+ "learning_rate": 1.4775479222116935e-05,
1762
+ "loss": 0.4458,
1763
+ "step": 1235
1764
+ },
1765
+ {
1766
+ "epoch": 0.37609948437973917,
1767
+ "grad_norm": 1.0620266861330567,
1768
+ "learning_rate": 1.4731368615822623e-05,
1769
+ "loss": 0.3992,
1770
+ "step": 1240
1771
+ },
1772
+ {
1773
+ "epoch": 0.3776160145586897,
1774
+ "grad_norm": 0.9686222054309538,
1775
+ "learning_rate": 1.468713907585311e-05,
1776
+ "loss": 0.3897,
1777
+ "step": 1245
1778
+ },
1779
+ {
1780
+ "epoch": 0.3791325447376403,
1781
+ "grad_norm": 0.9098741177020511,
1782
+ "learning_rate": 1.4642791714018148e-05,
1783
+ "loss": 0.3975,
1784
+ "step": 1250
1785
+ },
1786
+ {
1787
+ "epoch": 0.38064907491659083,
1788
+ "grad_norm": 1.0748028384120583,
1789
+ "learning_rate": 1.45983276450892e-05,
1790
+ "loss": 0.3719,
1791
+ "step": 1255
1792
+ },
1793
+ {
1794
+ "epoch": 0.3821656050955414,
1795
+ "grad_norm": 1.0534448032219224,
1796
+ "learning_rate": 1.4553747986771426e-05,
1797
+ "loss": 0.428,
1798
+ "step": 1260
1799
+ },
1800
+ {
1801
+ "epoch": 0.38368213527449196,
1802
+ "grad_norm": 1.1494681374152054,
1803
+ "learning_rate": 1.4509053859675601e-05,
1804
+ "loss": 0.43,
1805
+ "step": 1265
1806
+ },
1807
+ {
1808
+ "epoch": 0.3851986654534425,
1809
+ "grad_norm": 0.9573470062232139,
1810
+ "learning_rate": 1.4464246387289913e-05,
1811
+ "loss": 0.412,
1812
+ "step": 1270
1813
+ },
1814
+ {
1815
+ "epoch": 0.3867151956323931,
1816
+ "grad_norm": 1.3662334486068808,
1817
+ "learning_rate": 1.4419326695951752e-05,
1818
+ "loss": 0.3896,
1819
+ "step": 1275
1820
+ },
1821
+ {
1822
+ "epoch": 0.38823172581134363,
1823
+ "grad_norm": 1.0831328903659596,
1824
+ "learning_rate": 1.4374295914819385e-05,
1825
+ "loss": 0.3853,
1826
+ "step": 1280
1827
+ },
1828
+ {
1829
+ "epoch": 0.3897482559902942,
1830
+ "grad_norm": 1.1259269630812838,
1831
+ "learning_rate": 1.4329155175843572e-05,
1832
+ "loss": 0.3424,
1833
+ "step": 1285
1834
+ },
1835
+ {
1836
+ "epoch": 0.39126478616924476,
1837
+ "grad_norm": 1.0133166607814084,
1838
+ "learning_rate": 1.4283905613739107e-05,
1839
+ "loss": 0.3958,
1840
+ "step": 1290
1841
+ },
1842
+ {
1843
+ "epoch": 0.39278131634819535,
1844
+ "grad_norm": 0.9526285779039391,
1845
+ "learning_rate": 1.4238548365956308e-05,
1846
+ "loss": 0.3465,
1847
+ "step": 1295
1848
+ },
1849
+ {
1850
+ "epoch": 0.3942978465271459,
1851
+ "grad_norm": 0.8356690841228143,
1852
+ "learning_rate": 1.4193084572652415e-05,
1853
+ "loss": 0.3676,
1854
+ "step": 1300
1855
+ },
1856
+ {
1857
+ "epoch": 0.3958143767060964,
1858
+ "grad_norm": 1.056339975220587,
1859
+ "learning_rate": 1.4147515376662928e-05,
1860
+ "loss": 0.3866,
1861
+ "step": 1305
1862
+ },
1863
+ {
1864
+ "epoch": 0.397330906885047,
1865
+ "grad_norm": 1.093716550361687,
1866
+ "learning_rate": 1.4101841923472885e-05,
1867
+ "loss": 0.3623,
1868
+ "step": 1310
1869
+ },
1870
+ {
1871
+ "epoch": 0.39884743706399756,
1872
+ "grad_norm": 1.1381025026138512,
1873
+ "learning_rate": 1.4056065361188068e-05,
1874
+ "loss": 0.3829,
1875
+ "step": 1315
1876
+ },
1877
+ {
1878
+ "epoch": 0.40036396724294815,
1879
+ "grad_norm": 1.0195350418123983,
1880
+ "learning_rate": 1.4010186840506123e-05,
1881
+ "loss": 0.3487,
1882
+ "step": 1320
1883
+ },
1884
+ {
1885
+ "epoch": 0.4018804974218987,
1886
+ "grad_norm": 0.8852855193645849,
1887
+ "learning_rate": 1.396420751468768e-05,
1888
+ "loss": 0.3533,
1889
+ "step": 1325
1890
+ },
1891
+ {
1892
+ "epoch": 0.4033970276008493,
1893
+ "grad_norm": 0.993588466372261,
1894
+ "learning_rate": 1.3918128539527312e-05,
1895
+ "loss": 0.4471,
1896
+ "step": 1330
1897
+ },
1898
+ {
1899
+ "epoch": 0.4049135577797998,
1900
+ "grad_norm": 0.9519991181443925,
1901
+ "learning_rate": 1.3871951073324508e-05,
1902
+ "loss": 0.3521,
1903
+ "step": 1335
1904
+ },
1905
+ {
1906
+ "epoch": 0.40643008795875035,
1907
+ "grad_norm": 0.9493052976320027,
1908
+ "learning_rate": 1.3825676276854563e-05,
1909
+ "loss": 0.3804,
1910
+ "step": 1340
1911
+ },
1912
+ {
1913
+ "epoch": 0.40794661813770094,
1914
+ "grad_norm": 0.8469034564877794,
1915
+ "learning_rate": 1.377930531333938e-05,
1916
+ "loss": 0.4021,
1917
+ "step": 1345
1918
+ },
1919
+ {
1920
+ "epoch": 0.4094631483166515,
1921
+ "grad_norm": 1.1005168381310284,
1922
+ "learning_rate": 1.3732839348418234e-05,
1923
+ "loss": 0.328,
1924
+ "step": 1350
1925
+ },
1926
+ {
1927
+ "epoch": 0.4109796784956021,
1928
+ "grad_norm": 1.0684696716820943,
1929
+ "learning_rate": 1.3686279550118491e-05,
1930
+ "loss": 0.377,
1931
+ "step": 1355
1932
+ },
1933
+ {
1934
+ "epoch": 0.4124962086745526,
1935
+ "grad_norm": 1.0525285102148023,
1936
+ "learning_rate": 1.3639627088826217e-05,
1937
+ "loss": 0.3807,
1938
+ "step": 1360
1939
+ },
1940
+ {
1941
+ "epoch": 0.4140127388535032,
1942
+ "grad_norm": 0.9870824418498249,
1943
+ "learning_rate": 1.3592883137256776e-05,
1944
+ "loss": 0.3821,
1945
+ "step": 1365
1946
+ },
1947
+ {
1948
+ "epoch": 0.41552926903245374,
1949
+ "grad_norm": 0.7608715586792121,
1950
+ "learning_rate": 1.3546048870425356e-05,
1951
+ "loss": 0.3404,
1952
+ "step": 1370
1953
+ },
1954
+ {
1955
+ "epoch": 0.41704579921140433,
1956
+ "grad_norm": 0.99595610767164,
1957
+ "learning_rate": 1.3499125465617417e-05,
1958
+ "loss": 0.3872,
1959
+ "step": 1375
1960
+ },
1961
+ {
1962
+ "epoch": 0.41856232939035487,
1963
+ "grad_norm": 1.000735855327813,
1964
+ "learning_rate": 1.34521141023591e-05,
1965
+ "loss": 0.3428,
1966
+ "step": 1380
1967
+ },
1968
+ {
1969
+ "epoch": 0.4200788595693054,
1970
+ "grad_norm": 0.9679808555916511,
1971
+ "learning_rate": 1.3405015962387588e-05,
1972
+ "loss": 0.3543,
1973
+ "step": 1385
1974
+ },
1975
+ {
1976
+ "epoch": 0.421595389748256,
1977
+ "grad_norm": 0.9241312294010786,
1978
+ "learning_rate": 1.3357832229621393e-05,
1979
+ "loss": 0.4196,
1980
+ "step": 1390
1981
+ },
1982
+ {
1983
+ "epoch": 0.42311191992720654,
1984
+ "grad_norm": 0.9695965917058144,
1985
+ "learning_rate": 1.3310564090130588e-05,
1986
+ "loss": 0.4028,
1987
+ "step": 1395
1988
+ },
1989
+ {
1990
+ "epoch": 0.42462845010615713,
1991
+ "grad_norm": 0.9665874167024507,
1992
+ "learning_rate": 1.3263212732107014e-05,
1993
+ "loss": 0.4332,
1994
+ "step": 1400
1995
+ },
1996
+ {
1997
+ "epoch": 0.42614498028510767,
1998
+ "grad_norm": 1.1010539769239436,
1999
+ "learning_rate": 1.3215779345834385e-05,
2000
+ "loss": 0.3641,
2001
+ "step": 1405
2002
+ },
2003
+ {
2004
+ "epoch": 0.42766151046405826,
2005
+ "grad_norm": 0.738968771424043,
2006
+ "learning_rate": 1.3168265123658386e-05,
2007
+ "loss": 0.324,
2008
+ "step": 1410
2009
+ },
2010
+ {
2011
+ "epoch": 0.4291780406430088,
2012
+ "grad_norm": 0.8453745632530565,
2013
+ "learning_rate": 1.3120671259956699e-05,
2014
+ "loss": 0.3628,
2015
+ "step": 1415
2016
+ },
2017
+ {
2018
+ "epoch": 0.43069457082195933,
2019
+ "grad_norm": 1.0651118171429264,
2020
+ "learning_rate": 1.3072998951108978e-05,
2021
+ "loss": 0.3874,
2022
+ "step": 1420
2023
+ },
2024
+ {
2025
+ "epoch": 0.4322111010009099,
2026
+ "grad_norm": 0.8241128807768539,
2027
+ "learning_rate": 1.3025249395466758e-05,
2028
+ "loss": 0.3672,
2029
+ "step": 1425
2030
+ },
2031
+ {
2032
+ "epoch": 0.43372763117986046,
2033
+ "grad_norm": 0.8764574352955143,
2034
+ "learning_rate": 1.297742379332337e-05,
2035
+ "loss": 0.3784,
2036
+ "step": 1430
2037
+ },
2038
+ {
2039
+ "epoch": 0.43524416135881105,
2040
+ "grad_norm": 0.9127221397188253,
2041
+ "learning_rate": 1.292952334688373e-05,
2042
+ "loss": 0.3538,
2043
+ "step": 1435
2044
+ },
2045
+ {
2046
+ "epoch": 0.4367606915377616,
2047
+ "grad_norm": 0.9788821462735778,
2048
+ "learning_rate": 1.2881549260234137e-05,
2049
+ "loss": 0.3817,
2050
+ "step": 1440
2051
+ },
2052
+ {
2053
+ "epoch": 0.4382772217167122,
2054
+ "grad_norm": 0.9336141701850212,
2055
+ "learning_rate": 1.2833502739312009e-05,
2056
+ "loss": 0.3911,
2057
+ "step": 1445
2058
+ },
2059
+ {
2060
+ "epoch": 0.4397937518956627,
2061
+ "grad_norm": 1.1869329918853797,
2062
+ "learning_rate": 1.2785384991875565e-05,
2063
+ "loss": 0.3865,
2064
+ "step": 1450
2065
+ },
2066
+ {
2067
+ "epoch": 0.44131028207461326,
2068
+ "grad_norm": 0.8526421566105249,
2069
+ "learning_rate": 1.273719722747345e-05,
2070
+ "loss": 0.3398,
2071
+ "step": 1455
2072
+ },
2073
+ {
2074
+ "epoch": 0.44282681225356385,
2075
+ "grad_norm": 0.8605140943098712,
2076
+ "learning_rate": 1.2688940657414362e-05,
2077
+ "loss": 0.3591,
2078
+ "step": 1460
2079
+ },
2080
+ {
2081
+ "epoch": 0.4443433424325144,
2082
+ "grad_norm": 1.1744022801387977,
2083
+ "learning_rate": 1.264061649473657e-05,
2084
+ "loss": 0.4196,
2085
+ "step": 1465
2086
+ },
2087
+ {
2088
+ "epoch": 0.445859872611465,
2089
+ "grad_norm": 0.9522399346317854,
2090
+ "learning_rate": 1.2592225954177453e-05,
2091
+ "loss": 0.3397,
2092
+ "step": 1470
2093
+ },
2094
+ {
2095
+ "epoch": 0.4473764027904155,
2096
+ "grad_norm": 0.9996838939742224,
2097
+ "learning_rate": 1.2543770252142938e-05,
2098
+ "loss": 0.4086,
2099
+ "step": 1475
2100
+ },
2101
+ {
2102
+ "epoch": 0.4488929329693661,
2103
+ "grad_norm": 1.0879756329716435,
2104
+ "learning_rate": 1.2495250606676927e-05,
2105
+ "loss": 0.3837,
2106
+ "step": 1480
2107
+ },
2108
+ {
2109
+ "epoch": 0.45040946314831665,
2110
+ "grad_norm": 1.174932056472148,
2111
+ "learning_rate": 1.2446668237430697e-05,
2112
+ "loss": 0.4314,
2113
+ "step": 1485
2114
+ },
2115
+ {
2116
+ "epoch": 0.45192599332726724,
2117
+ "grad_norm": 0.8386548630342906,
2118
+ "learning_rate": 1.2398024365632229e-05,
2119
+ "loss": 0.3641,
2120
+ "step": 1490
2121
+ },
2122
+ {
2123
+ "epoch": 0.4534425235062178,
2124
+ "grad_norm": 0.9535711206768751,
2125
+ "learning_rate": 1.2349320214055502e-05,
2126
+ "loss": 0.3634,
2127
+ "step": 1495
2128
+ },
2129
+ {
2130
+ "epoch": 0.4549590536851683,
2131
+ "grad_norm": 0.9958036756892384,
2132
+ "learning_rate": 1.2300557006989768e-05,
2133
+ "loss": 0.386,
2134
+ "step": 1500
2135
+ },
2136
+ {
2137
+ "epoch": 0.4549590536851683,
2138
+ "eval_loss": 0.3838871121406555,
2139
+ "eval_runtime": 177.9377,
2140
+ "eval_samples_per_second": 50.214,
2141
+ "eval_steps_per_second": 25.11,
2142
+ "step": 1500
2143
+ },
2144
+ {
2145
+ "epoch": 0.4564755838641189,
2146
+ "grad_norm": 1.1714023611493132,
2147
+ "learning_rate": 1.2251735970208776e-05,
2148
+ "loss": 0.4244,
2149
+ "step": 1505
2150
+ },
2151
+ {
2152
+ "epoch": 0.45799211404306944,
2153
+ "grad_norm": 0.9816643204121793,
2154
+ "learning_rate": 1.2202858330939946e-05,
2155
+ "loss": 0.3582,
2156
+ "step": 1510
2157
+ },
2158
+ {
2159
+ "epoch": 0.45950864422202004,
2160
+ "grad_norm": 1.2215461118765152,
2161
+ "learning_rate": 1.2153925317833544e-05,
2162
+ "loss": 0.3571,
2163
+ "step": 1515
2164
+ },
2165
+ {
2166
+ "epoch": 0.46102517440097057,
2167
+ "grad_norm": 1.0705463427387603,
2168
+ "learning_rate": 1.2104938160931775e-05,
2169
+ "loss": 0.3729,
2170
+ "step": 1520
2171
+ },
2172
+ {
2173
+ "epoch": 0.46254170457992116,
2174
+ "grad_norm": 1.0397410437607952,
2175
+ "learning_rate": 1.2055898091637867e-05,
2176
+ "loss": 0.3818,
2177
+ "step": 1525
2178
+ },
2179
+ {
2180
+ "epoch": 0.4640582347588717,
2181
+ "grad_norm": 1.1007740207154797,
2182
+ "learning_rate": 1.2006806342685127e-05,
2183
+ "loss": 0.4075,
2184
+ "step": 1530
2185
+ },
2186
+ {
2187
+ "epoch": 0.46557476493782224,
2188
+ "grad_norm": 1.0553911498111368,
2189
+ "learning_rate": 1.195766414810595e-05,
2190
+ "loss": 0.4049,
2191
+ "step": 1535
2192
+ },
2193
+ {
2194
+ "epoch": 0.46709129511677283,
2195
+ "grad_norm": 1.048411929765153,
2196
+ "learning_rate": 1.1908472743200787e-05,
2197
+ "loss": 0.4239,
2198
+ "step": 1540
2199
+ },
2200
+ {
2201
+ "epoch": 0.46860782529572337,
2202
+ "grad_norm": 0.9022404852741357,
2203
+ "learning_rate": 1.1859233364507105e-05,
2204
+ "loss": 0.4135,
2205
+ "step": 1545
2206
+ },
2207
+ {
2208
+ "epoch": 0.47012435547467396,
2209
+ "grad_norm": 0.8399864038325862,
2210
+ "learning_rate": 1.1809947249768312e-05,
2211
+ "loss": 0.3431,
2212
+ "step": 1550
2213
+ },
2214
+ {
2215
+ "epoch": 0.4716408856536245,
2216
+ "grad_norm": 0.9086969723025363,
2217
+ "learning_rate": 1.1760615637902615e-05,
2218
+ "loss": 0.4285,
2219
+ "step": 1555
2220
+ },
2221
+ {
2222
+ "epoch": 0.4731574158325751,
2223
+ "grad_norm": 1.038569849088238,
2224
+ "learning_rate": 1.1711239768971908e-05,
2225
+ "loss": 0.3955,
2226
+ "step": 1560
2227
+ },
2228
+ {
2229
+ "epoch": 0.4746739460115256,
2230
+ "grad_norm": 1.0391274519013625,
2231
+ "learning_rate": 1.1661820884150577e-05,
2232
+ "loss": 0.3598,
2233
+ "step": 1565
2234
+ },
2235
+ {
2236
+ "epoch": 0.47619047619047616,
2237
+ "grad_norm": 1.1140866338724735,
2238
+ "learning_rate": 1.1612360225694317e-05,
2239
+ "loss": 0.419,
2240
+ "step": 1570
2241
+ },
2242
+ {
2243
+ "epoch": 0.47770700636942676,
2244
+ "grad_norm": 1.1753789072266334,
2245
+ "learning_rate": 1.1562859036908895e-05,
2246
+ "loss": 0.4039,
2247
+ "step": 1575
2248
+ },
2249
+ {
2250
+ "epoch": 0.4792235365483773,
2251
+ "grad_norm": 0.8669494063557552,
2252
+ "learning_rate": 1.1513318562118902e-05,
2253
+ "loss": 0.2988,
2254
+ "step": 1580
2255
+ },
2256
+ {
2257
+ "epoch": 0.4807400667273279,
2258
+ "grad_norm": 1.1084096884130417,
2259
+ "learning_rate": 1.1463740046636471e-05,
2260
+ "loss": 0.3788,
2261
+ "step": 1585
2262
+ },
2263
+ {
2264
+ "epoch": 0.4822565969062784,
2265
+ "grad_norm": 0.9636937678770835,
2266
+ "learning_rate": 1.1414124736729966e-05,
2267
+ "loss": 0.3891,
2268
+ "step": 1590
2269
+ },
2270
+ {
2271
+ "epoch": 0.483773127085229,
2272
+ "grad_norm": 1.0650960432078391,
2273
+ "learning_rate": 1.1364473879592674e-05,
2274
+ "loss": 0.3815,
2275
+ "step": 1595
2276
+ },
2277
+ {
2278
+ "epoch": 0.48528965726417955,
2279
+ "grad_norm": 0.8673741729500686,
2280
+ "learning_rate": 1.1314788723311438e-05,
2281
+ "loss": 0.402,
2282
+ "step": 1600
2283
+ },
2284
+ {
2285
+ "epoch": 0.48680618744313015,
2286
+ "grad_norm": 0.8885954157154354,
2287
+ "learning_rate": 1.1265070516835286e-05,
2288
+ "loss": 0.3701,
2289
+ "step": 1605
2290
+ },
2291
+ {
2292
+ "epoch": 0.4883227176220807,
2293
+ "grad_norm": 0.9978426248829919,
2294
+ "learning_rate": 1.1215320509944038e-05,
2295
+ "loss": 0.3451,
2296
+ "step": 1610
2297
+ },
2298
+ {
2299
+ "epoch": 0.4898392478010312,
2300
+ "grad_norm": 1.0183575162282423,
2301
+ "learning_rate": 1.1165539953216893e-05,
2302
+ "loss": 0.3681,
2303
+ "step": 1615
2304
+ },
2305
+ {
2306
+ "epoch": 0.4913557779799818,
2307
+ "grad_norm": 1.0579314109624234,
2308
+ "learning_rate": 1.1115730098000982e-05,
2309
+ "loss": 0.3972,
2310
+ "step": 1620
2311
+ },
2312
+ {
2313
+ "epoch": 0.49287230815893235,
2314
+ "grad_norm": 1.032202979176664,
2315
+ "learning_rate": 1.1065892196379928e-05,
2316
+ "loss": 0.4024,
2317
+ "step": 1625
2318
+ },
2319
+ {
2320
+ "epoch": 0.49438883833788294,
2321
+ "grad_norm": 1.1044778878734602,
2322
+ "learning_rate": 1.101602750114236e-05,
2323
+ "loss": 0.4044,
2324
+ "step": 1630
2325
+ },
2326
+ {
2327
+ "epoch": 0.4959053685168335,
2328
+ "grad_norm": 0.8668001202243194,
2329
+ "learning_rate": 1.0966137265750427e-05,
2330
+ "loss": 0.3988,
2331
+ "step": 1635
2332
+ },
2333
+ {
2334
+ "epoch": 0.49742189869578407,
2335
+ "grad_norm": 1.0209675003151268,
2336
+ "learning_rate": 1.0916222744308285e-05,
2337
+ "loss": 0.3875,
2338
+ "step": 1640
2339
+ },
2340
+ {
2341
+ "epoch": 0.4989384288747346,
2342
+ "grad_norm": 1.098694287995219,
2343
+ "learning_rate": 1.0866285191530572e-05,
2344
+ "loss": 0.3787,
2345
+ "step": 1645
2346
+ },
2347
+ {
2348
+ "epoch": 0.5004549590536852,
2349
+ "grad_norm": 1.0520492533718504,
2350
+ "learning_rate": 1.0816325862710884e-05,
2351
+ "loss": 0.367,
2352
+ "step": 1650
2353
+ },
2354
+ {
2355
+ "epoch": 0.5019714892326357,
2356
+ "grad_norm": 0.8661615469589424,
2357
+ "learning_rate": 1.0766346013690193e-05,
2358
+ "loss": 0.3818,
2359
+ "step": 1655
2360
+ },
2361
+ {
2362
+ "epoch": 0.5034880194115863,
2363
+ "grad_norm": 1.0448544478772042,
2364
+ "learning_rate": 1.0716346900825298e-05,
2365
+ "loss": 0.4066,
2366
+ "step": 1660
2367
+ },
2368
+ {
2369
+ "epoch": 0.5050045495905369,
2370
+ "grad_norm": 0.9437379396406751,
2371
+ "learning_rate": 1.066632978095724e-05,
2372
+ "loss": 0.3612,
2373
+ "step": 1665
2374
+ },
2375
+ {
2376
+ "epoch": 0.5065210797694875,
2377
+ "grad_norm": 1.0577638338010937,
2378
+ "learning_rate": 1.0616295911379706e-05,
2379
+ "loss": 0.3931,
2380
+ "step": 1670
2381
+ },
2382
+ {
2383
+ "epoch": 0.5080376099484379,
2384
+ "grad_norm": 1.0224841150046953,
2385
+ "learning_rate": 1.0566246549807424e-05,
2386
+ "loss": 0.3568,
2387
+ "step": 1675
2388
+ },
2389
+ {
2390
+ "epoch": 0.5095541401273885,
2391
+ "grad_norm": 1.0709827796628466,
2392
+ "learning_rate": 1.0516182954344548e-05,
2393
+ "loss": 0.3785,
2394
+ "step": 1680
2395
+ },
2396
+ {
2397
+ "epoch": 0.5110706703063391,
2398
+ "grad_norm": 1.0408023532900905,
2399
+ "learning_rate": 1.0466106383453033e-05,
2400
+ "loss": 0.4197,
2401
+ "step": 1685
2402
+ },
2403
+ {
2404
+ "epoch": 0.5125872004852896,
2405
+ "grad_norm": 0.9978865184694876,
2406
+ "learning_rate": 1.0416018095921002e-05,
2407
+ "loss": 0.361,
2408
+ "step": 1690
2409
+ },
2410
+ {
2411
+ "epoch": 0.5141037306642402,
2412
+ "grad_norm": 0.9482830099197627,
2413
+ "learning_rate": 1.0365919350831105e-05,
2414
+ "loss": 0.3846,
2415
+ "step": 1695
2416
+ },
2417
+ {
2418
+ "epoch": 0.5156202608431908,
2419
+ "grad_norm": 1.098252143243866,
2420
+ "learning_rate": 1.031581140752886e-05,
2421
+ "loss": 0.384,
2422
+ "step": 1700
2423
+ },
2424
+ {
2425
+ "epoch": 0.5171367910221414,
2426
+ "grad_norm": 0.8273439653306288,
2427
+ "learning_rate": 1.0265695525591003e-05,
2428
+ "loss": 0.3457,
2429
+ "step": 1705
2430
+ },
2431
+ {
2432
+ "epoch": 0.5186533212010919,
2433
+ "grad_norm": 0.8472856026513179,
2434
+ "learning_rate": 1.0215572964793838e-05,
2435
+ "loss": 0.3804,
2436
+ "step": 1710
2437
+ },
2438
+ {
2439
+ "epoch": 0.5201698513800425,
2440
+ "grad_norm": 0.9358580008284806,
2441
+ "learning_rate": 1.0165444985081543e-05,
2442
+ "loss": 0.3508,
2443
+ "step": 1715
2444
+ },
2445
+ {
2446
+ "epoch": 0.521686381558993,
2447
+ "grad_norm": 0.7254575613461925,
2448
+ "learning_rate": 1.0115312846534518e-05,
2449
+ "loss": 0.373,
2450
+ "step": 1720
2451
+ },
2452
+ {
2453
+ "epoch": 0.5232029117379435,
2454
+ "grad_norm": 1.0896517684977018,
2455
+ "learning_rate": 1.0065177809337703e-05,
2456
+ "loss": 0.3391,
2457
+ "step": 1725
2458
+ },
2459
+ {
2460
+ "epoch": 0.5247194419168941,
2461
+ "grad_norm": 0.9244440412828214,
2462
+ "learning_rate": 1.0015041133748908e-05,
2463
+ "loss": 0.3588,
2464
+ "step": 1730
2465
+ },
2466
+ {
2467
+ "epoch": 0.5262359720958447,
2468
+ "grad_norm": 1.0159645169016112,
2469
+ "learning_rate": 9.964904080067119e-06,
2470
+ "loss": 0.3627,
2471
+ "step": 1735
2472
+ },
2473
+ {
2474
+ "epoch": 0.5277525022747953,
2475
+ "grad_norm": 0.8376669598715493,
2476
+ "learning_rate": 9.914767908600835e-06,
2477
+ "loss": 0.3369,
2478
+ "step": 1740
2479
+ },
2480
+ {
2481
+ "epoch": 0.5292690324537458,
2482
+ "grad_norm": 0.7978725648817389,
2483
+ "learning_rate": 9.864633879636371e-06,
2484
+ "loss": 0.326,
2485
+ "step": 1745
2486
+ },
2487
+ {
2488
+ "epoch": 0.5307855626326964,
2489
+ "grad_norm": 0.9497158806314906,
2490
+ "learning_rate": 9.814503253406188e-06,
2491
+ "loss": 0.3761,
2492
+ "step": 1750
2493
+ },
2494
+ {
2495
+ "epoch": 0.532302092811647,
2496
+ "grad_norm": 0.8666942612740441,
2497
+ "learning_rate": 9.764377290057217e-06,
2498
+ "loss": 0.3237,
2499
+ "step": 1755
2500
+ },
2501
+ {
2502
+ "epoch": 0.5338186229905975,
2503
+ "grad_norm": 1.015724277407853,
2504
+ "learning_rate": 9.714257249619166e-06,
2505
+ "loss": 0.3851,
2506
+ "step": 1760
2507
+ },
2508
+ {
2509
+ "epoch": 0.535335153169548,
2510
+ "grad_norm": 0.8370259755101838,
2511
+ "learning_rate": 9.664144391972867e-06,
2512
+ "loss": 0.3499,
2513
+ "step": 1765
2514
+ },
2515
+ {
2516
+ "epoch": 0.5368516833484986,
2517
+ "grad_norm": 0.8982590989333401,
2518
+ "learning_rate": 9.614039976818591e-06,
2519
+ "loss": 0.3871,
2520
+ "step": 1770
2521
+ },
2522
+ {
2523
+ "epoch": 0.5383682135274492,
2524
+ "grad_norm": 0.8871582462703803,
2525
+ "learning_rate": 9.56394526364439e-06,
2526
+ "loss": 0.3852,
2527
+ "step": 1775
2528
+ },
2529
+ {
2530
+ "epoch": 0.5398847437063997,
2531
+ "grad_norm": 0.9733274161552575,
2532
+ "learning_rate": 9.513861511694432e-06,
2533
+ "loss": 0.3754,
2534
+ "step": 1780
2535
+ },
2536
+ {
2537
+ "epoch": 0.5414012738853503,
2538
+ "grad_norm": 0.9575249396664576,
2539
+ "learning_rate": 9.46378997993735e-06,
2540
+ "loss": 0.3571,
2541
+ "step": 1785
2542
+ },
2543
+ {
2544
+ "epoch": 0.5429178040643009,
2545
+ "grad_norm": 0.9607725237981556,
2546
+ "learning_rate": 9.413731927034607e-06,
2547
+ "loss": 0.3747,
2548
+ "step": 1790
2549
+ },
2550
+ {
2551
+ "epoch": 0.5444343342432515,
2552
+ "grad_norm": 1.165305127561796,
2553
+ "learning_rate": 9.363688611308825e-06,
2554
+ "loss": 0.39,
2555
+ "step": 1795
2556
+ },
2557
+ {
2558
+ "epoch": 0.545950864422202,
2559
+ "grad_norm": 1.2378284714450407,
2560
+ "learning_rate": 9.313661290712182e-06,
2561
+ "loss": 0.4023,
2562
+ "step": 1800
2563
+ },
2564
+ {
2565
+ "epoch": 0.5474673946011526,
2566
+ "grad_norm": 0.9390807801563313,
2567
+ "learning_rate": 9.26365122279479e-06,
2568
+ "loss": 0.3655,
2569
+ "step": 1805
2570
+ },
2571
+ {
2572
+ "epoch": 0.5489839247801032,
2573
+ "grad_norm": 0.8512608333919068,
2574
+ "learning_rate": 9.213659664673063e-06,
2575
+ "loss": 0.3833,
2576
+ "step": 1810
2577
+ },
2578
+ {
2579
+ "epoch": 0.5505004549590536,
2580
+ "grad_norm": 0.9579623704318949,
2581
+ "learning_rate": 9.163687872998134e-06,
2582
+ "loss": 0.3565,
2583
+ "step": 1815
2584
+ },
2585
+ {
2586
+ "epoch": 0.5520169851380042,
2587
+ "grad_norm": 0.9096584536659199,
2588
+ "learning_rate": 9.113737103924266e-06,
2589
+ "loss": 0.3946,
2590
+ "step": 1820
2591
+ },
2592
+ {
2593
+ "epoch": 0.5535335153169548,
2594
+ "grad_norm": 0.8686976809327038,
2595
+ "learning_rate": 9.063808613077265e-06,
2596
+ "loss": 0.3416,
2597
+ "step": 1825
2598
+ },
2599
+ {
2600
+ "epoch": 0.5550500454959054,
2601
+ "grad_norm": 1.04912717783663,
2602
+ "learning_rate": 9.013903655522931e-06,
2603
+ "loss": 0.4476,
2604
+ "step": 1830
2605
+ },
2606
+ {
2607
+ "epoch": 0.5565665756748559,
2608
+ "grad_norm": 1.1364230419987023,
2609
+ "learning_rate": 8.964023485735491e-06,
2610
+ "loss": 0.3816,
2611
+ "step": 1835
2612
+ },
2613
+ {
2614
+ "epoch": 0.5580831058538065,
2615
+ "grad_norm": 0.8842214833551543,
2616
+ "learning_rate": 8.914169357566082e-06,
2617
+ "loss": 0.3291,
2618
+ "step": 1840
2619
+ },
2620
+ {
2621
+ "epoch": 0.5595996360327571,
2622
+ "grad_norm": 0.9895945394957149,
2623
+ "learning_rate": 8.864342524211228e-06,
2624
+ "loss": 0.3881,
2625
+ "step": 1845
2626
+ },
2627
+ {
2628
+ "epoch": 0.5611161662117076,
2629
+ "grad_norm": 0.9031975360786588,
2630
+ "learning_rate": 8.814544238181327e-06,
2631
+ "loss": 0.407,
2632
+ "step": 1850
2633
+ },
2634
+ {
2635
+ "epoch": 0.5626326963906582,
2636
+ "grad_norm": 1.011452507718733,
2637
+ "learning_rate": 8.764775751269184e-06,
2638
+ "loss": 0.3784,
2639
+ "step": 1855
2640
+ },
2641
+ {
2642
+ "epoch": 0.5641492265696088,
2643
+ "grad_norm": 1.0249078931801974,
2644
+ "learning_rate": 8.715038314518532e-06,
2645
+ "loss": 0.387,
2646
+ "step": 1860
2647
+ },
2648
+ {
2649
+ "epoch": 0.5656657567485593,
2650
+ "grad_norm": 1.0281856208543434,
2651
+ "learning_rate": 8.66533317819259e-06,
2652
+ "loss": 0.349,
2653
+ "step": 1865
2654
+ },
2655
+ {
2656
+ "epoch": 0.5671822869275098,
2657
+ "grad_norm": 0.8713289058377761,
2658
+ "learning_rate": 8.615661591742626e-06,
2659
+ "loss": 0.3775,
2660
+ "step": 1870
2661
+ },
2662
+ {
2663
+ "epoch": 0.5686988171064604,
2664
+ "grad_norm": 0.9270856112232262,
2665
+ "learning_rate": 8.566024803776567e-06,
2666
+ "loss": 0.403,
2667
+ "step": 1875
2668
+ },
2669
+ {
2670
+ "epoch": 0.570215347285411,
2671
+ "grad_norm": 0.9868872288291042,
2672
+ "learning_rate": 8.516424062027587e-06,
2673
+ "loss": 0.3474,
2674
+ "step": 1880
2675
+ },
2676
+ {
2677
+ "epoch": 0.5717318774643615,
2678
+ "grad_norm": 0.9141693333613865,
2679
+ "learning_rate": 8.466860613322773e-06,
2680
+ "loss": 0.3596,
2681
+ "step": 1885
2682
+ },
2683
+ {
2684
+ "epoch": 0.5732484076433121,
2685
+ "grad_norm": 1.13028177709786,
2686
+ "learning_rate": 8.417335703551753e-06,
2687
+ "loss": 0.39,
2688
+ "step": 1890
2689
+ },
2690
+ {
2691
+ "epoch": 0.5747649378222627,
2692
+ "grad_norm": 0.963375132739776,
2693
+ "learning_rate": 8.3678505776354e-06,
2694
+ "loss": 0.3582,
2695
+ "step": 1895
2696
+ },
2697
+ {
2698
+ "epoch": 0.5762814680012133,
2699
+ "grad_norm": 0.8566314951014108,
2700
+ "learning_rate": 8.318406479494526e-06,
2701
+ "loss": 0.3499,
2702
+ "step": 1900
2703
+ },
2704
+ {
2705
+ "epoch": 0.5777979981801638,
2706
+ "grad_norm": 0.871827865692706,
2707
+ "learning_rate": 8.269004652018615e-06,
2708
+ "loss": 0.3378,
2709
+ "step": 1905
2710
+ },
2711
+ {
2712
+ "epoch": 0.5793145283591143,
2713
+ "grad_norm": 0.8864124825854094,
2714
+ "learning_rate": 8.219646337034587e-06,
2715
+ "loss": 0.3663,
2716
+ "step": 1910
2717
+ },
2718
+ {
2719
+ "epoch": 0.5808310585380649,
2720
+ "grad_norm": 1.0414992512722574,
2721
+ "learning_rate": 8.170332775275572e-06,
2722
+ "loss": 0.3895,
2723
+ "step": 1915
2724
+ },
2725
+ {
2726
+ "epoch": 0.5823475887170154,
2727
+ "grad_norm": 1.1285331789955597,
2728
+ "learning_rate": 8.12106520634973e-06,
2729
+ "loss": 0.391,
2730
+ "step": 1920
2731
+ },
2732
+ {
2733
+ "epoch": 0.583864118895966,
2734
+ "grad_norm": 1.1338277664903524,
2735
+ "learning_rate": 8.071844868709086e-06,
2736
+ "loss": 0.4213,
2737
+ "step": 1925
2738
+ },
2739
+ {
2740
+ "epoch": 0.5853806490749166,
2741
+ "grad_norm": 1.0011965373177365,
2742
+ "learning_rate": 8.022672999618394e-06,
2743
+ "loss": 0.3739,
2744
+ "step": 1930
2745
+ },
2746
+ {
2747
+ "epoch": 0.5868971792538672,
2748
+ "grad_norm": 0.8353717953496863,
2749
+ "learning_rate": 7.973550835124055e-06,
2750
+ "loss": 0.3886,
2751
+ "step": 1935
2752
+ },
2753
+ {
2754
+ "epoch": 0.5884137094328177,
2755
+ "grad_norm": 1.0274242422703623,
2756
+ "learning_rate": 7.924479610023016e-06,
2757
+ "loss": 0.3554,
2758
+ "step": 1940
2759
+ },
2760
+ {
2761
+ "epoch": 0.5899302396117683,
2762
+ "grad_norm": 0.9984155763187357,
2763
+ "learning_rate": 7.875460557831755e-06,
2764
+ "loss": 0.3556,
2765
+ "step": 1945
2766
+ },
2767
+ {
2768
+ "epoch": 0.5914467697907189,
2769
+ "grad_norm": 0.8717630327395798,
2770
+ "learning_rate": 7.82649491075527e-06,
2771
+ "loss": 0.3411,
2772
+ "step": 1950
2773
+ },
2774
+ {
2775
+ "epoch": 0.5929632999696693,
2776
+ "grad_norm": 1.0351492468740082,
2777
+ "learning_rate": 7.777583899656092e-06,
2778
+ "loss": 0.3599,
2779
+ "step": 1955
2780
+ },
2781
+ {
2782
+ "epoch": 0.5944798301486199,
2783
+ "grad_norm": 0.9357066998157261,
2784
+ "learning_rate": 7.728728754023354e-06,
2785
+ "loss": 0.3373,
2786
+ "step": 1960
2787
+ },
2788
+ {
2789
+ "epoch": 0.5959963603275705,
2790
+ "grad_norm": 0.8738803213245617,
2791
+ "learning_rate": 7.679930701941888e-06,
2792
+ "loss": 0.3475,
2793
+ "step": 1965
2794
+ },
2795
+ {
2796
+ "epoch": 0.5975128905065211,
2797
+ "grad_norm": 0.9262823621070619,
2798
+ "learning_rate": 7.631190970061349e-06,
2799
+ "loss": 0.3519,
2800
+ "step": 1970
2801
+ },
2802
+ {
2803
+ "epoch": 0.5990294206854716,
2804
+ "grad_norm": 0.9004476821654784,
2805
+ "learning_rate": 7.5825107835653814e-06,
2806
+ "loss": 0.326,
2807
+ "step": 1975
2808
+ },
2809
+ {
2810
+ "epoch": 0.6005459508644222,
2811
+ "grad_norm": 1.0358901873299116,
2812
+ "learning_rate": 7.533891366140815e-06,
2813
+ "loss": 0.3521,
2814
+ "step": 1980
2815
+ },
2816
+ {
2817
+ "epoch": 0.6020624810433728,
2818
+ "grad_norm": 0.8420033070489784,
2819
+ "learning_rate": 7.485333939946926e-06,
2820
+ "loss": 0.3379,
2821
+ "step": 1985
2822
+ },
2823
+ {
2824
+ "epoch": 0.6035790112223233,
2825
+ "grad_norm": 1.0494757959050431,
2826
+ "learning_rate": 7.4368397255846845e-06,
2827
+ "loss": 0.3945,
2828
+ "step": 1990
2829
+ },
2830
+ {
2831
+ "epoch": 0.6050955414012739,
2832
+ "grad_norm": 0.9524555408133961,
2833
+ "learning_rate": 7.388409942066099e-06,
2834
+ "loss": 0.3596,
2835
+ "step": 1995
2836
+ },
2837
+ {
2838
+ "epoch": 0.6066120715802245,
2839
+ "grad_norm": 1.0996405384525345,
2840
+ "learning_rate": 7.340045806783559e-06,
2841
+ "loss": 0.409,
2842
+ "step": 2000
2843
+ },
2844
+ {
2845
+ "epoch": 0.6066120715802245,
2846
+ "eval_loss": 0.3754875957965851,
2847
+ "eval_runtime": 173.9951,
2848
+ "eval_samples_per_second": 51.352,
2849
+ "eval_steps_per_second": 25.679,
2850
+ "step": 2000
2851
+ }
2852
+ ],
2853
+ "logging_steps": 5,
2854
+ "max_steps": 3297,
2855
+ "num_input_tokens_seen": 0,
2856
+ "num_train_epochs": 1,
2857
+ "save_steps": 500,
2858
+ "stateful_callbacks": {
2859
+ "EarlyStoppingCallback": {
2860
+ "args": {
2861
+ "early_stopping_patience": 3,
2862
+ "early_stopping_threshold": 0.0
2863
+ },
2864
+ "attributes": {
2865
+ "early_stopping_patience_counter": 0
2866
+ }
2867
+ },
2868
+ "TrainerControl": {
2869
+ "args": {
2870
+ "should_epoch_stop": false,
2871
+ "should_evaluate": false,
2872
+ "should_log": false,
2873
+ "should_save": true,
2874
+ "should_training_stop": false
2875
+ },
2876
+ "attributes": {}
2877
+ }
2878
+ },
2879
+ "total_flos": 97417710010368.0,
2880
+ "train_batch_size": 1,
2881
+ "trial_name": null,
2882
+ "trial_params": null
2883
+ }
checkpoint-2000/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-2000/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info(f"Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info(f"Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
checkpoint-2500/added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
checkpoint-2500/chat_template.jinja ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- else %}
85
+ {{- '<think>\n\n' }}
86
+ {%- endif %}
87
+ {%- endif %}
checkpoint-2500/config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 151645,
8
+ "head_dim": 128,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "layer_types": [
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention"
42
+ ],
43
+ "max_position_embeddings": 40960,
44
+ "max_window_layers": 28,
45
+ "model_type": "qwen3",
46
+ "num_attention_heads": 16,
47
+ "num_hidden_layers": 28,
48
+ "num_key_value_heads": 8,
49
+ "rms_norm_eps": 1e-06,
50
+ "rope_scaling": null,
51
+ "rope_theta": 1000000,
52
+ "sliding_window": null,
53
+ "tie_word_embeddings": true,
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.53.1",
56
+ "use_cache": false,
57
+ "use_sliding_window": false,
58
+ "vocab_size": 151936
59
+ }
checkpoint-2500/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.53.1"
13
+ }
checkpoint-2500/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step2500
checkpoint-2500/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-2500/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
checkpoint-2500/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
checkpoint-2500/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-2500/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info(f"Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info(f"Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
checkpoint-3000/added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
checkpoint-3000/chat_template.jinja ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- else %}
85
+ {{- '<think>\n\n' }}
86
+ {%- endif %}
87
+ {%- endif %}
checkpoint-3000/config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 151645,
8
+ "head_dim": 128,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "layer_types": [
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention"
42
+ ],
43
+ "max_position_embeddings": 40960,
44
+ "max_window_layers": 28,
45
+ "model_type": "qwen3",
46
+ "num_attention_heads": 16,
47
+ "num_hidden_layers": 28,
48
+ "num_key_value_heads": 8,
49
+ "rms_norm_eps": 1e-06,
50
+ "rope_scaling": null,
51
+ "rope_theta": 1000000,
52
+ "sliding_window": null,
53
+ "tie_word_embeddings": true,
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.53.1",
56
+ "use_cache": false,
57
+ "use_sliding_window": false,
58
+ "vocab_size": 151936
59
+ }
checkpoint-3000/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step3000
checkpoint-3000/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3000/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
checkpoint-3000/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
checkpoint-3000/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3000/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3000/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info(f"Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info(f"Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
checkpoint-3297/added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
checkpoint-3297/chat_template.jinja ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
+ {%- elif message.role == "assistant" %}
29
+ {%- set content = message.content %}
30
+ {%- set reasoning_content = '' %}
31
+ {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
+ {%- set reasoning_content = message.reasoning_content %}
33
+ {%- else %}
34
+ {%- if '</think>' in message.content %}
35
+ {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
+ {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
+ {%- endif %}
38
+ {%- endif %}
39
+ {%- if loop.index0 > ns.last_query_index %}
40
+ {%- if loop.last or (not loop.last and reasoning_content) %}
41
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
+ {%- else %}
43
+ {{- '<|im_start|>' + message.role + '\n' + content }}
44
+ {%- endif %}
45
+ {%- else %}
46
+ {{- '<|im_start|>' + message.role + '\n' + content }}
47
+ {%- endif %}
48
+ {%- if message.tool_calls %}
49
+ {%- for tool_call in message.tool_calls %}
50
+ {%- if (loop.first and content) or (not loop.first) %}
51
+ {{- '\n' }}
52
+ {%- endif %}
53
+ {%- if tool_call.function %}
54
+ {%- set tool_call = tool_call.function %}
55
+ {%- endif %}
56
+ {{- '<tool_call>\n{"name": "' }}
57
+ {{- tool_call.name }}
58
+ {{- '", "arguments": ' }}
59
+ {%- if tool_call.arguments is string %}
60
+ {{- tool_call.arguments }}
61
+ {%- else %}
62
+ {{- tool_call.arguments | tojson }}
63
+ {%- endif %}
64
+ {{- '}\n</tool_call>' }}
65
+ {%- endfor %}
66
+ {%- endif %}
67
+ {{- '<|im_end|>\n' }}
68
+ {%- elif message.role == "tool" %}
69
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
+ {{- '<|im_start|>user' }}
71
+ {%- endif %}
72
+ {{- '\n<tool_response>\n' }}
73
+ {{- message.content }}
74
+ {{- '\n</tool_response>' }}
75
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
+ {{- '<|im_end|>\n' }}
77
+ {%- endif %}
78
+ {%- endif %}
79
+ {%- endfor %}
80
+ {%- if add_generation_prompt %}
81
+ {{- '<|im_start|>assistant\n' }}
82
+ {%- if enable_thinking is defined and enable_thinking is false %}
83
+ {{- '<think>\n\n</think>\n\n' }}
84
+ {%- else %}
85
+ {{- '<think>\n\n' }}
86
+ {%- endif %}
87
+ {%- endif %}
checkpoint-3297/config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 151645,
8
+ "head_dim": 128,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "layer_types": [
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention"
42
+ ],
43
+ "max_position_embeddings": 40960,
44
+ "max_window_layers": 28,
45
+ "model_type": "qwen3",
46
+ "num_attention_heads": 16,
47
+ "num_hidden_layers": 28,
48
+ "num_key_value_heads": 8,
49
+ "rms_norm_eps": 1e-06,
50
+ "rope_scaling": null,
51
+ "rope_theta": 1000000,
52
+ "sliding_window": null,
53
+ "tie_word_embeddings": true,
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.53.1",
56
+ "use_cache": false,
57
+ "use_sliding_window": false,
58
+ "vocab_size": 151936
59
+ }
checkpoint-3297/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.53.1"
13
+ }
checkpoint-3297/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ global_step3297
checkpoint-3297/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3297/special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
checkpoint-3297/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
checkpoint-3297/trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3297/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-3297/zero_to_fp32.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import gc
25
+ import json
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+
31
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
32
+ # DeepSpeed data structures it has to be available in the current python environment.
33
+ from deepspeed.utils import logger
34
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
35
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
36
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
37
+
38
+
39
+ @dataclass
40
+ class zero_model_state:
41
+ buffers: dict()
42
+ param_shapes: dict()
43
+ shared_params: list
44
+ ds_version: int
45
+ frozen_param_shapes: dict()
46
+ frozen_param_fragments: dict()
47
+
48
+
49
+ debug = 0
50
+
51
+ # load to cpu
52
+ device = torch.device('cpu')
53
+
54
+
55
+ def atoi(text):
56
+ return int(text) if text.isdigit() else text
57
+
58
+
59
+ def natural_keys(text):
60
+ '''
61
+ alist.sort(key=natural_keys) sorts in human order
62
+ http://nedbatchelder.com/blog/200712/human_sorting.html
63
+ (See Toothy's implementation in the comments)
64
+ '''
65
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
66
+
67
+
68
+ def get_model_state_file(checkpoint_dir, zero_stage):
69
+ if not os.path.isdir(checkpoint_dir):
70
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
71
+
72
+ # there should be only one file
73
+ if zero_stage <= 2:
74
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
75
+ elif zero_stage == 3:
76
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
77
+
78
+ if not os.path.exists(file):
79
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
80
+
81
+ return file
82
+
83
+
84
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
85
+ # XXX: need to test that this simple glob rule works for multi-node setup too
86
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
87
+
88
+ if len(ckpt_files) == 0:
89
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
90
+
91
+ return ckpt_files
92
+
93
+
94
+ def get_optim_files(checkpoint_dir):
95
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
96
+
97
+
98
+ def get_model_state_files(checkpoint_dir):
99
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
100
+
101
+
102
+ def parse_model_states(files):
103
+ zero_model_states = []
104
+ for file in files:
105
+ state_dict = torch.load(file, map_location=device, weights_only=False)
106
+
107
+ if BUFFER_NAMES not in state_dict:
108
+ raise ValueError(f"{file} is not a model state checkpoint")
109
+ buffer_names = state_dict[BUFFER_NAMES]
110
+ if debug:
111
+ print("Found buffers:", buffer_names)
112
+
113
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
114
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
115
+ param_shapes = state_dict[PARAM_SHAPES]
116
+
117
+ # collect parameters that are included in param_shapes
118
+ param_names = []
119
+ for s in param_shapes:
120
+ for name in s.keys():
121
+ param_names.append(name)
122
+
123
+ # update with frozen parameters
124
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
125
+ if frozen_param_shapes is not None:
126
+ if debug:
127
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
128
+ param_names += list(frozen_param_shapes.keys())
129
+
130
+ # handle shared params
131
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
132
+
133
+ ds_version = state_dict.get(DS_VERSION, None)
134
+
135
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
136
+
137
+ z_model_state = zero_model_state(buffers=buffers,
138
+ param_shapes=param_shapes,
139
+ shared_params=shared_params,
140
+ ds_version=ds_version,
141
+ frozen_param_shapes=frozen_param_shapes,
142
+ frozen_param_fragments=frozen_param_fragments)
143
+ zero_model_states.append(z_model_state)
144
+
145
+ return zero_model_states
146
+
147
+
148
+ def parse_optim_states(files, ds_checkpoint_dir):
149
+ total_files = len(files)
150
+ state_dicts = []
151
+ for f in tqdm(files, desc='Loading checkpoint shards'):
152
+ state_dict = torch.load(f, map_location=device, mmap=True, weights_only=False)
153
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
154
+ # and also handle the case where it was already removed by another helper script
155
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
156
+ state_dicts.append(state_dict)
157
+
158
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
159
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
160
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
161
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
162
+
163
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
164
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
165
+ # use the max of the partition_count to get the dp world_size.
166
+
167
+ if type(world_size) is list:
168
+ world_size = max(world_size)
169
+
170
+ if world_size != total_files:
171
+ raise ValueError(
172
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
173
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
174
+ )
175
+
176
+ # the groups are named differently in each stage
177
+ if zero_stage <= 2:
178
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
179
+ elif zero_stage == 3:
180
+ fp32_groups_key = FP32_FLAT_GROUPS
181
+ else:
182
+ raise ValueError(f"unknown zero stage {zero_stage}")
183
+
184
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
185
+ return zero_stage, world_size, fp32_flat_groups
186
+
187
+
188
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
189
+ """
190
+ Returns fp32 state_dict reconstructed from ds checkpoint
191
+
192
+ Args:
193
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
194
+
195
+ """
196
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
197
+
198
+ optim_files = get_optim_files(ds_checkpoint_dir)
199
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
200
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
201
+
202
+ model_files = get_model_state_files(ds_checkpoint_dir)
203
+
204
+ zero_model_states = parse_model_states(model_files)
205
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
206
+
207
+ if zero_stage <= 2:
208
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
209
+ exclude_frozen_parameters)
210
+ elif zero_stage == 3:
211
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
212
+ exclude_frozen_parameters)
213
+
214
+
215
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
216
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
217
+ return
218
+
219
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
220
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
221
+
222
+ if debug:
223
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
224
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
225
+
226
+ wanted_params = len(frozen_param_shapes)
227
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
228
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
229
+ print(f'Frozen params: Have {avail_numel} numels to process.')
230
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
231
+
232
+ total_params = 0
233
+ total_numel = 0
234
+ for name, shape in frozen_param_shapes.items():
235
+ total_params += 1
236
+ unpartitioned_numel = shape.numel()
237
+ total_numel += unpartitioned_numel
238
+
239
+ state_dict[name] = frozen_param_fragments[name]
240
+
241
+ if debug:
242
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
243
+
244
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
245
+
246
+
247
+ def _has_callable(obj, fn):
248
+ attr = getattr(obj, fn, None)
249
+ return callable(attr)
250
+
251
+
252
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
253
+ param_shapes = zero_model_states[0].param_shapes
254
+
255
+ # Reconstruction protocol:
256
+ #
257
+ # XXX: document this
258
+
259
+ if debug:
260
+ for i in range(world_size):
261
+ for j in range(len(fp32_flat_groups[0])):
262
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
263
+
264
+ # XXX: memory usage doubles here (zero2)
265
+ num_param_groups = len(fp32_flat_groups[0])
266
+ merged_single_partition_of_fp32_groups = []
267
+ for i in range(num_param_groups):
268
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
269
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
270
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
271
+ avail_numel = sum(
272
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
273
+
274
+ if debug:
275
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
276
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
277
+ # not asserting if there is a mismatch due to possible padding
278
+ print(f"Have {avail_numel} numels to process.")
279
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
280
+
281
+ # params
282
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
283
+ # out-of-core computing solution
284
+ total_numel = 0
285
+ total_params = 0
286
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
287
+ offset = 0
288
+ avail_numel = full_single_fp32_vector.numel()
289
+ for name, shape in shapes.items():
290
+
291
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
292
+ total_numel += unpartitioned_numel
293
+ total_params += 1
294
+
295
+ if debug:
296
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
297
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
298
+ offset += unpartitioned_numel
299
+
300
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
301
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
302
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
303
+ # live optimizer object, so we are checking that the numbers are within the right range
304
+ align_to = 2 * world_size
305
+
306
+ def zero2_align(x):
307
+ return align_to * math.ceil(x / align_to)
308
+
309
+ if debug:
310
+ print(f"original offset={offset}, avail_numel={avail_numel}")
311
+
312
+ offset = zero2_align(offset)
313
+ avail_numel = zero2_align(avail_numel)
314
+
315
+ if debug:
316
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
317
+
318
+ # Sanity check
319
+ if offset != avail_numel:
320
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
321
+
322
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
323
+
324
+
325
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
326
+ exclude_frozen_parameters):
327
+ state_dict = OrderedDict()
328
+
329
+ # buffers
330
+ buffers = zero_model_states[0].buffers
331
+ state_dict.update(buffers)
332
+ if debug:
333
+ print(f"added {len(buffers)} buffers")
334
+
335
+ if not exclude_frozen_parameters:
336
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
337
+
338
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
339
+
340
+ # recover shared parameters
341
+ for pair in zero_model_states[0].shared_params:
342
+ if pair[1] in state_dict:
343
+ state_dict[pair[0]] = state_dict[pair[1]]
344
+
345
+ return state_dict
346
+
347
+
348
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
349
+ remainder = unpartitioned_numel % world_size
350
+ padding_numel = (world_size - remainder) if remainder else 0
351
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
352
+ return partitioned_numel, padding_numel
353
+
354
+
355
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
356
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
357
+ return
358
+
359
+ if debug:
360
+ for i in range(world_size):
361
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
362
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
363
+
364
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
365
+ wanted_params = len(frozen_param_shapes)
366
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
367
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
368
+ print(f'Frozen params: Have {avail_numel} numels to process.')
369
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
370
+
371
+ total_params = 0
372
+ total_numel = 0
373
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
374
+ total_params += 1
375
+ unpartitioned_numel = shape.numel()
376
+ total_numel += unpartitioned_numel
377
+
378
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
379
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
380
+
381
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
382
+
383
+ if debug:
384
+ print(
385
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
386
+ )
387
+
388
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
389
+
390
+
391
+ class GatheredTensor:
392
+ """
393
+ A pseudo tensor that collects partitioned weights.
394
+ It is more memory efficient when there are multiple groups.
395
+ """
396
+
397
+ def __init__(self, flat_groups, flat_groups_offset, offset, partitioned_numel, shape):
398
+ self.flat_groups = flat_groups
399
+ self.flat_groups_offset = flat_groups_offset
400
+ self.offset = offset
401
+ self.partitioned_numel = partitioned_numel
402
+ self.shape = shape
403
+ self.dtype = self.flat_groups[0][0].dtype
404
+
405
+ def contiguous(self):
406
+ """
407
+ Merge partitioned weights from flat_groups into a single tensor.
408
+ """
409
+ end_idx = self.offset + self.partitioned_numel
410
+ world_size = len(self.flat_groups)
411
+ pad_flat_param_chunks = []
412
+
413
+ for rank_i in range(world_size):
414
+ # for each rank, we need to collect weights from related group/groups
415
+ flat_groups_at_rank_i = self.flat_groups[rank_i]
416
+ start_group_id = None
417
+ end_group_id = None
418
+ for group_id in range(len(self.flat_groups_offset)):
419
+ if self.flat_groups_offset[group_id] <= self.offset < self.flat_groups_offset[group_id + 1]:
420
+ start_group_id = group_id
421
+ if self.flat_groups_offset[group_id] < end_idx <= self.flat_groups_offset[group_id + 1]:
422
+ end_group_id = group_id
423
+ break
424
+ # collect weights from related group/groups
425
+ for group_id in range(start_group_id, end_group_id + 1):
426
+ flat_tensor = flat_groups_at_rank_i[group_id]
427
+ start_offset = self.offset - self.flat_groups_offset[group_id]
428
+ end_offset = min(end_idx, self.flat_groups_offset[group_id + 1]) - self.flat_groups_offset[group_id]
429
+ pad_flat_param_chunks.append(flat_tensor[start_offset:end_offset])
430
+
431
+ # collect weights from all ranks
432
+ pad_flat_param = torch.cat(pad_flat_param_chunks, dim=0)
433
+ param = pad_flat_param[:self.shape.numel()].view(self.shape).contiguous()
434
+ return param
435
+
436
+
437
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
438
+ param_shapes = zero_model_states[0].param_shapes
439
+ avail_numel = sum([flat_group.numel() for flat_group in fp32_flat_groups[0]]) * world_size
440
+
441
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
442
+ # param, re-consolidating each param, while dealing with padding if any
443
+
444
+ # merge list of dicts, preserving order
445
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
446
+
447
+ if debug:
448
+ for i in range(world_size):
449
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
450
+
451
+ wanted_params = len(param_shapes)
452
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
453
+ # not asserting if there is a mismatch due to possible padding
454
+ avail_numel = fp32_flat_groups[0].numel() * world_size
455
+ print(f"Trainable params: Have {avail_numel} numels to process.")
456
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
457
+
458
+ # params
459
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
460
+ # out-of-core computing solution
461
+ offset = 0
462
+ total_numel = 0
463
+ total_params = 0
464
+ flat_groups_offset = [0] + list(np.cumsum([flat_tensor.numel() for flat_tensor in fp32_flat_groups[0]]))
465
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering sharded weights'):
466
+ unpartitioned_numel = shape.numel()
467
+ total_numel += unpartitioned_numel
468
+ total_params += 1
469
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
470
+
471
+ if debug:
472
+ print(
473
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
474
+ )
475
+
476
+ # memory efficient tensor
477
+ tensor = GatheredTensor(fp32_flat_groups, flat_groups_offset, offset, partitioned_numel, shape)
478
+ state_dict[name] = tensor
479
+ offset += partitioned_numel
480
+
481
+ offset *= world_size
482
+
483
+ # Sanity check
484
+ if offset != avail_numel:
485
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
486
+
487
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
488
+
489
+
490
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
491
+ exclude_frozen_parameters):
492
+ state_dict = OrderedDict()
493
+
494
+ # buffers
495
+ buffers = zero_model_states[0].buffers
496
+ state_dict.update(buffers)
497
+ if debug:
498
+ print(f"added {len(buffers)} buffers")
499
+
500
+ if not exclude_frozen_parameters:
501
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
502
+
503
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
504
+
505
+ # recover shared parameters
506
+ for pair in zero_model_states[0].shared_params:
507
+ if pair[1] in state_dict:
508
+ state_dict[pair[0]] = state_dict[pair[1]]
509
+
510
+ return state_dict
511
+
512
+
513
+ def to_torch_tensor(state_dict, return_empty_tensor=False):
514
+ """
515
+ Convert state_dict of GatheredTensor to torch tensor
516
+ """
517
+ torch_state_dict = {}
518
+ converted_tensors = {}
519
+ for name, tensor in state_dict.items():
520
+ tensor_id = id(tensor)
521
+ if tensor_id in converted_tensors: # shared tensors
522
+ shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
523
+ torch_state_dict[name] = shared_tensor
524
+ else:
525
+ converted_tensors[tensor_id] = name
526
+ if return_empty_tensor:
527
+ torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
528
+ else:
529
+ torch_state_dict[name] = tensor.contiguous()
530
+ return torch_state_dict
531
+
532
+
533
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
534
+ tag=None,
535
+ exclude_frozen_parameters=False,
536
+ lazy_mode=False):
537
+ """
538
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
539
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
540
+ via a model hub.
541
+
542
+ Args:
543
+ - ``checkpoint_dir``: path to the desired checkpoint folder
544
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
545
+ - ``exclude_frozen_parameters``: exclude frozen parameters
546
+ - ``lazy_mode``: get state_dict in lazy mode. It returns a dict of pesduo tensor instead of torch tensor, which is more memory efficient.
547
+ Convert the pesduo tensor to torch tensor by ``.contiguous()``
548
+
549
+ Returns:
550
+ - pytorch ``state_dict``
551
+
552
+ A typical usage might be ::
553
+
554
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
555
+ # do the training and checkpoint saving
556
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
557
+ model = model.cpu() # move to cpu
558
+ model.load_state_dict(state_dict)
559
+ # submit to model hub or save the model to share with others
560
+
561
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
562
+ application. i.e. you will need to re-initialize the deepspeed engine, since
563
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
564
+
565
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
566
+
567
+ Note: the above usage may not work if your application doesn't have sufficient free CPU memory.
568
+ You may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
569
+ the checkpoint. Or you can load state_dict in lazy mode ::
570
+
571
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
572
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, lazy_mode=True) # not on cpu
573
+ for name, lazy_tensor in state_dict.item():
574
+ tensor = lazy_tensor.contiguous() # to cpu
575
+ print(name, tensor)
576
+ # del tensor to release memory if it no longer in use
577
+ """
578
+ if tag is None:
579
+ latest_path = os.path.join(checkpoint_dir, 'latest')
580
+ if os.path.isfile(latest_path):
581
+ with open(latest_path, 'r') as fd:
582
+ tag = fd.read().strip()
583
+ else:
584
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
585
+
586
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
587
+
588
+ if not os.path.isdir(ds_checkpoint_dir):
589
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
590
+
591
+ state_dict = _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
592
+ if lazy_mode:
593
+ return state_dict
594
+ else:
595
+ return to_torch_tensor(state_dict)
596
+
597
+
598
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
599
+ output_dir,
600
+ max_shard_size="5GB",
601
+ safe_serialization=False,
602
+ tag=None,
603
+ exclude_frozen_parameters=False):
604
+ """
605
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
606
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
607
+
608
+ Args:
609
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
610
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
611
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
612
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
613
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
614
+ - ``exclude_frozen_parameters``: exclude frozen parameters
615
+ """
616
+
617
+ # Dependency pre-check
618
+ if safe_serialization:
619
+ try:
620
+ from safetensors.torch import save_file
621
+ except ImportError:
622
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
623
+ raise
624
+ if max_shard_size is not None:
625
+ try:
626
+ from huggingface_hub import split_torch_state_dict_into_shards
627
+ except ImportError:
628
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
629
+ raise
630
+
631
+ # Convert zero checkpoint to state_dict
632
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
633
+ tag,
634
+ exclude_frozen_parameters,
635
+ lazy_mode=True)
636
+
637
+ # Shard the model if it is too big.
638
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
639
+ if max_shard_size is not None:
640
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
641
+ # an memory-efficient approach for sharding
642
+ empty_state_dict = to_torch_tensor(state_dict, return_empty_tensor=True)
643
+ state_dict_split = split_torch_state_dict_into_shards(empty_state_dict,
644
+ filename_pattern=filename_pattern,
645
+ max_shard_size=max_shard_size)
646
+ else:
647
+ from collections import namedtuple
648
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
649
+ state_dict_split = StateDictSplit(is_sharded=False,
650
+ filename_to_tensors={weights_name: list(state_dict.keys())})
651
+
652
+ # Save the model by shard
653
+ os.makedirs(output_dir, exist_ok=True)
654
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
655
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
656
+ shard_state_dict = {tensor_name: state_dict[tensor_name] for tensor_name in tensors}
657
+ shard_state_dict = to_torch_tensor(shard_state_dict)
658
+ output_path = os.path.join(output_dir, shard_file)
659
+ if safe_serialization:
660
+ save_file(shard_state_dict, output_path, metadata={"format": "pt"})
661
+ else:
662
+ torch.save(shard_state_dict, output_path)
663
+ # release the memory of current shard
664
+ for tensor_name in list(shard_state_dict.keys()):
665
+ del state_dict[tensor_name]
666
+ del shard_state_dict[tensor_name]
667
+ del shard_state_dict
668
+ gc.collect()
669
+
670
+ # Save index if sharded
671
+ if state_dict_split.is_sharded:
672
+ index = {
673
+ "metadata": state_dict_split.metadata,
674
+ "weight_map": state_dict_split.tensor_to_filename,
675
+ }
676
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
677
+ save_index_file = os.path.join(output_dir, save_index_file)
678
+ with open(save_index_file, "w", encoding="utf-8") as f:
679
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
680
+ f.write(content)
681
+
682
+
683
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
684
+ """
685
+ 1. Put the provided model to cpu
686
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
687
+ 3. Load it into the provided model
688
+
689
+ Args:
690
+ - ``model``: the model object to update
691
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
692
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
693
+
694
+ Returns:
695
+ - ``model`: modified model
696
+
697
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
698
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
699
+ conveniently placed for you in the checkpoint folder.
700
+
701
+ A typical usage might be ::
702
+
703
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
704
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
705
+ # submit to model hub or save the model to share with others
706
+
707
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
708
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
709
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
710
+
711
+ """
712
+ logger.info(f"Extracting fp32 weights")
713
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
714
+
715
+ logger.info(f"Overwriting model with fp32 weights")
716
+ model = model.cpu()
717
+ model.load_state_dict(state_dict, strict=False)
718
+
719
+ return model
720
+
721
+
722
+ if __name__ == "__main__":
723
+ parser = argparse.ArgumentParser()
724
+ parser.add_argument("checkpoint_dir",
725
+ type=str,
726
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
727
+ parser.add_argument("output_dir",
728
+ type=str,
729
+ help="directory to the pytorch fp32 state_dict output files"
730
+ "(e.g. path/checkpoint-12-output/)")
731
+ parser.add_argument(
732
+ "--max_shard_size",
733
+ type=str,
734
+ default="5GB",
735
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
736
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
737
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
738
+ "without CPU OOM issues.")
739
+ parser.add_argument(
740
+ "--safe_serialization",
741
+ default=False,
742
+ action='store_true',
743
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
744
+ parser.add_argument("-t",
745
+ "--tag",
746
+ type=str,
747
+ default=None,
748
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
749
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
750
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
751
+ args = parser.parse_args()
752
+
753
+ debug = args.debug
754
+
755
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
756
+ args.output_dir,
757
+ max_shard_size=args.max_shard_size,
758
+ safe_serialization=args.safe_serialization,
759
+ tag=args.tag,
760
+ exclude_frozen_parameters=args.exclude_frozen_parameters)
config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "eos_token_id": 151645,
8
+ "head_dim": 128,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 2048,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 6144,
13
+ "layer_types": [
14
+ "full_attention",
15
+ "full_attention",
16
+ "full_attention",
17
+ "full_attention",
18
+ "full_attention",
19
+ "full_attention",
20
+ "full_attention",
21
+ "full_attention",
22
+ "full_attention",
23
+ "full_attention",
24
+ "full_attention",
25
+ "full_attention",
26
+ "full_attention",
27
+ "full_attention",
28
+ "full_attention",
29
+ "full_attention",
30
+ "full_attention",
31
+ "full_attention",
32
+ "full_attention",
33
+ "full_attention",
34
+ "full_attention",
35
+ "full_attention",
36
+ "full_attention",
37
+ "full_attention",
38
+ "full_attention",
39
+ "full_attention",
40
+ "full_attention",
41
+ "full_attention"
42
+ ],
43
+ "max_position_embeddings": 40960,
44
+ "max_window_layers": 28,
45
+ "model_type": "qwen3",
46
+ "num_attention_heads": 16,
47
+ "num_hidden_layers": 28,
48
+ "num_key_value_heads": 8,
49
+ "rms_norm_eps": 1e-06,
50
+ "rope_scaling": null,
51
+ "rope_theta": 1000000,
52
+ "sliding_window": null,
53
+ "tie_word_embeddings": true,
54
+ "torch_dtype": "bfloat16",
55
+ "transformers_version": "4.53.1",
56
+ "use_cache": false,
57
+ "use_sliding_window": false,
58
+ "vocab_size": 151936
59
+ }
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.53.1"
13
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }