---
base_model: distilbert/distilgpt2
datasets:
- wikimedia/wikipedia
library_name: Distily
license: apache-2.0
tags:
- generated_from_trainer
model-index:
- name: short_gpt2
  results: []
---

# Summary

Distilled with the [Distily](https://github.com/lapp0/distily) library, using [gpt2](https://huggingface.co/gpt2) as the teacher model and the [wikimedia/wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia) dataset.

# Model Architecture:

- **Architecture**: `GPT2LMHeadModel`
- **Total Parameters**: 81,912,576
- **Data Type (dtype)**: torch.bfloat16
- **Model Size**: 0.16 GB

# Evaluation Metrics Comparison

| step | epoch | enwikippl | frwikippl | loss | runtime | samples_per_second | steps_per_second | tinystoriesppl | zhwikippl |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **teacher eval** | | 43.25 | 61.25 | | | | | 11.6875 | 19.125 |
| 0 | 0 | 2018634629120.0 | 122045790683136.0 | 21.0022 | 102.1494 | 97.896 | 12.237 | 9999220736.0 | 43705587204096.0 |
| 2500 | 0.0101 | 299008.0 | 6422528.0 | 5.8065 | 101.9861 | 98.053 | 12.257 | 45824.0 | 14483456.0 |
| 5000 | 0.0202 | 6880.0 | 96256.0 | 3.3113 | 102.9516 | 97.133 | 12.142 | 4160.0 | 493568.0 |
| 7500 | 0.0303 | 1216.0 | 8096.0 | 2.1560 | 103.0236 | 97.065 | 12.133 | 692.0 | 42752.0 |
| 10000 | 0.0404 | 608.0 | 3664.0 | 1.7825 | 102.3752 | 97.68 | 12.21 | 388.0 | 888.0 |
| 12500 | 0.0505 | 358.0 | 1632.0 | 1.4664 | 102.1871 | 97.86 | 12.232 | 272.0 | 308.0 |
| 15000 | 0.0606 | 288.0 | 1176.0 | 1.3488 | 102.6007 | 97.465 | 12.183 | 228.0 | 260.0 |
| 17500 | 0.0707 | 255.0 | 1040.0 | 1.2932 | 102.1542 | 97.891 | 12.236 | 199.0 | 215.0 |
| 20000 | 0.0808 | 216.0 | 892.0 | 1.1570 | 102.1073 | 97.936 | 12.242 | 173.0 | 149.0 |
| 22500 | 0.0909 | 178.0 | 740.0 | 1.0350 | 102.0765 | 97.966 | 12.246 | 146.0 | 141.0 |
| 25000 | 0.1010 | 155.0 | 524.0 | 0.9676 | 102.1019 | 97.941 | 12.243 | 122.5 | 139.0 |
| 27500 | 0.1111 | 142.0 | 560.0 | 0.9230 | 102.0256 | 98.015 | 12.252 | 114.0 | 130.0 |
| 30000 | 0.1212 | 137.0 | 470.0 | 0.8998 | 102.3365 | 97.717 | 12.215 | 108.5 | 138.0 |
| 32500 | 0.1313 | 134.0 | 476.0 | 0.8740 | 102.3911 | 97.665 | 12.208 | 104.0 | 140.0 |
| 35000 | 0.1414 | 129.0 | 496.0 | 0.8657 | 102.2153 | 97.833 | 12.229 | 102.5 | 141.0 |
| 37500 | 0.1515 | 127.0 | 464.0 | 0.8513 | 102.0489 | 97.992 | 12.249 | 97.0 | 117.0 |
| 40000 | 0.1616 | 108.0 | 446.0 | 0.7522 | 102.9331 | 97.15 | 12.144 | 93.0 | 104.0 |
| 42500 | 0.1717 | 99.5 | 374.0 | 0.6850 | 103.1088 | 96.985 | 12.123 | 82.0 | 116.0 |
| 45000 | 0.1818 | 90.5 | 346.0 | 0.6316 | 102.7903 | 97.285 | 12.161 | 73.5 | 113.0 |
| 47500 | 0.1919 | 82.5 | 320.0 | 0.5960 | 102.5988 | 97.467 | 12.183 | 71.0 | 101.0 |
| 50000 | 0.2020 | 78.5 | 306.0 | 0.5676 | 102.5936 | 97.472 | 12.184 | 72.5 | 106.0 |
| 52500 | 0.2121 | 79.5 | 290.0 | 0.5424 | 102.5863 | 97.479 | 12.185 | 64.5 | 92.0 |
| 55000 | 0.2222 | 76.0 | 270.0 | 0.5280 | 102.6307 | 97.437 | 12.18 | 65.0 | 87.0 |
| 57500 | 0.2323 | 76.5 | 272.0 | 0.5278 | 101.9639 | 98.074 | 12.259 | 64.5 | 102.0 |
| 60000 | 0.2424 | 77.5 | 268.0 | 0.5286 | 102.0921 | 97.951 | 12.244 | 62.75 | 99.5 |
| 62500 | 0.2525 | 75.5 | 264.0 | 0.5204 | 102.0679 | 97.974 | 12.247 | 63.25 | 83.0 |
| 65000 | 0.2626 | 76.0 | 260.0 | 0.5176 | 102.1795 | 97.867 | 12.233 | 61.5 | 90.5 |
| 67500 | 0.2727 | 74.5 | 256.0 | 0.5112 | 102.5764 | 97.488 | 12.186 | 62.25 | 93.5 |
| 70000 | 0.2828 | 73.5 | 258.0 | 0.5128 | 101.9569 | 98.081 | 12.26 | 62.0 | 79.0 |
| 72500 | 0.2929 | 75.0 | 250.0 | 0.5053 | 101.9382 | 98.099 | 12.262 | 64.0 | 96.0 |
| 75000 | 0.3030 | 72.5 | 238.0 | 0.5068 | 102.0407 | 98.0 | 12.25 | 61.5 | 88.5 |
| 77500 | 0.3131 | 73.5 | 256.0 | 0.5085 | 102.0542 | 97.987 | 12.248 | 64.5 | 86.5 |
| 80000 | 0.3232 | 70.5 | 238.0 | 0.4699 | 102.4042 | 97.652 | 12.207 | 54.75 | 98.5 |
| 82500 | 0.3333 | 68.0 | 242.0 | 0.4574 | 102.2684 | 97.782 | 12.223 | 55.5 | 160.0 |
| 85000 | 0.3434 | 64.5 | 218.0 | 0.4490 | 102.3277 | 97.725 | 12.216 | 52.0 | 77.5 |
| 87500 | 0.3535 | 66.5 | 203.0 | 0.4394 | 102.1134 | 97.93 | 12.241 | 51.25 | 67.5 |
| 90000 | 0.3636 | 63.75 | 212.0 | 0.4310 | 102.0438 | 97.997 | 12.25 | 51.25 | 88.5 |
| 92500 | 0.3737 | 65.5 | 209.0 | 0.4262 | 101.9984 | 98.041 | 12.255 | 49.75 | 103.5 |
| 95000 | 0.3838 | 65.0 | 204.0 | 0.4274 | 102.0781 | 97.964 | 12.246 | 46.25 | 83.0 |
| 97500 | 0.3939 | 64.5 | 201.0 | 0.4192 | 102.0692 | 97.973 | 12.247 | 50.5 | 94.5 |
| 100000 | 0.4040 | 64.5 | 203.0 | 0.4207 | 102.1283 | 97.916 | 12.24 | 49.0 | 88.0 |
| 102500 | 0.4141 | 63.0 | 209.0 | 0.4184 | 102.224 | 97.824 | 12.228 | 48.0 | 125.0 |
| 105000 | 0.4242 | 62.75 | 193.0 | 0.4166 | 102.1918 | 97.855 | 12.232 | 46.0 | 76.0 |
| 107500 | 0.4343 | 62.75 | 197.0 | 0.4128 | 102.1719 | 97.874 | 12.234 | 47.0 | 113.0 |
| 110000 | 0.4444 | 64.5 | 191.0 | 0.4118 | 103.0992 | 96.994 | 12.124 | 49.0 | 82.0 |
| 112500 | 0.4545 | 65.0 | 213.0 | 0.4128 | 102.7296 | 97.343 | 12.168 | 47.0 | 111.5 |
| 115000 | 0.4646 | 68.5 | 207.0 | 0.4301 | 102.178 | 97.868 | 12.234 | 49.0 | 108.0 |
| 117500 | 0.4747 | 65.0 | 217.0 | 0.4372 | 102.2302 | 97.818 | 12.227 | 50.25 | 124.0 |
| 120000 | 0.4848 | 65.5 | 210.0 | 0.4351 | 102.2952 | 97.756 | 12.22 | 51.0 | 139.0 |
| 122500 | 0.4949 | 66.0 | 272.0 | 0.4352 | 102.1941 | 97.853 | 12.232 | 50.5 | 226.0 |
| 125000 | 0.5051 | 67.0 | 240.0 | 0.4387 | 101.978 | 98.06 | 12.258 | 49.0 | 71.0 |
| 127500 | 0.5152 | 66.5 | 224.0 | 0.4396 | 101.9014 | 98.134 | 12.267 | 49.75 | 100.0 |
| 130000 | 0.5253 | 65.5 | 227.0 | 0.4354 | 102.1244 | 97.92 | 12.24 | 50.75 | 146.0 |
| 132500 | 0.5354 | 66.0 | 209.0 | 0.4286 | 102.0218 | 98.018 | 12.252 | 52.25 | 101.5 |
| 135000 | 0.5455 | 64.5 | 220.0 | 0.4361 | 101.9074 | 98.128 | 12.266 | 51.25 | 181.0 |
| 137500 | 0.5556 | 66.5 | 223.0 | 0.4288 | 102.0744 | 97.968 | 12.246 | 49.0 | 103.0 |
| 140000 | 0.5657 | 66.5 | 232.0 | 0.4287 | 102.1162 | 97.928 | 12.241 | 49.25 | 127.5 |
| 142500 | 0.5758 | 66.5 | 220.0 | 0.4299 | 101.9461 | 98.091 | 12.261 | 49.5 | 88.5 |
| 145000 | 0.5859 | 65.5 | 217.0 | 0.4238 | 101.9572 | 98.08 | 12.26 | 48.75 | 177.0 |
| 147500 | 0.5960 | 64.0 | 205.0 | 0.4109 | 101.9497 | 98.088 | 12.261 | 48.75 | 128.0 |
| 150000 | 0.6061 | 63.5 | 224.0 | 0.4051 | 102.0205 | 98.02 | 12.252 | 48.5 | 117.5 |
| 152500 | 0.6162 | 63.25 | 202.0 | 0.4000 | 101.9318 | 98.105 | 12.263 | 47.5 | 160.0 |
| 155000 | 0.6263 | 63.75 | 195.0 | 0.4052 | 102.0203 | 98.02 | 12.252 | 48.75 | 100.0 |
| 157500 | 0.6364 | 63.75 | 212.0 | 0.4014 | 101.8935 | 98.142 | 12.268 | 49.25 | 113.0 |
| 160000 | 0.6465 | 62.75 | 198.0 | 0.3988 | 101.9178 | 98.118 | 12.265 | 44.5 | 132.0 |
| 162500 | 0.6566 | 64.5 | 192.0 | 0.3918 | 102.0303 | 98.01 | 12.251 | 45.5 | 100.0 |
| 165000 | 0.6667 | 62.5 | 202.0 | 0.3958 | 102.3627 | 97.692 | 12.211 | 47.75 | 88.5 |
| 167500 | 0.6768 | 62.5 | 191.0 | 0.3883 | 102.1537 | 97.892 | 12.236 | 44.75 | 80.5 |
| 170000 | 0.6869 | 63.5 | 195.0 | 0.3880 | 102.0728 | 97.969 | 12.246 | 51.0 | 91.5 |
| 172500 | 0.6970 | 60.75 | 201.0 | 0.3863 | 101.9235 | 98.113 | 12.264 | 47.5 | 90.5 |
| 175000 | 0.7071 | 61.5 | 189.0 | 0.3806 | 101.9376 | 98.099 | 12.262 | 46.5 | 82.5 |
| 177500 | 0.7172 | 58.75 | 171.0 | 0.3512 | 101.9844 | 98.054 | 12.257 | 42.75 | 66.0 |
| 180000 | 0.7273 | 55.5 | 161.0 | 0.3218 | 101.881 | 98.154 | 12.269 | 39.25 | 54.0 |
| 182500 | 0.7374 | 54.25 | 149.0 | 0.3148 | 101.9839 | 98.055 | 12.257 | 38.75 | 47.75 |
| 185000 | 0.7475 | 53.5 | 160.0 | 0.3133 | 101.9875 | 98.051 | 12.256 | 38.75 | 45.0 |
| 187500 | 0.7576 | 54.75 | 160.0 | 0.3114 | 101.9762 | 98.062 | 12.258 | 38.0 | 43.75 |
| 190000 | 0.7677 | 53.75 | 147.0 | 0.3075 | 101.9972 | 98.042 | 12.255 | 38.0 | 38.25 |
| 192500 | 0.7778 | 54.0 | 157.0 | 0.3057 | 101.9431 | 98.094 | 12.262 | 38.0 | 48.0 |
| 195000 | 0.7879 | 53.25 | 149.0 | 0.3058 | 101.9778 | 98.061 | 12.258 | 37.0 | 41.0 |
| 197500 | 0.7980 | 54.0 | 152.0 | 0.3032 | 102.0059 | 98.034 | 12.254 | 37.25 | 40.0 |
| 200000 | 0.8081 | 53.75 | 151.0 | 0.3033 | 102.0615 | 97.98 | 12.248 | 37.25 | 47.25 |
| 202500 | 0.8182 | 53.0 | 146.0 | 0.2957 | 102.0116 | 98.028 | 12.254 | 36.75 | 39.0 |
| 205000 | 0.8283 | 52.5 | 139.0 | 0.2903 | 102.1449 | 97.9 | 12.238 | 36.5 | 35.75 |
| 207500 | 0.8384 | 52.0 | 142.0 | 0.2894 | 102.0126 | 98.027 | 12.253 | 36.25 | 38.25 |
| 210000 | 0.8485 | 52.25 | 142.0 | 0.2883 | 102.0938 | 97.949 | 12.244 | 36.0 | 37.25 |
| 212500 | 0.8586 | 52.5 | 141.0 | 0.2874 | 101.9515 | 98.086 | 12.261 | 36.0 | 37.0 |
| 215000 | 0.8687 | 52.25 | 140.0 | 0.2873 | 101.9427 | 98.094 | 12.262 | 36.0 | 36.0 |
| 217500 | 0.8788 | 51.75 | 141.0 | 0.2863 | 102.0114 | 98.028 | 12.254 | 36.0 | 35.5 |
| 220000 | 0.8889 | 52.0 | 141.0 | 0.2854 | 102.0424 | 97.999 | 12.25 | 36.0 | 35.75 |
| 222500 | 0.8990 | 52.5 | 143.0 | 0.2853 | 102.0368 | 98.004 | 12.25 | 36.0 | 35.25 |
| 225000 | 0.9091 | 52.0 | 142.0 | 0.2849 | 102.115 | 97.929 | 12.241 | 35.75 | 35.0 |
| 227500 | 0.9192 | 52.0 | 141.0 | 0.2851 | 102.0455 | 97.996 | 12.249 | 36.0 | 35.25 |
| 230000 | 0.9293 | 52.0 | 141.0 | 0.2846 | 102.0273 | 98.013 | 12.252 | 35.75 | 35.25 |
| 232500 | 0.9394 | 52.0 | 141.0 | 0.2843 | 101.961 | 98.077 | 12.26 | 35.75 | 35.0 |
| 235000 | 0.9495 | 52.0 | 141.0 | 0.2844 | 102.0188 | 98.021 | 12.253 | 35.75 | 35.25 |
| 237500 | 0.9596 | 52.0 | 141.0 | 0.2845 | 102.0714 | 97.971 | 12.246 | 35.75 | 35.25 |
| 240000 | 0.9697 | 52.0 | 141.0 | 0.2844 | 102.0371 | 98.004 | 12.25 | 35.75 | 35.25 |
| 242500 | 0.9798 | 52.0 | 141.0 | 0.2844 | 102.0363 | 98.004 | 12.251 | 35.75 | 35.25 |
| 245000 | 0.9899 | 52.0 | 141.0 | 0.2844 | 102.0254 | 98.015 | 12.252 | 35.75 | 35.25 |
| 247500 | 1.0 | 52.0 | 141.0 | 0.2846 | 102.5728 | 97.492 | 12.186 | 35.75 | 35.25 |

# Resource Usage Comparison

- VRAM Use: 7.2012 GB

# Distillation (Teacher -> Student) Architecture Difference:

- **Architecture**: `GPT2LMHeadModel` -> `GPT2LMHeadModel`
- **Total Parameters**: 124,439,808 -> 81,912,576
- **Data Type (dtype)**: (not reported) -> torch.bfloat16
- **Model Size**: 0.24 GB -> 0.16 GB

Module Diff Details

```diff
--- teacher model modules
+++ student model modules
@@ -4,7 +4,7 @@
     (wpe): Embedding(1024, 768)
     (drop): Dropout(p=0.1, inplace=False)
     (h): ModuleList(
-      (0-11): 12 x GPT2Block(
+      (0-5): 6 x GPT2Block(
         (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
         (attn): GPT2FlashAttention2(
           (c_attn): Conv1D()
```
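
The parameter counts reported above are consistent with halving the number of `GPT2Block`s from 12 to 6 while keeping everything else. The sketch below reproduces that arithmetic under stated assumptions: the standard GPT-2 small dimensions (hidden size 768, vocabulary 50,257, 1,024 positions, 4x MLP expansion), biased `Conv1D` projections, and an `lm_head` that shares its weight with the token embedding. `gpt2_param_count` is a hypothetical helper written for illustration, not a Distily or Transformers API.

```python
def gpt2_param_count(n_layer: int, vocab_size: int = 50257,
                     n_positions: int = 1024, d_model: int = 768) -> int:
    """Count trainable parameters of a GPT-2 LM head model.

    Assumes the standard GPT-2 layout and an lm_head tied to the token
    embedding (so the lm_head adds no extra parameters).
    """
    layer_norm = 2 * d_model                          # weight + bias
    c_attn = d_model * 3 * d_model + 3 * d_model      # fused Q/K/V projection
    attn_proj = d_model * d_model + d_model           # attention output projection
    c_fc = d_model * 4 * d_model + 4 * d_model        # MLP up-projection
    mlp_proj = 4 * d_model * d_model + d_model        # MLP down-projection
    block = 2 * layer_norm + c_attn + attn_proj + c_fc + mlp_proj

    embeddings = vocab_size * d_model + n_positions * d_model  # wte + wpe
    final_ln = layer_norm                                     # ln_f
    return embeddings + n_layer * block + final_ln


teacher = gpt2_param_count(n_layer=12)
student = gpt2_param_count(n_layer=6)
print(f"teacher: {teacher:,}")              # teacher: 124,439,808
print(f"student: {student:,}")              # student: 81,912,576
print(f"removed: {teacher - student:,}")    # removed: 42,527,232
# bfloat16 stores 2 bytes per parameter
print(f"student bf16 weights: {student * 2 / 1e9:.2f} GB")  # 0.16 GB
```

Each removed block accounts for 7,087,872 parameters, so dropping six of them removes 42,527,232 of the teacher's 124,439,808 parameters. At 2 bytes per bfloat16 value, the student's weights take roughly 0.16 GB.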

# Train Dataset

Trained on 521,350,663 tokens from the [wikimedia/wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia) dataset.

- Num Samples: `990,000`
- Subset: `20231101.en`
- Split: `train`

# Training Objective

```
DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))
```

# Hyperparameters

The following hyperparameters were used during training:
- learning_rate: `0.0001`
- train_batch_size: `4`
- eval_batch_size: `8`
- seed: `42`
- optimizer: `Adam with betas=(0.9,0.999) and epsilon=1e-08`
- lr_scheduler_type: `cosine`
- lr_scheduler_warmup_ratio: `0.5`
- num_epochs: `1.0`
- distillation_objective: `DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))`
- train_embeddings: `True`
- lr_scheduler: ``
- student_model_name_or_path: `None`
- student_config_name_or_path: `distilbert/distilgpt2`
- student_model_config: `None`
- reinitialize_weights: `None`
- copy_teacher_modules: `[('lm_head', False)]`
- student_model_as_bitnet: `False`
- student_model_compile: `False`
- dropout: `None`
- teacher_model_name_or_path: `gpt2`
- teacher_load_in_8bit: `False`
- teacher_load_in_4bit: `False`
- teacher_model_compile: `False`
- dataset_uri: `wikimedia/wikipedia`
- dataset_subset: `20231101.en`
- dataset_split: `train`
- dataset_column_name: `text`
- dataset_sample_size: `1000000`
- dataset_test_size: `0.01`
- gradient_accumulation_steps: `1`
- weight_decay: `0.0`
- max_grad_norm: `1.0`
- warmup_ratio: `0.5`
- warmup_steps: `0`
- gradient_checkpointing: `True`
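
The training objective and learning-rate schedule listed above can be sketched in plain Python. This is an illustrative sketch, not Distily's implementation. It assumes that `loss_fn=kl` means the forward KL divergence KL(teacher || student) between the softmax distributions at each token position, and that the `cosine` scheduler behaves like the Transformers cosine-with-warmup schedule: linear warmup over `ceil(warmup_ratio * total_steps)` steps (used because `warmup_steps` is `0`), then cosine decay to zero. `kl_logits_loss` and `cosine_lr` are hypothetical helper names.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Numerically stable log-softmax over one vector of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_logits_loss(teacher_logits: list[float], student_logits: list[float]) -> float:
    """KL(teacher || student) at one token position (assumed direction)."""
    t_logp = log_softmax(teacher_logits)
    s_logp = log_softmax(student_logits)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))


def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4,
              warmup_ratio: float = 0.5) -> float:
    """Linear warmup, then cosine decay to zero (assumed Transformers-style)."""
    warmup_steps = math.ceil(warmup_ratio * total_steps)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# 990,000 samples / (train_batch_size 4 * gradient_accumulation_steps 1), 1 epoch
total_steps = 990_000 // (4 * 1)
print(total_steps)                                         # 247500
print(cosine_lr(123_750, total_steps))                     # 0.0001 (end of warmup)
print(kl_logits_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0]))  # 0.0
```

With 990,000 training samples, a batch size of 4, and no gradient accumulation, one epoch is 247,500 optimizer steps, which matches the last step in the evaluation table. Under the assumed schedule, the learning rate peaks at `1e-4` at step 123,750 and decays to zero by the end of training.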

# Framework Versions

- Distily 0.2.0
- Transformers 4.44.0
- PyTorch 2.3.0
- Datasets 2.21.0