Add model card metadata (#1)
- Add model card metadata (33ae0651789be167dd5f292ec40b272df4de5904)
Co-authored-by: Niels Rogge <[email protected]>
README.md
ADDED
@@ -0,0 +1,134 @@
---
pipeline_tag: text-generation
library_name: transformers
license: apache-2.0
---

# RADLADS
## Rapid Attention Distillation to Linear Attention Decoders at Scale

Paper link: https://arxiv.org/abs/2505.03005

Checkpoints: https://huggingface.co/collections/recursal/radlads-6818ee69e99e729ba8a87102

RADLADS converts traditional softmax attention transformers to use linear attention variants that feature constant-time inference per token. This is accomplished via a three-stage distillation process that maintains quality close to the original teacher model. Conversion requires 700 million tokens or fewer of distillation training.

<div align="center" >
<img src="assets/radlads_process.png" height=63 alt="RADLADS Conversion Process" />
</div>

We provide two new RWKV variants, RAD-RWKV6 and RAD-RWKV7, that serve as efficient destination architectures for transformer conversions. Our method achieves outstanding results, often with far fewer training tokens than other methods:

<div align="center" >
<img src="assets/radlads_evals.png" height=275 alt="RADLADS evals" />
</div>

Please see the RADLADS paper at https://arxiv.org/abs/2505.03005 for more details.

## What's included in this repository

- Reconfigurable Transformer base model code with support for carried state
- Pluggable time and channel mixer component classes for several model architectures
  - RAD-RWKV6
  - RAD-RWKV7
  - Qwen2.5
- HuggingFace Transformers conversion scripts and model code
- Simple config system
- Lightning-based trainer
- lm_eval_harness support
- Inference support (limited)

## Setup

```bash
pip install lightning torch flash-linear-attention triton deepspeed wandb ninja --upgrade
```

You can download the DCLM dataset in bin/idx format via:

```bash
mkdir -p data
wget --continue -O data/dclm-10B.idx https://huggingface.co/datasets/recursal/DCLM-10B-Qwen2-binidx/resolve/main/dclm-10B.idx?download=true
wget --continue -O data/dclm-10B.bin https://huggingface.co/datasets/recursal/DCLM-10B-Qwen2-binidx/resolve/main/dclm-10B.bin?download=true
```

You can also convert other datasets, or examine the magic primes required for an existing bin/idx dataset, using `python3 make_data_hf.py`.

## Configuration

The config system allows you to specify one or more config files via `-c CONFIG_PATH`, in YAML or JSON format. Later configs override earlier ones. You can also set individual config parameters on the command line, e.g. `--model.n_layer 12 --train.lr_init 6e-4`.

See `configs.py` for the specific configuration settings, which are defined as dataclasses.

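For example, you might layer several of the provided configs and then override a single hyperparameter on the command line (a sketch; the learning-rate value here is purely illustrative):

```bash
# Configs are applied left to right, later ones overriding earlier ones;
# explicit --section.key flags override all config files.
python3 train.py \
  -c configs/qwen7b.yaml \
  -c configs/qwerky7.yaml \
  -c configs/distill1.yaml \
  --train.lr_init 6e-4
```
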
- `model.tmix` is the first variety of time mixer; it resolves to the class at path `f'tmix.tmix_{tmix}.TMix_{tmix}'`, as sketched below
- `model.tmix2` is the second variety of time mixer, if any
- `model.cmix` is the first variety of channel mixer
- `model.cmix2` is the second variety of channel mixer, if any
- `model.inv_other_layer_ratio` is the inverse of the fraction of layers that use the second variety (e.g. 3 means 1/3 of the layers use the second variety and 2/3 use the first)

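As a concrete illustration of that class-path convention, a time mixer name can be resolved roughly like this (a minimal sketch; the actual resolution logic lives in the repository's training code, and the `rwkv7` name below is only a hypothetical example):

```python
import importlib

def resolve_time_mixer(tmix: str):
    # model.tmix='foo' maps to the class tmix.tmix_foo.TMix_foo
    module = importlib.import_module(f"tmix.tmix_{tmix}")
    return getattr(module, f"TMix_{tmix}")

# e.g. resolve_time_mixer("rwkv7") would return TMix_rwkv7,
# assuming a tmix/tmix_rwkv7.py module exists in the repository.
```
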
Inherited from LinearAttentionArena, training is broken up into 'mini-batches' of 40320 samples, where a sample is one context length of the model.
`magic_prime` is used to pseudo-randomize the location of these samples within the dataset, and is calculated as follows (from the LinearAttentionArena documentation):

```
magic_prime = the largest prime of the form 3n+2 smaller than datalen/ctxlen-1
(here: 1498226207/512 - 1 = 2926222.06, so magic_prime = 2926181)

use https://www.dcode.fr/prime-numbers-search
```

You can also examine the magic primes required for an existing bin/idx dataset using `python3 make_data_hf.py`.

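If you prefer to compute `magic_prime` directly rather than using the prime-search website, here is a small pure-Python sketch (not part of the repository; `make_data_hf.py` remains the supported route):

```python
def is_prime(n: int) -> bool:
    # Simple trial division; fine for values in the millions.
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True

def magic_prime(datalen: int, ctxlen: int) -> int:
    # Largest prime p with p % 3 == 2 that is strictly smaller than datalen/ctxlen - 1.
    p = (datalen - 1) // ctxlen - 1  # largest integer strictly below datalen/ctxlen - 1
    while not (p % 3 == 2 and is_prime(p)):
        p -= 1
    return p

# Reproduces the documented example: magic_prime(1498226207, 512) == 2926181
```
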
## Running it

### Example for Qwen2.5-7B-Instruct

Download Qwen/Qwen2.5-7B-Instruct from Hugging Face:
`huggingface-cli download Qwen/Qwen2.5-7B-Instruct`

Convert it to PTH format:
`python3 convert_hf_to_pth.py YOUR_CACHED_HF_QWEN_MODEL_LOCATION out/Qwen2.5-7B-Instruct.pth`

RADLADS Step 0:
`RWKV_TORCH_COMPILE=0 RWKV_JIT_ON=0 python3 train.py -c configs/qwen7b.yaml -c configs/qwerky7.yaml -c configs/distill1.yaml --train.load_model out/Qwen2.5-7B-Instruct.pth`

RADLADS Step 1:
`RWKV_TORCH_COMPILE=0 RWKV_JIT_ON=0 python3 train.py -c configs/qwen7b.yaml -c configs/qwerky7.yaml -c configs/qwen7binstructteacher.yaml -c configs/distill2.yaml --train.load_model out/L28-D3584-qwerky7_qwen2-1/rwkv-final.pth`

RADLADS Step 2:
`RWKV_TORCH_COMPILE=0 RWKV_JIT_ON=0 python3 train.py -c configs/qwen7b.yaml -c configs/qwerky7.yaml -c configs/qwen7binstructteacher.yaml -c configs/distill3.yaml --train.load_model out/L28-D3584-qwerky7_qwen2-2/rwkv-final.pth`

You can convert the resulting PTH files back to safetensors format for use with HF Transformers via
`python3 convert_to_safetensors.py out/L28-D3584-qwerky7_qwen2-3/rwkv-final.pth RADRWKV7Qwen2.5-7B/model.safetensors`
(Note: you can instead specify just a destination directory and it will emit chunked safetensors files rather than a single file, but HF sometimes has issues with this; in that case, convert to a single file first and then from that file to chunks using the same convert_to_safetensors.py tool.)

The HF Transformers model code is provided in the rwkv6qwen2 and rwkv7qwen2 subdirectories. You can put together a working HF model mostly by copy-and-pasting. Full details are beyond the scope of this tutorial, but you can look at the pre-converted models to see how it's done.

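Once the converted directory contains the weights, config, tokenizer files, and the custom model code, it should load like any other remote-code HF model. A minimal sketch, assuming the local `RADRWKV7Qwen2.5-7B/` directory from the conversion step above has been fully assembled:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# The directory must contain the custom modeling code (see the rwkv7qwen2
# subdirectory and the pre-converted checkpoints for reference), hence
# trust_remote_code=True.
model_dir = "RADRWKV7Qwen2.5-7B"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)

prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
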
Beware: if you run again in the same output directory, training will resume from any numbered checkpoints still present there.

There is also some lm_eval support in run_lm_eval.py, which uses the same config system.

dragon_test.py can be used to run a quick inference test, again using the same config system.

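Since both scripts share the `-c` config mechanism with `train.py`, invocations along these lines should work, though the exact flags for pointing at a checkpoint may differ (check the scripts themselves):

```bash
# Hypothetical invocations -- the -c config mechanism is shared with train.py,
# but verify the expected flags in run_lm_eval.py and dragon_test.py.
python3 run_lm_eval.py -c configs/qwen7b.yaml -c configs/qwerky7.yaml --train.load_model out/L28-D3584-qwerky7_qwen2-3/rwkv-final.pth
python3 dragon_test.py -c configs/qwen7b.yaml -c configs/qwerky7.yaml --train.load_model out/L28-D3584-qwerky7_qwen2-3/rwkv-final.pth
```
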
## Citation

If you use this code or find our work valuable, please consider citing RADLADS:

```bibtex
@misc{goldstein2025radladsrapidattentiondistillation,
      title={RADLADS: Rapid Attention Distillation to Linear Attention Decoders at Scale},
      author={Daniel Goldstein and Eric Alcaide and Janna Lu and Eugene Cheah},
      year={2025},
      eprint={2505.03005},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2505.03005},
}
```

Note: 72B models are also governed by the Qwen License Agreement.