diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..52373fe24473b1aa44333d318f578ae6bf04b49b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..181c3d85e8120399467234f968b6484f2a88891c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,181 @@
+---
+license: other
+library_name: peft
+tags:
+- generated_from_trainer
+base_model: google/gemma-7b
+model-index:
+- name: gemma-python
+ results: []
+---
+
+
+
+[
](https://github.com/OpenAccess-AI-Collective/axolotl)
+See axolotl config
+
+axolotl version: `0.4.0`
+```yaml
+# use google/gemma-7b if you have access
+base_model: google/gemma-7b
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+
+
+load_in_8bit: false
+load_in_4bit: true
+strict: false
+
+# huggingface repo
+datasets:
+ - path: ./dataset/data1.jsonl
+ type: input_output
+val_set_size: 0.1
+output_dir: ./gemma-python
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+
+gradient_accumulation_steps: 3
+micro_batch_size: 2
+num_epochs: 10
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed: deepspeed_configs/zero1.json
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+
+```
+
+
+
+# gemma-python
+
+This model is a fine-tuned version of [google/gemma-7b](https://huggingface.co/google/gemma-7b) on the None dataset.
+It achieves the following results on the evaluation set:
+- Loss: 2.1143
+
+## Model description
+
+More information needed
+
+## Intended uses & limitations
+
+More information needed
+
+## Training and evaluation data
+
+More information needed
+
+## Training procedure
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: 0.0002
+- train_batch_size: 2
+- eval_batch_size: 2
+- seed: 42
+- distributed_type: multi-GPU
+- num_devices: 4
+- gradient_accumulation_steps: 3
+- total_train_batch_size: 24
+- total_eval_batch_size: 8
+- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+- lr_scheduler_type: cosine
+- lr_scheduler_warmup_steps: 2
+- num_epochs: 10
+
+### Training results
+
+| Training Loss | Epoch | Step | Validation Loss |
+|:-------------:|:-----:|:----:|:---------------:|
+| 19.0016 | 0.12 | 1 | 18.6992 |
+| 19.4686 | 0.25 | 2 | 16.2578 |
+| 11.468 | 0.5 | 4 | 8.2891 |
+| 7.5305 | 0.75 | 6 | 5.8847 |
+| 5.7572 | 1.0 | 8 | 4.3635 |
+| 4.3903 | 1.25 | 10 | 3.2849 |
+| 2.9497 | 1.5 | 12 | 2.8539 |
+| 2.8738 | 1.75 | 14 | 2.6203 |
+| 2.7298 | 2.0 | 16 | 2.4534 |
+| 2.4284 | 2.25 | 18 | 2.3077 |
+| 2.394 | 2.5 | 20 | 2.1876 |
+| 2.069 | 2.75 | 22 | 2.1294 |
+| 1.9355 | 3.0 | 24 | 2.1048 |
+| 1.9635 | 3.25 | 26 | 2.0707 |
+| 2.092 | 3.5 | 28 | 2.0596 |
+| 1.9675 | 3.75 | 30 | 2.0287 |
+| 1.9693 | 4.0 | 32 | 2.0220 |
+| 2.0198 | 4.25 | 34 | 2.0124 |
+| 1.9357 | 4.5 | 36 | 1.9946 |
+| 1.8147 | 4.75 | 38 | 1.9979 |
+| 1.9084 | 5.0 | 40 | 1.9751 |
+| 1.6678 | 5.25 | 42 | 2.0049 |
+| 1.7639 | 5.5 | 44 | 1.9885 |
+| 1.7475 | 5.75 | 46 | 1.9777 |
+| 1.4848 | 6.0 | 48 | 1.9939 |
+| 1.3065 | 6.25 | 50 | 2.0264 |
+| 1.4792 | 6.5 | 52 | 2.0125 |
+| 1.4233 | 6.75 | 54 | 2.0204 |
+| 1.2534 | 7.0 | 56 | 2.0318 |
+| 1.2409 | 7.25 | 58 | 2.0445 |
+| 1.4309 | 7.5 | 60 | 2.0641 |
+| 1.1622 | 7.75 | 62 | 2.0633 |
+| 1.228 | 8.0 | 64 | 2.0930 |
+| 1.3076 | 8.25 | 66 | 2.1077 |
+| 1.2323 | 8.5 | 68 | 2.1060 |
+| 1.1635 | 8.75 | 70 | 2.1039 |
+| 1.261 | 9.0 | 72 | 2.1068 |
+| 1.0122 | 9.25 | 74 | 2.1110 |
+| 1.218 | 9.5 | 76 | 2.1180 |
+| 1.1022 | 9.75 | 78 | 2.1226 |
+| 1.2072 | 10.0 | 80 | 2.1143 |
+
+
+### Framework versions
+
+- PEFT 0.9.0
+- Transformers 4.38.2
+- Pytorch 2.2.1
+- Datasets 2.18.0
+- Tokenizers 0.15.0
\ No newline at end of file
diff --git a/adapter_config.json b/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f48c351b3328c029833db4675ebe2c0dbdf14af4
--- /dev/null
+++ b/adapter_config.json
@@ -0,0 +1,33 @@
+{
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-7b",
+ "bias": "none",
+ "fan_in_fan_out": null,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 32,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "o_proj",
+ "up_proj",
+ "k_proj",
+ "q_proj",
+ "v_proj",
+ "gate_proj",
+ "down_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+}
\ No newline at end of file
diff --git a/adapter_model.bin b/adapter_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..3ab23ca8d90396282555dce094a51a00414f7dc3
--- /dev/null
+++ b/adapter_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56027d240d1c3a71c42694a8d10be6ce43895d6fcf1f952d2c724a9e87326474
+size 200078074
diff --git a/checkpoint-40/README.md b/checkpoint-40/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fde103e177d517a68ed416ca36925d7f86b488b
--- /dev/null
+++ b/checkpoint-40/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: google/gemma-7b
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.9.0
\ No newline at end of file
diff --git a/checkpoint-40/adapter_config.json b/checkpoint-40/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f48c351b3328c029833db4675ebe2c0dbdf14af4
--- /dev/null
+++ b/checkpoint-40/adapter_config.json
@@ -0,0 +1,33 @@
+{
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-7b",
+ "bias": "none",
+ "fan_in_fan_out": null,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 32,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "o_proj",
+ "up_proj",
+ "k_proj",
+ "q_proj",
+ "v_proj",
+ "gate_proj",
+ "down_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+}
\ No newline at end of file
diff --git a/checkpoint-40/adapter_model.safetensors b/checkpoint-40/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3a9542b2d2dbd1bd7bb60461c1fcad499517d447
--- /dev/null
+++ b/checkpoint-40/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d1aed72bc91825c6eaf1575f0a5c94f50d56417ca5b53f0b81cce29048b6ab70
+size 200068904
diff --git a/checkpoint-40/global_step40/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-40/global_step40/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..d239877dd78b25e75ca47ef30d9c1b709813f67a
--- /dev/null
+++ b/checkpoint-40/global_step40/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c24cda968c111802eab57451e1860053f2e38fd65f26135f266c9ffeac134c45
+size 150126608
diff --git a/checkpoint-40/global_step40/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-40/global_step40/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b5d0100bcf2c82db4e0d3ef1de1fc9b16407f5ce
--- /dev/null
+++ b/checkpoint-40/global_step40/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fad69a773e6fc3a934d7bdb9d232fc948270842ba6fc8efbf370876ffa0f7e03
+size 150126672
diff --git a/checkpoint-40/global_step40/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-40/global_step40/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..15fbfe191160bd50853c7c0eb094dcf1f6fe84cb
--- /dev/null
+++ b/checkpoint-40/global_step40/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbefc165cb203c96d8e832780b04524237e7b6eecc71d7ff978399d0d0c545ee
+size 150126736
diff --git a/checkpoint-40/global_step40/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-40/global_step40/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..0cf551edb840aa9515255d18270c1f49bdf35232
--- /dev/null
+++ b/checkpoint-40/global_step40/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bfd9defda8117e8665d598fff4c10cf611313f809395d507fbfc7658fbf5a9e
+size 150126736
diff --git a/checkpoint-40/global_step40/mp_rank_00_model_states.pt b/checkpoint-40/global_step40/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6484cdaf7b576fc74d53c5882609eb8616570416
--- /dev/null
+++ b/checkpoint-40/global_step40/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:408829abb39ec8d398d4973de44b3946c13a8327253229249eaf1572b03d1b54
+size 1896781478
diff --git a/checkpoint-40/latest b/checkpoint-40/latest
new file mode 100644
index 0000000000000000000000000000000000000000..8631ab8ddebf60eb3e7f5f2c2b1a2da8298a43c3
--- /dev/null
+++ b/checkpoint-40/latest
@@ -0,0 +1 @@
+global_step40
\ No newline at end of file
diff --git a/checkpoint-40/rng_state_0.pth b/checkpoint-40/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8e69fa8c1c053135df7bb242789e1d6ac005e85c
--- /dev/null
+++ b/checkpoint-40/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfa8f3ba412a4ede1340e4612f378f735f109cbf5a004a7ef3413d51993099c5
+size 15024
diff --git a/checkpoint-40/rng_state_1.pth b/checkpoint-40/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..5326d604816824d458cd595524bdf9a2c668b53f
--- /dev/null
+++ b/checkpoint-40/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d540901c9ea15d4cbbe676b69891f7b748ca516ed58e850a2fd4e6d02d301a10
+size 15024
diff --git a/checkpoint-40/rng_state_2.pth b/checkpoint-40/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..23b3867a9a6cc0d52f9a399497dea7bf6ee0b142
--- /dev/null
+++ b/checkpoint-40/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c0baa6c67b9316790653f049223543efdc12d27422fe3e39b0b8ac11b1af04e
+size 15024
diff --git a/checkpoint-40/rng_state_3.pth b/checkpoint-40/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ac6af2abe6a014a52fc05809e9bd6f5fa17191dc
--- /dev/null
+++ b/checkpoint-40/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f6491f57903cffa60cc5ed0ffde720e7ccebee6b0c3dcccdb9c0e1d27509c70
+size 15024
diff --git a/checkpoint-40/scheduler.pt b/checkpoint-40/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..56773ec53282837b5ede146214bce148357a9921
--- /dev/null
+++ b/checkpoint-40/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9fe9aa3a69b7aa00ab6c2c283052e530e526040db3d71112487efe44649fc62
+size 1064
diff --git a/checkpoint-40/trainer_state.json b/checkpoint-40/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d2bad40bce944c10c2456a63c1219d9a2ee254a
--- /dev/null
+++ b/checkpoint-40/trainer_state.json
@@ -0,0 +1,469 @@
+{
+ "best_metric": 1.9750508069992065,
+ "best_model_checkpoint": "./gemma-python/checkpoint-40",
+ "epoch": 5.0,
+ "eval_steps": 2,
+ "global_step": 40,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.12,
+ "grad_norm": 40.636978402335416,
+ "learning_rate": 0.0001,
+ "loss": 19.0016,
+ "step": 1
+ },
+ {
+ "epoch": 0.12,
+ "eval_loss": 18.6992130279541,
+ "eval_runtime": 2.881,
+ "eval_samples_per_second": 7.289,
+ "eval_steps_per_second": 1.041,
+ "step": 1
+ },
+ {
+ "epoch": 0.25,
+ "grad_norm": 41.61053527062362,
+ "learning_rate": 0.0002,
+ "loss": 19.4686,
+ "step": 2
+ },
+ {
+ "epoch": 0.25,
+ "eval_loss": 16.257802963256836,
+ "eval_runtime": 2.9111,
+ "eval_samples_per_second": 7.214,
+ "eval_steps_per_second": 1.031,
+ "step": 2
+ },
+ {
+ "epoch": 0.38,
+ "grad_norm": 28.704819713850974,
+ "learning_rate": 0.00019991889981715698,
+ "loss": 13.2303,
+ "step": 3
+ },
+ {
+ "epoch": 0.5,
+ "grad_norm": 26.40444243073739,
+ "learning_rate": 0.00019967573081342103,
+ "loss": 11.468,
+ "step": 4
+ },
+ {
+ "epoch": 0.5,
+ "eval_loss": 8.28911018371582,
+ "eval_runtime": 2.9257,
+ "eval_samples_per_second": 7.178,
+ "eval_steps_per_second": 1.025,
+ "step": 4
+ },
+ {
+ "epoch": 0.62,
+ "grad_norm": 12.912981323843146,
+ "learning_rate": 0.0001992708874098054,
+ "loss": 9.3107,
+ "step": 5
+ },
+ {
+ "epoch": 0.75,
+ "grad_norm": 7.943058500648636,
+ "learning_rate": 0.00019870502626379127,
+ "loss": 7.5305,
+ "step": 6
+ },
+ {
+ "epoch": 0.75,
+ "eval_loss": 5.884701728820801,
+ "eval_runtime": 2.9479,
+ "eval_samples_per_second": 7.124,
+ "eval_steps_per_second": 1.018,
+ "step": 6
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 6.267657551985817,
+ "learning_rate": 0.00019797906520422677,
+ "loss": 6.6492,
+ "step": 7
+ },
+ {
+ "epoch": 1.0,
+ "grad_norm": 5.0825555341832365,
+ "learning_rate": 0.0001970941817426052,
+ "loss": 5.7572,
+ "step": 8
+ },
+ {
+ "epoch": 1.0,
+ "eval_loss": 4.363473892211914,
+ "eval_runtime": 2.9653,
+ "eval_samples_per_second": 7.082,
+ "eval_steps_per_second": 1.012,
+ "step": 8
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 4.88565620317727,
+ "learning_rate": 0.00019605181116313724,
+ "loss": 4.5414,
+ "step": 9
+ },
+ {
+ "epoch": 1.25,
+ "grad_norm": 5.0847008955317605,
+ "learning_rate": 0.00019485364419471454,
+ "loss": 4.3903,
+ "step": 10
+ },
+ {
+ "epoch": 1.25,
+ "eval_loss": 3.284867763519287,
+ "eval_runtime": 2.9746,
+ "eval_samples_per_second": 7.06,
+ "eval_steps_per_second": 1.009,
+ "step": 10
+ },
+ {
+ "epoch": 1.38,
+ "grad_norm": 3.424587898800574,
+ "learning_rate": 0.0001935016242685415,
+ "loss": 3.79,
+ "step": 11
+ },
+ {
+ "epoch": 1.5,
+ "grad_norm": 2.7255824385278506,
+ "learning_rate": 0.00019199794436588243,
+ "loss": 2.9497,
+ "step": 12
+ },
+ {
+ "epoch": 1.5,
+ "eval_loss": 2.853942394256592,
+ "eval_runtime": 2.9866,
+ "eval_samples_per_second": 7.031,
+ "eval_steps_per_second": 1.004,
+ "step": 12
+ },
+ {
+ "epoch": 1.62,
+ "grad_norm": 2.1001906898750624,
+ "learning_rate": 0.00019034504346103823,
+ "loss": 2.7728,
+ "step": 13
+ },
+ {
+ "epoch": 1.75,
+ "grad_norm": 1.9200021565941778,
+ "learning_rate": 0.000188545602565321,
+ "loss": 2.8738,
+ "step": 14
+ },
+ {
+ "epoch": 1.75,
+ "eval_loss": 2.62028431892395,
+ "eval_runtime": 2.9982,
+ "eval_samples_per_second": 7.004,
+ "eval_steps_per_second": 1.001,
+ "step": 14
+ },
+ {
+ "epoch": 1.88,
+ "grad_norm": 1.8837224890225774,
+ "learning_rate": 0.00018660254037844388,
+ "loss": 3.0787,
+ "step": 15
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 1.8929687978608318,
+ "learning_rate": 0.0001845190085543795,
+ "loss": 2.7298,
+ "step": 16
+ },
+ {
+ "epoch": 2.0,
+ "eval_loss": 2.453444242477417,
+ "eval_runtime": 2.9964,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 16
+ },
+ {
+ "epoch": 2.12,
+ "grad_norm": 1.3652069569291694,
+ "learning_rate": 0.00018229838658936564,
+ "loss": 2.5967,
+ "step": 17
+ },
+ {
+ "epoch": 2.25,
+ "grad_norm": 2.4263600812149417,
+ "learning_rate": 0.00017994427634035015,
+ "loss": 2.4284,
+ "step": 18
+ },
+ {
+ "epoch": 2.25,
+ "eval_loss": 2.307706832885742,
+ "eval_runtime": 2.9963,
+ "eval_samples_per_second": 7.009,
+ "eval_steps_per_second": 1.001,
+ "step": 18
+ },
+ {
+ "epoch": 2.38,
+ "grad_norm": 2.5673391658400053,
+ "learning_rate": 0.00017746049618276545,
+ "loss": 2.6721,
+ "step": 19
+ },
+ {
+ "epoch": 2.5,
+ "grad_norm": 2.2252437500899656,
+ "learning_rate": 0.00017485107481711012,
+ "loss": 2.394,
+ "step": 20
+ },
+ {
+ "epoch": 2.5,
+ "eval_loss": 2.187636137008667,
+ "eval_runtime": 2.9975,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 20
+ },
+ {
+ "epoch": 2.62,
+ "grad_norm": 2.345233295279928,
+ "learning_rate": 0.00017212024473438147,
+ "loss": 2.3972,
+ "step": 21
+ },
+ {
+ "epoch": 2.75,
+ "grad_norm": 1.1122620317353238,
+ "learning_rate": 0.00016927243535095997,
+ "loss": 2.069,
+ "step": 22
+ },
+ {
+ "epoch": 2.75,
+ "eval_loss": 2.1294100284576416,
+ "eval_runtime": 2.993,
+ "eval_samples_per_second": 7.016,
+ "eval_steps_per_second": 1.002,
+ "step": 22
+ },
+ {
+ "epoch": 2.88,
+ "grad_norm": 2.8270209249093803,
+ "learning_rate": 0.00016631226582407952,
+ "loss": 2.211,
+ "step": 23
+ },
+ {
+ "epoch": 3.0,
+ "grad_norm": 7.323169716541166,
+ "learning_rate": 0.00016324453755953773,
+ "loss": 1.9355,
+ "step": 24
+ },
+ {
+ "epoch": 3.0,
+ "eval_loss": 2.1047682762145996,
+ "eval_runtime": 2.9871,
+ "eval_samples_per_second": 7.03,
+ "eval_steps_per_second": 1.004,
+ "step": 24
+ },
+ {
+ "epoch": 3.12,
+ "grad_norm": 1.9938311808450486,
+ "learning_rate": 0.0001600742264237979,
+ "loss": 2.1962,
+ "step": 25
+ },
+ {
+ "epoch": 3.25,
+ "grad_norm": 3.330986691029466,
+ "learning_rate": 0.00015680647467311557,
+ "loss": 1.9635,
+ "step": 26
+ },
+ {
+ "epoch": 3.25,
+ "eval_loss": 2.0707101821899414,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 26
+ },
+ {
+ "epoch": 3.38,
+ "grad_norm": 2.0371854480792178,
+ "learning_rate": 0.0001534465826127801,
+ "loss": 2.2319,
+ "step": 27
+ },
+ {
+ "epoch": 3.5,
+ "grad_norm": 3.2163831286077653,
+ "learning_rate": 0.00015000000000000001,
+ "loss": 2.092,
+ "step": 28
+ },
+ {
+ "epoch": 3.5,
+ "eval_loss": 2.059619426727295,
+ "eval_runtime": 2.9996,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 28
+ },
+ {
+ "epoch": 3.62,
+ "grad_norm": 2.853987323853131,
+ "learning_rate": 0.00014647231720437686,
+ "loss": 1.9182,
+ "step": 29
+ },
+ {
+ "epoch": 3.75,
+ "grad_norm": 2.2997509863024352,
+ "learning_rate": 0.00014286925614030542,
+ "loss": 1.9675,
+ "step": 30
+ },
+ {
+ "epoch": 3.75,
+ "eval_loss": 2.0287458896636963,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 30
+ },
+ {
+ "epoch": 3.88,
+ "grad_norm": 2.2770679758385244,
+ "learning_rate": 0.00013919666098600753,
+ "loss": 1.9815,
+ "step": 31
+ },
+ {
+ "epoch": 4.0,
+ "grad_norm": 0.8553765652252152,
+ "learning_rate": 0.00013546048870425356,
+ "loss": 1.9693,
+ "step": 32
+ },
+ {
+ "epoch": 4.0,
+ "eval_loss": 2.022012710571289,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 32
+ },
+ {
+ "epoch": 4.12,
+ "grad_norm": 3.8094922067262336,
+ "learning_rate": 0.00013166679938014726,
+ "loss": 1.6479,
+ "step": 33
+ },
+ {
+ "epoch": 4.25,
+ "grad_norm": 3.5435911597121277,
+ "learning_rate": 0.0001278217463916453,
+ "loss": 2.0198,
+ "step": 34
+ },
+ {
+ "epoch": 4.25,
+ "eval_loss": 2.012432336807251,
+ "eval_runtime": 2.9987,
+ "eval_samples_per_second": 7.003,
+ "eval_steps_per_second": 1.0,
+ "step": 34
+ },
+ {
+ "epoch": 4.38,
+ "grad_norm": 1.4676241516417539,
+ "learning_rate": 0.0001239315664287558,
+ "loss": 1.7496,
+ "step": 35
+ },
+ {
+ "epoch": 4.5,
+ "grad_norm": 1.4772602834377506,
+ "learning_rate": 0.00012000256937760445,
+ "loss": 1.9357,
+ "step": 36
+ },
+ {
+ "epoch": 4.5,
+ "eval_loss": 1.9945744276046753,
+ "eval_runtime": 3.0019,
+ "eval_samples_per_second": 6.995,
+ "eval_steps_per_second": 0.999,
+ "step": 36
+ },
+ {
+ "epoch": 4.62,
+ "grad_norm": 0.8198622785029981,
+ "learning_rate": 0.00011604112808577603,
+ "loss": 1.8365,
+ "step": 37
+ },
+ {
+ "epoch": 4.75,
+ "grad_norm": 2.5267989029749556,
+ "learning_rate": 0.0001120536680255323,
+ "loss": 1.8147,
+ "step": 38
+ },
+ {
+ "epoch": 4.75,
+ "eval_loss": 1.9979486465454102,
+ "eval_runtime": 2.9865,
+ "eval_samples_per_second": 7.032,
+ "eval_steps_per_second": 1.005,
+ "step": 38
+ },
+ {
+ "epoch": 4.88,
+ "grad_norm": 1.2889515222114942,
+ "learning_rate": 0.00010804665687167262,
+ "loss": 1.6703,
+ "step": 39
+ },
+ {
+ "epoch": 5.0,
+ "grad_norm": 1.3474067788797102,
+ "learning_rate": 0.00010402659401094152,
+ "loss": 1.9084,
+ "step": 40
+ },
+ {
+ "epoch": 5.0,
+ "eval_loss": 1.9750508069992065,
+ "eval_runtime": 2.9945,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 40
+ }
+ ],
+ "logging_steps": 1,
+ "max_steps": 80,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 10,
+ "save_steps": 8,
+ "total_flos": 1.8523438033403904e+17,
+ "train_batch_size": 2,
+ "trial_name": null,
+ "trial_params": null
+}
diff --git a/checkpoint-40/training_args.bin b/checkpoint-40/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..b11ae566a70dfd7bcafb281eef91bfd37c1b257b
--- /dev/null
+++ b/checkpoint-40/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbd3cdf0c7e847516177c465407e4f8b9cbcc9b8664e3b64c39191721cf5ef99
+size 6776
diff --git a/checkpoint-40/zero_to_fp32.py b/checkpoint-40/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b846633d6eb1e836e34681e44033581f4edb7b
--- /dev/null
+++ b/checkpoint-40/zero_to_fp32.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+ buffers: dict()
+ param_shapes: dict()
+ shared_params: list
+ ds_version: int
+ frozen_param_shapes: dict()
+ frozen_param_fragments: dict()
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+ total_files = len(files)
+ state_dicts = []
+ for f in files:
+ state_dict = torch.load(f, map_location=device)
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+ # and also handle the case where it was already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ if zero_stage <= 2:
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ elif zero_stage == 3:
+ # if there is more than one param group, there will be multiple flattened tensors - one
+ # flattened tensor per group - for simplicity merge them into a single tensor
+ #
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+ fp32_flat_groups = [
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+ ]
+
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+ # Reconstruction protocol:
+ #
+ # XXX: document this
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+ # param, re-consolidating each param, while dealing with padding if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ for name, shape in param_shapes.items():
+
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # XXX: memory usage doubles here
+ state_dict[name] = torch.cat(
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
+ offset += partitioned_numel
+
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ """
+
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+ print(f"Saving fp32 state dict to {output_file}")
+ torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Put the provided model to cpu
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+ - ``model`: modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument(
+ "output_file",
+ type=str,
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
diff --git a/checkpoint-64/README.md b/checkpoint-64/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fde103e177d517a68ed416ca36925d7f86b488b
--- /dev/null
+++ b/checkpoint-64/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: google/gemma-7b
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.9.0
\ No newline at end of file
diff --git a/checkpoint-64/adapter_config.json b/checkpoint-64/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f48c351b3328c029833db4675ebe2c0dbdf14af4
--- /dev/null
+++ b/checkpoint-64/adapter_config.json
@@ -0,0 +1,33 @@
+{
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-7b",
+ "bias": "none",
+ "fan_in_fan_out": null,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 32,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "o_proj",
+ "up_proj",
+ "k_proj",
+ "q_proj",
+ "v_proj",
+ "gate_proj",
+ "down_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+}
\ No newline at end of file
diff --git a/checkpoint-64/adapter_model.safetensors b/checkpoint-64/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..4f7755b95ed335b6dcce41f681c3d2236ce83cd2
--- /dev/null
+++ b/checkpoint-64/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3a309ee2731ba474e8a0458bdcea156d55a66ebac666c29ad3fe07d60d64949
+size 200068904
diff --git a/checkpoint-64/global_step64/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-64/global_step64/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..1c883c5d4a5af46f074ee0488c99a3359173c71c
--- /dev/null
+++ b/checkpoint-64/global_step64/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:69020befe25ce201249ca288824b860ccdd3a97b0dd6ddd5b05d33cd916509f2
+size 150126608
diff --git a/checkpoint-64/global_step64/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-64/global_step64/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..c3b780f093ba5b385f5c4d0bb66989879849da3f
--- /dev/null
+++ b/checkpoint-64/global_step64/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51e2d6a8023dd55f29d5046f0a69aab1551a7139a1b644f60bc25a2a01b8a2a1
+size 150126672
diff --git a/checkpoint-64/global_step64/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-64/global_step64/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..7e0abafe8d98e8416a22f3240b915137e689eaba
--- /dev/null
+++ b/checkpoint-64/global_step64/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea33fc3eff7e1857897f2d9f6a166c00b0a4bf9e7f116c99878ca648cab421a0
+size 150126736
diff --git a/checkpoint-64/global_step64/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-64/global_step64/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..02012f3453b75669c3ec66c41a4d1b8447385924
--- /dev/null
+++ b/checkpoint-64/global_step64/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b5425cbcd7924377570973efd1db89165c62bcc2fc5a3fae5bad6f9e621f9e18
+size 150126736
diff --git a/checkpoint-64/global_step64/mp_rank_00_model_states.pt b/checkpoint-64/global_step64/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..762f57cf049677f68834ba890f17d1313fd415d3
--- /dev/null
+++ b/checkpoint-64/global_step64/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67f8431e72b92e00e8c035cce5cffc86e5601219db4c20919cfec992381dd225
+size 1896781478
diff --git a/checkpoint-64/latest b/checkpoint-64/latest
new file mode 100644
index 0000000000000000000000000000000000000000..4a12e7f9029554e8e5ce68ebe3e97d0b4e734304
--- /dev/null
+++ b/checkpoint-64/latest
@@ -0,0 +1 @@
+global_step64
\ No newline at end of file
diff --git a/checkpoint-64/rng_state_0.pth b/checkpoint-64/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..53cc18a3d9b53cf709360a57377266c2c3c7085d
--- /dev/null
+++ b/checkpoint-64/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9527b9b5ae29ac374c87db2096874998096d81acfc9d70d4bbaf48795fad788f
+size 15024
diff --git a/checkpoint-64/rng_state_1.pth b/checkpoint-64/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..e29fe0b7d4038bc27dbb61154675ea67e1a42fc2
--- /dev/null
+++ b/checkpoint-64/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ffb943cf7bf621beffa16bb82712d440c7f94be3b7d0d0d4da79e0f0a2feac0
+size 15024
diff --git a/checkpoint-64/rng_state_2.pth b/checkpoint-64/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..67ad1fbf12e8abad407d317fef786b9e5106efd0
--- /dev/null
+++ b/checkpoint-64/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81a4e19888f9eb3b62dac0a0242272d66796d01b3f3212243c7701aa65eebf3c
+size 15024
diff --git a/checkpoint-64/rng_state_3.pth b/checkpoint-64/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f5c4f066ac0b61b6675337bbc5b0ebfda762de9e
--- /dev/null
+++ b/checkpoint-64/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd08e343faf2a38113c34dca2fd8802aaee14dc414d6e81008bf8ab9d8855859
+size 15024
diff --git a/checkpoint-64/scheduler.pt b/checkpoint-64/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b11dacd39929f1f368ae47ecfbac29761cc35e92
--- /dev/null
+++ b/checkpoint-64/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81667d34dc03f7a0b89cbca2d657bfead11ac055747074bc43cccaf1feb58bbc
+size 1064
diff --git a/checkpoint-64/trainer_state.json b/checkpoint-64/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..4a75ccd9b9f06bcf750207f70ea4833b2f0d593e
--- /dev/null
+++ b/checkpoint-64/trainer_state.json
@@ -0,0 +1,733 @@
+{
+ "best_metric": 1.9750508069992065,
+ "best_model_checkpoint": "./gemma-python/checkpoint-40",
+ "epoch": 8.0,
+ "eval_steps": 2,
+ "global_step": 64,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.12,
+ "grad_norm": 40.636978402335416,
+ "learning_rate": 0.0001,
+ "loss": 19.0016,
+ "step": 1
+ },
+ {
+ "epoch": 0.12,
+ "eval_loss": 18.6992130279541,
+ "eval_runtime": 2.881,
+ "eval_samples_per_second": 7.289,
+ "eval_steps_per_second": 1.041,
+ "step": 1
+ },
+ {
+ "epoch": 0.25,
+ "grad_norm": 41.61053527062362,
+ "learning_rate": 0.0002,
+ "loss": 19.4686,
+ "step": 2
+ },
+ {
+ "epoch": 0.25,
+ "eval_loss": 16.257802963256836,
+ "eval_runtime": 2.9111,
+ "eval_samples_per_second": 7.214,
+ "eval_steps_per_second": 1.031,
+ "step": 2
+ },
+ {
+ "epoch": 0.38,
+ "grad_norm": 28.704819713850974,
+ "learning_rate": 0.00019991889981715698,
+ "loss": 13.2303,
+ "step": 3
+ },
+ {
+ "epoch": 0.5,
+ "grad_norm": 26.40444243073739,
+ "learning_rate": 0.00019967573081342103,
+ "loss": 11.468,
+ "step": 4
+ },
+ {
+ "epoch": 0.5,
+ "eval_loss": 8.28911018371582,
+ "eval_runtime": 2.9257,
+ "eval_samples_per_second": 7.178,
+ "eval_steps_per_second": 1.025,
+ "step": 4
+ },
+ {
+ "epoch": 0.62,
+ "grad_norm": 12.912981323843146,
+ "learning_rate": 0.0001992708874098054,
+ "loss": 9.3107,
+ "step": 5
+ },
+ {
+ "epoch": 0.75,
+ "grad_norm": 7.943058500648636,
+ "learning_rate": 0.00019870502626379127,
+ "loss": 7.5305,
+ "step": 6
+ },
+ {
+ "epoch": 0.75,
+ "eval_loss": 5.884701728820801,
+ "eval_runtime": 2.9479,
+ "eval_samples_per_second": 7.124,
+ "eval_steps_per_second": 1.018,
+ "step": 6
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 6.267657551985817,
+ "learning_rate": 0.00019797906520422677,
+ "loss": 6.6492,
+ "step": 7
+ },
+ {
+ "epoch": 1.0,
+ "grad_norm": 5.0825555341832365,
+ "learning_rate": 0.0001970941817426052,
+ "loss": 5.7572,
+ "step": 8
+ },
+ {
+ "epoch": 1.0,
+ "eval_loss": 4.363473892211914,
+ "eval_runtime": 2.9653,
+ "eval_samples_per_second": 7.082,
+ "eval_steps_per_second": 1.012,
+ "step": 8
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 4.88565620317727,
+ "learning_rate": 0.00019605181116313724,
+ "loss": 4.5414,
+ "step": 9
+ },
+ {
+ "epoch": 1.25,
+ "grad_norm": 5.0847008955317605,
+ "learning_rate": 0.00019485364419471454,
+ "loss": 4.3903,
+ "step": 10
+ },
+ {
+ "epoch": 1.25,
+ "eval_loss": 3.284867763519287,
+ "eval_runtime": 2.9746,
+ "eval_samples_per_second": 7.06,
+ "eval_steps_per_second": 1.009,
+ "step": 10
+ },
+ {
+ "epoch": 1.38,
+ "grad_norm": 3.424587898800574,
+ "learning_rate": 0.0001935016242685415,
+ "loss": 3.79,
+ "step": 11
+ },
+ {
+ "epoch": 1.5,
+ "grad_norm": 2.7255824385278506,
+ "learning_rate": 0.00019199794436588243,
+ "loss": 2.9497,
+ "step": 12
+ },
+ {
+ "epoch": 1.5,
+ "eval_loss": 2.853942394256592,
+ "eval_runtime": 2.9866,
+ "eval_samples_per_second": 7.031,
+ "eval_steps_per_second": 1.004,
+ "step": 12
+ },
+ {
+ "epoch": 1.62,
+ "grad_norm": 2.1001906898750624,
+ "learning_rate": 0.00019034504346103823,
+ "loss": 2.7728,
+ "step": 13
+ },
+ {
+ "epoch": 1.75,
+ "grad_norm": 1.9200021565941778,
+ "learning_rate": 0.000188545602565321,
+ "loss": 2.8738,
+ "step": 14
+ },
+ {
+ "epoch": 1.75,
+ "eval_loss": 2.62028431892395,
+ "eval_runtime": 2.9982,
+ "eval_samples_per_second": 7.004,
+ "eval_steps_per_second": 1.001,
+ "step": 14
+ },
+ {
+ "epoch": 1.88,
+ "grad_norm": 1.8837224890225774,
+ "learning_rate": 0.00018660254037844388,
+ "loss": 3.0787,
+ "step": 15
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 1.8929687978608318,
+ "learning_rate": 0.0001845190085543795,
+ "loss": 2.7298,
+ "step": 16
+ },
+ {
+ "epoch": 2.0,
+ "eval_loss": 2.453444242477417,
+ "eval_runtime": 2.9964,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 16
+ },
+ {
+ "epoch": 2.12,
+ "grad_norm": 1.3652069569291694,
+ "learning_rate": 0.00018229838658936564,
+ "loss": 2.5967,
+ "step": 17
+ },
+ {
+ "epoch": 2.25,
+ "grad_norm": 2.4263600812149417,
+ "learning_rate": 0.00017994427634035015,
+ "loss": 2.4284,
+ "step": 18
+ },
+ {
+ "epoch": 2.25,
+ "eval_loss": 2.307706832885742,
+ "eval_runtime": 2.9963,
+ "eval_samples_per_second": 7.009,
+ "eval_steps_per_second": 1.001,
+ "step": 18
+ },
+ {
+ "epoch": 2.38,
+ "grad_norm": 2.5673391658400053,
+ "learning_rate": 0.00017746049618276545,
+ "loss": 2.6721,
+ "step": 19
+ },
+ {
+ "epoch": 2.5,
+ "grad_norm": 2.2252437500899656,
+ "learning_rate": 0.00017485107481711012,
+ "loss": 2.394,
+ "step": 20
+ },
+ {
+ "epoch": 2.5,
+ "eval_loss": 2.187636137008667,
+ "eval_runtime": 2.9975,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 20
+ },
+ {
+ "epoch": 2.62,
+ "grad_norm": 2.345233295279928,
+ "learning_rate": 0.00017212024473438147,
+ "loss": 2.3972,
+ "step": 21
+ },
+ {
+ "epoch": 2.75,
+ "grad_norm": 1.1122620317353238,
+ "learning_rate": 0.00016927243535095997,
+ "loss": 2.069,
+ "step": 22
+ },
+ {
+ "epoch": 2.75,
+ "eval_loss": 2.1294100284576416,
+ "eval_runtime": 2.993,
+ "eval_samples_per_second": 7.016,
+ "eval_steps_per_second": 1.002,
+ "step": 22
+ },
+ {
+ "epoch": 2.88,
+ "grad_norm": 2.8270209249093803,
+ "learning_rate": 0.00016631226582407952,
+ "loss": 2.211,
+ "step": 23
+ },
+ {
+ "epoch": 3.0,
+ "grad_norm": 7.323169716541166,
+ "learning_rate": 0.00016324453755953773,
+ "loss": 1.9355,
+ "step": 24
+ },
+ {
+ "epoch": 3.0,
+ "eval_loss": 2.1047682762145996,
+ "eval_runtime": 2.9871,
+ "eval_samples_per_second": 7.03,
+ "eval_steps_per_second": 1.004,
+ "step": 24
+ },
+ {
+ "epoch": 3.12,
+ "grad_norm": 1.9938311808450486,
+ "learning_rate": 0.0001600742264237979,
+ "loss": 2.1962,
+ "step": 25
+ },
+ {
+ "epoch": 3.25,
+ "grad_norm": 3.330986691029466,
+ "learning_rate": 0.00015680647467311557,
+ "loss": 1.9635,
+ "step": 26
+ },
+ {
+ "epoch": 3.25,
+ "eval_loss": 2.0707101821899414,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 26
+ },
+ {
+ "epoch": 3.38,
+ "grad_norm": 2.0371854480792178,
+ "learning_rate": 0.0001534465826127801,
+ "loss": 2.2319,
+ "step": 27
+ },
+ {
+ "epoch": 3.5,
+ "grad_norm": 3.2163831286077653,
+ "learning_rate": 0.00015000000000000001,
+ "loss": 2.092,
+ "step": 28
+ },
+ {
+ "epoch": 3.5,
+ "eval_loss": 2.059619426727295,
+ "eval_runtime": 2.9996,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 28
+ },
+ {
+ "epoch": 3.62,
+ "grad_norm": 2.853987323853131,
+ "learning_rate": 0.00014647231720437686,
+ "loss": 1.9182,
+ "step": 29
+ },
+ {
+ "epoch": 3.75,
+ "grad_norm": 2.2997509863024352,
+ "learning_rate": 0.00014286925614030542,
+ "loss": 1.9675,
+ "step": 30
+ },
+ {
+ "epoch": 3.75,
+ "eval_loss": 2.0287458896636963,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 30
+ },
+ {
+ "epoch": 3.88,
+ "grad_norm": 2.2770679758385244,
+ "learning_rate": 0.00013919666098600753,
+ "loss": 1.9815,
+ "step": 31
+ },
+ {
+ "epoch": 4.0,
+ "grad_norm": 0.8553765652252152,
+ "learning_rate": 0.00013546048870425356,
+ "loss": 1.9693,
+ "step": 32
+ },
+ {
+ "epoch": 4.0,
+ "eval_loss": 2.022012710571289,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 32
+ },
+ {
+ "epoch": 4.12,
+ "grad_norm": 3.8094922067262336,
+ "learning_rate": 0.00013166679938014726,
+ "loss": 1.6479,
+ "step": 33
+ },
+ {
+ "epoch": 4.25,
+ "grad_norm": 3.5435911597121277,
+ "learning_rate": 0.0001278217463916453,
+ "loss": 2.0198,
+ "step": 34
+ },
+ {
+ "epoch": 4.25,
+ "eval_loss": 2.012432336807251,
+ "eval_runtime": 2.9987,
+ "eval_samples_per_second": 7.003,
+ "eval_steps_per_second": 1.0,
+ "step": 34
+ },
+ {
+ "epoch": 4.38,
+ "grad_norm": 1.4676241516417539,
+ "learning_rate": 0.0001239315664287558,
+ "loss": 1.7496,
+ "step": 35
+ },
+ {
+ "epoch": 4.5,
+ "grad_norm": 1.4772602834377506,
+ "learning_rate": 0.00012000256937760445,
+ "loss": 1.9357,
+ "step": 36
+ },
+ {
+ "epoch": 4.5,
+ "eval_loss": 1.9945744276046753,
+ "eval_runtime": 3.0019,
+ "eval_samples_per_second": 6.995,
+ "eval_steps_per_second": 0.999,
+ "step": 36
+ },
+ {
+ "epoch": 4.62,
+ "grad_norm": 0.8198622785029981,
+ "learning_rate": 0.00011604112808577603,
+ "loss": 1.8365,
+ "step": 37
+ },
+ {
+ "epoch": 4.75,
+ "grad_norm": 2.5267989029749556,
+ "learning_rate": 0.0001120536680255323,
+ "loss": 1.8147,
+ "step": 38
+ },
+ {
+ "epoch": 4.75,
+ "eval_loss": 1.9979486465454102,
+ "eval_runtime": 2.9865,
+ "eval_samples_per_second": 7.032,
+ "eval_steps_per_second": 1.005,
+ "step": 38
+ },
+ {
+ "epoch": 4.88,
+ "grad_norm": 1.2889515222114942,
+ "learning_rate": 0.00010804665687167262,
+ "loss": 1.6703,
+ "step": 39
+ },
+ {
+ "epoch": 5.0,
+ "grad_norm": 1.3474067788797102,
+ "learning_rate": 0.00010402659401094152,
+ "loss": 1.9084,
+ "step": 40
+ },
+ {
+ "epoch": 5.0,
+ "eval_loss": 1.9750508069992065,
+ "eval_runtime": 2.9945,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 40
+ },
+ {
+ "epoch": 5.12,
+ "grad_norm": 1.320063776368443,
+ "learning_rate": 0.0001,
+ "loss": 1.6233,
+ "step": 41
+ },
+ {
+ "epoch": 5.25,
+ "grad_norm": 0.7858628087737163,
+ "learning_rate": 9.597340598905852e-05,
+ "loss": 1.6678,
+ "step": 42
+ },
+ {
+ "epoch": 5.25,
+ "eval_loss": 2.004897356033325,
+ "eval_runtime": 2.9946,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 42
+ },
+ {
+ "epoch": 5.38,
+ "grad_norm": 1.149181462350102,
+ "learning_rate": 9.195334312832742e-05,
+ "loss": 1.5673,
+ "step": 43
+ },
+ {
+ "epoch": 5.5,
+ "grad_norm": 1.961547695831496,
+ "learning_rate": 8.79463319744677e-05,
+ "loss": 1.7639,
+ "step": 44
+ },
+ {
+ "epoch": 5.5,
+ "eval_loss": 1.9885122776031494,
+ "eval_runtime": 2.9905,
+ "eval_samples_per_second": 7.022,
+ "eval_steps_per_second": 1.003,
+ "step": 44
+ },
+ {
+ "epoch": 5.62,
+ "grad_norm": 0.794217334050356,
+ "learning_rate": 8.395887191422397e-05,
+ "loss": 1.6191,
+ "step": 45
+ },
+ {
+ "epoch": 5.75,
+ "grad_norm": 1.5568588659062292,
+ "learning_rate": 7.999743062239557e-05,
+ "loss": 1.7475,
+ "step": 46
+ },
+ {
+ "epoch": 5.75,
+ "eval_loss": 1.9777300357818604,
+ "eval_runtime": 2.9821,
+ "eval_samples_per_second": 7.042,
+ "eval_steps_per_second": 1.006,
+ "step": 46
+ },
+ {
+ "epoch": 5.88,
+ "grad_norm": 0.9110203190054421,
+ "learning_rate": 7.606843357124426e-05,
+ "loss": 1.5998,
+ "step": 47
+ },
+ {
+ "epoch": 6.0,
+ "grad_norm": 1.4501990937976796,
+ "learning_rate": 7.217825360835473e-05,
+ "loss": 1.4848,
+ "step": 48
+ },
+ {
+ "epoch": 6.0,
+ "eval_loss": 1.9939006567001343,
+ "eval_runtime": 2.9785,
+ "eval_samples_per_second": 7.05,
+ "eval_steps_per_second": 1.007,
+ "step": 48
+ },
+ {
+ "epoch": 6.12,
+ "grad_norm": 1.3413384555399062,
+ "learning_rate": 6.833320061985277e-05,
+ "loss": 1.5343,
+ "step": 49
+ },
+ {
+ "epoch": 6.25,
+ "grad_norm": 0.9844954583473513,
+ "learning_rate": 6.453951129574644e-05,
+ "loss": 1.3065,
+ "step": 50
+ },
+ {
+ "epoch": 6.25,
+ "eval_loss": 2.0264320373535156,
+ "eval_runtime": 2.9839,
+ "eval_samples_per_second": 7.038,
+ "eval_steps_per_second": 1.005,
+ "step": 50
+ },
+ {
+ "epoch": 6.38,
+ "grad_norm": 1.268663878876962,
+ "learning_rate": 6.080333901399251e-05,
+ "loss": 1.4153,
+ "step": 51
+ },
+ {
+ "epoch": 6.5,
+ "grad_norm": 1.1638516740810099,
+ "learning_rate": 5.713074385969457e-05,
+ "loss": 1.4792,
+ "step": 52
+ },
+ {
+ "epoch": 6.5,
+ "eval_loss": 2.012540817260742,
+ "eval_runtime": 2.9954,
+ "eval_samples_per_second": 7.011,
+ "eval_steps_per_second": 1.002,
+ "step": 52
+ },
+ {
+ "epoch": 6.62,
+ "grad_norm": 0.8956974540095054,
+ "learning_rate": 5.3527682795623146e-05,
+ "loss": 1.5184,
+ "step": 53
+ },
+ {
+ "epoch": 6.75,
+ "grad_norm": 0.8166104294104601,
+ "learning_rate": 5.000000000000002e-05,
+ "loss": 1.4233,
+ "step": 54
+ },
+ {
+ "epoch": 6.75,
+ "eval_loss": 2.0203704833984375,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 54
+ },
+ {
+ "epoch": 6.88,
+ "grad_norm": 1.2567309830006292,
+ "learning_rate": 4.6553417387219886e-05,
+ "loss": 1.5766,
+ "step": 55
+ },
+ {
+ "epoch": 7.0,
+ "grad_norm": 1.202021898168564,
+ "learning_rate": 4.3193525326884435e-05,
+ "loss": 1.2534,
+ "step": 56
+ },
+ {
+ "epoch": 7.0,
+ "eval_loss": 2.0317745208740234,
+ "eval_runtime": 2.9887,
+ "eval_samples_per_second": 7.027,
+ "eval_steps_per_second": 1.004,
+ "step": 56
+ },
+ {
+ "epoch": 7.12,
+ "grad_norm": 1.0179404054971375,
+ "learning_rate": 3.99257735762021e-05,
+ "loss": 1.3538,
+ "step": 57
+ },
+ {
+ "epoch": 7.25,
+ "grad_norm": 0.8024465225797554,
+ "learning_rate": 3.675546244046228e-05,
+ "loss": 1.2409,
+ "step": 58
+ },
+ {
+ "epoch": 7.25,
+ "eval_loss": 2.0444860458374023,
+ "eval_runtime": 2.9957,
+ "eval_samples_per_second": 7.01,
+ "eval_steps_per_second": 1.001,
+ "step": 58
+ },
+ {
+ "epoch": 7.38,
+ "grad_norm": 1.0938821440297672,
+ "learning_rate": 3.36877341759205e-05,
+ "loss": 1.2446,
+ "step": 59
+ },
+ {
+ "epoch": 7.5,
+ "grad_norm": 1.4397725924431397,
+ "learning_rate": 3.072756464904006e-05,
+ "loss": 1.4309,
+ "step": 60
+ },
+ {
+ "epoch": 7.5,
+ "eval_loss": 2.0641307830810547,
+ "eval_runtime": 3.0002,
+ "eval_samples_per_second": 6.999,
+ "eval_steps_per_second": 1.0,
+ "step": 60
+ },
+ {
+ "epoch": 7.62,
+ "grad_norm": 1.084317322881849,
+ "learning_rate": 2.7879755265618555e-05,
+ "loss": 1.4057,
+ "step": 61
+ },
+ {
+ "epoch": 7.75,
+ "grad_norm": 0.8921847488708302,
+ "learning_rate": 2.514892518288988e-05,
+ "loss": 1.1622,
+ "step": 62
+ },
+ {
+ "epoch": 7.75,
+ "eval_loss": 2.0632762908935547,
+ "eval_runtime": 2.9934,
+ "eval_samples_per_second": 7.015,
+ "eval_steps_per_second": 1.002,
+ "step": 62
+ },
+ {
+ "epoch": 7.88,
+ "grad_norm": 1.2733235220422945,
+ "learning_rate": 2.2539503817234553e-05,
+ "loss": 1.2667,
+ "step": 63
+ },
+ {
+ "epoch": 8.0,
+ "grad_norm": 1.01591405423162,
+ "learning_rate": 2.0055723659649904e-05,
+ "loss": 1.228,
+ "step": 64
+ },
+ {
+ "epoch": 8.0,
+ "eval_loss": 2.09301495552063,
+ "eval_runtime": 2.9938,
+ "eval_samples_per_second": 7.014,
+ "eval_steps_per_second": 1.002,
+ "step": 64
+ }
+ ],
+ "logging_steps": 1,
+ "max_steps": 80,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 10,
+ "save_steps": 8,
+ "total_flos": 2.9637500853446246e+17,
+ "train_batch_size": 2,
+ "trial_name": null,
+ "trial_params": null
+}
diff --git a/checkpoint-64/training_args.bin b/checkpoint-64/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..b11ae566a70dfd7bcafb281eef91bfd37c1b257b
--- /dev/null
+++ b/checkpoint-64/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbd3cdf0c7e847516177c465407e4f8b9cbcc9b8664e3b64c39191721cf5ef99
+size 6776
diff --git a/checkpoint-64/zero_to_fp32.py b/checkpoint-64/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b846633d6eb1e836e34681e44033581f4edb7b
--- /dev/null
+++ b/checkpoint-64/zero_to_fp32.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+ buffers: dict()
+ param_shapes: dict()
+ shared_params: list
+ ds_version: int
+ frozen_param_shapes: dict()
+ frozen_param_fragments: dict()
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+ total_files = len(files)
+ state_dicts = []
+ for f in files:
+ state_dict = torch.load(f, map_location=device)
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+ # and also handle the case where it was already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ if zero_stage <= 2:
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ elif zero_stage == 3:
+ # if there is more than one param group, there will be multiple flattened tensors - one
+ # flattened tensor per group - for simplicity merge them into a single tensor
+ #
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+ fp32_flat_groups = [
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+ ]
+
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+ # Reconstruction protocol:
+ #
+ # XXX: document this
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+ # param, re-consolidating each param, while dealing with padding if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ for name, shape in param_shapes.items():
+
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # XXX: memory usage doubles here
+ state_dict[name] = torch.cat(
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
+ offset += partitioned_numel
+
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ """
+
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+ print(f"Saving fp32 state dict to {output_file}")
+ torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Put the provided model to cpu
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+ - ``model`: modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument(
+ "output_file",
+ type=str,
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
diff --git a/checkpoint-72/README.md b/checkpoint-72/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fde103e177d517a68ed416ca36925d7f86b488b
--- /dev/null
+++ b/checkpoint-72/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: google/gemma-7b
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.9.0
\ No newline at end of file
diff --git a/checkpoint-72/adapter_config.json b/checkpoint-72/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f48c351b3328c029833db4675ebe2c0dbdf14af4
--- /dev/null
+++ b/checkpoint-72/adapter_config.json
@@ -0,0 +1,33 @@
+{
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-7b",
+ "bias": "none",
+ "fan_in_fan_out": null,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 32,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "o_proj",
+ "up_proj",
+ "k_proj",
+ "q_proj",
+ "v_proj",
+ "gate_proj",
+ "down_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+}
\ No newline at end of file
diff --git a/checkpoint-72/adapter_model.safetensors b/checkpoint-72/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..a95c4b468a8fc3b50b643ab945a25f5e0cd3a8d6
--- /dev/null
+++ b/checkpoint-72/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44954d71c44b0b3a77c82c4e61cb154f0620626427991da9edec48e2006b123a
+size 200068904
diff --git a/checkpoint-72/global_step72/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-72/global_step72/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..ccea8de14564fdc97e033ed71c6157a693808039
--- /dev/null
+++ b/checkpoint-72/global_step72/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81fd6d1854d4065202f5faaf517a529c9babc13cbcdb54cfe6e80f75c3e68591
+size 150126608
diff --git a/checkpoint-72/global_step72/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-72/global_step72/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..832e7a03111da4703a5e8f45c30063e487f857b4
--- /dev/null
+++ b/checkpoint-72/global_step72/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ddde495ae8fee1b1af3e50c77935ce1d2760a49a5fe0e32d48f88f54c6d4bad
+size 150126672
diff --git a/checkpoint-72/global_step72/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-72/global_step72/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..d6def409c541b81a1215a73805e967d830e25183
--- /dev/null
+++ b/checkpoint-72/global_step72/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfc74a5b5403785a3a238055defca1048acbeb8766ed3e1a7aa600ebd4408864
+size 150126736
diff --git a/checkpoint-72/global_step72/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-72/global_step72/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..23bb9a5e8da6f1747e5f7e1d557b9abd32da8577
--- /dev/null
+++ b/checkpoint-72/global_step72/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f4fcc5ebe0bd0b28c13562346884763ed07438fa1a6c030fce64c64395b3671
+size 150126736
diff --git a/checkpoint-72/global_step72/mp_rank_00_model_states.pt b/checkpoint-72/global_step72/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..a79b92d60ff360b575e81c20b76c9f8abfd19b9c
--- /dev/null
+++ b/checkpoint-72/global_step72/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20a1e69a9ab961b6a4dd5f7cb8f604b17f89050e4fa1bf8665a541186272ede7
+size 1896781478
diff --git a/checkpoint-72/latest b/checkpoint-72/latest
new file mode 100644
index 0000000000000000000000000000000000000000..f3ff0f3ef57eac4f36c543b2d7ef78ca727041bd
--- /dev/null
+++ b/checkpoint-72/latest
@@ -0,0 +1 @@
+global_step72
\ No newline at end of file
diff --git a/checkpoint-72/rng_state_0.pth b/checkpoint-72/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b05f736c66b10af1a30dc8d103a89e1007c01865
--- /dev/null
+++ b/checkpoint-72/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:668b3c267070e4954cf0fb4816322e2dc903d37f5a7200afaeffc177308dba71
+size 15024
diff --git a/checkpoint-72/rng_state_1.pth b/checkpoint-72/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c37727104121e357796d46d0efa342ebcc19a181
--- /dev/null
+++ b/checkpoint-72/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac91d0dff53f6ec4da985ead857c455f6a2b5328e0d8d6d1b5c52db53d8b6dba
+size 15024
diff --git a/checkpoint-72/rng_state_2.pth b/checkpoint-72/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..368f2b6cb77c0612a01a501a76d78c3621239155
--- /dev/null
+++ b/checkpoint-72/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29312515ecdeb078acf20cc2c79eee00742f44ea4a3b75ff7112ede39ce6c19d
+size 15024
diff --git a/checkpoint-72/rng_state_3.pth b/checkpoint-72/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..45cbcf617d8bcf811694d86982b9573ccc3eecc6
--- /dev/null
+++ b/checkpoint-72/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0994265605963e9e0370391c3fde861e6e7648bea2aa71f4878efa25ceef7e0c
+size 15024
diff --git a/checkpoint-72/scheduler.pt b/checkpoint-72/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..48c08cc1451421b5917471aacb562b58f0d35f1c
--- /dev/null
+++ b/checkpoint-72/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:703522edce8dae3d48ce371eedc384a15870706940aaf5f704a573378f193df2
+size 1064
diff --git a/checkpoint-72/trainer_state.json b/checkpoint-72/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..30d807c785c1e4409b10fa1939aa9e45ef093fd4
--- /dev/null
+++ b/checkpoint-72/trainer_state.json
@@ -0,0 +1,821 @@
+{
+ "best_metric": 1.9750508069992065,
+ "best_model_checkpoint": "./gemma-python/checkpoint-40",
+ "epoch": 9.0,
+ "eval_steps": 2,
+ "global_step": 72,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.12,
+ "grad_norm": 40.636978402335416,
+ "learning_rate": 0.0001,
+ "loss": 19.0016,
+ "step": 1
+ },
+ {
+ "epoch": 0.12,
+ "eval_loss": 18.6992130279541,
+ "eval_runtime": 2.881,
+ "eval_samples_per_second": 7.289,
+ "eval_steps_per_second": 1.041,
+ "step": 1
+ },
+ {
+ "epoch": 0.25,
+ "grad_norm": 41.61053527062362,
+ "learning_rate": 0.0002,
+ "loss": 19.4686,
+ "step": 2
+ },
+ {
+ "epoch": 0.25,
+ "eval_loss": 16.257802963256836,
+ "eval_runtime": 2.9111,
+ "eval_samples_per_second": 7.214,
+ "eval_steps_per_second": 1.031,
+ "step": 2
+ },
+ {
+ "epoch": 0.38,
+ "grad_norm": 28.704819713850974,
+ "learning_rate": 0.00019991889981715698,
+ "loss": 13.2303,
+ "step": 3
+ },
+ {
+ "epoch": 0.5,
+ "grad_norm": 26.40444243073739,
+ "learning_rate": 0.00019967573081342103,
+ "loss": 11.468,
+ "step": 4
+ },
+ {
+ "epoch": 0.5,
+ "eval_loss": 8.28911018371582,
+ "eval_runtime": 2.9257,
+ "eval_samples_per_second": 7.178,
+ "eval_steps_per_second": 1.025,
+ "step": 4
+ },
+ {
+ "epoch": 0.62,
+ "grad_norm": 12.912981323843146,
+ "learning_rate": 0.0001992708874098054,
+ "loss": 9.3107,
+ "step": 5
+ },
+ {
+ "epoch": 0.75,
+ "grad_norm": 7.943058500648636,
+ "learning_rate": 0.00019870502626379127,
+ "loss": 7.5305,
+ "step": 6
+ },
+ {
+ "epoch": 0.75,
+ "eval_loss": 5.884701728820801,
+ "eval_runtime": 2.9479,
+ "eval_samples_per_second": 7.124,
+ "eval_steps_per_second": 1.018,
+ "step": 6
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 6.267657551985817,
+ "learning_rate": 0.00019797906520422677,
+ "loss": 6.6492,
+ "step": 7
+ },
+ {
+ "epoch": 1.0,
+ "grad_norm": 5.0825555341832365,
+ "learning_rate": 0.0001970941817426052,
+ "loss": 5.7572,
+ "step": 8
+ },
+ {
+ "epoch": 1.0,
+ "eval_loss": 4.363473892211914,
+ "eval_runtime": 2.9653,
+ "eval_samples_per_second": 7.082,
+ "eval_steps_per_second": 1.012,
+ "step": 8
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 4.88565620317727,
+ "learning_rate": 0.00019605181116313724,
+ "loss": 4.5414,
+ "step": 9
+ },
+ {
+ "epoch": 1.25,
+ "grad_norm": 5.0847008955317605,
+ "learning_rate": 0.00019485364419471454,
+ "loss": 4.3903,
+ "step": 10
+ },
+ {
+ "epoch": 1.25,
+ "eval_loss": 3.284867763519287,
+ "eval_runtime": 2.9746,
+ "eval_samples_per_second": 7.06,
+ "eval_steps_per_second": 1.009,
+ "step": 10
+ },
+ {
+ "epoch": 1.38,
+ "grad_norm": 3.424587898800574,
+ "learning_rate": 0.0001935016242685415,
+ "loss": 3.79,
+ "step": 11
+ },
+ {
+ "epoch": 1.5,
+ "grad_norm": 2.7255824385278506,
+ "learning_rate": 0.00019199794436588243,
+ "loss": 2.9497,
+ "step": 12
+ },
+ {
+ "epoch": 1.5,
+ "eval_loss": 2.853942394256592,
+ "eval_runtime": 2.9866,
+ "eval_samples_per_second": 7.031,
+ "eval_steps_per_second": 1.004,
+ "step": 12
+ },
+ {
+ "epoch": 1.62,
+ "grad_norm": 2.1001906898750624,
+ "learning_rate": 0.00019034504346103823,
+ "loss": 2.7728,
+ "step": 13
+ },
+ {
+ "epoch": 1.75,
+ "grad_norm": 1.9200021565941778,
+ "learning_rate": 0.000188545602565321,
+ "loss": 2.8738,
+ "step": 14
+ },
+ {
+ "epoch": 1.75,
+ "eval_loss": 2.62028431892395,
+ "eval_runtime": 2.9982,
+ "eval_samples_per_second": 7.004,
+ "eval_steps_per_second": 1.001,
+ "step": 14
+ },
+ {
+ "epoch": 1.88,
+ "grad_norm": 1.8837224890225774,
+ "learning_rate": 0.00018660254037844388,
+ "loss": 3.0787,
+ "step": 15
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 1.8929687978608318,
+ "learning_rate": 0.0001845190085543795,
+ "loss": 2.7298,
+ "step": 16
+ },
+ {
+ "epoch": 2.0,
+ "eval_loss": 2.453444242477417,
+ "eval_runtime": 2.9964,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 16
+ },
+ {
+ "epoch": 2.12,
+ "grad_norm": 1.3652069569291694,
+ "learning_rate": 0.00018229838658936564,
+ "loss": 2.5967,
+ "step": 17
+ },
+ {
+ "epoch": 2.25,
+ "grad_norm": 2.4263600812149417,
+ "learning_rate": 0.00017994427634035015,
+ "loss": 2.4284,
+ "step": 18
+ },
+ {
+ "epoch": 2.25,
+ "eval_loss": 2.307706832885742,
+ "eval_runtime": 2.9963,
+ "eval_samples_per_second": 7.009,
+ "eval_steps_per_second": 1.001,
+ "step": 18
+ },
+ {
+ "epoch": 2.38,
+ "grad_norm": 2.5673391658400053,
+ "learning_rate": 0.00017746049618276545,
+ "loss": 2.6721,
+ "step": 19
+ },
+ {
+ "epoch": 2.5,
+ "grad_norm": 2.2252437500899656,
+ "learning_rate": 0.00017485107481711012,
+ "loss": 2.394,
+ "step": 20
+ },
+ {
+ "epoch": 2.5,
+ "eval_loss": 2.187636137008667,
+ "eval_runtime": 2.9975,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 20
+ },
+ {
+ "epoch": 2.62,
+ "grad_norm": 2.345233295279928,
+ "learning_rate": 0.00017212024473438147,
+ "loss": 2.3972,
+ "step": 21
+ },
+ {
+ "epoch": 2.75,
+ "grad_norm": 1.1122620317353238,
+ "learning_rate": 0.00016927243535095997,
+ "loss": 2.069,
+ "step": 22
+ },
+ {
+ "epoch": 2.75,
+ "eval_loss": 2.1294100284576416,
+ "eval_runtime": 2.993,
+ "eval_samples_per_second": 7.016,
+ "eval_steps_per_second": 1.002,
+ "step": 22
+ },
+ {
+ "epoch": 2.88,
+ "grad_norm": 2.8270209249093803,
+ "learning_rate": 0.00016631226582407952,
+ "loss": 2.211,
+ "step": 23
+ },
+ {
+ "epoch": 3.0,
+ "grad_norm": 7.323169716541166,
+ "learning_rate": 0.00016324453755953773,
+ "loss": 1.9355,
+ "step": 24
+ },
+ {
+ "epoch": 3.0,
+ "eval_loss": 2.1047682762145996,
+ "eval_runtime": 2.9871,
+ "eval_samples_per_second": 7.03,
+ "eval_steps_per_second": 1.004,
+ "step": 24
+ },
+ {
+ "epoch": 3.12,
+ "grad_norm": 1.9938311808450486,
+ "learning_rate": 0.0001600742264237979,
+ "loss": 2.1962,
+ "step": 25
+ },
+ {
+ "epoch": 3.25,
+ "grad_norm": 3.330986691029466,
+ "learning_rate": 0.00015680647467311557,
+ "loss": 1.9635,
+ "step": 26
+ },
+ {
+ "epoch": 3.25,
+ "eval_loss": 2.0707101821899414,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 26
+ },
+ {
+ "epoch": 3.38,
+ "grad_norm": 2.0371854480792178,
+ "learning_rate": 0.0001534465826127801,
+ "loss": 2.2319,
+ "step": 27
+ },
+ {
+ "epoch": 3.5,
+ "grad_norm": 3.2163831286077653,
+ "learning_rate": 0.00015000000000000001,
+ "loss": 2.092,
+ "step": 28
+ },
+ {
+ "epoch": 3.5,
+ "eval_loss": 2.059619426727295,
+ "eval_runtime": 2.9996,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 28
+ },
+ {
+ "epoch": 3.62,
+ "grad_norm": 2.853987323853131,
+ "learning_rate": 0.00014647231720437686,
+ "loss": 1.9182,
+ "step": 29
+ },
+ {
+ "epoch": 3.75,
+ "grad_norm": 2.2997509863024352,
+ "learning_rate": 0.00014286925614030542,
+ "loss": 1.9675,
+ "step": 30
+ },
+ {
+ "epoch": 3.75,
+ "eval_loss": 2.0287458896636963,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 30
+ },
+ {
+ "epoch": 3.88,
+ "grad_norm": 2.2770679758385244,
+ "learning_rate": 0.00013919666098600753,
+ "loss": 1.9815,
+ "step": 31
+ },
+ {
+ "epoch": 4.0,
+ "grad_norm": 0.8553765652252152,
+ "learning_rate": 0.00013546048870425356,
+ "loss": 1.9693,
+ "step": 32
+ },
+ {
+ "epoch": 4.0,
+ "eval_loss": 2.022012710571289,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 32
+ },
+ {
+ "epoch": 4.12,
+ "grad_norm": 3.8094922067262336,
+ "learning_rate": 0.00013166679938014726,
+ "loss": 1.6479,
+ "step": 33
+ },
+ {
+ "epoch": 4.25,
+ "grad_norm": 3.5435911597121277,
+ "learning_rate": 0.0001278217463916453,
+ "loss": 2.0198,
+ "step": 34
+ },
+ {
+ "epoch": 4.25,
+ "eval_loss": 2.012432336807251,
+ "eval_runtime": 2.9987,
+ "eval_samples_per_second": 7.003,
+ "eval_steps_per_second": 1.0,
+ "step": 34
+ },
+ {
+ "epoch": 4.38,
+ "grad_norm": 1.4676241516417539,
+ "learning_rate": 0.0001239315664287558,
+ "loss": 1.7496,
+ "step": 35
+ },
+ {
+ "epoch": 4.5,
+ "grad_norm": 1.4772602834377506,
+ "learning_rate": 0.00012000256937760445,
+ "loss": 1.9357,
+ "step": 36
+ },
+ {
+ "epoch": 4.5,
+ "eval_loss": 1.9945744276046753,
+ "eval_runtime": 3.0019,
+ "eval_samples_per_second": 6.995,
+ "eval_steps_per_second": 0.999,
+ "step": 36
+ },
+ {
+ "epoch": 4.62,
+ "grad_norm": 0.8198622785029981,
+ "learning_rate": 0.00011604112808577603,
+ "loss": 1.8365,
+ "step": 37
+ },
+ {
+ "epoch": 4.75,
+ "grad_norm": 2.5267989029749556,
+ "learning_rate": 0.0001120536680255323,
+ "loss": 1.8147,
+ "step": 38
+ },
+ {
+ "epoch": 4.75,
+ "eval_loss": 1.9979486465454102,
+ "eval_runtime": 2.9865,
+ "eval_samples_per_second": 7.032,
+ "eval_steps_per_second": 1.005,
+ "step": 38
+ },
+ {
+ "epoch": 4.88,
+ "grad_norm": 1.2889515222114942,
+ "learning_rate": 0.00010804665687167262,
+ "loss": 1.6703,
+ "step": 39
+ },
+ {
+ "epoch": 5.0,
+ "grad_norm": 1.3474067788797102,
+ "learning_rate": 0.00010402659401094152,
+ "loss": 1.9084,
+ "step": 40
+ },
+ {
+ "epoch": 5.0,
+ "eval_loss": 1.9750508069992065,
+ "eval_runtime": 2.9945,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 40
+ },
+ {
+ "epoch": 5.12,
+ "grad_norm": 1.320063776368443,
+ "learning_rate": 0.0001,
+ "loss": 1.6233,
+ "step": 41
+ },
+ {
+ "epoch": 5.25,
+ "grad_norm": 0.7858628087737163,
+ "learning_rate": 9.597340598905852e-05,
+ "loss": 1.6678,
+ "step": 42
+ },
+ {
+ "epoch": 5.25,
+ "eval_loss": 2.004897356033325,
+ "eval_runtime": 2.9946,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 42
+ },
+ {
+ "epoch": 5.38,
+ "grad_norm": 1.149181462350102,
+ "learning_rate": 9.195334312832742e-05,
+ "loss": 1.5673,
+ "step": 43
+ },
+ {
+ "epoch": 5.5,
+ "grad_norm": 1.961547695831496,
+ "learning_rate": 8.79463319744677e-05,
+ "loss": 1.7639,
+ "step": 44
+ },
+ {
+ "epoch": 5.5,
+ "eval_loss": 1.9885122776031494,
+ "eval_runtime": 2.9905,
+ "eval_samples_per_second": 7.022,
+ "eval_steps_per_second": 1.003,
+ "step": 44
+ },
+ {
+ "epoch": 5.62,
+ "grad_norm": 0.794217334050356,
+ "learning_rate": 8.395887191422397e-05,
+ "loss": 1.6191,
+ "step": 45
+ },
+ {
+ "epoch": 5.75,
+ "grad_norm": 1.5568588659062292,
+ "learning_rate": 7.999743062239557e-05,
+ "loss": 1.7475,
+ "step": 46
+ },
+ {
+ "epoch": 5.75,
+ "eval_loss": 1.9777300357818604,
+ "eval_runtime": 2.9821,
+ "eval_samples_per_second": 7.042,
+ "eval_steps_per_second": 1.006,
+ "step": 46
+ },
+ {
+ "epoch": 5.88,
+ "grad_norm": 0.9110203190054421,
+ "learning_rate": 7.606843357124426e-05,
+ "loss": 1.5998,
+ "step": 47
+ },
+ {
+ "epoch": 6.0,
+ "grad_norm": 1.4501990937976796,
+ "learning_rate": 7.217825360835473e-05,
+ "loss": 1.4848,
+ "step": 48
+ },
+ {
+ "epoch": 6.0,
+ "eval_loss": 1.9939006567001343,
+ "eval_runtime": 2.9785,
+ "eval_samples_per_second": 7.05,
+ "eval_steps_per_second": 1.007,
+ "step": 48
+ },
+ {
+ "epoch": 6.12,
+ "grad_norm": 1.3413384555399062,
+ "learning_rate": 6.833320061985277e-05,
+ "loss": 1.5343,
+ "step": 49
+ },
+ {
+ "epoch": 6.25,
+ "grad_norm": 0.9844954583473513,
+ "learning_rate": 6.453951129574644e-05,
+ "loss": 1.3065,
+ "step": 50
+ },
+ {
+ "epoch": 6.25,
+ "eval_loss": 2.0264320373535156,
+ "eval_runtime": 2.9839,
+ "eval_samples_per_second": 7.038,
+ "eval_steps_per_second": 1.005,
+ "step": 50
+ },
+ {
+ "epoch": 6.38,
+ "grad_norm": 1.268663878876962,
+ "learning_rate": 6.080333901399251e-05,
+ "loss": 1.4153,
+ "step": 51
+ },
+ {
+ "epoch": 6.5,
+ "grad_norm": 1.1638516740810099,
+ "learning_rate": 5.713074385969457e-05,
+ "loss": 1.4792,
+ "step": 52
+ },
+ {
+ "epoch": 6.5,
+ "eval_loss": 2.012540817260742,
+ "eval_runtime": 2.9954,
+ "eval_samples_per_second": 7.011,
+ "eval_steps_per_second": 1.002,
+ "step": 52
+ },
+ {
+ "epoch": 6.62,
+ "grad_norm": 0.8956974540095054,
+ "learning_rate": 5.3527682795623146e-05,
+ "loss": 1.5184,
+ "step": 53
+ },
+ {
+ "epoch": 6.75,
+ "grad_norm": 0.8166104294104601,
+ "learning_rate": 5.000000000000002e-05,
+ "loss": 1.4233,
+ "step": 54
+ },
+ {
+ "epoch": 6.75,
+ "eval_loss": 2.0203704833984375,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 54
+ },
+ {
+ "epoch": 6.88,
+ "grad_norm": 1.2567309830006292,
+ "learning_rate": 4.6553417387219886e-05,
+ "loss": 1.5766,
+ "step": 55
+ },
+ {
+ "epoch": 7.0,
+ "grad_norm": 1.202021898168564,
+ "learning_rate": 4.3193525326884435e-05,
+ "loss": 1.2534,
+ "step": 56
+ },
+ {
+ "epoch": 7.0,
+ "eval_loss": 2.0317745208740234,
+ "eval_runtime": 2.9887,
+ "eval_samples_per_second": 7.027,
+ "eval_steps_per_second": 1.004,
+ "step": 56
+ },
+ {
+ "epoch": 7.12,
+ "grad_norm": 1.0179404054971375,
+ "learning_rate": 3.99257735762021e-05,
+ "loss": 1.3538,
+ "step": 57
+ },
+ {
+ "epoch": 7.25,
+ "grad_norm": 0.8024465225797554,
+ "learning_rate": 3.675546244046228e-05,
+ "loss": 1.2409,
+ "step": 58
+ },
+ {
+ "epoch": 7.25,
+ "eval_loss": 2.0444860458374023,
+ "eval_runtime": 2.9957,
+ "eval_samples_per_second": 7.01,
+ "eval_steps_per_second": 1.001,
+ "step": 58
+ },
+ {
+ "epoch": 7.38,
+ "grad_norm": 1.0938821440297672,
+ "learning_rate": 3.36877341759205e-05,
+ "loss": 1.2446,
+ "step": 59
+ },
+ {
+ "epoch": 7.5,
+ "grad_norm": 1.4397725924431397,
+ "learning_rate": 3.072756464904006e-05,
+ "loss": 1.4309,
+ "step": 60
+ },
+ {
+ "epoch": 7.5,
+ "eval_loss": 2.0641307830810547,
+ "eval_runtime": 3.0002,
+ "eval_samples_per_second": 6.999,
+ "eval_steps_per_second": 1.0,
+ "step": 60
+ },
+ {
+ "epoch": 7.62,
+ "grad_norm": 1.084317322881849,
+ "learning_rate": 2.7879755265618555e-05,
+ "loss": 1.4057,
+ "step": 61
+ },
+ {
+ "epoch": 7.75,
+ "grad_norm": 0.8921847488708302,
+ "learning_rate": 2.514892518288988e-05,
+ "loss": 1.1622,
+ "step": 62
+ },
+ {
+ "epoch": 7.75,
+ "eval_loss": 2.0632762908935547,
+ "eval_runtime": 2.9934,
+ "eval_samples_per_second": 7.015,
+ "eval_steps_per_second": 1.002,
+ "step": 62
+ },
+ {
+ "epoch": 7.88,
+ "grad_norm": 1.2733235220422945,
+ "learning_rate": 2.2539503817234553e-05,
+ "loss": 1.2667,
+ "step": 63
+ },
+ {
+ "epoch": 8.0,
+ "grad_norm": 1.01591405423162,
+ "learning_rate": 2.0055723659649904e-05,
+ "loss": 1.228,
+ "step": 64
+ },
+ {
+ "epoch": 8.0,
+ "eval_loss": 2.09301495552063,
+ "eval_runtime": 2.9938,
+ "eval_samples_per_second": 7.014,
+ "eval_steps_per_second": 1.002,
+ "step": 64
+ },
+ {
+ "epoch": 8.12,
+ "grad_norm": 0.9494450303367244,
+ "learning_rate": 1.7701613410634365e-05,
+ "loss": 1.1147,
+ "step": 65
+ },
+ {
+ "epoch": 8.25,
+ "grad_norm": 0.8254286577206483,
+ "learning_rate": 1.5480991445620542e-05,
+ "loss": 1.3076,
+ "step": 66
+ },
+ {
+ "epoch": 8.25,
+ "eval_loss": 2.1076860427856445,
+ "eval_runtime": 2.9974,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 66
+ },
+ {
+ "epoch": 8.38,
+ "grad_norm": 0.9874923331530434,
+ "learning_rate": 1.339745962155613e-05,
+ "loss": 1.1572,
+ "step": 67
+ },
+ {
+ "epoch": 8.5,
+ "grad_norm": 0.8701092754993289,
+ "learning_rate": 1.1454397434679021e-05,
+ "loss": 1.2323,
+ "step": 68
+ },
+ {
+ "epoch": 8.5,
+ "eval_loss": 2.1060104370117188,
+ "eval_runtime": 2.9923,
+ "eval_samples_per_second": 7.018,
+ "eval_steps_per_second": 1.003,
+ "step": 68
+ },
+ {
+ "epoch": 8.62,
+ "grad_norm": 0.9048894666644874,
+ "learning_rate": 9.65495653896179e-06,
+ "loss": 1.1888,
+ "step": 69
+ },
+ {
+ "epoch": 8.75,
+ "grad_norm": 0.8899151834513122,
+ "learning_rate": 8.002055634117578e-06,
+ "loss": 1.1635,
+ "step": 70
+ },
+ {
+ "epoch": 8.75,
+ "eval_loss": 2.1039013862609863,
+ "eval_runtime": 2.9883,
+ "eval_samples_per_second": 7.027,
+ "eval_steps_per_second": 1.004,
+ "step": 70
+ },
+ {
+ "epoch": 8.88,
+ "grad_norm": 0.9759646607551775,
+ "learning_rate": 6.498375731458528e-06,
+ "loss": 1.0924,
+ "step": 71
+ },
+ {
+ "epoch": 9.0,
+ "grad_norm": 1.067529387326401,
+ "learning_rate": 5.146355805285452e-06,
+ "loss": 1.261,
+ "step": 72
+ },
+ {
+ "epoch": 9.0,
+ "eval_loss": 2.1068060398101807,
+ "eval_runtime": 2.9995,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 72
+ }
+ ],
+ "logging_steps": 1,
+ "max_steps": 80,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 10,
+ "save_steps": 8,
+ "total_flos": 3.334218846012703e+17,
+ "train_batch_size": 2,
+ "trial_name": null,
+ "trial_params": null
+}
diff --git a/checkpoint-72/training_args.bin b/checkpoint-72/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..b11ae566a70dfd7bcafb281eef91bfd37c1b257b
--- /dev/null
+++ b/checkpoint-72/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbd3cdf0c7e847516177c465407e4f8b9cbcc9b8664e3b64c39191721cf5ef99
+size 6776
diff --git a/checkpoint-72/zero_to_fp32.py b/checkpoint-72/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b846633d6eb1e836e34681e44033581f4edb7b
--- /dev/null
+++ b/checkpoint-72/zero_to_fp32.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+ buffers: dict()
+ param_shapes: dict()
+ shared_params: list
+ ds_version: int
+ frozen_param_shapes: dict()
+ frozen_param_fragments: dict()
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+ total_files = len(files)
+ state_dicts = []
+ for f in files:
+ state_dict = torch.load(f, map_location=device)
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+ # and also handle the case where it was already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ if zero_stage <= 2:
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ elif zero_stage == 3:
+ # if there is more than one param group, there will be multiple flattened tensors - one
+ # flattened tensor per group - for simplicity merge them into a single tensor
+ #
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+ fp32_flat_groups = [
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+ ]
+
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+ # Reconstruction protocol:
+ #
+ # XXX: document this
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+ # param, re-consolidating each param, while dealing with padding if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ for name, shape in param_shapes.items():
+
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # XXX: memory usage doubles here
+ state_dict[name] = torch.cat(
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
+ offset += partitioned_numel
+
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ """
+
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+ print(f"Saving fp32 state dict to {output_file}")
+ torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Put the provided model to cpu
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+ - ``model`: modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument(
+ "output_file",
+ type=str,
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
diff --git a/checkpoint-80/README.md b/checkpoint-80/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7fde103e177d517a68ed416ca36925d7f86b488b
--- /dev/null
+++ b/checkpoint-80/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: google/gemma-7b
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.9.0
\ No newline at end of file
diff --git a/checkpoint-80/adapter_config.json b/checkpoint-80/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f48c351b3328c029833db4675ebe2c0dbdf14af4
--- /dev/null
+++ b/checkpoint-80/adapter_config.json
@@ -0,0 +1,33 @@
+{
+ "alpha_pattern": {},
+ "auto_mapping": null,
+ "base_model_name_or_path": "google/gemma-7b",
+ "bias": "none",
+ "fan_in_fan_out": null,
+ "inference_mode": true,
+ "init_lora_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "loftq_config": {},
+ "lora_alpha": 16,
+ "lora_dropout": 0.05,
+ "megatron_config": null,
+ "megatron_core": "megatron.core",
+ "modules_to_save": null,
+ "peft_type": "LORA",
+ "r": 32,
+ "rank_pattern": {},
+ "revision": null,
+ "target_modules": [
+ "o_proj",
+ "up_proj",
+ "k_proj",
+ "q_proj",
+ "v_proj",
+ "gate_proj",
+ "down_proj"
+ ],
+ "task_type": "CAUSAL_LM",
+ "use_dora": false,
+ "use_rslora": false
+}
\ No newline at end of file
diff --git a/checkpoint-80/adapter_model.safetensors b/checkpoint-80/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..94a1cdc601a5e9cc800bbe510f60ebd9083d307e
--- /dev/null
+++ b/checkpoint-80/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:946e4076344c368f83172717f91c6280df3fee0451c37c4cf680ac9b84b04900
+size 200068904
diff --git a/checkpoint-80/global_step80/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-80/global_step80/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..b8d17f398d8209a7f665bcec768a34b0b1122c12
--- /dev/null
+++ b/checkpoint-80/global_step80/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7dac6bc3e440e953fe02e7306dcedc946bed8e4c9d31a7016e38d7caf965c868
+size 150126608
diff --git a/checkpoint-80/global_step80/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt b/checkpoint-80/global_step80/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..cfab92819ea540bb8dfab6abea69365b3958f109
--- /dev/null
+++ b/checkpoint-80/global_step80/bf16_zero_pp_rank_1_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29d531cd6d45b37f0525ad85945bd65ebb34d3ea55cc9658396f93af1763486
+size 150126672
diff --git a/checkpoint-80/global_step80/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt b/checkpoint-80/global_step80/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..64ba277855adb80273c37c56d0f1938919b69b1b
--- /dev/null
+++ b/checkpoint-80/global_step80/bf16_zero_pp_rank_2_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce873409e6044da5bb63746e090c833024ba5164391fe69df00d515d71df3634
+size 150126736
diff --git a/checkpoint-80/global_step80/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt b/checkpoint-80/global_step80/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..aac8e949c67ed74aeff83ce59c5b1eb9e25780a8
--- /dev/null
+++ b/checkpoint-80/global_step80/bf16_zero_pp_rank_3_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af0705c1926f18bbb29f9848beb4578dc2d75a1e5f7b80c0be0ed713b33ed22c
+size 150126736
diff --git a/checkpoint-80/global_step80/mp_rank_00_model_states.pt b/checkpoint-80/global_step80/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..1345dc0cb621d878e2b1f125751d54b446a7770d
--- /dev/null
+++ b/checkpoint-80/global_step80/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b25941654e69a3eea28102baf15293cec60095bbb24029f61e762704b68930bb
+size 1896781478
diff --git a/checkpoint-80/latest b/checkpoint-80/latest
new file mode 100644
index 0000000000000000000000000000000000000000..75eab498d0366633484ab40334e4b8fb92b16dad
--- /dev/null
+++ b/checkpoint-80/latest
@@ -0,0 +1 @@
+global_step80
\ No newline at end of file
diff --git a/checkpoint-80/rng_state_0.pth b/checkpoint-80/rng_state_0.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7dff3c76bb39c77b38dda0c45e8e446fb406e760
--- /dev/null
+++ b/checkpoint-80/rng_state_0.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a00391921765e2e3c27e90f8d86bcf37562aaafd9c931e17ba8c2c475cdf77f0
+size 15024
diff --git a/checkpoint-80/rng_state_1.pth b/checkpoint-80/rng_state_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3a0c88de92b59a7c629936b3f7df27d08f7fca80
--- /dev/null
+++ b/checkpoint-80/rng_state_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f2543d60458415f5de6c675e9abf84987fbf199169a666a95aa3649ca7c1f43
+size 15024
diff --git a/checkpoint-80/rng_state_2.pth b/checkpoint-80/rng_state_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..6b4a8ea6b4decdd05801a783cffe8b7b0d73fd0c
--- /dev/null
+++ b/checkpoint-80/rng_state_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d831768bcffbcde48d6e577ce291504ea9145a38ba635c787be93292c9b4a10
+size 15024
diff --git a/checkpoint-80/rng_state_3.pth b/checkpoint-80/rng_state_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1856e4f82690127b83f8cbeb9dbe36147bc00ec8
--- /dev/null
+++ b/checkpoint-80/rng_state_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17c60617b54eaa2ee560dd80b4f0cb8c672c6312d73bd70cc675c8885da24a87
+size 15024
diff --git a/checkpoint-80/scheduler.pt b/checkpoint-80/scheduler.pt
new file mode 100644
index 0000000000000000000000000000000000000000..1fae5b481d5db36d9acebe45fab76f7773a9405c
--- /dev/null
+++ b/checkpoint-80/scheduler.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1a6c7e77749f0168af945b20e781b95b56d3c3430951ad3aea179b9bba0a92d
+size 1064
diff --git a/checkpoint-80/trainer_state.json b/checkpoint-80/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..01d16fd96efe5a82bc58ba0e6f50f99bb461d5f2
--- /dev/null
+++ b/checkpoint-80/trainer_state.json
@@ -0,0 +1,909 @@
+{
+ "best_metric": 1.9750508069992065,
+ "best_model_checkpoint": "./gemma-python/checkpoint-40",
+ "epoch": 10.0,
+ "eval_steps": 2,
+ "global_step": 80,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.12,
+ "grad_norm": 40.636978402335416,
+ "learning_rate": 0.0001,
+ "loss": 19.0016,
+ "step": 1
+ },
+ {
+ "epoch": 0.12,
+ "eval_loss": 18.6992130279541,
+ "eval_runtime": 2.881,
+ "eval_samples_per_second": 7.289,
+ "eval_steps_per_second": 1.041,
+ "step": 1
+ },
+ {
+ "epoch": 0.25,
+ "grad_norm": 41.61053527062362,
+ "learning_rate": 0.0002,
+ "loss": 19.4686,
+ "step": 2
+ },
+ {
+ "epoch": 0.25,
+ "eval_loss": 16.257802963256836,
+ "eval_runtime": 2.9111,
+ "eval_samples_per_second": 7.214,
+ "eval_steps_per_second": 1.031,
+ "step": 2
+ },
+ {
+ "epoch": 0.38,
+ "grad_norm": 28.704819713850974,
+ "learning_rate": 0.00019991889981715698,
+ "loss": 13.2303,
+ "step": 3
+ },
+ {
+ "epoch": 0.5,
+ "grad_norm": 26.40444243073739,
+ "learning_rate": 0.00019967573081342103,
+ "loss": 11.468,
+ "step": 4
+ },
+ {
+ "epoch": 0.5,
+ "eval_loss": 8.28911018371582,
+ "eval_runtime": 2.9257,
+ "eval_samples_per_second": 7.178,
+ "eval_steps_per_second": 1.025,
+ "step": 4
+ },
+ {
+ "epoch": 0.62,
+ "grad_norm": 12.912981323843146,
+ "learning_rate": 0.0001992708874098054,
+ "loss": 9.3107,
+ "step": 5
+ },
+ {
+ "epoch": 0.75,
+ "grad_norm": 7.943058500648636,
+ "learning_rate": 0.00019870502626379127,
+ "loss": 7.5305,
+ "step": 6
+ },
+ {
+ "epoch": 0.75,
+ "eval_loss": 5.884701728820801,
+ "eval_runtime": 2.9479,
+ "eval_samples_per_second": 7.124,
+ "eval_steps_per_second": 1.018,
+ "step": 6
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 6.267657551985817,
+ "learning_rate": 0.00019797906520422677,
+ "loss": 6.6492,
+ "step": 7
+ },
+ {
+ "epoch": 1.0,
+ "grad_norm": 5.0825555341832365,
+ "learning_rate": 0.0001970941817426052,
+ "loss": 5.7572,
+ "step": 8
+ },
+ {
+ "epoch": 1.0,
+ "eval_loss": 4.363473892211914,
+ "eval_runtime": 2.9653,
+ "eval_samples_per_second": 7.082,
+ "eval_steps_per_second": 1.012,
+ "step": 8
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 4.88565620317727,
+ "learning_rate": 0.00019605181116313724,
+ "loss": 4.5414,
+ "step": 9
+ },
+ {
+ "epoch": 1.25,
+ "grad_norm": 5.0847008955317605,
+ "learning_rate": 0.00019485364419471454,
+ "loss": 4.3903,
+ "step": 10
+ },
+ {
+ "epoch": 1.25,
+ "eval_loss": 3.284867763519287,
+ "eval_runtime": 2.9746,
+ "eval_samples_per_second": 7.06,
+ "eval_steps_per_second": 1.009,
+ "step": 10
+ },
+ {
+ "epoch": 1.38,
+ "grad_norm": 3.424587898800574,
+ "learning_rate": 0.0001935016242685415,
+ "loss": 3.79,
+ "step": 11
+ },
+ {
+ "epoch": 1.5,
+ "grad_norm": 2.7255824385278506,
+ "learning_rate": 0.00019199794436588243,
+ "loss": 2.9497,
+ "step": 12
+ },
+ {
+ "epoch": 1.5,
+ "eval_loss": 2.853942394256592,
+ "eval_runtime": 2.9866,
+ "eval_samples_per_second": 7.031,
+ "eval_steps_per_second": 1.004,
+ "step": 12
+ },
+ {
+ "epoch": 1.62,
+ "grad_norm": 2.1001906898750624,
+ "learning_rate": 0.00019034504346103823,
+ "loss": 2.7728,
+ "step": 13
+ },
+ {
+ "epoch": 1.75,
+ "grad_norm": 1.9200021565941778,
+ "learning_rate": 0.000188545602565321,
+ "loss": 2.8738,
+ "step": 14
+ },
+ {
+ "epoch": 1.75,
+ "eval_loss": 2.62028431892395,
+ "eval_runtime": 2.9982,
+ "eval_samples_per_second": 7.004,
+ "eval_steps_per_second": 1.001,
+ "step": 14
+ },
+ {
+ "epoch": 1.88,
+ "grad_norm": 1.8837224890225774,
+ "learning_rate": 0.00018660254037844388,
+ "loss": 3.0787,
+ "step": 15
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 1.8929687978608318,
+ "learning_rate": 0.0001845190085543795,
+ "loss": 2.7298,
+ "step": 16
+ },
+ {
+ "epoch": 2.0,
+ "eval_loss": 2.453444242477417,
+ "eval_runtime": 2.9964,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 16
+ },
+ {
+ "epoch": 2.12,
+ "grad_norm": 1.3652069569291694,
+ "learning_rate": 0.00018229838658936564,
+ "loss": 2.5967,
+ "step": 17
+ },
+ {
+ "epoch": 2.25,
+ "grad_norm": 2.4263600812149417,
+ "learning_rate": 0.00017994427634035015,
+ "loss": 2.4284,
+ "step": 18
+ },
+ {
+ "epoch": 2.25,
+ "eval_loss": 2.307706832885742,
+ "eval_runtime": 2.9963,
+ "eval_samples_per_second": 7.009,
+ "eval_steps_per_second": 1.001,
+ "step": 18
+ },
+ {
+ "epoch": 2.38,
+ "grad_norm": 2.5673391658400053,
+ "learning_rate": 0.00017746049618276545,
+ "loss": 2.6721,
+ "step": 19
+ },
+ {
+ "epoch": 2.5,
+ "grad_norm": 2.2252437500899656,
+ "learning_rate": 0.00017485107481711012,
+ "loss": 2.394,
+ "step": 20
+ },
+ {
+ "epoch": 2.5,
+ "eval_loss": 2.187636137008667,
+ "eval_runtime": 2.9975,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 20
+ },
+ {
+ "epoch": 2.62,
+ "grad_norm": 2.345233295279928,
+ "learning_rate": 0.00017212024473438147,
+ "loss": 2.3972,
+ "step": 21
+ },
+ {
+ "epoch": 2.75,
+ "grad_norm": 1.1122620317353238,
+ "learning_rate": 0.00016927243535095997,
+ "loss": 2.069,
+ "step": 22
+ },
+ {
+ "epoch": 2.75,
+ "eval_loss": 2.1294100284576416,
+ "eval_runtime": 2.993,
+ "eval_samples_per_second": 7.016,
+ "eval_steps_per_second": 1.002,
+ "step": 22
+ },
+ {
+ "epoch": 2.88,
+ "grad_norm": 2.8270209249093803,
+ "learning_rate": 0.00016631226582407952,
+ "loss": 2.211,
+ "step": 23
+ },
+ {
+ "epoch": 3.0,
+ "grad_norm": 7.323169716541166,
+ "learning_rate": 0.00016324453755953773,
+ "loss": 1.9355,
+ "step": 24
+ },
+ {
+ "epoch": 3.0,
+ "eval_loss": 2.1047682762145996,
+ "eval_runtime": 2.9871,
+ "eval_samples_per_second": 7.03,
+ "eval_steps_per_second": 1.004,
+ "step": 24
+ },
+ {
+ "epoch": 3.12,
+ "grad_norm": 1.9938311808450486,
+ "learning_rate": 0.0001600742264237979,
+ "loss": 2.1962,
+ "step": 25
+ },
+ {
+ "epoch": 3.25,
+ "grad_norm": 3.330986691029466,
+ "learning_rate": 0.00015680647467311557,
+ "loss": 1.9635,
+ "step": 26
+ },
+ {
+ "epoch": 3.25,
+ "eval_loss": 2.0707101821899414,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 26
+ },
+ {
+ "epoch": 3.38,
+ "grad_norm": 2.0371854480792178,
+ "learning_rate": 0.0001534465826127801,
+ "loss": 2.2319,
+ "step": 27
+ },
+ {
+ "epoch": 3.5,
+ "grad_norm": 3.2163831286077653,
+ "learning_rate": 0.00015000000000000001,
+ "loss": 2.092,
+ "step": 28
+ },
+ {
+ "epoch": 3.5,
+ "eval_loss": 2.059619426727295,
+ "eval_runtime": 2.9996,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 28
+ },
+ {
+ "epoch": 3.62,
+ "grad_norm": 2.853987323853131,
+ "learning_rate": 0.00014647231720437686,
+ "loss": 1.9182,
+ "step": 29
+ },
+ {
+ "epoch": 3.75,
+ "grad_norm": 2.2997509863024352,
+ "learning_rate": 0.00014286925614030542,
+ "loss": 1.9675,
+ "step": 30
+ },
+ {
+ "epoch": 3.75,
+ "eval_loss": 2.0287458896636963,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 30
+ },
+ {
+ "epoch": 3.88,
+ "grad_norm": 2.2770679758385244,
+ "learning_rate": 0.00013919666098600753,
+ "loss": 1.9815,
+ "step": 31
+ },
+ {
+ "epoch": 4.0,
+ "grad_norm": 0.8553765652252152,
+ "learning_rate": 0.00013546048870425356,
+ "loss": 1.9693,
+ "step": 32
+ },
+ {
+ "epoch": 4.0,
+ "eval_loss": 2.022012710571289,
+ "eval_runtime": 2.9895,
+ "eval_samples_per_second": 7.025,
+ "eval_steps_per_second": 1.004,
+ "step": 32
+ },
+ {
+ "epoch": 4.12,
+ "grad_norm": 3.8094922067262336,
+ "learning_rate": 0.00013166679938014726,
+ "loss": 1.6479,
+ "step": 33
+ },
+ {
+ "epoch": 4.25,
+ "grad_norm": 3.5435911597121277,
+ "learning_rate": 0.0001278217463916453,
+ "loss": 2.0198,
+ "step": 34
+ },
+ {
+ "epoch": 4.25,
+ "eval_loss": 2.012432336807251,
+ "eval_runtime": 2.9987,
+ "eval_samples_per_second": 7.003,
+ "eval_steps_per_second": 1.0,
+ "step": 34
+ },
+ {
+ "epoch": 4.38,
+ "grad_norm": 1.4676241516417539,
+ "learning_rate": 0.0001239315664287558,
+ "loss": 1.7496,
+ "step": 35
+ },
+ {
+ "epoch": 4.5,
+ "grad_norm": 1.4772602834377506,
+ "learning_rate": 0.00012000256937760445,
+ "loss": 1.9357,
+ "step": 36
+ },
+ {
+ "epoch": 4.5,
+ "eval_loss": 1.9945744276046753,
+ "eval_runtime": 3.0019,
+ "eval_samples_per_second": 6.995,
+ "eval_steps_per_second": 0.999,
+ "step": 36
+ },
+ {
+ "epoch": 4.62,
+ "grad_norm": 0.8198622785029981,
+ "learning_rate": 0.00011604112808577603,
+ "loss": 1.8365,
+ "step": 37
+ },
+ {
+ "epoch": 4.75,
+ "grad_norm": 2.5267989029749556,
+ "learning_rate": 0.0001120536680255323,
+ "loss": 1.8147,
+ "step": 38
+ },
+ {
+ "epoch": 4.75,
+ "eval_loss": 1.9979486465454102,
+ "eval_runtime": 2.9865,
+ "eval_samples_per_second": 7.032,
+ "eval_steps_per_second": 1.005,
+ "step": 38
+ },
+ {
+ "epoch": 4.88,
+ "grad_norm": 1.2889515222114942,
+ "learning_rate": 0.00010804665687167262,
+ "loss": 1.6703,
+ "step": 39
+ },
+ {
+ "epoch": 5.0,
+ "grad_norm": 1.3474067788797102,
+ "learning_rate": 0.00010402659401094152,
+ "loss": 1.9084,
+ "step": 40
+ },
+ {
+ "epoch": 5.0,
+ "eval_loss": 1.9750508069992065,
+ "eval_runtime": 2.9945,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 40
+ },
+ {
+ "epoch": 5.12,
+ "grad_norm": 1.320063776368443,
+ "learning_rate": 0.0001,
+ "loss": 1.6233,
+ "step": 41
+ },
+ {
+ "epoch": 5.25,
+ "grad_norm": 0.7858628087737163,
+ "learning_rate": 9.597340598905852e-05,
+ "loss": 1.6678,
+ "step": 42
+ },
+ {
+ "epoch": 5.25,
+ "eval_loss": 2.004897356033325,
+ "eval_runtime": 2.9946,
+ "eval_samples_per_second": 7.013,
+ "eval_steps_per_second": 1.002,
+ "step": 42
+ },
+ {
+ "epoch": 5.38,
+ "grad_norm": 1.149181462350102,
+ "learning_rate": 9.195334312832742e-05,
+ "loss": 1.5673,
+ "step": 43
+ },
+ {
+ "epoch": 5.5,
+ "grad_norm": 1.961547695831496,
+ "learning_rate": 8.79463319744677e-05,
+ "loss": 1.7639,
+ "step": 44
+ },
+ {
+ "epoch": 5.5,
+ "eval_loss": 1.9885122776031494,
+ "eval_runtime": 2.9905,
+ "eval_samples_per_second": 7.022,
+ "eval_steps_per_second": 1.003,
+ "step": 44
+ },
+ {
+ "epoch": 5.62,
+ "grad_norm": 0.794217334050356,
+ "learning_rate": 8.395887191422397e-05,
+ "loss": 1.6191,
+ "step": 45
+ },
+ {
+ "epoch": 5.75,
+ "grad_norm": 1.5568588659062292,
+ "learning_rate": 7.999743062239557e-05,
+ "loss": 1.7475,
+ "step": 46
+ },
+ {
+ "epoch": 5.75,
+ "eval_loss": 1.9777300357818604,
+ "eval_runtime": 2.9821,
+ "eval_samples_per_second": 7.042,
+ "eval_steps_per_second": 1.006,
+ "step": 46
+ },
+ {
+ "epoch": 5.88,
+ "grad_norm": 0.9110203190054421,
+ "learning_rate": 7.606843357124426e-05,
+ "loss": 1.5998,
+ "step": 47
+ },
+ {
+ "epoch": 6.0,
+ "grad_norm": 1.4501990937976796,
+ "learning_rate": 7.217825360835473e-05,
+ "loss": 1.4848,
+ "step": 48
+ },
+ {
+ "epoch": 6.0,
+ "eval_loss": 1.9939006567001343,
+ "eval_runtime": 2.9785,
+ "eval_samples_per_second": 7.05,
+ "eval_steps_per_second": 1.007,
+ "step": 48
+ },
+ {
+ "epoch": 6.12,
+ "grad_norm": 1.3413384555399062,
+ "learning_rate": 6.833320061985277e-05,
+ "loss": 1.5343,
+ "step": 49
+ },
+ {
+ "epoch": 6.25,
+ "grad_norm": 0.9844954583473513,
+ "learning_rate": 6.453951129574644e-05,
+ "loss": 1.3065,
+ "step": 50
+ },
+ {
+ "epoch": 6.25,
+ "eval_loss": 2.0264320373535156,
+ "eval_runtime": 2.9839,
+ "eval_samples_per_second": 7.038,
+ "eval_steps_per_second": 1.005,
+ "step": 50
+ },
+ {
+ "epoch": 6.38,
+ "grad_norm": 1.268663878876962,
+ "learning_rate": 6.080333901399251e-05,
+ "loss": 1.4153,
+ "step": 51
+ },
+ {
+ "epoch": 6.5,
+ "grad_norm": 1.1638516740810099,
+ "learning_rate": 5.713074385969457e-05,
+ "loss": 1.4792,
+ "step": 52
+ },
+ {
+ "epoch": 6.5,
+ "eval_loss": 2.012540817260742,
+ "eval_runtime": 2.9954,
+ "eval_samples_per_second": 7.011,
+ "eval_steps_per_second": 1.002,
+ "step": 52
+ },
+ {
+ "epoch": 6.62,
+ "grad_norm": 0.8956974540095054,
+ "learning_rate": 5.3527682795623146e-05,
+ "loss": 1.5184,
+ "step": 53
+ },
+ {
+ "epoch": 6.75,
+ "grad_norm": 0.8166104294104601,
+ "learning_rate": 5.000000000000002e-05,
+ "loss": 1.4233,
+ "step": 54
+ },
+ {
+ "epoch": 6.75,
+ "eval_loss": 2.0203704833984375,
+ "eval_runtime": 2.9966,
+ "eval_samples_per_second": 7.008,
+ "eval_steps_per_second": 1.001,
+ "step": 54
+ },
+ {
+ "epoch": 6.88,
+ "grad_norm": 1.2567309830006292,
+ "learning_rate": 4.6553417387219886e-05,
+ "loss": 1.5766,
+ "step": 55
+ },
+ {
+ "epoch": 7.0,
+ "grad_norm": 1.202021898168564,
+ "learning_rate": 4.3193525326884435e-05,
+ "loss": 1.2534,
+ "step": 56
+ },
+ {
+ "epoch": 7.0,
+ "eval_loss": 2.0317745208740234,
+ "eval_runtime": 2.9887,
+ "eval_samples_per_second": 7.027,
+ "eval_steps_per_second": 1.004,
+ "step": 56
+ },
+ {
+ "epoch": 7.12,
+ "grad_norm": 1.0179404054971375,
+ "learning_rate": 3.99257735762021e-05,
+ "loss": 1.3538,
+ "step": 57
+ },
+ {
+ "epoch": 7.25,
+ "grad_norm": 0.8024465225797554,
+ "learning_rate": 3.675546244046228e-05,
+ "loss": 1.2409,
+ "step": 58
+ },
+ {
+ "epoch": 7.25,
+ "eval_loss": 2.0444860458374023,
+ "eval_runtime": 2.9957,
+ "eval_samples_per_second": 7.01,
+ "eval_steps_per_second": 1.001,
+ "step": 58
+ },
+ {
+ "epoch": 7.38,
+ "grad_norm": 1.0938821440297672,
+ "learning_rate": 3.36877341759205e-05,
+ "loss": 1.2446,
+ "step": 59
+ },
+ {
+ "epoch": 7.5,
+ "grad_norm": 1.4397725924431397,
+ "learning_rate": 3.072756464904006e-05,
+ "loss": 1.4309,
+ "step": 60
+ },
+ {
+ "epoch": 7.5,
+ "eval_loss": 2.0641307830810547,
+ "eval_runtime": 3.0002,
+ "eval_samples_per_second": 6.999,
+ "eval_steps_per_second": 1.0,
+ "step": 60
+ },
+ {
+ "epoch": 7.62,
+ "grad_norm": 1.084317322881849,
+ "learning_rate": 2.7879755265618555e-05,
+ "loss": 1.4057,
+ "step": 61
+ },
+ {
+ "epoch": 7.75,
+ "grad_norm": 0.8921847488708302,
+ "learning_rate": 2.514892518288988e-05,
+ "loss": 1.1622,
+ "step": 62
+ },
+ {
+ "epoch": 7.75,
+ "eval_loss": 2.0632762908935547,
+ "eval_runtime": 2.9934,
+ "eval_samples_per_second": 7.015,
+ "eval_steps_per_second": 1.002,
+ "step": 62
+ },
+ {
+ "epoch": 7.88,
+ "grad_norm": 1.2733235220422945,
+ "learning_rate": 2.2539503817234553e-05,
+ "loss": 1.2667,
+ "step": 63
+ },
+ {
+ "epoch": 8.0,
+ "grad_norm": 1.01591405423162,
+ "learning_rate": 2.0055723659649904e-05,
+ "loss": 1.228,
+ "step": 64
+ },
+ {
+ "epoch": 8.0,
+ "eval_loss": 2.09301495552063,
+ "eval_runtime": 2.9938,
+ "eval_samples_per_second": 7.014,
+ "eval_steps_per_second": 1.002,
+ "step": 64
+ },
+ {
+ "epoch": 8.12,
+ "grad_norm": 0.9494450303367244,
+ "learning_rate": 1.7701613410634365e-05,
+ "loss": 1.1147,
+ "step": 65
+ },
+ {
+ "epoch": 8.25,
+ "grad_norm": 0.8254286577206483,
+ "learning_rate": 1.5480991445620542e-05,
+ "loss": 1.3076,
+ "step": 66
+ },
+ {
+ "epoch": 8.25,
+ "eval_loss": 2.1076860427856445,
+ "eval_runtime": 2.9974,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 66
+ },
+ {
+ "epoch": 8.38,
+ "grad_norm": 0.9874923331530434,
+ "learning_rate": 1.339745962155613e-05,
+ "loss": 1.1572,
+ "step": 67
+ },
+ {
+ "epoch": 8.5,
+ "grad_norm": 0.8701092754993289,
+ "learning_rate": 1.1454397434679021e-05,
+ "loss": 1.2323,
+ "step": 68
+ },
+ {
+ "epoch": 8.5,
+ "eval_loss": 2.1060104370117188,
+ "eval_runtime": 2.9923,
+ "eval_samples_per_second": 7.018,
+ "eval_steps_per_second": 1.003,
+ "step": 68
+ },
+ {
+ "epoch": 8.62,
+ "grad_norm": 0.9048894666644874,
+ "learning_rate": 9.65495653896179e-06,
+ "loss": 1.1888,
+ "step": 69
+ },
+ {
+ "epoch": 8.75,
+ "grad_norm": 0.8899151834513122,
+ "learning_rate": 8.002055634117578e-06,
+ "loss": 1.1635,
+ "step": 70
+ },
+ {
+ "epoch": 8.75,
+ "eval_loss": 2.1039013862609863,
+ "eval_runtime": 2.9883,
+ "eval_samples_per_second": 7.027,
+ "eval_steps_per_second": 1.004,
+ "step": 70
+ },
+ {
+ "epoch": 8.88,
+ "grad_norm": 0.9759646607551775,
+ "learning_rate": 6.498375731458528e-06,
+ "loss": 1.0924,
+ "step": 71
+ },
+ {
+ "epoch": 9.0,
+ "grad_norm": 1.067529387326401,
+ "learning_rate": 5.146355805285452e-06,
+ "loss": 1.261,
+ "step": 72
+ },
+ {
+ "epoch": 9.0,
+ "eval_loss": 2.1068060398101807,
+ "eval_runtime": 2.9995,
+ "eval_samples_per_second": 7.001,
+ "eval_steps_per_second": 1.0,
+ "step": 72
+ },
+ {
+ "epoch": 9.12,
+ "grad_norm": 0.6912625265797223,
+ "learning_rate": 3.948188836862776e-06,
+ "loss": 1.2225,
+ "step": 73
+ },
+ {
+ "epoch": 9.25,
+ "grad_norm": 0.8624797224342494,
+ "learning_rate": 2.905818257394799e-06,
+ "loss": 1.0122,
+ "step": 74
+ },
+ {
+ "epoch": 9.25,
+ "eval_loss": 2.1110289096832275,
+ "eval_runtime": 2.9973,
+ "eval_samples_per_second": 7.006,
+ "eval_steps_per_second": 1.001,
+ "step": 74
+ },
+ {
+ "epoch": 9.38,
+ "grad_norm": 0.7882237650664056,
+ "learning_rate": 2.0209347957732328e-06,
+ "loss": 1.0959,
+ "step": 75
+ },
+ {
+ "epoch": 9.5,
+ "grad_norm": 0.8572353081855683,
+ "learning_rate": 1.2949737362087156e-06,
+ "loss": 1.218,
+ "step": 76
+ },
+ {
+ "epoch": 9.5,
+ "eval_loss": 2.117999315261841,
+ "eval_runtime": 2.9874,
+ "eval_samples_per_second": 7.03,
+ "eval_steps_per_second": 1.004,
+ "step": 76
+ },
+ {
+ "epoch": 9.62,
+ "grad_norm": 0.8712624014542376,
+ "learning_rate": 7.291125901946027e-07,
+ "loss": 1.2579,
+ "step": 77
+ },
+ {
+ "epoch": 9.75,
+ "grad_norm": 0.7643303727279644,
+ "learning_rate": 3.2426918657900704e-07,
+ "loss": 1.1022,
+ "step": 78
+ },
+ {
+ "epoch": 9.75,
+ "eval_loss": 2.1226046085357666,
+ "eval_runtime": 2.9845,
+ "eval_samples_per_second": 7.036,
+ "eval_steps_per_second": 1.005,
+ "step": 78
+ },
+ {
+ "epoch": 9.88,
+ "grad_norm": 0.7335554379014946,
+ "learning_rate": 8.110018284304133e-08,
+ "loss": 1.1355,
+ "step": 79
+ },
+ {
+ "epoch": 10.0,
+ "grad_norm": 0.7141910762762422,
+ "learning_rate": 0.0,
+ "loss": 1.2072,
+ "step": 80
+ },
+ {
+ "epoch": 10.0,
+ "eval_loss": 2.1142799854278564,
+ "eval_runtime": 2.9807,
+ "eval_samples_per_second": 7.045,
+ "eval_steps_per_second": 1.006,
+ "step": 80
+ }
+ ],
+ "logging_steps": 1,
+ "max_steps": 80,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 10,
+ "save_steps": 8,
+ "total_flos": 3.704687606680781e+17,
+ "train_batch_size": 2,
+ "trial_name": null,
+ "trial_params": null
+}
diff --git a/checkpoint-80/training_args.bin b/checkpoint-80/training_args.bin
new file mode 100644
index 0000000000000000000000000000000000000000..b11ae566a70dfd7bcafb281eef91bfd37c1b257b
--- /dev/null
+++ b/checkpoint-80/training_args.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbd3cdf0c7e847516177c465407e4f8b9cbcc9b8664e3b64c39191721cf5ef99
+size 6776
diff --git a/checkpoint-80/zero_to_fp32.py b/checkpoint-80/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..49b846633d6eb1e836e34681e44033581f4edb7b
--- /dev/null
+++ b/checkpoint-80/zero_to_fp32.py
@@ -0,0 +1,592 @@
+#!/usr/bin/env python
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+from dataclasses import dataclass
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+
+
+@dataclass
+class zero_model_state:
+ buffers: dict()
+ param_shapes: dict()
+ shared_params: list
+ ds_version: int
+ frozen_param_shapes: dict()
+ frozen_param_fragments: dict()
+
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+ '''
+ alist.sort(key=natural_keys) sorts in human order
+ http://nedbatchelder.com/blog/200712/human_sorting.html
+ (See Toothy's implementation in the comments)
+ '''
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+ if not os.path.isdir(checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+ # there should be only one file
+ if zero_stage <= 2:
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+ elif zero_stage == 3:
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+ if not os.path.exists(file):
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+ return file
+
+
+def get_checkpoint_files(checkpoint_dir, glob_pattern):
+ # XXX: need to test that this simple glob rule works for multi-node setup too
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
+
+ if len(ckpt_files) == 0:
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
+
+ return ckpt_files
+
+
+def get_optim_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
+
+
+def get_model_state_files(checkpoint_dir):
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
+
+
+def parse_model_states(files):
+ zero_model_states = []
+ for file in files:
+ state_dict = torch.load(file, map_location=device)
+
+ if BUFFER_NAMES not in state_dict:
+ raise ValueError(f"{file} is not a model state checkpoint")
+ buffer_names = state_dict[BUFFER_NAMES]
+ if debug:
+ print("Found buffers:", buffer_names)
+
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
+ param_shapes = state_dict[PARAM_SHAPES]
+
+ # collect parameters that are included in param_shapes
+ param_names = []
+ for s in param_shapes:
+ for name in s.keys():
+ param_names.append(name)
+
+ # update with frozen parameters
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
+ if frozen_param_shapes is not None:
+ if debug:
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
+ param_names += list(frozen_param_shapes.keys())
+
+ # handle shared params
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
+
+ ds_version = state_dict.get(DS_VERSION, None)
+
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
+
+ z_model_state = zero_model_state(buffers=buffers,
+ param_shapes=param_shapes,
+ shared_params=shared_params,
+ ds_version=ds_version,
+ frozen_param_shapes=frozen_param_shapes,
+ frozen_param_fragments=frozen_param_fragments)
+ zero_model_states.append(z_model_state)
+
+ return zero_model_states
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+ total_files = len(files)
+ state_dicts = []
+ for f in files:
+ state_dict = torch.load(f, map_location=device)
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
+ # and also handle the case where it was already removed by another helper script
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
+ state_dicts.append(state_dict)
+
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
+
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
+ # use the max of the partition_count to get the dp world_size.
+
+ if type(world_size) is list:
+ world_size = max(world_size)
+
+ if world_size != total_files:
+ raise ValueError(
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
+ )
+
+ # the groups are named differently in each stage
+ if zero_stage <= 2:
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
+ elif zero_stage == 3:
+ fp32_groups_key = FP32_FLAT_GROUPS
+ else:
+ raise ValueError(f"unknown zero stage {zero_stage}")
+
+ if zero_stage <= 2:
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
+ elif zero_stage == 3:
+ # if there is more than one param group, there will be multiple flattened tensors - one
+ # flattened tensor per group - for simplicity merge them into a single tensor
+ #
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
+
+ fp32_flat_groups = [
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
+ ]
+
+ return zero_stage, world_size, fp32_flat_groups
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+ """
+ Returns fp32 state_dict reconstructed from ds checkpoint
+
+ Args:
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+ """
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+ optim_files = get_optim_files(ds_checkpoint_dir)
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+ model_files = get_model_state_files(ds_checkpoint_dir)
+
+ zero_model_states = parse_model_states(model_files)
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+
+ if zero_stage <= 2:
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+ elif zero_stage == 3:
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
+
+
+def _zero2_merge_frozen_params(state_dict, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
+
+ if debug:
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ state_dict[name] = frozen_param_fragments[name]
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _has_callable(obj, fn):
+ attr = getattr(obj, fn, None)
+ return callable(attr)
+
+
+def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+
+ # Reconstruction protocol:
+ #
+ # XXX: document this
+
+ if debug:
+ for i in range(world_size):
+ for j in range(len(fp32_flat_groups[0])):
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+ # XXX: memory usage doubles here (zero2)
+ num_param_groups = len(fp32_flat_groups[0])
+ merged_single_partition_of_fp32_groups = []
+ for i in range(num_param_groups):
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+ avail_numel = sum(
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+
+ if debug:
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
+ # not asserting if there is a mismatch due to possible padding
+ print(f"Have {avail_numel} numels to process.")
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ total_numel = 0
+ total_params = 0
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+ offset = 0
+ avail_numel = full_single_fp32_vector.numel()
+ for name, shape in shapes.items():
+
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ if debug:
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
+ offset += unpartitioned_numel
+
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+ # live optimizer object, so we are checking that the numbers are within the right range
+ align_to = 2 * world_size
+
+ def zero2_align(x):
+ return align_to * math.ceil(x / align_to)
+
+ if debug:
+ print(f"original offset={offset}, avail_numel={avail_numel}")
+
+ offset = zero2_align(offset)
+ avail_numel = zero2_align(avail_numel)
+
+ if debug:
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
+
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+ remainder = unpartitioned_numel % world_size
+ padding_numel = (world_size - remainder) if remainder else 0
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+ return partitioned_numel, padding_numel
+
+
+def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
+ return
+
+ if debug:
+ for i in range(world_size):
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
+ wanted_params = len(frozen_param_shapes)
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
+ print(f'Frozen params: Have {avail_numel} numels to process.')
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+
+ total_params = 0
+ total_numel = 0
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
+ total_params += 1
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
+ param_shapes = zero_model_states[0].param_shapes
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
+ # param, re-consolidating each param, while dealing with padding if any
+
+ # merge list of dicts, preserving order
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+ if debug:
+ for i in range(world_size):
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+ wanted_params = len(param_shapes)
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+ # not asserting if there is a mismatch due to possible padding
+ avail_numel = fp32_flat_groups[0].numel() * world_size
+ print(f"Trainable params: Have {avail_numel} numels to process.")
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
+
+ # params
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
+ # out-of-core computing solution
+ offset = 0
+ total_numel = 0
+ total_params = 0
+ for name, shape in param_shapes.items():
+
+ unpartitioned_numel = shape.numel()
+ total_numel += unpartitioned_numel
+ total_params += 1
+
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+ if debug:
+ print(
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+ )
+
+ # XXX: memory usage doubles here
+ state_dict[name] = torch.cat(
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
+ offset += partitioned_numel
+
+ offset *= world_size
+
+ # Sanity check
+ if offset != avail_numel:
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
+ state_dict = OrderedDict()
+
+ # buffers
+ buffers = zero_model_states[0].buffers
+ state_dict.update(buffers)
+ if debug:
+ print(f"added {len(buffers)} buffers")
+
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
+
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+
+ # recover shared parameters
+ for pair in zero_model_states[0].shared_params:
+ if pair[1] in state_dict:
+ state_dict[pair[0]] = state_dict[pair[1]]
+
+ return state_dict
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
+ via a model hub.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
+
+ Returns:
+ - pytorch ``state_dict``
+
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
+ the checkpoint.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+ # do the training and checkpoint saving
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+ model = model.cpu() # move to cpu
+ model.load_state_dict(state_dict)
+ # submit to model hub or save the model to share with others
+
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
+ application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+ """
+ if tag is None:
+ latest_path = os.path.join(checkpoint_dir, 'latest')
+ if os.path.isfile(latest_path):
+ with open(latest_path, 'r') as fd:
+ tag = fd.read().strip()
+ else:
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+ if not os.path.isdir(ds_checkpoint_dir):
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+ """
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+ Args:
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+ """
+
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+ print(f"Saving fp32 state dict to {output_file}")
+ torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+ """
+ 1. Put the provided model to cpu
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+ 3. Load it into the provided model
+
+ Args:
+ - ``model``: the model object to update
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+ Returns:
+ - ``model`: modified model
+
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+ conveniently placed for you in the checkpoint folder.
+
+ A typical usage might be ::
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+ # submit to model hub or save the model to share with others
+
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+ """
+ logger.info(f"Extracting fp32 weights")
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+ logger.info(f"Overwriting model with fp32 weights")
+ model = model.cpu()
+ model.load_state_dict(state_dict, strict=False)
+
+ return model
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint_dir",
+ type=str,
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+ parser.add_argument(
+ "output_file",
+ type=str,
+ help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+ parser.add_argument("-t",
+ "--tag",
+ type=str,
+ default=None,
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+ args = parser.parse_args()
+
+ debug = args.debug
+
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag)
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..d139138b8cfb3bf903c116881736a35363cfd077
--- /dev/null
+++ b/config.json
@@ -0,0 +1,42 @@
+{
+ "_name_or_path": "google/gemma-7b",
+ "architectures": [
+ "GemmaForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 2,
+ "eos_token_id": 1,
+ "head_dim": 256,
+ "hidden_act": "gelu",
+ "hidden_size": 3072,
+ "initializer_range": 0.02,
+ "intermediate_size": 24576,
+ "max_position_embeddings": 8192,
+ "model_type": "gemma",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 28,
+ "num_key_value_heads": 16,
+ "pad_token_id": 0,
+ "quantization_config": {
+ "_load_in_4bit": true,
+ "_load_in_8bit": false,
+ "bnb_4bit_compute_dtype": "bfloat16",
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_use_double_quant": true,
+ "llm_int8_enable_fp32_cpu_offload": false,
+ "llm_int8_has_fp16_weight": false,
+ "llm_int8_skip_modules": null,
+ "llm_int8_threshold": 6.0,
+ "load_in_4bit": true,
+ "load_in_8bit": false,
+ "quant_method": "bitsandbytes"
+ },
+ "rms_norm_eps": 1e-06,
+ "rope_scaling": null,
+ "rope_theta": 10000.0,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.38.2",
+ "use_cache": false,
+ "vocab_size": 256000
+}
diff --git a/runs/Mar12_21-34-13_myVm/events.out.tfevents.1710279255.myVm.207035.0 b/runs/Mar12_21-34-13_myVm/events.out.tfevents.1710279255.myVm.207035.0
new file mode 100644
index 0000000000000000000000000000000000000000..caef95f4cbb48abb569c30204f4ce3038023fb27
--- /dev/null
+++ b/runs/Mar12_21-34-13_myVm/events.out.tfevents.1710279255.myVm.207035.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a5e1845943e7b86810f8048c5615cc49427e1d2a25c1af84c3926843a55160a
+size 33616
diff --git a/special_tokens_map.json b/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..f6119589e367b2de0fc8cbd2f1217667532e3174
--- /dev/null
+++ b/special_tokens_map.json
@@ -0,0 +1,30 @@
+{
+ "bos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+}
diff --git a/tokenizer.json b/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..7d526fa4a2fc483e768773df90a944b54094e367
--- /dev/null
+++ b/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0d908b4f9326e0998815690e325b6abbd378978553e10627924dd825db7e243
+size 17477553
diff --git a/tokenizer.model b/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..796efe9ab515c15e146ce7588e6d7b9b8134dbf8
--- /dev/null
+++ b/tokenizer.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2
+size 4241003
diff --git a/tokenizer_config.json b/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c22f33ac79a1ea3d99aa4e2c78d2379d98747f72
--- /dev/null
+++ b/tokenizer_config.json
@@ -0,0 +1,49 @@
+{
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "3": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "",
+ "legacy": null,
+ "model_max_length": 1000000000000000019884624838656,
+ "pad_token": "",
+ "sp_model_kwargs": {},
+ "spaces_between_special_tokens": false,
+ "tokenizer_class": "GemmaTokenizer",
+ "unk_token": "",
+ "use_default_system_prompt": false
+}