Erland committed on
Commit ec1fbcf · verified · 1 Parent(s): f63b81d

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. LICENSE +21 -0
  2. README.md +471 -0
  3. config.json +34 -0
  4. download_checkpoint.py +35 -0
  5. fla/layers/__pycache__/linear_attn.cpython-311.pyc +0 -0
  6. fla/layers/bitattn.py +192 -0
  7. fla/layers/lightnet.py +210 -0
  8. fla/layers/nsa.py +138 -0
  9. fla/layers/rwkv7.py +221 -0
  10. fla/models/delta_net/__pycache__/__init__.cpython-311.pyc +0 -0
  11. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-311.pyc +0 -0
  12. fla/models/delta_net/configuration_delta_net.py +91 -0
  13. fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-311.pyc +0 -0
  14. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-311.pyc +0 -0
  15. fla/models/gated_deltanet/configuration_gated_deltanet.py +83 -0
  16. fla/models/gated_deltanet/modeling_gated_deltanet.py +412 -0
  17. fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-311.pyc +0 -0
  18. fla/models/gla/configuration_gla.py +95 -0
  19. fla/models/gla/modeling_gla.py +417 -0
  20. fla/models/gsa/__init__.py +13 -0
  21. fla/models/gsa/__pycache__/modeling_gsa.cpython-311.pyc +0 -0
  22. fla/models/hgrn/__init__.py +13 -0
  23. fla/models/hgrn/__pycache__/configuration_hgrn.cpython-311.pyc +0 -0
  24. fla/models/hgrn/__pycache__/modeling_hgrn.cpython-311.pyc +0 -0
  25. fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-311.pyc +0 -0
  26. fla/models/lightnet/__init__.py +13 -0
  27. fla/models/lightnet/modeling_lightnet.py +410 -0
  28. fla/models/linear_attn/__init__.py +12 -0
  29. fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-311.pyc +0 -0
  30. fla/models/mamba/__pycache__/__init__.cpython-311.pyc +0 -0
  31. fla/models/mamba/configuration_mamba.py +166 -0
  32. fla/models/mamba2/__init__.py +13 -0
  33. fla/models/mamba2/__pycache__/__init__.cpython-311.pyc +0 -0
  34. fla/models/retnet/__pycache__/__init__.cpython-311.pyc +0 -0
  35. fla/models/samba/__pycache__/__init__.cpython-311.pyc +0 -0
  36. flame/config_manager.py +940 -0
  37. flame/data.py +570 -0
  38. logs/none_75lcom2m/attempt_0/3/stderr.log +17 -0
  39. logs/none_75lcom2m/attempt_0/3/stdout.log +0 -0
  40. logs/none_vngrbiu1/attempt_0/0/stderr.log +0 -0
  41. logs/none_vngrbiu1/attempt_0/0/stdout.log +0 -0
  42. logs/none_vngrbiu1/attempt_0/1/stderr.log +0 -0
  43. logs/none_vngrbiu1/attempt_0/1/stdout.log +0 -0
  44. logs/none_vngrbiu1/attempt_0/2/stdout.log +0 -0
  45. logs/none_vngrbiu1/attempt_0/3/stderr.log +0 -0
  46. logs/none_vngrbiu1/attempt_0/3/stdout.log +0 -0
  47. profile_trace/iteration_1024/rank0_trace.json +0 -0
  48. profile_trace/iteration_1024/rank1_trace.json +0 -0
  49. profile_trace/iteration_1024/rank2_trace.json +0 -0
  50. profile_trace/iteration_1024/rank3_trace.json +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-2025 Songlin Yang, Yu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,471 @@
1
+ <div align="center">
2
+
3
+ # 🔥 Flame: Flash Linear Attention Made Easy
4
+
5
+ </div>
6
+
7
+ Welcome to 🔥 `flame`, a minimal and efficient framework built on `torchtitan` for training Flash Linear Attention (FLA) models (and more broadly, arbitrary autoregressive language models) with blazing efficiency.
8
+
9
+ **Feature Highlights:**
10
+
11
+ - 🚀 Minimal, easy-to-use, extensible training framework
12
+ - 🤗 Seamless integration with `fla` and `transformers`
13
+ - 🔄 Zero-cost data preprocessing: online tokenization, dataset shuffling, and support for multiple datasets
14
+ - 🔮 4D parallelism (coming soon)
15
+
16
+ ## Setup
17
+
18
+ To get started, clone the `flame` repository and install the required dependencies:
19
+
20
+ ```bash
21
+ git clone https://github.com/fla-org/flame.git
22
+ cd flame
23
+ pip install .
24
+ ```
25
+
26
+ `flame` keeps its dependencies minimal, including only `fla` and `torchtitan` as submodules.
27
+ After installation, initialize and update the submodules:
28
+ ```sh
29
+ git submodule update --init --recursive
30
+ ```
31
+
32
+ ## Dataset Preparation
33
+ To download the dataset to your local disk, create a new Python file with the following content and execute it:
34
+
35
+ ```py
36
+ from datasets import load_dataset
37
+
38
+ # load fineweb-edu with parallel processing
39
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="default", num_proc=64, cache_dir="/your/cache/path")
40
+
41
+ # or load a subset with roughly 100B tokens, suitable for small- or medium-sized experiments
42
+ dataset = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", num_proc=64, cache_dir="/your/cache/path")
43
+ ```
44
+
45
+ ## Training Recipes
46
+
47
+ Here's an example of training a 340M FLA Transformer model with a LLaMA-like architecture from scratch on a 100BT subset of the Fineweb-edu corpus in streaming mode.
48
+
49
+ > [!WARNING]
50
+ > If the dataset is not downloaded beforehand, the streaming mode will attempt to fetch it from a remote server and download it on-the-fly, which can be highly unstable during training due to network issues.
51
+ > For stable training, ensure the dataset is downloaded locally (see [**Dataset Preparation**](#dataset-preparation)); rely on on-the-fly streaming only when quickly testing a new corpus.
52
+
53
+ ```sh
54
+ bash train.sh \
55
+ --job.config_file flame/models/fla.toml \
56
+ --job.dump_folder exp/transformer-340M-4K-10B/batch1.seqlen65536.context4096.warmup1024.update1.steps20480.lr3e-4.cosine \
57
+ --model.config configs/transformer_340M.json \
58
+ --model.tokenizer_path fla-hub/transformer-1.3B-100B \
59
+ --optimizer.name AdamW \
60
+ --optimizer.eps 1e-15 \
61
+ --optimizer.lr 3e-4 \
62
+ --lr_scheduler.warmup_steps 1024 \
63
+ --lr_scheduler.lr_min 0.1 \
64
+ --lr_scheduler.decay_type cosine \
65
+ --training.batch_size 1 \
66
+ --training.seq_len 65536 \
67
+ --training.context_len 4096 \
68
+ --training.varlen \
69
+ --training.gradient_accumulation_steps 1 \
70
+ --training.steps 20480 \
71
+ --training.max_norm 1.0 \
72
+ --training.skip_nan_inf \
73
+ --training.dataset HuggingFaceFW/fineweb-edu \
74
+ --training.dataset_name sample-100BT \
75
+ --training.dataset_split train \
76
+ --training.streaming \
77
+ --training.num_workers 32 \
78
+ --training.prefetch_factor 2 \
79
+ --training.seed 42 \
80
+ --training.compile \
81
+ --checkpoint.interval 2048 \
82
+ --checkpoint.load_step -1 \
83
+ --checkpoint.keep_latest_k 2 \
84
+ --metrics.log_freq 1
85
+ ```
86
+
87
+ You can specify the number of GPUs by setting the environment variable `NGPU`, which defaults to 8.
88
+ **For single-GPU debugging, set `NGPU=1`.**
89
+
90
+ We provide several [config files](https://github.com/fla-org/flame/tree/main/configs) for different models.
91
+ By default, the learning rate is set to 3e-4 with a cosine scheduler. Other schedulers, such as WSD (wsd), are also supported.
92
+
93
+ **Key parameters:**
94
+ - `--lr_scheduler.decay_ratio`: The proportion of the steps allocated to the decay phase. The learning rate will remain stable after the warmup period and only start decaying during the last `decay_ratio` portion of the total training steps, which is known as the Warmup-Stable-Decay (WSD) schedule.
95
+ - `--lr_scheduler.warmup_steps`: The number of steps for the learning rate warmup phase.
96
+ - `--training.steps`: Total number of training steps.
97
+ - `--training.batch_size`: Batch size per device, must be 1 if `--training.varlen` is set.
98
+ - `--training.seq_len`: The length of each sequence in the batch, which is concatenated from multiple samples.
99
+ - `--training.context_len`: The max allowed length of a sample. For non-varlen mode, this is equivalent to `seq_len`.
100
+ - `--training.varlen`: Whether to conduct variable-length sequence training.
101
+ - `--training.gradient_accumulation_steps`: Number of gradient accumulation steps.
102
+
103
+ > [!WARNING]
104
+ > The total number of tokens processed per batch, referred to as `global_batch_size`, is calculated as batch_size × gradient_accumulation_steps × num_gpus.
105
+ > Each step processes `global_batch_size * seq_len` tokens.
106
+ > Monitor the values of `global_batch_size`, `warmup_steps`, and `steps` carefully when modifying any of the hyperparameters (see the quick check below)!
107
+
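+ As a quick sanity check for the recipe above (assuming the default `NGPU=8`):
+
+ ```py
+ # token accounting for the example recipe (values taken from the command above)
+ batch_size, grad_accum, num_gpus = 1, 1, 8
+ seq_len, steps = 65536, 20480
+ global_batch_size = batch_size * grad_accum * num_gpus  # 8
+ tokens_per_step = global_batch_size * seq_len           # 524,288 tokens per step
+ total_tokens = tokens_per_step * steps                  # ~10.7B tokens, matching the "10B" run name
+ print(global_batch_size, tokens_per_step, total_tokens)
+ ```
+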
108
+ For a detailed explanation of all parameters, run:
109
+
110
+ ```sh
111
+ bash train.sh -h
112
+ ```
113
+
114
+ <details>
115
+ <summary>Usage</summary>
116
+
117
+ ```py
118
+ options:
119
+ -h, --help show this help message and exit
120
+ --job.config_file JOB.CONFIG_FILE
121
+ Job config file
122
+ --job.dump_folder JOB.DUMP_FOLDER
123
+ Folder to dump job outputs
124
+ --job.description JOB.DESCRIPTION
125
+ Description of the job
126
+ --job.use_for_integration_test
127
+ Add this config to the integration test suite
128
+ --job.print_args Print the args to terminal
129
+ --model.config MODEL.CONFIG
130
+ Path to the model config
131
+ --model.norm_type MODEL.NORM_TYPE
132
+ Type of layer normalization to use [layernorm,
133
+ np_layernorm, rmsnorm, fused_rmsnorm]
134
+ --model.tokenizer_path MODEL.TOKENIZER_PATH
135
+ Tokenizer path
136
+ --profiling.enable_profiling
137
+ Whether to enable pytorch profiler
138
+ --profiling.save_traces_folder PROFILING.SAVE_TRACES_FOLDER
139
+ Trace files location
140
+ --profiling.profile_freq PROFILING.PROFILE_FREQ
141
+ How often to collect profiler traces, in iterations
142
+ --profiling.enable_memory_snapshot
143
+ Whether to dump memory snapshot
144
+ --profiling.save_memory_snapshot_folder PROFILING.SAVE_MEMORY_SNAPSHOT_FOLDER
145
+ Memory snapshot files location
146
+ --optimizer.name OPTIMIZER.NAME
147
+ Optimizer to use
148
+ --optimizer.eps OPTIMIZER.EPS
149
+ Epsilon value for the optimizer.
150
+ --optimizer.fused Whether the fused implementation (CUDA only) is used.
151
+ --optimizer.scheduler {wsd,cosine,linear}
152
+ Scheduler to use. Currently supported: wsd, cosine,
153
+ and linear.
154
+ --optimizer.lr OPTIMIZER.LR
155
+ Learning rate to use
156
+ --optimizer.min_lr_ratio OPTIMIZER.MIN_LR_RATIO
157
+ Min lr ratio for lr scheduler
158
+ --optimizer.early_step_in_backward
159
+ Whether to apply optimizer in the backward. Caution,
160
+ optimizer_in_backward is not compatible with gradients
161
+ clipping, users should not call
162
+ register_post_accumulate_grad_hook after the optimizer
163
+ is built.
164
+ --training.batch_size TRAINING.BATCH_SIZE
165
+ Batch size
166
+ --training.seq_len TRAINING.SEQ_LEN
167
+ Sequence length
168
+ --training.context_len TRAINING.CONTEXT_LEN
169
+ Max length allowed for each sequence
170
+ --training.varlen Whether to take sequences of variable length as input
171
+ --training.warmup_steps TRAINING.WARMUP_STEPS
172
+ Steps for lr scheduler warmup, normally 1/5 of
173
+ --training.steps
174
+ --training.gradient_accumulation_steps TRAINING.GRADIENT_ACCUMULATION_STEPS
175
+ Number of steps to accumulate gradients before
176
+ updating parameters
177
+ --training.steps TRAINING.STEPS
178
+ How many train steps to run
179
+ --training.max_norm TRAINING.MAX_NORM
180
+ Max norm for gradient clipping
181
+ --training.skip_nan_inf
182
+ Skip batch updates when NaN or INF gradients are
183
+ encountered during training
184
+ --training.dataset TRAINING.DATASET
185
+ Dataset to use, with comma separated values
186
+ --training.dataset_name TRAINING.DATASET_NAME
187
+ The name of the dataset config, with comma separated
188
+ values if provided
189
+ --training.dataset_split TRAINING.DATASET_SPLIT
190
+ Dataset split to use, with comma separated values if
191
+ provided
192
+ --training.data_dir TRAINING.DATA_DIR
193
+ Data dirs to use, with comma separated values if
194
+ provided
195
+ --training.data_files TRAINING.DATA_FILES
196
+ Data files to use, with comma separated values if
197
+ provided
198
+ --training.data_probs TRAINING.DATA_PROBS
199
+ Data sampling probabilities, with comma separated
200
+ values if provided
201
+ --training.streaming Whether to load dataset in streaming mode, used for
202
+ huge dataset
203
+ --training.num_workers TRAINING.NUM_WORKERS
204
+ Number of subprocesses to use for data loading. 0
205
+ means that the data will be loaded in the main
206
+ process.
207
+ --training.prefetch_factor TRAINING.PREFETCH_FACTOR
208
+ Number of batches loaded in advance by each worker. 2
209
+ means there will be a total of 2 * num_workers batches
210
+ prefetched across all workers.
211
+ --training.data_parallel_replicate_degree TRAINING.DATA_PARALLEL_REPLICATE_DEGREE
212
+ The `data_parallel_replicate_degree` argument
213
+ specifies the degree of data parallelism for weight
214
+ replication. When this value is greater than 1,
215
+ weights will be replicated across
216
+ `data_parallel_replicate_degree` ranks. If
217
+ `data_parallel_shard_degree` is also greater than 1,
218
+ the parallelism method used is HSDP (Hybrid Sharded
219
+ Data Parallelism). Otherwise, the parallelism method
220
+ used is DDP (Distributed Data Parallelism). 1 means
221
+ disabled.
222
+ --training.data_parallel_shard_degree TRAINING.DATA_PARALLEL_SHARD_DEGREE
223
+ The `data_parallel_shard_degree` argument specifies
224
+ the degree of data parallelism for weight sharding.
225
+ When this value is greater than 1, weights will be
226
+ sharded across `data_parallel_shard_degree` ranks. If
227
+ `data_parallel_replicate_degree` is also greater than
228
+ 1, the parallelism method used is HSDP (Hybrid Sharded
229
+ Data Parallelism). Otherwise, the parallelism method
230
+ used is FSDP (Fully Sharded Data Parallelism). -1
231
+ means leftover ranks will be used (After
232
+ DP_REPLICATE/SP/PP). Note that only
233
+ `data_parallel_shard_degree` can be negative. 1 means
234
+ disabled.
235
+ --training.enable_cpu_offload
236
+ Whether to apply CPU offloading of parameters,
237
+ gradients, and optimizer states in FSDP
238
+ --training.tensor_parallel_degree TRAINING.TENSOR_PARALLEL_DEGREE
239
+ Tensor Parallelism degree. 1 means disabled.
240
+ --training.disable_loss_parallel
241
+ Whether to apply loss parallel when sequence parallel
242
+ is enabled
243
+ --training.mixed_precision_param {bfloat16,float32}
244
+ torch dtype to use for parameters when applying mixed
245
+ precision via FSDP. This feature only takes effect
246
+ when data_parallel_shard_degree > 1
247
+ --training.mixed_precision_reduce {float32}
248
+ torch dtype to use for reductions when applying mixed
249
+ precision via FSDP. This feature only takes effect
250
+ when data_parallel_shard_degree > 1
251
+ --training.compile Whether to compile the model
252
+ --training.gc_freq TRAINING.GC_FREQ
253
+ Python garbage control scheduling interval, in steps
254
+ --training.seed TRAINING.SEED
255
+ Choose the base RNG seed used for training
256
+ --training.deterministic
257
+ Use deterministic algorithms wherever possible, may be
258
+ slower
259
+ --metrics.log_freq METRICS.LOG_FREQ
260
+ How often to log metrics to TensorBoard, in iterations
261
+ --metrics.enable_tensorboard
262
+ Whether to log metrics to TensorBoard
263
+ --metrics.disable_color_printing
264
+ Whether to disable color printing in logs
265
+ --metrics.save_tb_folder METRICS.SAVE_TB_FOLDER
266
+ Folder to dump TensorBoard states
267
+ --metrics.rank_0_only
268
+ Whether to save TensorBoard metrics only for rank 0 or
269
+ for all ranks. When pipeline_parallel_degree is > 1,
270
+ this option uses the 0th rank of the last stage
271
+ pipeline group, which is the only stage that computes
272
+ loss metrics.
273
+ --metrics.enable_wandb
274
+ Whether to log metrics to Weights & Biases
275
+ --experimental.enable_async_tensor_parallel
276
+ Whether to apply async tensor parallel (currently only
277
+ effective when compile is enabled)
278
+ --experimental.pipeline_parallel_degree EXPERIMENTAL.PIPELINE_PARALLEL_DEGREE
279
+ Pipeline Parallelism degree, or number of ranks. 1
280
+ means disabled. If using looped schedules, this still
281
+ specifies the number of physical ranks, not the number
282
+ of stages. Stages per rank are inferred from split
283
+ points degree, and schedule.
284
+ --experimental.pipeline_parallel_split_points EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS [EXPERIMENTAL.PIPELINE_PARALLEL_SPLIT_POINTS ...]
285
+ Specify comma-separated names of modules to use as the
286
+ beginning of a split point. e.g. "layers.0,layers.2"
287
+ will cause the model to be split into 3 stages, the
288
+ first containing all the layers up to layers.0, the
289
+ second containing layers.0 and up to layers.2, the
290
+ third containing layers.2 and all the remaining
291
+ layers. Note: fully-automated splitting may be enabled
292
+ in the future, but currently the split points must be
293
+ specified manually.
294
+ --experimental.pipeline_parallel_schedule EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE
295
+ Specify the Pipeline Parallel schedule to use. The
296
+ supported schedules are: https://github.com/pytorch/py
297
+ torch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/to
298
+ rch/distributed/pipelining/schedules.py#L2161. The
299
+ schedule must be compatible with the split points and
300
+ stages_per_rank. Looped schedules (e.g.
301
+ Interleaved1F1B) require specifying
302
+ pipeline_parallel_degree = number of ranks, and
303
+ split_points = number of stages - 1
304
+ --experimental.pipeline_parallel_schedule_csv EXPERIMENTAL.PIPELINE_PARALLEL_SCHEDULE_CSV
305
+ Specify the path to the pipeline parallel schedule csv
306
+ file to use. The pipeline_parallel_schedule argument
307
+ must be either PipelineScheduleSingle,
308
+ PipelineScheduleMulti, or _PipelineScheduleRuntime.
309
+ --experimental.pipeline_parallel_microbatches EXPERIMENTAL.PIPELINE_PARALLEL_MICROBATCHES
310
+ How many microbatches to split the global training
311
+ batch into when using pipeline parallelism. The global
312
+ training batch size must be evenly divisible by the
313
+ number of microbatches. The default value will be the
314
+ number of pipeline stages, if unspecified.
315
+ --experimental.enable_compiled_autograd
316
+ Enable CompiledAutograd to compile the backward.
317
+ --experimental.context_parallel_degree EXPERIMENTAL.CONTEXT_PARALLEL_DEGREE
318
+ Context parallelism degree. 1 means disabled.
319
+ --experimental.context_parallel_rotate_method EXPERIMENTAL.CONTEXT_PARALLEL_ROTATE_METHOD
320
+ The collective to use in context parallel SDPA for kv
321
+ shards exchange. 'allgather' means to all-gather all
322
+ kv shards on ranks after the first sub-SDPA
323
+ computation, 'alltoall' means to all-to-all shuffle
324
+ the kv shards. The default value is 'allgather'.
325
+ --checkpoint.enable_checkpoint
326
+ Whether to enable checkpoint
327
+ --checkpoint.folder CHECKPOINT.FOLDER
328
+ The folder to store the checkpoints. When
329
+ enable_checkpoint is set to true, checkpoints will be
330
+ in {--job.dump_folder}/{--checkpoint.folder}.
331
+ --checkpoint.interval_type CHECKPOINT.INTERVAL_TYPE
332
+ Checkpointing interval unit of measurement ['step',
333
+ 'seconds']
334
+ --checkpoint.interval CHECKPOINT.INTERVAL
335
+ Checkpointing interval, in steps or seconds depending
336
+ on --checkpoint.interval_type
337
+ --checkpoint.model_weights_only
338
+ When model_weights_only=True, only model weights will
339
+ be saved at the end of training. With this,
340
+ checkpoints can be loaded using `torch.load(...,
341
+ weights_only=True)` after conversion. When
342
+ model_weights_only=False, the full checkpoint will be
343
+ saved. A full checkpoint includes model, optimizer and
344
+ train_state, which can be used to resume training. The
345
+ default value is false.
346
+ --checkpoint.export_dtype {float16,bfloat16,float32}
347
+ Converts to the specified precision when training
348
+ completes and model_weights_only=true. Currently
349
+ supports float32, float16, and bfloat16. The default
350
+ value is float32.
351
+ --checkpoint.create_seed_checkpoint
352
+ Initializes the full model without applying
353
+ parallelisms, and then saves it as a seed checkpoint.
354
+ Note: requires user to call train.py without
355
+ specifying any parallelisms, e.g. NGPU=1. Could be
356
+ implemented as a separate script, but this way shares
357
+ more code.
358
+ --checkpoint.async_mode CHECKPOINT.ASYNC_MODE
359
+ Which async checkpoint mode to use. Currently there
360
+ are 3 different modes. 1. "disabled": synchronized
361
+ checkpointing will be used. 2. "async":
362
+ torch.distributed.checkpoint.async_save will be used.
363
+ 1. "async_with_pinned_mem": this option utilizes a
364
+ dedicated pinned memory space and creates a separate
365
+ process for faster GPU->CPU transfer performance and
366
+ eliminating GIL contention. The cost is increased CPU
367
+ memory usage. If insufficient CPU memory is available,
368
+ performance may degrade due to memory paging. For most
369
+ users, "async" should suffice as the performance
370
+ overhead is typically small (on the order of tens of
371
+ seconds) compared to checkpointing frequency. This
372
+ mode can be employed to pursue near-zero checkpointing
373
+ times (e.g., < 1 second) given appropriate hardware
374
+ support such as ample CPU memory and fast PCIe.
375
+ "disabled" is the default mode.
376
+ --checkpoint.keep_latest_k CHECKPOINT.KEEP_LATEST_K
377
+ Keeps only the latest k checkpoints, purging older
378
+ ones. If 0, keep all checkpoints. 0 is the default
379
+ value.
380
+ --checkpoint.load_step CHECKPOINT.LOAD_STEP
381
+ Load the checkpoint at the specified step. If -1, load
382
+ the latest checkpoint.
383
+ --float8.enable_float8_linear
384
+ If true, swaps `torch.nn.Linear` with `Float8Linear`.
385
+ This feature requires you to install 'torchao' which
386
+ can be found here: https://github.com/pytorch/ao
387
+ --float8.enable_fsdp_float8_all_gather
388
+ Whether enable float8 all-gather in FSDP
389
+ --float8.precompute_float8_dynamic_scale_for_fsdp
390
+ Whether precompute float8 scales dynamically for FSDP
391
+ --float8.scaling_type_input {dynamic,delayed}
392
+ float8 scaling for input, dynamic (default) or delayed
393
+ --float8.scaling_type_weight FLOAT8.SCALING_TYPE_WEIGHT
394
+ float8 scaling for weight, dynamic (default) or delayed
395
+ --float8.scaling_type_grad_output FLOAT8.SCALING_TYPE_GRAD_OUTPUT
396
+ float8 scaling for grad_output, dynamic (default) or delayed
397
+ --comm.init_timeout_seconds COMM.INIT_TIMEOUT_SECONDS
398
+ Timeout for communication operations, during
399
+ initialization and first train step.
400
+ --comm.train_timeout_seconds COMM.TRAIN_TIMEOUT_SECONDS
401
+ Timeout for communication operations after the first
402
+ train step -- usually a tighter bound than during
403
+ initialization.
404
+ --comm.trace_buf_size COMM.TRACE_BUF_SIZE
405
+ Flight recorder ring buffer size, >0 means recording
406
+ by default, 0 means disabled
407
+ --memory_estimation.enabled
408
+ Whether to estimate memory usage for FSDP
409
+ --memory_estimation.disable_fake_mode
410
+ Whether to estimate memory under FakeTensorMode
411
+ ```
412
+ </details>
413
+
414
+ ### Training with `torch.compile`
415
+
416
+ Since `torch` 2.0, `torch.compile` has been available to seamlessly accelerate training.
417
+ In `flame`, you can enable `torch.compile` by simply adding the `--training.compile` flag to your training script.
418
+
419
+ However, `fla` has integrated numerous fused kernels for acceleration, which may potentially conflict with `torch.compile`.
420
+ We are actively working on resolving these issues to make compilation transparent to users.
421
+ In the meantime, please ensure you are using the latest dependencies.
422
+
423
+ Specifically, **we recommend using `torch>=2.6` and `triton>=3.0`**.
424
+
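+ A quick way to confirm your environment meets this recommendation:
+
+ ```py
+ # check the installed versions against the recommended torch>=2.6 and triton>=3.0
+ import torch
+ import triton
+ print(torch.__version__, triton.__version__)
+ ```
+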
425
+ ### Training with multiple datasets
426
+
427
+ If you wish to train a model with all-round capabilities (e.g., code, math, and multilingual ability), it's necessary to train on multiple datasets.
428
+ `flame` allows training with multiple datasets easily.
429
+ For example, you can specify the following arguments to train on 6 datasets with different proportions:
430
+
431
+ ```sh
432
+ --training.dataset HuggingFaceFW/fineweb-edu,opencsg/Fineweb-Edu-Chinese-V2.1,OpenCoder-LLM/opc-fineweb-code-corpus,math-ai/AutoMathText,EleutherAI/proof-pile-2,OpenCoder-LLM/opc-fineweb-math-corpus \
433
+ --training.data_probs 0.6,0.15,0.15,0.014,0.058,0.028 \
434
+ ```
435
+
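+ The sampling weights in this example sum to 1, which is a convenient convention to keep when adjusting the mixture:
+
+ ```py
+ # quick check that the sampling weights above form a proper distribution
+ probs = [0.6, 0.15, 0.15, 0.014, 0.058, 0.028]
+ assert abs(sum(probs) - 1.0) < 1e-6
+ ```
+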
436
+ ### ~Finalizing training~
437
+
438
+ > [!NOTE]
439
+ > Since our latest updates, the training script performs this conversion automatically.
440
+
441
+ Once training is complete, you may want to convert the distributed checkpoints (DCPs) into the 🤗 format for broader use.
442
+ To facilitate this, we provide a straightforward conversion script:
443
+
444
+ ```sh
445
+ python -m flame.utils.convert_dcp_to_hf --path <path_to_model> --step <step> --config <path_to_config> --tokenizer <path_to_tokenizer>
446
+ ```
447
+ After this, your model will be in the 🤗 format, ready to be shared or deployed.
448
+ You can then easily publish your model using the `huggingface_hub` for wider accessibility.
449
+
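+ For example, here is a minimal sketch of loading the converted checkpoint and pushing it to the Hub (the local path and repo name below are placeholders, and it assumes `fla` is installed so its model classes are registered with `transformers`):
+
+ ```py
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ path = "exp/my-run/hf"  # hypothetical path of the converted checkpoint
+ model = AutoModelForCausalLM.from_pretrained(path)
+ tokenizer = AutoTokenizer.from_pretrained(path)
+ model.push_to_hub("your-username/transformer-340M-10B")      # requires a logged-in Hub token
+ tokenizer.push_to_hub("your-username/transformer-340M-10B")
+ ```
+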
450
+ ### Continual training
451
+
452
+ If you wish to build upon a strong pre-trained model (in 🤗 format) and continue training, we also offer a script to convert the 🤗 format model back into DCP format.
453
+ This allows you to seamlessly resume training with `flame`.
454
+ ```sh
455
+ python -m flame.utils.convert_hf_to_dcp --model <path_to_hf> --checkpoint <path_to_dcp/checkpoint/step-0>
456
+ ```
457
+ Here, `<path_to_dcp>` is the directory where your distributed checkpoints will be stored.
458
+ The checkpoint is intentionally saved at `<step-0>` within the checkpoint folder to ensure it is loadable by `flame` during the initial training step, similar to how a seed checkpoint is handled.
459
+
460
+ Once the conversion is complete, you can proceed with training using `flame` as usual, continuing from where the pretrained model left off.
461
+
462
+ ## Multi-node training
463
+
464
+ If you have access to multi-node GPUs, consider leveraging them for optimal performance.
465
+ This process is straightforward and well-documented in the PyTorch [docs](https://pytorch.org/docs/stable/elastic/run.html).
466
+
467
+ To set up multi-node training:
468
+ * Set the environment variables `MASTER_ADDR=<ip>` and `MASTER_PORT=<port>` before running the training script across all nodes.
469
+ * If you're using a job scheduler like Slurm, it will handle these variables for you.
470
+
471
+ `torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "architectures": [
3
+ "MTPTransformerForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "bos_token_id": 1,
7
+ "elementwise_affine": true,
8
+ "eos_token_id": 2,
9
+ "fuse_cross_entropy": true,
10
+ "fuse_norm": false,
11
+ "fuse_swiglu": true,
12
+ "hidden_act": "swish",
13
+ "hidden_ratio": 4,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": null,
17
+ "max_position_embeddings": 4096,
18
+ "model_type": "mtp_transformer",
19
+ "n_future_tokens": 4,
20
+ "norm_eps": 1e-06,
21
+ "num_heads": 12,
22
+ "num_hidden_layers": 14,
23
+ "num_kv_heads": null,
24
+ "qk_norm": false,
25
+ "qkv_bias": false,
26
+ "rope_theta": 10000.0,
27
+ "tie_word_embeddings": true,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.51.3",
30
+ "use_cache": true,
31
+ "use_custom_backward": false,
32
+ "vocab_size": 32000,
33
+ "window_size": null
34
+ }
download_checkpoint.py ADDED
@@ -0,0 +1,35 @@
1
+ import argparse
2
+ import os
3
+ from huggingface_hub import HfApi, HfFolder, snapshot_download
4
+
5
+ def main(args):
6
+ api = HfApi()
7
+ token = HfFolder.get_token()
8
+ experiment_checkpoint_folder = os.path.join(args.experiment_checkpoint_folder, "checkpoint")
9
+ os.makedirs(
10
+ experiment_checkpoint_folder,
11
+ exist_ok=True
12
+ )
13
+
14
+ snapshot_download(
15
+ repo_id=args.repo_id,
16
+ token=token,
17
+ local_dir=experiment_checkpoint_folder,
18
+ )
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser(description="Download a checkpoint from Hugging Face Hub.")
22
+ parser.add_argument(
23
+ "--repo_id",
24
+ type=str,
25
+ required=True,
26
+ help="The repository ID on Hugging Face Hub.",
27
+ )
28
+ parser.add_argument(
29
+ "--experiment_checkpoint_folder",
30
+ type=str,
31
+ required=True,
32
+ help="The local directory to save the downloaded checkpoint.",
33
+ )
34
+ args = parser.parse_args()
35
+ main(args)
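+
+ # Example usage (repo id and output folder below are placeholders):
+ #   python download_checkpoint.py \
+ #       --repo_id <user>/<checkpoint-repo> \
+ #       --experiment_checkpoint_folder exp/my-experiment
+ # The snapshot is placed under <experiment_checkpoint_folder>/checkpoint/.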
fla/layers/__pycache__/linear_attn.cpython-311.pyc ADDED
Binary file (7.96 kB).
 
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from einops import rearrange
14
+ from transformers.utils import logging
15
+
16
+ from fla.modules import RotaryEmbedding
17
+ from fla.modules.fused_bitlinear import FusedBitLinear
18
+
19
+ if TYPE_CHECKING:
20
+ from fla.models.utils import Cache
21
+
22
+ try:
23
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
24
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
25
+ except ImportError:
26
+ warnings.warn(
27
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
28
+ category=ImportWarning
29
+ )
30
+ flash_attn_func = None
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class BitAttention(nn.Module):
36
+
37
+ def __init__(
38
+ self,
39
+ hidden_size: int = 2048,
40
+ num_heads: int = 32,
41
+ num_kv_heads: Optional[int] = None,
42
+ window_size: Optional[int] = None,
43
+ rope_theta: Optional[float] = 10000.,
44
+ max_position_embeddings: Optional[int] = None,
45
+ norm_eps: float = 1e-5,
46
+ layer_idx: int = None
47
+ ):
48
+ super().__init__()
49
+
50
+ self.num_heads = num_heads
51
+ if num_kv_heads is None:
52
+ self.num_kv_heads = self.num_heads
53
+ else:
54
+ self.num_kv_heads = num_kv_heads
55
+ self.num_kv_groups = num_heads // self.num_kv_heads
56
+ self.hidden_size = hidden_size
57
+ self.head_dim = self.hidden_size // self.num_heads
58
+ self.kv_dim = self.num_kv_heads * self.head_dim
59
+ self.kv_dim = self.num_kv_heads * self.head_dim
60
+ self.window_size = window_size
61
+ self.rope_theta = rope_theta
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.layer_idx = layer_idx
64
+
65
+ self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
66
+ self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
67
+ self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
68
+ self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
69
+
70
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.LongTensor] = None,
76
+ past_key_values: Optional[Cache] = None,
77
+ output_attentions: bool = False,
78
+ use_cache: bool = False,
79
+ **kwargs,
80
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
81
+ if attention_mask is not None:
82
+ assert len(attention_mask.shape) == 2, (
83
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
84
+ "for padding purposes (0 indicating padding). "
85
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
86
+ )
87
+
88
+ batch_size, q_len, _ = hidden_states.size()
89
+
90
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
91
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
92
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
93
+
94
+ # equivalent to cu_seqlens in `flash_attn`
95
+ cu_seqlens = kwargs.get('cu_seqlens', None)
96
+
97
+ seqlen_offset, max_seqlen = 0, q_len
98
+ if past_key_values is not None:
99
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
100
+ max_seqlen = q.shape[1] + seqlen_offset
101
+
102
+ if attention_mask is not None:
103
+ # to determine the offsets of padding tokens
104
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
105
+ max_seqlen = q.shape[1] + max(seqlen_offset)
106
+
107
+ if self.max_position_embeddings is not None:
108
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
109
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
110
+
111
+ if past_key_values is not None:
112
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
113
+ k_cached, v_cached = past_key_values.update(
114
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
115
+ layer_idx=self.layer_idx,
116
+ offset=q_len,
117
+ cache_kwargs=dict(window_size=self.window_size)
118
+ )['attn_state']
119
+ if cache_has_content:
120
+ k, v = k_cached, v_cached
121
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
122
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
123
+
124
+ if flash_attn_func is None:
125
+ raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")
126
+
127
+ # Contains at least one padding token in the sequence
128
+ if attention_mask is not None:
129
+ q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
130
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
131
+ max_seqlen_q, max_seqlen_k = max_seq_lens
132
+ o = flash_attn_varlen_func(
133
+ q, k, v,
134
+ cu_seqlens_q=cu_seqlens_q,
135
+ cu_seqlens_k=cu_seqlens_k,
136
+ max_seqlen_q=max_seqlen_q,
137
+ max_seqlen_k=max_seqlen_k,
138
+ causal=True,
139
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
140
+ )
141
+ o = pad_input(o, indices_q, batch_size, q_len)
142
+ elif cu_seqlens is not None:
143
+ o = flash_attn_varlen_func(
144
+ q.squeeze(0), k.squeeze(0), v.squeeze(0),
145
+ cu_seqlens_q=cu_seqlens,
146
+ cu_seqlens_k=cu_seqlens,
147
+ max_seqlen_q=max_seqlen,
148
+ max_seqlen_k=max_seqlen,
149
+ causal=True,
150
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
151
+ ).unsqueeze(0)
152
+ else:
153
+ o = flash_attn_func(
154
+ q, k, v,
155
+ causal=True,
156
+ window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
157
+ )
158
+ o = o.reshape(batch_size, q_len, -1)
159
+ o = self.o_proj(o)
160
+
161
+ if not output_attentions:
162
+ attentions = None
163
+
164
+ return o, attentions, past_key_values
165
+
166
+ def _upad_input(self, q, k, v, attention_mask, q_len):
167
+ batch_size, seq_len, num_key_value_heads, head_dim = k.shape
168
+ cache_mask = attention_mask[:, -seq_len:]
169
+ seqlens = cache_mask.sum(-1, dtype=torch.int32)
170
+ indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
171
+ max_seqlen_k = seqlens.max().item()
172
+ cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
173
+
174
+ k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
175
+ v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
176
+ if q_len == seq_len:
177
+ q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
178
+ cu_seqlens_q = cu_seqlens_k
179
+ max_seqlen_q = max_seqlen_k
180
+ indices_q = indices_k
181
+ elif q_len == 1:
182
+ max_seqlen_q = 1
183
+ # There is a memcpy here, that is very bad.
184
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
185
+ indices_q = cu_seqlens_q[:-1]
186
+ q = q.squeeze(1)
187
+ else:
188
+ # The -q_len: slice assumes left padding.
189
+ attention_mask = attention_mask[:, -q_len:]
190
+ q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)
191
+
192
+ return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
fla/layers/lightnet.py ADDED
@@ -0,0 +1,210 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022)
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from einops import rearrange
14
+
15
+ from fla.modules import FusedRMSNormGated, ShortConvolution
16
+ from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear
17
+ from fla.ops.gla import chunk_gla, fused_recurrent_gla
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class LightNetAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ num_heads: Optional[int] = None,
32
+ expand_ratio: Optional[int] = 128,
33
+ use_short_conv: bool = False,
34
+ conv_size: int = 4,
35
+ conv_bias: bool = False,
36
+ gate_low_rank_dim: int = 128,
37
+ elementwise_affine: Optional[bool] = True,
38
+ norm_eps: float = 1e-5,
39
+ layer_idx: int = None
40
+ ) -> LightNetAttention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ self.hidden_size = hidden_size
45
+
46
+ if expand_ratio is None and num_heads is not None:
47
+ expand_ratio = hidden_size // num_heads
48
+ elif expand_ratio is not None and num_heads is None:
49
+ num_heads = hidden_size // expand_ratio
50
+ elif expand_ratio is None and num_heads is None:
51
+ raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.")
52
+ self.num_heads = num_heads
53
+ self.expand_ratio = expand_ratio
54
+
55
+ self.use_short_conv = use_short_conv
56
+ self.conv_size = conv_size
57
+ self.conv_bias = conv_bias
58
+
59
+ self.key_dim = int(self.num_heads * self.expand_ratio)
60
+ self.value_dim = hidden_size
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.layer_idx = layer_idx
63
+
64
+ assert mode in ['chunk', 'fused_chunk'], f"Not supported mode `{mode}`."
65
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
66
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
67
+
68
+ self.head_f_dim = self.expand_ratio
69
+ self.head_i_dim = self.hidden_size // num_heads
70
+
71
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
72
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
73
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
74
+
75
+ if use_short_conv:
76
+ self.conv_size = conv_size
77
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
78
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None)
79
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None)
80
+
81
+ self.g_proj = nn.Sequential(
82
+ nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
83
+ nn.Linear(gate_low_rank_dim, hidden_size, bias=False)
84
+ )
85
+ self.g_norm = FusedRMSNormGated(
86
+ hidden_size=hidden_size,
87
+ elementwise_affine=elementwise_affine,
88
+ eps=norm_eps
89
+ )
90
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
91
+
92
+ def forward(
93
+ self,
94
+ hidden_states: torch.Tensor,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ past_key_values: Optional[Cache] = None,
97
+ use_cache: Optional[bool] = False,
98
+ output_attentions: Optional[bool] = False,
99
+ **kwargs: Unpack[Dict]
100
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
101
+ if attention_mask is not None:
102
+ assert len(attention_mask.shape) == 2, (
103
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
104
+ "for padding purposes (0 indicating padding). "
105
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
106
+ )
107
+
108
+ # launching the triton kernel for just one token will actually be slower
109
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
110
+
111
+ last_state = None
112
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
113
+ last_state = past_key_values[self.layer_idx]
114
+
115
+ cu_seqlens = kwargs.get('cu_seqlens', None)
116
+ if self.use_short_conv:
117
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
118
+ if last_state is not None:
119
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
120
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
121
+ q, conv_state_q = self.q_conv1d(
122
+ x=self.q_proj(hidden_states),
123
+ mask=conv_mask,
124
+ cache=conv_state_q,
125
+ output_final_state=use_cache,
126
+ cu_seqlens=cu_seqlens
127
+ )
128
+ k, conv_state_k = self.k_conv1d(
129
+ x=self.k_proj(hidden_states),
130
+ mask=conv_mask,
131
+ cache=conv_state_k,
132
+ output_final_state=use_cache,
133
+ cu_seqlens=cu_seqlens
134
+ )
135
+ v, conv_state_v = self.v_conv1d(
136
+ x=self.v_proj(hidden_states),
137
+ mask=conv_mask,
138
+ cache=conv_state_v,
139
+ output_final_state=use_cache,
140
+ cu_seqlens=cu_seqlens
141
+ )
142
+ else:
143
+ q = self.q_proj(hidden_states)
144
+ k = self.k_proj(hidden_states)
145
+ v = self.v_proj(hidden_states)
146
+
147
+ # dealing with left-padding
148
+ if attention_mask is not None:
149
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
150
+
151
+ q = F.silu(q)
152
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k))
153
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_i_dim)
154
+ # TODO: these two steps take a huge amount of time and should be optimized
155
+ z = k.float().logcumsumexp(1)
156
+
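+ # z is a running log-normalizer over time; below, keys are renormalized as exp(k - z)
+ # and g carries the per-step log decay (z_{t-1} - z_t) passed to the GLA kernels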
157
+ if cu_seqlens is not None:
158
+ raise NotImplementedError("LightNet does not support variable-length sequences for now.")
159
+ k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype)
160
+
161
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
162
+ if mode == 'fused_recurrent':
163
+ o, recurrent_state = fused_recurrent_gla(
164
+ q=q,
165
+ k=k,
166
+ v=v,
167
+ gk=g,
168
+ initial_state=recurrent_state,
169
+ output_final_state=use_cache,
170
+ cu_seqlens=cu_seqlens,
171
+ head_first=False
172
+ )
173
+ elif mode == 'chunk':
174
+ o, recurrent_state = chunk_gla(
175
+ q=q,
176
+ k=k,
177
+ v=v,
178
+ g=g,
179
+ initial_state=recurrent_state,
180
+ output_final_state=use_cache,
181
+ cu_seqlens=cu_seqlens,
182
+ head_first=False
183
+ )
184
+ else:
185
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
186
+
187
+ if past_key_values is not None:
188
+ past_key_values.update(
189
+ recurrent_state=recurrent_state,
190
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
191
+ layer_idx=self.layer_idx,
192
+ offset=q.shape[1]
193
+ )
194
+
195
+ o = rms_norm_swish_gate_linear(
196
+ rearrange(o, 'b t h d -> b t (h d)'),
197
+ self.g_proj(hidden_states),
198
+ self.g_norm.weight,
199
+ self.g_norm.bias,
200
+ self.o_proj.weight,
201
+ self.o_proj.bias
202
+ )
203
+ return o, None, past_key_values
204
+
205
+ def state_size(self, **kwargs) -> int:
206
+ state_size = self.key_dim * self.head_i_dim
207
+ for module in self.children():
208
+ if isinstance(module, ShortConvolution):
209
+ state_size += module.state_size
210
+ return state_size
fla/layers/nsa.py ADDED
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from transformers.utils import logging
12
+
13
+ from fla.modules import RotaryEmbedding
14
+ from fla.ops.nsa.parallel import parallel_nsa
15
+
16
+ if TYPE_CHECKING:
17
+ from fla.models.utils import Cache
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class NativeSparseAttention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ hidden_size: int = 2048,
27
+ num_heads: int = 64,
28
+ num_kv_heads: Optional[int] = 4,
29
+ head_dim: int = 64,
30
+ qkv_bias: bool = False,
31
+ block_size: Optional[int] = 64,
32
+ block_counts: Optional[Union[torch.LongTensor, int]] = 16,
33
+ window_size: Optional[int] = 512,
34
+ rope_theta: Optional[float] = 10000.,
35
+ max_position_embeddings: Optional[int] = None,
36
+ layer_idx: int = None
37
+ ):
38
+ super().__init__()
39
+
40
+ self.hidden_size = hidden_size
41
+ self.num_heads = num_heads
42
+ if num_kv_heads is None:
43
+ self.num_kv_heads = self.num_heads
44
+ else:
45
+ self.num_kv_heads = num_kv_heads
46
+ self.num_kv_groups = num_heads // self.num_kv_heads
47
+ self.head_dim = head_dim
48
+ self.kv_dim = self.num_kv_heads * self.head_dim
49
+ self.qkv_bias = qkv_bias
50
+
51
+ self.block_size = block_size
52
+ self.block_counts = block_counts
53
+ self.window_size = window_size
54
+ self.rope_theta = rope_theta
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.layer_idx = layer_idx
57
+
58
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
61
+ self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
62
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
63
+
64
+ self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
65
+
66
+ def forward(
67
+ self,
68
+ hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.LongTensor] = None,
70
+ past_key_values: Optional[Cache] = None,
71
+ output_attentions: bool = False,
72
+ use_cache: bool = False,
73
+ **kwargs,
74
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
75
+ if attention_mask is not None:
76
+ assert len(attention_mask.shape) == 2, (
77
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
78
+ "for padding purposes (0 indicating padding). "
79
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
80
+ )
81
+
82
+ batch_size, seq_len, _ = hidden_states.size()
83
+
84
+ q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
85
+ k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
86
+ v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
87
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
88
+ g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)
89
+
90
+ cu_seqlens = kwargs.get('cu_seqlens', None)
91
+
92
+ seqlen_offset, max_seqlen = 0, seq_len
93
+ if past_key_values is not None:
94
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
95
+ max_seqlen = q.shape[1] + seqlen_offset
96
+
97
+ if attention_mask is not None:
98
+ # to determine the offsets of padding tokens
99
+ seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
100
+ max_seqlen = q.shape[1] + max(seqlen_offset)
101
+
102
+ if self.max_position_embeddings is not None:
103
+ max_seqlen = max(max_seqlen, self.max_position_embeddings)
104
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
105
+
106
+ if past_key_values is not None:
107
+ cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
108
+ k_cached, v_cached = past_key_values.update(
109
+ attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
110
+ layer_idx=self.layer_idx,
111
+ offset=seq_len,
112
+ cache_kwargs=dict(window_size=self.window_size)
113
+ )['attn_state']
114
+ if cache_has_content:
115
+ k, v = k_cached, v_cached
116
+ k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
117
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
118
+
119
+ o = parallel_nsa(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ g_cmp=g_cmp,
124
+ g_slc=g_slc,
125
+ g_swa=g_swa,
126
+ block_size=self.block_size,
127
+ block_counts=self.block_counts,
128
+ window_size=self.window_size,
129
+ cu_seqlens=cu_seqlens,
130
+ head_first=False
131
+ )
132
+ o = o.reshape(batch_size, seq_len, -1)
133
+ o = self.o_proj(o)
134
+
135
+ if not output_attentions:
136
+ attentions = None
137
+
138
+ return o, attentions, past_key_values
fla/layers/rwkv7.py ADDED
@@ -0,0 +1,221 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.layers.rwkv6 import LoRA
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.l2norm import l2_norm
16
+ from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV7Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ head_dim: Optional[int] = 64,
29
+ num_heads: Optional[int] = None,
30
+ decay_low_rank_dim: int = 64,
31
+ gate_low_rank_dim: int = 128,
32
+ a_low_rank_dim: int = 64,
33
+ v_low_rank_dim: int = 16,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: int = None,
37
+ fuse_norm: bool = False,
38
+ value_dim: int = None,
39
+ **kwargs
40
+ ) -> RWKV7Attention:
41
+ super().__init__()
42
+
43
+ self.mode = mode
44
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
45
+ self.hidden_size = hidden_size
46
+
47
+ self.key_dim = hidden_size
48
+ self.value_dim = value_dim if value_dim is not None else hidden_size
49
+ if head_dim is None and num_heads is None:
50
+ raise ValueError("Either `head_dim` or `num_heads` must be specified.")
51
+ elif head_dim is not None:
52
+ self.head_dim = head_dim
53
+ self.num_heads = int(hidden_size // head_dim)
54
+ elif num_heads is not None:
55
+ self.head_dim = int(hidden_size // num_heads)
56
+ self.num_heads = num_heads
57
+ self.head_v_dim = int(self.value_dim // self.num_heads)
58
+
59
+ self.decay_low_rank_dim = decay_low_rank_dim
60
+ self.gate_low_rank_dim = gate_low_rank_dim
61
+ self.a_low_rank_dim = a_low_rank_dim
62
+ self.v_low_rank_dim = v_low_rank_dim
63
+ self.layer_idx = layer_idx
64
+ self.fuse_norm = fuse_norm
65
+
66
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
67
+
68
+ self.x_x = nn.Parameter(torch.zeros(6, hidden_size))
69
+
70
+ self.k_k = nn.Parameter(torch.zeros(self.key_dim))
71
+ self.k_a = nn.Parameter(torch.zeros(self.key_dim))
72
+ self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim))
73
+
74
+ self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
75
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
76
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
77
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
78
+
79
+ self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh')
80
+ if self.layer_idx != 0:
81
+ self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None)
82
+ self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None)
83
+ self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False)
84
+
85
+ if self.fuse_norm:
86
+ self.g_norm = GroupNorm(
87
+ num_groups=self.num_heads,
88
+ hidden_size=self.value_dim,
89
+ elementwise_affine=elementwise_affine,
90
+ eps=self.head_dim*norm_eps,
91
+ bias=True,
92
+ )
93
+ else:
94
+ self.g_norm = nn.GroupNorm(
95
+ num_groups=self.num_heads,
96
+ num_channels=self.value_dim,
97
+ eps=self.head_dim*norm_eps,
98
+ affine=elementwise_affine
99
+ )
100
+
101
+ self.apply(self._initialize_weights)
102
+
103
+ def _initialize_weights(self, module: nn.Module):
104
+ if getattr(module, "_is_hf_initialized", False):
105
+ return
106
+ if isinstance(module, nn.Linear):
107
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
108
+ if module.bias is not None:
109
+ nn.init.zeros_(module.bias)
110
+ if isinstance(module, nn.Parameter):
111
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
112
+ module._is_hf_initialized = True
113
+
114
+ def forward(
115
+ self,
116
+ hidden_states: torch.Tensor,
117
+ attention_mask: Optional[torch.Tensor] = None,
118
+ past_key_values: Optional[Cache] = None,
119
+ use_cache: Optional[bool] = False,
120
+ output_attentions: Optional[bool] = False,
121
+ v_first: torch.Tensor = None,
122
+ **kwargs
123
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
124
+ if attention_mask is not None:
125
+ assert len(attention_mask.shape) == 2, (
126
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
127
+ "for padding purposes (0 indicating padding). "
128
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
129
+ )
130
+
131
+ batch_size, seq_len, _ = hidden_states.shape
132
+
133
+ if self.training:
134
+ # if training, use chunk mode no matter how short the sequence is
135
+ mode = 'chunk'
136
+ else:
137
+ # launching the triton kernel for just one token will actually be slower
138
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
139
+
140
+ last_state = None
141
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
142
+ last_state = past_key_values[self.layer_idx]
143
+
144
+ if attention_mask is not None:
145
+ hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
146
+ if hidden_states.shape[1] == 1 and last_state is not None:
147
+ shifted = last_state['conv_state'].unsqueeze(1)
148
+ else:
149
+ shifted = self.time_shift(hidden_states)
150
+ if last_state is not None:
151
+ shifted[:, 0] = last_state['conv_state']
152
+
153
+ # [batch_size, seq_len, hidden_size]
154
+ delta = shifted - hidden_states
155
+ xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)
156
+
157
+ r = self.r_proj(xr)
158
+ # -math.exp(-0.5) = -0.6065306597126334
159
+ # I think .to(torch.float) is unnecessary here, since we compute the LoRA in bfloat16
160
+ # when we apply sigmoid, a bf16 input will not cause numerical issues
161
+ # FIXME: check if we can remove .to(torch.float)
162
+ w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()
163
+
164
+ k = self.k_proj(xk)
165
+ v = self.v_proj(xv)
166
+
167
+ if self.layer_idx == 0:
168
+ v_first = v
169
+ else:
170
+ v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid())
171
+ a = self.a_lora(xa).sigmoid()
172
+ g = self.g_lora(xg)
173
+
174
+ if self.fuse_norm:
175
+ kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim))
176
+ else:
177
+ kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0)
178
+
179
+ k = k.addcmul(k * (a - 1), self.k_a)
180
+
181
+ # dealing with left-padding
182
+ if attention_mask is not None:
183
+ v = v * attention_mask[:, -v.shape[-2]:, None]
184
+ r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
185
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
186
+
187
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
188
+
189
+ rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
190
+ cu_seqlens = kwargs.get('cu_seqlens', None)
191
+ o, recurrent_state = rwkv7_fn(
192
+ r=r,
193
+ w=w,
194
+ k=k,
195
+ v=v,
196
+ a=-kk,
197
+ b=kk * a,
198
+ scale=1.,
199
+ initial_state=recurrent_state,
200
+ output_final_state=use_cache,
201
+ cu_seqlens=cu_seqlens,
202
+ head_first=False
203
+ )
204
+
205
+ if past_key_values is not None:
206
+ past_key_values.update(
207
+ recurrent_state=recurrent_state,
208
+ conv_state=hidden_states[:, -1],
209
+ layer_idx=self.layer_idx,
210
+ offset=r.shape[1]
211
+ )
212
+
213
+ if self.fuse_norm:
214
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)'))
215
+ else:
216
+ o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1)
217
+
218
+ o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1)
219
+ o = self.o_proj(o * g)
220
+
221
+ return o, None, past_key_values, v_first
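
A minimal usage sketch for the RWKV7Attention layer above, assuming the `fla` package from this upload is installed together with its Triton dependency and a CUDA device is available; the sizes are illustrative, and the call pattern follows the `forward` signature shown in the diff (layer 0 emits `v_first`, deeper layers consume it).

import torch
from fla.layers.rwkv7 import RWKV7Attention

# layer 0 produces `v_first`; deeper layers mix it back in through `v_lora`
attn0 = RWKV7Attention(mode='chunk', hidden_size=1024, head_dim=64, layer_idx=0).cuda().to(torch.bfloat16)
attn1 = RWKV7Attention(mode='chunk', hidden_size=1024, head_dim=64, layer_idx=1).cuda().to(torch.bfloat16)

x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o0, _, _, v_first = attn0(x)                   # v_first is set to v at layer 0
o1, _, _, v_first = attn1(x, v_first=v_first)  # later layers lerp v towards v_first
print(o0.shape, o1.shape)                      # torch.Size([2, 128, 1024]) twice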
fla/models/delta_net/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (758 Bytes). View file
 
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-311.pyc ADDED
Binary file (19.3 kB). View file
 
fla/models/delta_net/configuration_delta_net.py ADDED
@@ -0,0 +1,91 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class DeltaNetConfig(PretrainedConfig):
9
+
10
+ model_type = 'delta_net'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "chunk",
16
+ hidden_size: int = 2048,
17
+ expand_k: int = 1,
18
+ expand_v: int = 1,
19
+ use_gate: bool = False,
20
+ use_short_conv: bool = True,
21
+ conv_size: int = 4,
22
+ use_beta: bool = True,
23
+ use_output_norm: bool = True,
24
+ num_heads: int = 16,
25
+ qk_norm: str = 'l2',
26
+ qk_activation: str = 'silu',
27
+ max_position_embeddings: int = 2048,
28
+ hidden_ratio: Optional[int] = 4,
29
+ intermediate_size: Optional[int] = None,
30
+ hidden_act: str = "swish",
31
+ num_hidden_layers: int = 24,
32
+ norm_eps: float = 1e-6,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: int = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.006,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.attn_mode = attn_mode
47
+ self.hidden_size = hidden_size
48
+ self.expand_k = expand_k
49
+ self.expand_v = expand_v
50
+ self.use_gate = use_gate
51
+ self.use_short_conv = use_short_conv
52
+ self.conv_size = conv_size
53
+ self.use_beta = use_beta
54
+ self.use_output_norm = use_output_norm
55
+ self.num_heads = num_heads
56
+ self.qk_norm = qk_norm
57
+ self.qk_activation = qk_activation
58
+ self.max_position_embeddings = max_position_embeddings
59
+
60
+ self.hidden_ratio = hidden_ratio
61
+ self.intermediate_size = intermediate_size
62
+ self.hidden_act = hidden_act
63
+ self.num_hidden_layers = num_hidden_layers
64
+ self.norm_eps = norm_eps
65
+ self.attn = attn
66
+ self.use_cache = use_cache
67
+ self.initializer_range = initializer_range
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
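
A hedged example of the hybrid-attention handling in the configuration above: `attn` must provide `layers` and `num_heads`, and the remaining keys are back-filled with the defaults shown; the layer indices here are purely illustrative.

from fla.models.delta_net.configuration_delta_net import DeltaNetConfig

# hypothetical hybrid stack: softmax attention on layers 7, 15, 23; DeltaNet elsewhere
config = DeltaNetConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    num_heads=8,
    attn={'layers': [7, 15, 23], 'num_heads': 16},
)
print(config.attn)
# {'layers': [7, 15, 23], 'num_heads': 16, 'num_kv_heads': 16,
#  'qkv_bias': False, 'window_size': None, 'rope_theta': 10000.0}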
fla/models/forgetting_transformer/__pycache__/configuration_forgetting_transformer.cpython-311.pyc ADDED
Binary file (2.77 kB). View file
 
fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-311.pyc ADDED
Binary file (3.73 kB). View file
 
fla/models/gated_deltanet/configuration_gated_deltanet.py ADDED
@@ -0,0 +1,83 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GatedDeltaNetConfig(PretrainedConfig):
9
+ model_type = 'gated_deltanet'
10
+ keys_to_ignore_at_inference = ['past_key_values']
11
+
12
+ def __init__(
13
+ self,
14
+ attn_mode: str = "chunk",
15
+ hidden_size: int = 2048,
16
+ expand_v: int = 2,
17
+ use_gate: bool = True,
18
+ use_short_conv: bool = True,
19
+ conv_size: int = 4,
20
+ head_dim: int = 256,
21
+ num_heads: int = 6,
22
+ max_position_embeddings: int = 2048,
23
+ hidden_ratio: Optional[int] = 4,
24
+ intermediate_size: Optional[int] = None,
25
+ hidden_act: str = "swish",
26
+ num_hidden_layers: int = 21,
27
+ norm_eps: float = 1e-6,
28
+ attn: Optional[Dict] = None,
29
+ use_cache: bool = True,
30
+ pad_token_id: int = None,
31
+ bos_token_id: int = 1,
32
+ eos_token_id: int = 2,
33
+ tie_word_embeddings: bool = False,
34
+ initializer_range: float = 0.006,
35
+ fuse_norm: bool = True,
36
+ fuse_swiglu: bool = True,
37
+ fuse_cross_entropy: bool = True,
38
+ vocab_size: int = 32000,
39
+ **kwargs
40
+ ):
41
+ self.attn_mode = attn_mode
42
+ self.hidden_size = hidden_size
43
+ self.expand_v = expand_v
44
+ self.use_gate = use_gate
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.head_dim = head_dim
48
+ self.num_heads = num_heads
49
+ self.max_position_embeddings = max_position_embeddings
50
+
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.hidden_act = hidden_act
54
+ self.num_hidden_layers = num_hidden_layers
55
+ self.norm_eps = norm_eps
56
+ self.attn = attn
57
+ self.use_cache = use_cache
58
+ self.initializer_range = initializer_range
59
+
60
+ self.fuse_norm = fuse_norm
61
+ self.fuse_swiglu = fuse_swiglu
62
+ self.fuse_cross_entropy = fuse_cross_entropy
63
+ self.vocab_size = vocab_size
64
+
65
+ if attn is not None:
66
+ if not isinstance(attn, Dict):
67
+ raise ValueError("attn must be a dictionary")
68
+ if 'layers' not in attn:
69
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
70
+ if 'num_heads' not in attn:
71
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
72
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
73
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
74
+ attn['window_size'] = attn.get('window_size', None)
75
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
76
+
77
+ super().__init__(
78
+ pad_token_id=pad_token_id,
79
+ bos_token_id=bos_token_id,
80
+ eos_token_id=eos_token_id,
81
+ tie_word_embeddings=tie_word_embeddings,
82
+ **kwargs,
83
+ )
fla/models/gated_deltanet/modeling_gated_deltanet.py ADDED
@@ -0,0 +1,412 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gated_deltanet import GatedDeltaNet
20
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GatedDeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetBlock(nn.Module):
34
+ def __init__(self, config: GatedDeltaNetConfig, layer_idx: int):
35
+ super().__init__()
36
+
37
+ self.config = config
38
+ self.layer_idx = layer_idx
39
+
40
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
41
+ if config.attn is not None and layer_idx in config.attn['layers']:
42
+ self.attn = Attention(
43
+ hidden_size=config.hidden_size,
44
+ num_heads=config.attn['num_heads'],
45
+ num_kv_heads=config.attn['num_kv_heads'],
46
+ qkv_bias=config.attn['qkv_bias'],
47
+ window_size=config.attn['window_size'],
48
+ rope_theta=config.attn['rope_theta'],
49
+ max_position_embeddings=config.max_position_embeddings,
50
+ layer_idx=layer_idx
51
+ )
52
+ else:
53
+ self.attn = GatedDeltaNet(
54
+ mode=config.attn_mode,
55
+ hidden_size=config.hidden_size,
56
+ expand_v=config.expand_v,
57
+ head_dim=config.head_dim,
58
+ num_heads=config.num_heads,
59
+ use_gate=config.use_gate,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ norm_eps=config.norm_eps,
63
+ layer_idx=layer_idx
64
+ )
65
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
66
+ self.mlp = GatedDeltaNetMLP(
67
+ hidden_size=config.hidden_size,
68
+ hidden_ratio=config.hidden_ratio,
69
+ intermediate_size=config.intermediate_size,
70
+ hidden_act=config.hidden_act,
71
+ fuse_swiglu=config.fuse_swiglu
72
+ )
73
+
74
+ def forward(
75
+ self,
76
+ hidden_states: torch.Tensor,
77
+ attention_mask: Optional[torch.Tensor] = None,
78
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
79
+ use_cache: Optional[bool] = False,
80
+ output_attentions: Optional[bool] = False,
81
+ **kwargs: Unpack[Dict]
82
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
83
+ residual = hidden_states
84
+ hidden_states = self.attn_norm(hidden_states)
85
+ hidden_states, attentions, past_key_values = self.attn(
86
+ hidden_states=hidden_states,
87
+ attention_mask=attention_mask,
88
+ past_key_values=past_key_values,
89
+ use_cache=use_cache,
90
+ output_attentions=output_attentions,
91
+ **kwargs
92
+ )
93
+ if self.config.fuse_norm:
94
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
95
+ else:
96
+ hidden_states = residual + hidden_states
97
+ residual = hidden_states
98
+ hidden_states = self.mlp_norm(hidden_states)
99
+ hidden_states = self.mlp(hidden_states, **kwargs)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states, attentions, past_key_values)
103
+
104
+ return outputs
105
+
106
+
107
+ class GatedDeltaNetPreTrainedModel(PreTrainedModel):
108
+
109
+ config_class = GatedDeltaNetConfig
110
+ base_model_prefix = 'model'
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ['GatedDeltaNetBlock']
113
+ _supports_cache_class = True
114
+
115
+ def __init__(self, *inputs, **kwargs):
116
+ super().__init__(*inputs, **kwargs)
117
+
118
+ def _init_weights(
119
+ self,
120
+ module: nn.Module,
121
+ prenorm_residual_strategy: Optional[str] = 'rescale',
122
+ num_residuals_per_layer: int = 2,
123
+ ):
124
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
125
+ # Slightly different from the TF version which uses truncated_normal for initialization
126
+ # cf https://github.com/pytorch/pytorch/pull/5617
127
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
128
+ if module.bias is not None:
129
+ nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.Embedding):
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ elif hasattr(module, 'reset_parameters'):
133
+ module.reset_parameters()
134
+
135
+ if prenorm_residual_strategy is not None:
136
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
137
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
138
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
139
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
140
+ #
141
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
142
+ p = None
143
+ if hasattr(module, 'o_proj'):
144
+ p = module.o_proj.weight
145
+ elif hasattr(module, 'down_proj'):
146
+ p = module.down_proj.weight
147
+ if p is not None:
148
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
149
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
150
+ # We need to reinit p since this code could be called multiple times
151
+ # Having just p *= scale would repeatedly scale it down
152
+ if prenorm_residual_strategy == 'rescale':
153
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
154
+ with torch.no_grad():
155
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
156
+ elif prenorm_residual_strategy == 'zero':
157
+ nn.init.zeros_(p)
158
+ else:
159
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
160
+
161
+
162
+ class GatedDeltaNetModel(GatedDeltaNetPreTrainedModel):
163
+
164
+ def __init__(self, config: GatedDeltaNetConfig):
165
+ super().__init__(config)
166
+ self.padding_idx = config.pad_token_id
167
+ self.vocab_size = config.vocab_size
168
+
169
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
170
+ self.layers = nn.ModuleList([GatedDeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
171
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
172
+
173
+ self.gradient_checkpointing = False
174
+
175
+ self.post_init()
176
+
177
+ def get_input_embeddings(self):
178
+ return self.embeddings
179
+
180
+ def set_input_embeddings(self, value):
181
+ self.embeddings = value
182
+
183
+ def forward(
184
+ self,
185
+ input_ids: Optional[torch.LongTensor] = None,
186
+ attention_mask: Optional[torch.Tensor] = None, # noqa
187
+ inputs_embeds: Optional[torch.FloatTensor] = None,
188
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
189
+ use_cache: Optional[bool] = None,
190
+ output_attentions: Optional[bool] = None,
191
+ output_hidden_states: Optional[bool] = None,
192
+ return_dict: Optional[bool] = None,
193
+ **kwargs: Unpack[Dict]
194
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
195
+ if output_attentions:
196
+ warnings.warn("`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.")
197
+ output_attentions = False
198
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
199
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
200
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
201
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
202
+
203
+ # retrieve input_ids and inputs_embeds
204
+ if input_ids is not None and inputs_embeds is not None:
205
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
206
+ if input_ids is None and inputs_embeds is None:
207
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
208
+
209
+ if inputs_embeds is None:
210
+ inputs_embeds = self.embeddings(input_ids)
211
+ hidden_states = inputs_embeds
212
+
213
+ if use_cache and not isinstance(past_key_values, Cache):
214
+ past_key_values = Cache.from_legacy_cache(past_key_values)
215
+
216
+ if self.gradient_checkpointing and self.training and use_cache:
217
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
218
+ use_cache = False
219
+
220
+ all_hidden_states = () if output_hidden_states else None
221
+ all_attns = () if output_attentions else None
222
+ for layer in self.layers:
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ if self.gradient_checkpointing and self.training:
227
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
228
+ layer.__call__,
229
+ hidden_states,
230
+ attention_mask,
231
+ past_key_values,
232
+ use_cache,
233
+ output_attentions,
234
+ **kwargs
235
+ )
236
+ else:
237
+ hidden_states, attentions, past_key_values = layer(
238
+ hidden_states,
239
+ attention_mask=attention_mask,
240
+ past_key_values=past_key_values,
241
+ use_cache=use_cache,
242
+ output_attentions=output_attentions,
243
+ **kwargs
244
+ )
245
+
246
+ if output_attentions:
247
+ all_attns += (attentions,)
248
+
249
+ hidden_states = self.norm(hidden_states)
250
+
251
+ # add hidden states from the last decoder layer
252
+ if output_hidden_states:
253
+ all_hidden_states += (hidden_states,)
254
+
255
+ if not return_dict:
256
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
257
+ return BaseModelOutputWithPast(
258
+ last_hidden_state=hidden_states,
259
+ past_key_values=past_key_values,
260
+ hidden_states=all_hidden_states,
261
+ attentions=all_attns
262
+ )
263
+
264
+
265
+ class GatedDeltaNetForCausalLM(GatedDeltaNetPreTrainedModel, GenerationMixin):
266
+
267
+ _tied_weights_keys = ["lm_head.weight"]
268
+
269
+ def __init__(self, config):
270
+ super().__init__(config)
271
+ self.model = GatedDeltaNetModel(config)
272
+ self.vocab_size = config.vocab_size
273
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
274
+ self.criterion = None
275
+
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.model.embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.model.embeddings = value
284
+
285
+ def get_output_embeddings(self):
286
+ return self.lm_head
287
+
288
+ def set_output_embeddings(self, new_embeddings):
289
+ self.lm_head = new_embeddings
290
+
291
+ def set_decoder(self, decoder):
292
+ self.model = decoder
293
+
294
+ def get_decoder(self):
295
+ return self.model
296
+
297
+ def generate(self, *args, **kwargs):
298
+ try:
299
+ return super().generate(*args, **kwargs)
300
+ except AttributeError as exception:
301
+ if 'past_key_values' in str(exception):
302
+ raise AttributeError(
303
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
304
+ f"which is not supported for {self.__class__.__name__}. "
305
+ f"Try another generation strategy instead. "
306
+ f"For the available generation strategies, check this doc: "
307
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
308
+ )
309
+ else:
310
+ raise exception
311
+
312
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
313
+ def prepare_inputs_for_generation(
314
+ self,
315
+ input_ids: torch.LongTensor = None,
316
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
317
+ attention_mask: Optional[torch.Tensor] = None,
318
+ inputs_embeds: Optional[torch.Tensor] = None,
319
+ use_cache: bool = True,
320
+ logits_to_keep: Optional[int] = None,
321
+ **kwargs
322
+ ):
323
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
324
+ if past_key_values is not None and len(past_key_values) > 0:
325
+ input_ids = input_ids[:, -1:]
326
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
327
+ if inputs_embeds is not None and len(past_key_values) == 0:
328
+ model_inputs = {'inputs_embeds': inputs_embeds}
329
+ else:
330
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
331
+ # recompiles graphs as the stride of the inputs is a guard.
332
+ # Ref: https://github.com/huggingface/transformers/pull/29114
333
+ # TODO: use `next_tokens` directly instead.
334
+ model_inputs = {'input_ids': input_ids.contiguous()}
335
+
336
+ if logits_to_keep is not None:
337
+ model_inputs['logits_to_keep'] = logits_to_keep
338
+
339
+ model_inputs.update({
340
+ 'past_key_values': past_key_values,
341
+ 'use_cache': use_cache,
342
+ 'attention_mask': attention_mask,
343
+ })
344
+ return model_inputs
345
+
346
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
347
+ def forward(
348
+ self,
349
+ input_ids: torch.LongTensor = None,
350
+ attention_mask: Optional[torch.Tensor] = None,
351
+ inputs_embeds: Optional[torch.Tensor] = None,
352
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
353
+ labels: Optional[torch.LongTensor] = None,
354
+ use_cache: Optional[bool] = None,
355
+ output_attentions: Optional[bool] = None,
356
+ output_hidden_states: Optional[bool] = None,
357
+ return_dict: Optional[bool] = None,
358
+ logits_to_keep: Optional[int] = 0,
359
+ **kwargs: Unpack[Dict]
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
362
+ output_hidden_states = (
363
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
364
+ )
365
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
366
+
367
+ outputs = self.model(
368
+ input_ids=input_ids,
369
+ attention_mask=attention_mask,
370
+ inputs_embeds=inputs_embeds,
371
+ past_key_values=past_key_values,
372
+ use_cache=use_cache,
373
+ output_attentions=output_attentions,
374
+ output_hidden_states=output_hidden_states,
375
+ return_dict=return_dict,
376
+ **kwargs
377
+ )
378
+
379
+ hidden_states = outputs[0]
380
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
381
+
382
+ loss, logits = None, None
383
+ if not fuse_linear_and_cross_entropy or labels is None:
384
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
385
+ if labels is not None:
386
+ if getattr(self, 'criterion', None) is None:
387
+ if fuse_linear_and_cross_entropy:
388
+ criterion = FusedLinearCrossEntropyLoss()
389
+ elif self.config.fuse_cross_entropy:
390
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
391
+ else:
392
+ criterion = nn.CrossEntropyLoss()
393
+ else:
394
+ criterion = self.criterion
395
+ labels = labels.to(hidden_states.device)
396
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
397
+ if fuse_linear_and_cross_entropy:
398
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
399
+ else:
400
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
401
+
402
+ if not return_dict:
403
+ output = (logits,) + outputs[1:]
404
+ return (loss,) + output if loss is not None else output
405
+
406
+ return CausalLMOutputWithPast(
407
+ loss=loss,
408
+ logits=logits,
409
+ past_key_values=outputs.past_key_values,
410
+ hidden_states=outputs.hidden_states,
411
+ attentions=outputs.attentions,
412
+ )
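
A small end-to-end sketch of the causal-LM wrapper above, assuming a CUDA device (the GatedDeltaNet and fused-norm kernels are Triton-based); the config values are illustrative only, and `fuse_cross_entropy=False` selects the plain `nn.CrossEntropyLoss` branch of the loss code.

import torch
from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM

config = GatedDeltaNetConfig(
    hidden_size=256, num_hidden_layers=2, num_heads=2, head_dim=64,
    vocab_size=1000, fuse_cross_entropy=False,
)
model = GatedDeltaNetForCausalLM(config).cuda().to(torch.bfloat16)

input_ids = torch.randint(0, 1000, (1, 32), device='cuda')
out = model(input_ids=input_ids, labels=input_ids)  # labels are shifted inside `forward`
print(out.loss, out.logits.shape)                   # scalar loss, torch.Size([1, 32, 1000])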
fla/models/gated_deltaproduct/__pycache__/modeling_gated_deltaproduct.cpython-311.pyc ADDED
Binary file (21.5 kB). View file
 
fla/models/gla/configuration_gla.py ADDED
@@ -0,0 +1,95 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GLAConfig(PretrainedConfig):
9
+
10
+ model_type = 'gla'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ expand_k: int = 0.5,
17
+ expand_v: int = 1,
18
+ hidden_ratio: Optional[int] = 4,
19
+ intermediate_size: Optional[int] = None,
20
+ num_hidden_layers: int = 24,
21
+ num_heads: int = 4,
22
+ num_kv_heads: Optional[int] = None,
23
+ feature_map: Optional[str] = None,
24
+ attn_mode: str = "chunk",
25
+ use_short_conv: bool = False,
26
+ conv_size: int = 4,
27
+ use_output_gate: bool = True,
28
+ clamp_min: Optional[float] = None,
29
+ hidden_act: str = "swish",
30
+ max_position_embeddings: int = 2048,
31
+ elementwise_affine: Optional[bool] = True,
32
+ norm_eps: float = 1e-6,
33
+ use_gk: bool = True,
34
+ use_gv: bool = False,
35
+ attn: Optional[Dict] = None,
36
+ use_cache: bool = True,
37
+ pad_token_id: int = None,
38
+ bos_token_id: int = 1,
39
+ eos_token_id: int = 2,
40
+ tie_word_embeddings: bool = False,
41
+ initializer_range: float = 0.006,
42
+ fuse_norm: bool = True,
43
+ fuse_swiglu: bool = True,
44
+ fuse_cross_entropy: bool = True,
45
+ vocab_size: int = 32000,
46
+ **kwargs
47
+ ):
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.num_hidden_layers = num_hidden_layers
54
+ self.num_heads = num_heads
55
+ self.num_kv_heads = num_kv_heads
56
+ self.feature_map = feature_map
57
+ self.attn_mode = attn_mode
58
+ self.use_short_conv = use_short_conv
59
+ self.conv_size = conv_size
60
+ self.use_output_gate = use_output_gate
61
+ self.clamp_min = clamp_min
62
+ self.hidden_act = hidden_act
63
+ self.max_position_embeddings = max_position_embeddings
64
+ self.elementwise_affine = elementwise_affine
65
+ self.norm_eps = norm_eps
66
+ self.use_gk = use_gk
67
+ self.use_gv = use_gv
68
+ self.attn = attn
69
+ self.use_cache = use_cache
70
+ self.initializer_range = initializer_range
71
+
72
+ self.fuse_norm = fuse_norm
73
+ self.fuse_swiglu = fuse_swiglu
74
+ self.fuse_cross_entropy = fuse_cross_entropy
75
+ self.vocab_size = vocab_size
76
+
77
+ if attn is not None:
78
+ if not isinstance(attn, Dict):
79
+ raise ValueError("attn must be a dictionary")
80
+ if 'layers' not in attn:
81
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
82
+ if 'num_heads' not in attn:
83
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
84
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
85
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
86
+ attn['window_size'] = attn.get('window_size', None)
87
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
88
+
89
+ super().__init__(
90
+ pad_token_id=pad_token_id,
91
+ bos_token_id=bos_token_id,
92
+ eos_token_id=eos_token_id,
93
+ tie_word_embeddings=tie_word_embeddings,
94
+ **kwargs,
95
+ )
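
Since `GLAConfig` is an ordinary `transformers` `PretrainedConfig`, it round-trips through the standard serialization API; a brief sketch with a placeholder path:

from fla.models.gla.configuration_gla import GLAConfig

config = GLAConfig(hidden_size=1024, num_hidden_layers=12, num_heads=4, use_short_conv=True)
config.save_pretrained('/tmp/gla-config')               # writes config.json with model_type='gla'
reloaded = GLAConfig.from_pretrained('/tmp/gla-config')
assert reloaded.use_short_conv and reloaded.attn_mode == 'chunk'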
fla/models/gla/modeling_gla.py ADDED
@@ -0,0 +1,417 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gla import GatedLinearAttention
20
+ from fla.models.gla.configuration_gla import GLAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GLAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GLABlock(nn.Module):
33
+ def __init__(self, config: GLAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedLinearAttention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_k=config.expand_k,
56
+ expand_v=config.expand_v,
57
+ num_heads=config.num_heads,
58
+ num_kv_heads=config.num_kv_heads,
59
+ feature_map=config.feature_map,
60
+ use_short_conv=config.use_short_conv,
61
+ conv_size=config.conv_size,
62
+ use_output_gate=config.use_output_gate,
63
+ gate_fn=config.hidden_act,
64
+ elementwise_affine=config.elementwise_affine,
65
+ norm_eps=config.norm_eps,
66
+ clamp_min=config.clamp_min,
67
+ fuse_norm=config.fuse_norm,
68
+ layer_idx=layer_idx
69
+ )
70
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
71
+ self.mlp = GLAMLP(
72
+ hidden_size=config.hidden_size,
73
+ hidden_ratio=config.hidden_ratio,
74
+ intermediate_size=config.intermediate_size,
75
+ hidden_act=config.hidden_act,
76
+ fuse_swiglu=config.fuse_swiglu
77
+ )
78
+
79
+ def forward(
80
+ self,
81
+ hidden_states: torch.Tensor,
82
+ attention_mask: Optional[torch.Tensor] = None,
83
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
84
+ use_cache: Optional[bool] = False,
85
+ output_attentions: Optional[bool] = False,
86
+ **kwargs: Unpack[Dict]
87
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
88
+ residual = hidden_states
89
+ hidden_states = self.attn_norm(hidden_states)
90
+ hidden_states, attentions, past_key_values = self.attn(
91
+ hidden_states=hidden_states,
92
+ attention_mask=attention_mask,
93
+ past_key_values=past_key_values,
94
+ use_cache=use_cache,
95
+ output_attentions=output_attentions,
96
+ **kwargs
97
+ )
98
+ if self.config.fuse_norm:
99
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
100
+ else:
101
+ hidden_states = residual + hidden_states
102
+ residual = hidden_states
103
+ hidden_states = self.mlp_norm(hidden_states)
104
+ hidden_states = self.mlp(hidden_states, **kwargs)
105
+ hidden_states = residual + hidden_states
106
+
107
+ outputs = (hidden_states, attentions, past_key_values)
108
+
109
+ return outputs
110
+
111
+
112
+ class GLAPreTrainedModel(PreTrainedModel):
113
+
114
+ config_class = GLAConfig
115
+ base_model_prefix = 'model'
116
+ supports_gradient_checkpointing = True
117
+ _no_split_modules = ['GLABlock']
118
+ _supports_cache_class = True
119
+
120
+ def __init__(self, *inputs, **kwargs):
121
+ super().__init__(*inputs, **kwargs)
122
+
123
+ def _init_weights(
124
+ self,
125
+ module: nn.Module,
126
+ prenorm_residual_strategy: Optional[str] = 'rescale',
127
+ num_residuals_per_layer: int = 2,
128
+ ):
129
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
130
+ # Slightly different from the TF version which uses truncated_normal for initialization
131
+ # cf https://github.com/pytorch/pytorch/pull/5617
132
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
133
+ if module.bias is not None:
134
+ nn.init.zeros_(module.bias)
135
+ elif isinstance(module, nn.Embedding):
136
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
137
+ elif hasattr(module, 'reset_parameters'):
138
+ module.reset_parameters()
139
+
140
+ if prenorm_residual_strategy is not None:
141
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
142
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
143
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
144
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
145
+ #
146
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
147
+ p = None
148
+ if hasattr(module, 'o_proj'):
149
+ p = module.o_proj.weight
150
+ elif hasattr(module, 'down_proj'):
151
+ p = module.down_proj.weight
152
+ if p is not None:
153
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
154
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
155
+ # We need to reinit p since this code could be called multiple times
156
+ # Having just p *= scale would repeatedly scale it down
157
+ if prenorm_residual_strategy == 'rescale':
158
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
159
+ with torch.no_grad():
160
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
161
+ elif prenorm_residual_strategy == 'zero':
162
+ nn.init.zeros_(p)
163
+ else:
164
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
165
+
166
+
167
+ class GLAModel(GLAPreTrainedModel):
168
+
169
+ def __init__(self, config: GLAConfig):
170
+ super().__init__(config)
171
+ self.padding_idx = config.pad_token_id
172
+ self.vocab_size = config.vocab_size
173
+
174
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
175
+ self.layers = nn.ModuleList([GLABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
176
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
177
+
178
+ self.gradient_checkpointing = False
179
+
180
+ self.post_init()
181
+
182
+ def get_input_embeddings(self):
183
+ return self.embeddings
184
+
185
+ def set_input_embeddings(self, value):
186
+ self.embeddings = value
187
+
188
+ def forward(
189
+ self,
190
+ input_ids: Optional[torch.LongTensor] = None,
191
+ attention_mask: Optional[torch.Tensor] = None, # noqa
192
+ inputs_embeds: Optional[torch.FloatTensor] = None,
193
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
194
+ use_cache: Optional[bool] = None,
195
+ output_attentions: Optional[bool] = None,
196
+ output_hidden_states: Optional[bool] = None,
197
+ return_dict: Optional[bool] = None,
198
+ **kwargs: Unpack[Dict]
199
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
200
+ if output_attentions:
201
+ warnings.warn("`GLAModel` does not support `output_attentions` yet, setting it to `False`.")
202
+ output_attentions = False
203
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
204
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
205
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
206
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
+
208
+ # retrieve input_ids and inputs_embeds
209
+ if input_ids is not None and inputs_embeds is not None:
210
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
211
+ if input_ids is None and inputs_embeds is None:
212
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
213
+
214
+ if inputs_embeds is None:
215
+ inputs_embeds = self.embeddings(input_ids)
216
+ hidden_states = inputs_embeds
217
+
218
+ if use_cache and not isinstance(past_key_values, Cache):
219
+ past_key_values = Cache.from_legacy_cache(past_key_values)
220
+
221
+ if self.gradient_checkpointing and self.training and use_cache:
222
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
223
+ use_cache = False
224
+
225
+ all_hidden_states = () if output_hidden_states else None
226
+ all_attns = () if output_attentions else None
227
+ for layer in self.layers:
228
+ if output_hidden_states:
229
+ all_hidden_states += (hidden_states,)
230
+
231
+ if self.gradient_checkpointing and self.training:
232
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
233
+ layer.__call__,
234
+ hidden_states,
235
+ attention_mask,
236
+ past_key_values,
237
+ use_cache,
238
+ output_attentions,
239
+ **kwargs
240
+ )
241
+ else:
242
+ hidden_states, attentions, past_key_values = layer(
243
+ hidden_states,
244
+ attention_mask=attention_mask,
245
+ past_key_values=past_key_values,
246
+ use_cache=use_cache,
247
+ output_attentions=output_attentions,
248
+ **kwargs
249
+ )
250
+
251
+ if output_attentions:
252
+ all_attns += (attentions,)
253
+
254
+ hidden_states = self.norm(hidden_states)
255
+
256
+ # add hidden states from the last decoder layer
257
+ if output_hidden_states:
258
+ all_hidden_states += (hidden_states,)
259
+
260
+ if not return_dict:
261
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
262
+ return BaseModelOutputWithPast(
263
+ last_hidden_state=hidden_states,
264
+ past_key_values=past_key_values,
265
+ hidden_states=all_hidden_states,
266
+ attentions=all_attns
267
+ )
268
+
269
+
270
+ class GLAForCausalLM(GLAPreTrainedModel, GenerationMixin):
271
+
272
+ _tied_weights_keys = ["lm_head.weight"]
273
+
274
+ def __init__(self, config):
275
+ super().__init__(config)
276
+ self.model = GLAModel(config)
277
+ self.vocab_size = config.vocab_size
278
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
279
+ self.criterion = None
280
+
281
+ # Initialize weights and apply final processing
282
+ self.post_init()
283
+
284
+ def get_input_embeddings(self):
285
+ return self.model.embeddings
286
+
287
+ def set_input_embeddings(self, value):
288
+ self.model.embeddings = value
289
+
290
+ def get_output_embeddings(self):
291
+ return self.lm_head
292
+
293
+ def set_output_embeddings(self, new_embeddings):
294
+ self.lm_head = new_embeddings
295
+
296
+ def set_decoder(self, decoder):
297
+ self.model = decoder
298
+
299
+ def get_decoder(self):
300
+ return self.model
301
+
302
+ def generate(self, *args, **kwargs):
303
+ try:
304
+ return super().generate(*args, **kwargs)
305
+ except AttributeError as exception:
306
+ if 'past_key_values' in str(exception):
307
+ raise AttributeError(
308
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
309
+ f"which is not supported for {self.__class__.__name__}. "
310
+ f"Try another generation strategy instead. "
311
+ f"For the available generation strategies, check this doc: "
312
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
313
+ )
314
+ else:
315
+ raise exception
316
+
317
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
318
+ def prepare_inputs_for_generation(
319
+ self,
320
+ input_ids: torch.LongTensor = None,
321
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
322
+ attention_mask: Optional[torch.Tensor] = None,
323
+ inputs_embeds: Optional[torch.Tensor] = None,
324
+ use_cache: bool = True,
325
+ logits_to_keep: Optional[int] = None,
326
+ **kwargs
327
+ ):
328
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
329
+ if past_key_values is not None and len(past_key_values) > 0:
330
+ input_ids = input_ids[:, -1:]
331
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
332
+ if inputs_embeds is not None and len(past_key_values) == 0:
333
+ model_inputs = {'inputs_embeds': inputs_embeds}
334
+ else:
335
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
336
+ # recompiles graphs as the stride of the inputs is a guard.
337
+ # Ref: https://github.com/huggingface/transformers/pull/29114
338
+ # TODO: use `next_tokens` directly instead.
339
+ model_inputs = {'input_ids': input_ids.contiguous()}
340
+
341
+ if logits_to_keep is not None:
342
+ model_inputs['logits_to_keep'] = logits_to_keep
343
+
344
+ model_inputs.update({
345
+ 'past_key_values': past_key_values,
346
+ 'use_cache': use_cache,
347
+ 'attention_mask': attention_mask,
348
+ })
349
+ return model_inputs
350
+
351
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
352
+ def forward(
353
+ self,
354
+ input_ids: torch.LongTensor = None,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ inputs_embeds: Optional[torch.Tensor] = None,
357
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
358
+ labels: Optional[torch.LongTensor] = None,
359
+ use_cache: Optional[bool] = None,
360
+ output_attentions: Optional[bool] = None,
361
+ output_hidden_states: Optional[bool] = None,
362
+ return_dict: Optional[bool] = None,
363
+ logits_to_keep: Optional[int] = 0,
364
+ **kwargs: Unpack[Dict]
365
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
366
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
367
+ output_hidden_states = (
368
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
369
+ )
370
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
371
+
372
+ outputs = self.model(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ inputs_embeds=inputs_embeds,
376
+ past_key_values=past_key_values,
377
+ use_cache=use_cache,
378
+ output_attentions=output_attentions,
379
+ output_hidden_states=output_hidden_states,
380
+ return_dict=return_dict,
381
+ **kwargs
382
+ )
383
+
384
+ hidden_states = outputs[0]
385
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
386
+
387
+ loss, logits = None, None
388
+ if not fuse_linear_and_cross_entropy or labels is None:
389
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
390
+ if labels is not None:
391
+ if getattr(self, 'criterion', None) is None:
392
+ if fuse_linear_and_cross_entropy:
393
+ criterion = FusedLinearCrossEntropyLoss()
394
+ elif self.config.fuse_cross_entropy:
395
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
396
+ else:
397
+ criterion = nn.CrossEntropyLoss()
398
+ else:
399
+ criterion = self.criterion
400
+ labels = labels.to(hidden_states.device)
401
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
402
+ if fuse_linear_and_cross_entropy:
403
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
404
+ else:
405
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
406
+
407
+ if not return_dict:
408
+ output = (logits,) + outputs[1:]
409
+ return (loss,) + output if loss is not None else output
410
+
411
+ return CausalLMOutputWithPast(
412
+ loss=loss,
413
+ logits=logits,
414
+ past_key_values=outputs.past_key_values,
415
+ hidden_states=outputs.hidden_states,
416
+ attentions=outputs.attentions,
417
+ )
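
A hedged incremental-decoding sketch for the wrapper above, again assuming a CUDA device and toy config values: a prefill call with `use_cache=True` returns the fla `Cache`, which is then fed back for a single-token step (alternatively, `generate` from `GenerationMixin` drives the same path via `prepare_inputs_for_generation`).

import torch
from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM

config = GLAConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, vocab_size=1000)
model = GLAForCausalLM(config).cuda().eval()

prompt = torch.randint(0, 1000, (1, 16), device='cuda')
with torch.no_grad():
    prefill = model(input_ids=prompt, use_cache=True)   # builds the recurrent-state cache
    next_token = prefill.logits[:, -1:].argmax(-1)
    step = model(input_ids=next_token, past_key_values=prefill.past_key_values, use_cache=True)
print(step.logits.shape)                                # torch.Size([1, 1, 1000])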
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gsa.configuration_gsa import GSAConfig
6
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
7
+
8
+ AutoConfig.register(GSAConfig.model_type, GSAConfig)
9
+ AutoModel.register(GSAConfig, GSAModel)
10
+ AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)
11
+
12
+
13
+ __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
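
The `__init__.py` above registers GSA with the `transformers` Auto classes as an import side effect, so the generic factories resolve to the fla implementations; a brief sketch (the overridden config values are assumed to exist by analogy with the other configs in this upload):

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.gsa  # noqa: F401  (import side effect: registers 'gsa')
from fla.models.gsa import GSAConfig, GSAForCausalLM

config = AutoConfig.for_model('gsa', hidden_size=256, num_hidden_layers=2)  # -> GSAConfig
model = AutoModelForCausalLM.from_config(config)                            # -> GSAForCausalLM
assert isinstance(config, GSAConfig) and isinstance(model, GSAForCausalLM)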
fla/models/gsa/__pycache__/modeling_gsa.cpython-311.pyc ADDED
Binary file (19.5 kB). View file
 
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
6
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
7
+
8
+ AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
9
+ AutoModel.register(HGRNConfig, HGRNModel)
10
+ AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
11
+
12
+
13
+ __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
fla/models/hgrn/__pycache__/configuration_hgrn.cpython-311.pyc ADDED
Binary file (3.67 kB). View file
 
fla/models/hgrn/__pycache__/modeling_hgrn.cpython-311.pyc ADDED
Binary file (19.7 kB). View file
 
fla/models/hgrn2/__pycache__/configuration_hgrn2.cpython-311.pyc ADDED
Binary file (3.97 kB). View file
 
fla/models/lightnet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
6
+ from fla.models.lightnet.modeling_lightnet import LightNetForCausalLM, LightNetModel
7
+
8
+ AutoConfig.register(LightNetConfig.model_type, LightNetConfig)
9
+ AutoModel.register(LightNetConfig, LightNetModel)
10
+ AutoModelForCausalLM.register(LightNetConfig, LightNetForCausalLM)
11
+
12
+
13
+ __all__ = ['LightNetConfig', 'LightNetForCausalLM', 'LightNetModel']
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.lightnet import LightNetAttention
20
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LightNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LightNetBlock(nn.Module):
33
+ def __init__(self, config: LightNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ max_position_embeddings=config.max_position_embeddings,
48
+ layer_idx=layer_idx
49
+ )
50
+ else:
51
+ self.attn = LightNetAttention(
52
+ mode=config.attn_mode,
53
+ hidden_size=config.hidden_size,
54
+ num_heads=config.num_heads,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ gate_low_rank_dim=config.gate_low_rank_dim,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = LightNetMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ **kwargs
90
+ )
91
+ if self.config.fuse_norm:
92
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
93
+ else:
94
+ hidden_states = residual + hidden_states
95
+ residual = hidden_states
96
+ hidden_states = self.mlp_norm(hidden_states)
97
+ hidden_states = self.mlp(hidden_states, **kwargs)
98
+ hidden_states = residual + hidden_states
99
+
100
+ outputs = (hidden_states, attentions, past_key_values)
101
+
102
+ return outputs
103
+
104
+
105
+ class LightNetPreTrainedModel(PreTrainedModel):
106
+
107
+ config_class = LightNetConfig
108
+ supports_gradient_checkpointing = True
109
+ _no_split_modules = ['LightNetBlock']
110
+ _supports_cache_class = True
111
+
112
+ def __init__(self, *inputs, **kwargs):
113
+ super().__init__(*inputs, **kwargs)
114
+
115
+ def _init_weights(
116
+ self,
117
+ module: nn.Module,
118
+ prenorm_residual_strategy: Optional[str] = 'rescale',
119
+ num_residuals_per_layer: int = 2,
120
+ ):
121
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
122
+ # Slightly different from the TF version which uses truncated_normal for initialization
123
+ # cf https://github.com/pytorch/pytorch/pull/5617
124
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
125
+ if module.bias is not None:
126
+ nn.init.zeros_(module.bias)
127
+ elif isinstance(module, nn.Embedding):
128
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
129
+ elif hasattr(module, 'reset_parameters'):
130
+ module.reset_parameters()
131
+
132
+ if prenorm_residual_strategy is not None:
133
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
134
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
135
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
136
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
137
+ #
138
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
139
+ p = None
140
+ if hasattr(module, 'o_proj'):
141
+ p = module.o_proj.weight
142
+ elif hasattr(module, 'down_proj'):
143
+ p = module.down_proj.weight
144
+ if p is not None:
145
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
146
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
147
+ # We need to reinit p since this code could be called multiple times
148
+ # Having just p *= scale would repeatedly scale it down
149
+ if prenorm_residual_strategy == 'rescale':
150
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
151
+ with torch.no_grad():
152
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
153
+ elif prenorm_residual_strategy == 'zero':
154
+ nn.init.zeros_(p)
155
+ else:
156
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
157
+
158
+
159
+ class LightNetModel(LightNetPreTrainedModel):
160
+
161
+ def __init__(self, config: LightNetConfig):
162
+ super().__init__(config)
163
+ self.padding_idx = config.pad_token_id
164
+ self.vocab_size = config.vocab_size
165
+
166
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
167
+ self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
168
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
169
+
170
+ self.gradient_checkpointing = False
171
+
172
+ self.post_init()
173
+
174
+ def get_input_embeddings(self):
175
+ return self.embeddings
176
+
177
+ def set_input_embeddings(self, value):
178
+ self.embeddings = value
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None, # noqa
184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
185
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
186
+ use_cache: Optional[bool] = None,
187
+ output_attentions: Optional[bool] = None,
188
+ output_hidden_states: Optional[bool] = None,
189
+ return_dict: Optional[bool] = None,
190
+ **kwargs: Unpack[Dict]
191
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
192
+ if output_attentions:
193
+ warnings.warn("`LightNetModel` does not support `output_attentions` for now, setting it to `False`.")
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ **kwargs
233
+ )
234
+ else:
235
+ hidden_states, attentions, past_key_values = layer(
236
+ hidden_states,
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `input_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
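A rough usage sketch for the LightNet classes defined above. The import path and the tiny hyperparameters are assumptions for illustration, and the fla Triton kernels behind `LightNetAttention` generally expect a CUDA device:

```python
import torch
from fla.models import LightNetConfig, LightNetForCausalLM  # assumed export path

# Tiny, illustrative sizes -- not a recommended configuration.
config = LightNetConfig(hidden_size=256, num_hidden_layers=2, num_heads=4)
model = LightNetForCausalLM(config).eval().cuda()  # eval() avoids the fused-loss training path

input_ids = torch.randint(0, config.vocab_size, (1, 32), device='cuda')
with torch.no_grad():
    out = model(input_ids=input_ids)
print(out.logits.shape)  # [1, 32, vocab_size]
```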
fla/models/linear_attn/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig
6
+ from fla.models.linear_attn.modeling_linear_attn import LinearAttentionForCausalLM, LinearAttentionModel
7
+
8
+ AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
9
+ AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
10
+ AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)
11
+
12
+ __all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
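Because this `__init__.py` registers the classes with the `transformers` Auto* factories at import time, downstream code can stay model-agnostic. A minimal sketch, assuming the registered `model_type` is `'linear_attn'` and that importing `fla` runs the registration above:

```python
import fla  # noqa: F401  -- side effect: executes the Auto* registration above
from transformers import AutoConfig, AutoModelForCausalLM

# Build a config through the generic factory and let it resolve to the fla classes.
config = AutoConfig.for_model('linear_attn', hidden_size=256, num_hidden_layers=2)
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # LinearAttentionForCausalLM
```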
fla/models/linear_attn/__pycache__/modeling_linear_attn.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
fla/models/mamba/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (777 Bytes). View file
 
fla/models/mamba/configuration_mamba.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MAMBA configuration"""
16
+
17
+ import math
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+
21
+
22
+ class MambaConfig(PretrainedConfig):
23
+ """
24
+ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
25
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
26
+ defaults will yield a similar configuration to that of the MAMBA
27
+ [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*):
35
+ Vocabulary size of the Mamba model. Default: 32000.
36
+ hidden_size (`int`, *optional*):
37
+ Dimensionality of the embeddings and hidden states. Default: 2048.
38
+ state_size (`int`, *optional*):
39
+ Shape of the state space latents. Default: 16.
40
+ num_hidden_layers (`int`, *optional*):
41
+ Number of hidden layers in the model. Default: 48.
42
+ layer_norm_epsilon (`float`, *optional*):
43
+ The epsilon to use in the layer normalization layers. Default: 1e-5.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id. Default: 0.
46
+ bos_token_id (`int`, *optional*):
47
+ The id of the beginning of sentence token in the vocabulary. Default: 0.
48
+ eos_token_id (`int`, *optional*):
49
+ The id of the end of sentence token in the vocabulary. Default: 0.
50
+ expand (`int`, *optional*):
51
+ Expanding factor used to determine the intermediate size. Default: 2.
52
+ conv_kernel (`int`, *optional*):
53
+ Size of the convolution kernel. Default: 4.
54
+ use_bias (`bool`, *optional*):
55
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. Default: `False`.
56
+ use_conv_bias (`bool`, *optional*):
57
+ Whether or not to use bias in the convolution layer of the mixer block. Default: `True`.
58
+ hidden_act (`str`, *optional*):
59
+ The non-linear activation function (function or string) in the decoder. Default: `"silu"`.
60
+ initializer_range (`float`, *optional*):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Default: 0.1.
62
+ residual_in_fp32 (`bool`, *optional*):
63
+ Whether or not residuals should be in `float32`.
64
+ If set to `False` residuals will keep the same `dtype` as the rest of the model. Default: `True`.
65
+ time_step_rank (`Union[int,str]`, *optional*):
66
+ Rank of the discretization projection matrix.
67
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`. Default: `"auto"`.
68
+ time_step_scale (`float`, *optional*):
69
+ Scale used to scale `dt_proj.bias`. Default: 1.0.
70
+ time_step_min (`float`, *optional*):
71
+ Minimum `time_step` used to bound `dt_proj.bias`. Default: 0.001.
72
+ time_step_max (`float`, *optional*):
73
+ Maximum `time_step` used to bound `dt_proj.bias`. Default: 0.1.
74
+ time_step_init_scheme (`float`, *optional*):
75
+ Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`. Default: `"random"`.
76
+ time_step_floor (`float`, *optional*):
77
+ Minimum clamping value of the `dt_proj.bias` layer initialization. Default: 0.0001.
78
+ window_size (`int`, *optional*):
79
+ The window size used for sliding window attention. Default: 2048.
80
+ rescale_prenorm_residual (`bool`, *optional*):
81
+ Whether or not to rescale `out_proj` weights when initializing. Default: `False`.
82
+ use_cache (`bool`, *optional*):
83
+ Whether or not the cache should be used. Default: `True`.
84
+
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import MambaConfig, MambaModel
90
+
91
+ >>> # Initializing a Mamba configuration
92
+ >>> configuration = MambaConfig()
93
+
94
+ >>> # Initializing a model (with random weights) from the configuration
95
+ >>> model = MambaModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = "mamba"
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size: int = 32000,
106
+ hidden_size: int = 2048,
107
+ state_size: int = 16,
108
+ num_hidden_layers: int = 48,
109
+ layer_norm_epsilon=1e-5,
110
+ pad_token_id: int = 0,
111
+ bos_token_id: int = 1,
112
+ eos_token_id: int = 2,
113
+ expand: int = 2,
114
+ conv_kernel: int = 4,
115
+ use_bias: bool = False,
116
+ use_conv_bias: bool = True,
117
+ hidden_act: str = "silu",
118
+ initializer_range: float = 0.1,
119
+ residual_in_fp32: bool = False,
120
+ time_step_rank: str = "auto",
121
+ time_step_scale: float = 1.0,
122
+ time_step_min: float = 0.001,
123
+ time_step_max: float = 0.1,
124
+ time_step_init_scheme: str = "random",
125
+ time_step_floor: float = 1e-4,
126
+ rescale_prenorm_residual: bool = False,
127
+ use_cache: bool = True,
128
+ fuse_norm: bool = True,
129
+ fuse_cross_entropy: bool = True,
130
+ tie_word_embeddings: bool = False,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.hidden_size = hidden_size
135
+ self.state_size = state_size
136
+ self.num_hidden_layers = num_hidden_layers
137
+ self.layer_norm_epsilon = layer_norm_epsilon
138
+ self.conv_kernel = conv_kernel
139
+ self.expand = expand
140
+ self.intermediate_size = int(expand * self.hidden_size)
141
+ self.bos_token_id = bos_token_id
142
+ self.eos_token_id = eos_token_id
143
+ self.pad_token_id = pad_token_id
144
+ self.use_bias = use_bias
145
+ self.use_conv_bias = use_conv_bias
146
+ self.hidden_act = hidden_act
147
+ self.initializer_range = initializer_range
148
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
149
+ self.time_step_scale = time_step_scale
150
+ self.time_step_min = time_step_min
151
+ self.time_step_max = time_step_max
152
+ self.time_step_init_scheme = time_step_init_scheme
153
+ self.time_step_floor = time_step_floor
154
+ self.rescale_prenorm_residual = rescale_prenorm_residual
155
+ self.residual_in_fp32 = residual_in_fp32
156
+ self.use_cache = use_cache
157
+ self.fuse_norm = fuse_norm
158
+ self.fuse_cross_entropy = fuse_cross_entropy
159
+
160
+ super().__init__(
161
+ bos_token_id=bos_token_id,
162
+ eos_token_id=eos_token_id,
163
+ pad_token_id=pad_token_id,
164
+ tie_word_embeddings=tie_word_embeddings,
165
+ **kwargs
166
+ )
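Two fields above are derived rather than stored verbatim: `intermediate_size` is `expand * hidden_size`, and `time_step_rank` resolves `"auto"` to `math.ceil(hidden_size / 16)`. A quick check of that behaviour:

```python
from fla.models.mamba.configuration_mamba import MambaConfig

config = MambaConfig(hidden_size=2048, expand=2)
assert config.intermediate_size == 4096  # expand * hidden_size
assert config.time_step_rank == 128      # ceil(2048 / 16) when left as "auto"
```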
fla/models/mamba2/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.mamba2.configuration_mamba2 import Mamba2Config
6
+ from fla.models.mamba2.modeling_mamba2 import Mamba2ForCausalLM, Mamba2Model
7
+
8
+ AutoConfig.register(Mamba2Config.model_type, Mamba2Config, True)
9
+ AutoModel.register(Mamba2Config, Mamba2Model, True)
10
+ AutoModelForCausalLM.register(Mamba2Config, Mamba2ForCausalLM, True)
11
+
12
+
13
+ __all__ = ['Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model']
fla/models/mamba2/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (752 Bytes). View file
 
fla/models/retnet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (739 Bytes). View file
 
fla/models/samba/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (777 Bytes). View file
 
flame/config_manager.py ADDED
@@ -0,0 +1,940 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import sys
9
+ from collections import defaultdict
10
+ from typing import Tuple
11
+
12
+ import torch
13
+
14
+ try:
15
+ import tomllib
16
+ except ModuleNotFoundError:
17
+ import tomli as tomllib
18
+
19
+ from torchtitan.tools.logging import logger
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def string_list(raw_arg):
29
+ """Comma-separated string list argument."""
30
+ return [s.strip() for s in raw_arg.split(",") if s.strip()]
31
+
32
+
33
+ def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
34
+ section, name = fullargname.split(".")
35
+ # Split string list which are still raw strings.
36
+ if (
37
+ section in args_dict
38
+ and name in args_dict[section]
39
+ and isinstance(args_dict[section][name], str)
40
+ ):
41
+ sec = args_dict[section]
42
+ sec[name] = string_list(sec[name])
43
+
44
+
45
+ class JobConfig:
46
+ """
47
+ A helper class to manage the train configuration.
48
+ Semantics:
49
+ - Default config is loaded from a toml file. If no toml file is provided,
50
+ then the default config is loaded from argparse defaults.
51
+ - if toml file has missing keys, they are filled with argparse defaults.
52
+ - if additional explicit cmd args are provided in addition to the toml
53
+ file, they will override the toml config and the argparse defaults
54
+
55
+ precedence order: cmdline > toml > argparse default
56
+
57
+ Arg parsing semantics:
58
+
59
+ Each argument starts with <prefix>_ which is the section name in the toml file
60
+ followed by name of the option in the toml file. For ex,
61
+ model.name translates to:
62
+ [model]
63
+ name
64
+ in the toml file
65
+ """
66
+
67
+ def __init__(self):
68
+ self.args_dict = None
69
+ # main parser
70
+ self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
71
+
72
+ self.parser.add_argument(
73
+ "--job.config_file",
74
+ type=str,
75
+ default=None,
76
+ help="Job config file",
77
+ )
78
+
79
+ # job level configs
80
+ self.parser.add_argument(
81
+ "--job.dump_folder",
82
+ type=str,
83
+ default="./torchtitan/outputs",
84
+ help="Folder to dump job outputs",
85
+ )
86
+ self.parser.add_argument(
87
+ "--job.description",
88
+ type=str,
89
+ default="default job",
90
+ help="Description of the job",
91
+ )
92
+ self.parser.add_argument(
93
+ "--job.use_for_integration_test",
94
+ action="store_true",
95
+ help="Add this config to the integration test suite",
96
+ )
97
+ self.parser.add_argument(
98
+ "--job.print_args",
99
+ action="store_true",
100
+ help="Print the args to terminal",
101
+ )
102
+
103
+ # model configs
104
+ self.parser.add_argument(
105
+ "--model.name",
106
+ type=str,
107
+ default="fla",
108
+ help="Which model to train",
109
+ )
110
+ self.parser.add_argument(
111
+ "--model.config",
112
+ type=str,
113
+ default="fla-hub/transformer-1.3B-100B",
114
+ help="Path to the model config",
115
+ )
116
+ self.parser.add_argument(
117
+ "--model.tokenizer_path",
118
+ type=str,
119
+ default="fla-hub/transformer-1.3B-100B",
120
+ help="Tokenizer path",
121
+ )
122
+ self.parser.add_argument(
123
+ "--model.converters",
124
+ type=string_list,
125
+ nargs="+",
126
+ default=[],
127
+ help="""
128
+ Comma separated list of converters to apply to the model.
129
+ For instance, the `float8` converter swaps `torch.nn.Linear`
130
+ with `Float8Linear`. This feature requires you to install 'torchao'
131
+ which can be found here: https://github.com/pytorch/ao
132
+ """,
133
+ )
134
+ self.parser.add_argument(
135
+ "--model.print_after_conversion",
136
+ action="store_true",
137
+ help="""
138
+ If true, model definition will be printed to stdout after all model
139
+ converters have been applied.
140
+ """,
141
+ )
142
+
143
+ # profiling configs
144
+ self.parser.add_argument(
145
+ "--profiling.enable_profiling",
146
+ action="store_true",
147
+ help="Whether to enable pytorch profiler",
148
+ )
149
+ self.parser.add_argument(
150
+ "--profiling.save_traces_folder",
151
+ type=str,
152
+ default="profile_traces",
153
+ help="Trace files location",
154
+ )
155
+ self.parser.add_argument(
156
+ "--profiling.profile_freq",
157
+ type=int,
158
+ default=10,
159
+ help="How often to collect profiler traces, in iterations",
160
+ )
161
+ self.parser.add_argument(
162
+ "--profiling.enable_memory_snapshot",
163
+ action="store_true",
164
+ help="Whether to dump memory snapshot",
165
+ )
166
+ self.parser.add_argument(
167
+ "--profiling.save_memory_snapshot_folder",
168
+ type=str,
169
+ default="memory_snapshot",
170
+ help="Memory snapshot files location",
171
+ )
172
+
173
+ # optimizer configs
174
+ self.parser.add_argument(
175
+ "--optimizer.name", type=str, default="AdamW", help="Optimizer to use"
176
+ )
177
+ self.parser.add_argument(
178
+ "--optimizer.eps",
179
+ type=float,
180
+ default=1e-8,
181
+ help="Epsilon value for the optimizer.",
182
+ )
183
+ self.parser.add_argument(
184
+ "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185
+ )
186
+ self.parser.add_argument(
187
+ "--optimizer.implementation",
188
+ type=str,
189
+ default="fused",
190
+ choices=["for-loop", "foreach", "fused"],
191
+ help="""
192
+ Specify which optimizer implementation to use:
193
+ - 'fused': Use fused implementation (CUDA only) for best performance.
194
+ - 'foreach': Use some horizontal fusion of tensors for better performance.
195
+ - 'for-loop': Use the default implementation for the optimizer (slowest).
196
+ - more info: https://pytorch.org/docs/stable/optim.html
197
+ """,
198
+ )
199
+ self.parser.add_argument(
200
+ "--optimizer.early_step_in_backward",
201
+ action="store_true",
202
+ help="""
203
+ Whether to apply optimizer in the backward. Caution, optimizer_in_backward
204
+ is not compatible with gradients clipping, users should not call
205
+ register_post_accumulate_grad_hook after the optimizer is built.""",
206
+ )
207
+
208
+ # lr scheduler configs
209
+ self.parser.add_argument(
210
+ "--lr_scheduler.warmup_steps",
211
+ type=int,
212
+ default=200,
213
+ help="Steps for lr scheduler warmup, normally 1/5 of --training.steps",
214
+ )
215
+ self.parser.add_argument(
216
+ "--lr_scheduler.decay_ratio",
217
+ type=float,
218
+ default=None,
219
+ help="""
220
+ Controls the proportion of the training steps allocated to the learning rate decay phase.
221
+
222
+ If `None`, the learning rate will begin decaying immediately after the warmup period.
223
+ Otherwise, the learning rate will remain stable after the warmup period and
224
+ only start decaying during the last `decay_ratio` portion of the total training steps.
225
+
226
+ This is known as the Warmup-Stable-Decay (WSD) schedule, as described in https://arxiv.org/abs/2404.06395.
227
+ """,
228
+ )
229
+ self.parser.add_argument(
230
+ "--lr_scheduler.decay_type",
231
+ type=str,
232
+ default="linear",
233
+ choices=["linear", "sqrt", "cosine"],
234
+ help="""
235
+ Learning rate decay type to use during training:
236
+ - 'linear': linearly decays learning rate from initial to final value
237
+ - 'sqrt': decays learning rate following a 1 minus square root curve
238
+ - 'cosine': smoothly decays learning rate following a cosine curve
239
+ """,
240
+ )
241
+ self.parser.add_argument(
242
+ "--lr_scheduler.lr_min",
243
+ type=float,
244
+ default=0.0,
245
+ help="""
246
+ Min lr ratio for lr scheduler.
247
+
248
+ If provided, the range of decay factor is scaled from 1 to `lr_min`
249
+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
250
+ """,
251
+ )
252
+
253
+ # training configs
254
+ self.parser.add_argument(
255
+ "--training.batch_size", type=int, default=8, help="Batch size"
256
+ )
257
+ self.parser.add_argument(
258
+ "--training.seq_len", type=int, default=2048, help="Sequence length"
259
+ )
260
+ self.parser.add_argument(
261
+ "--training.context_len",
262
+ type=int,
263
+ default=2048,
264
+ help="Max length allowed for each sequence",
265
+ )
266
+ self.parser.add_argument(
267
+ "--training.varlen",
268
+ action="store_true",
269
+ help="Whether to take sequences of variable length as input",
270
+ )
271
+ self.parser.add_argument(
272
+ "--training.gradient_accumulation_steps",
273
+ type=int,
274
+ default=1,
275
+ help="Number of steps to accumulate gradients before updating parameters",
276
+ )
277
+ self.parser.add_argument(
278
+ "--training.steps",
279
+ type=int,
280
+ default=10000,
281
+ help="How many train steps to run",
282
+ )
283
+ self.parser.add_argument(
284
+ "--training.max_norm",
285
+ type=float,
286
+ default=1.0,
287
+ help="Max norm for gradient clipping",
288
+ )
289
+ self.parser.add_argument(
290
+ "--training.skip_nan_inf",
291
+ action="store_true",
292
+ help="Skip batch updates when NaN or INF gradients are encountered during training",
293
+ )
294
+ self.parser.add_argument(
295
+ "--training.dataset",
296
+ default="HuggingFaceFW/fineweb-edu",
297
+ help="Dataset to use, with comma separated values",
298
+ )
299
+ self.parser.add_argument(
300
+ "--training.dataset_name",
301
+ default=None,
302
+ help="The name of the dataset config, with comma separated values if provided",
303
+ )
304
+ self.parser.add_argument(
305
+ "--training.dataset_split",
306
+ default=None,
307
+ help="Dataset split to use, with comma separated values if provided",
308
+ )
309
+ self.parser.add_argument(
310
+ "--training.data_dir",
311
+ default=None,
312
+ help="Data dirs to use, with comma separated values if provided",
313
+ )
314
+ self.parser.add_argument(
315
+ "--training.data_files",
316
+ default=None,
317
+ help="Data files to use, with comma separated values if provided",
318
+ )
319
+ self.parser.add_argument(
320
+ "--training.data_probs",
321
+ default=None,
322
+ help="Data sampling probabilities, with comma separated values if provided",
323
+ )
324
+ self.parser.add_argument(
325
+ "--training.streaming",
326
+ action="store_true",
327
+ help="Whether to load dataset in streaming mode, used for huge dataset",
328
+ )
329
+ self.parser.add_argument(
330
+ "--training.num_workers",
331
+ type=int,
332
+ default=32,
333
+ help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
334
+ )
335
+ self.parser.add_argument(
336
+ "--training.prefetch_factor",
337
+ type=int,
338
+ default=2,
339
+ help="Number of batches loaded in advance by each worker."
340
+ " 2 means there will be a total of 2 * num_workers batches prefetched across all workers.",
341
+ )
342
+ self.parser.add_argument(
343
+ "--training.data_parallel_replicate_degree",
344
+ type=int,
345
+ default=1,
346
+ help="""
347
+ The `data_parallel_replicate_degree` argument specifies the degree of
348
+ data parallelism for weight replication. When this value is greater
349
+ than 1, weights will be replicated across `data_parallel_replicate_degree`
350
+ ranks. If `data_parallel_shard_degree` is also greater than 1, the parallelism
351
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
352
+ parallelism method used is DDP (Distributed Data Parallelism).
353
+ 1 means disabled.""",
354
+ )
355
+ self.parser.add_argument(
356
+ "--training.data_parallel_shard_degree",
357
+ type=int,
358
+ default=-1,
359
+ help="""
360
+ The `data_parallel_shard_degree` argument specifies the degree of data
361
+ parallelism for weight sharding. When this value is greater than 1, weights
362
+ will be sharded across `data_parallel_shard_degree` ranks. If
363
+ `data_parallel_replicate_degree` is also greater than 1, the parallelism
364
+ method used is HSDP (Hybrid Sharded Data Parallelism). Otherwise, the
365
+ parallelism method used is FSDP (Fully Sharded Data Parallelism).
366
+
367
+ -1 means leftover ranks will be used (After DP_REPLICATE/SP/PP). Note that
368
+ only `data_parallel_shard_degree` can be negative. 1 means disabled.""",
369
+ )
370
+ self.parser.add_argument(
371
+ "--training.enable_cpu_offload",
372
+ action="store_true",
373
+ help="""
374
+ Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
375
+ )
376
+ self.parser.add_argument(
377
+ "--training.tensor_parallel_degree",
378
+ type=int,
379
+ default=1,
380
+ help="Tensor Parallelism degree. 1 means disabled.",
381
+ )
382
+ self.parser.add_argument(
383
+ "--training.disable_loss_parallel",
384
+ action="store_true",
385
+ help="Whether to apply loss parallel when sequence parallel is enabled",
386
+ )
387
+ self.parser.add_argument(
388
+ "--training.fsdp_reshard_after_forward",
389
+ type=str,
390
+ default="default",
391
+ choices=["default", "always", "never"],
392
+ help="""
393
+ `reshard_after_forward` specifies the policy for applying `reshard_after_forward`
394
+ within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward,
395
+ trading off memory and communication. See torch's `fully_shard` API for more documentation
396
+ on `reshard_after_forward`.
397
+ The supported policies include "default", "always" and "never":
398
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal
399
+ scenarios.
400
+ - "always" will enable `reshard_after_forward` for all forward passes.
401
+ - "never" will disable `reshard_after_forward` for all forward passes.
402
+ """,
403
+ )
404
+ self.parser.add_argument(
405
+ "--training.mixed_precision_param",
406
+ type=str,
407
+ default="bfloat16",
408
+ choices=["bfloat16", "float32"],
409
+ help="""
410
+ torch dtype to use for parameters when applying mixed precision via FSDP.
411
+ This feature only takes effect when data_parallel_shard_degree > 1
412
+ """,
413
+ )
414
+ self.parser.add_argument(
415
+ "--training.mixed_precision_reduce",
416
+ type=str,
417
+ default="float32",
418
+ choices=["float32"],
419
+ help="""
420
+ torch dtype to use for reductions when applying mixed precision via FSDP.
421
+ This feature only takes effect when data_parallel_shard_degree > 1
422
+ """,
423
+ )
424
+ self.parser.add_argument(
425
+ "--training.compile",
426
+ action="store_true",
427
+ help="Whether to compile the model",
428
+ )
429
+ self.parser.add_argument(
430
+ "--training.gc_freq",
431
+ type=int,
432
+ default=50,
433
+ help="Python garbage control scheduling interval, in steps",
434
+ )
435
+ self.parser.add_argument(
436
+ "--training.seed",
437
+ type=int,
438
+ default=42,
439
+ help="Choose the base RNG seed used for training",
440
+ )
441
+ self.parser.add_argument(
442
+ "--training.deterministic",
443
+ action="store_true",
444
+ help="Use deterministic algorithms wherever possible, may be slower",
445
+ )
446
+ # metrics configs
447
+ self.parser.add_argument(
448
+ "--metrics.log_freq",
449
+ type=int,
450
+ default=10,
451
+ help="How often to log metrics to TensorBoard, in iterations",
452
+ )
453
+ self.parser.add_argument(
454
+ "--metrics.enable_tensorboard",
455
+ action="store_true",
456
+ help="Whether to log metrics to TensorBoard",
457
+ )
458
+ self.parser.add_argument(
459
+ "--metrics.disable_color_printing",
460
+ action="store_true",
461
+ help="Whether to disable color printing in logs",
462
+ )
463
+ self.parser.add_argument(
464
+ "--metrics.save_tb_folder",
465
+ type=str,
466
+ default="tb",
467
+ help="Folder to dump TensorBoard states",
468
+ )
469
+ self.parser.add_argument(
470
+ "--metrics.save_for_all_ranks",
471
+ action="store_true",
472
+ default=False,
473
+ help="""
474
+ Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475
+ When this option is False and pipeline_parallel_degree is > 1, the metrics
476
+ component uses the 0th rank of the last stage pipeline group, which is the
477
+ only stage that computes loss metrics.
478
+ """,
479
+ )
480
+ self.parser.add_argument(
481
+ "--metrics.enable_wandb",
482
+ action="store_true",
483
+ help="Whether to log metrics to Weights & Biases",
484
+ )
485
+
486
+ self.parser.add_argument(
487
+ "--experimental.enable_async_tensor_parallel",
488
+ action="store_true",
489
+ help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
490
+ )
491
+ self.parser.add_argument(
492
+ "--experimental.pipeline_parallel_degree",
493
+ type=int,
494
+ default=1,
495
+ help="""
496
+ Pipeline Parallelism degree, or number of ranks. 1 means disabled.
497
+ If using looped schedules, this still specifies the number of physical ranks, not the number
498
+ of stages. Stages per rank are inferred from split points degree, and schedule.""",
499
+ )
500
+ self.parser.add_argument(
501
+ "--experimental.pipeline_parallel_split_points",
502
+ type=string_list,
503
+ nargs="+",
504
+ default=[],
505
+ help="""
506
+ Specify comma-separated names of modules to use as the beginning of a split point.
507
+
508
+ e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
509
+ the first containing all the layers up to layers.0,
510
+ the second containing layers.0 and up to layers.2,
511
+ the third containing layers.2 and all the remaining layers.
512
+
513
+ Note: fully-automated splitting may be enabled in the future,
514
+ but currently the split points must be specified manually.""",
515
+ )
516
+ self.parser.add_argument(
517
+ "--experimental.pipeline_parallel_schedule",
518
+ type=str,
519
+ default="1F1B",
520
+ help="""
521
+ Specify the Pipeline Parallel schedule to use. The supported schedules are:
522
+ https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/pipelining/schedules.py#L2161.
523
+ The schedule must be compatible with the split points and stages_per_rank.
524
+
525
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
526
+ and split_points = number of stages - 1
527
+ """,
528
+ )
529
+ self.parser.add_argument(
530
+ "--experimental.pipeline_parallel_schedule_csv",
531
+ type=str,
532
+ default="",
533
+ help="""
534
+ Specify the path to the pipeline parallel schedule csv file to use.
535
+ The pipeline_parallel_schedule argument must be either
536
+ PipelineScheduleSingle, PipelineScheduleMulti, or _PipelineScheduleRuntime.
537
+ """,
538
+ )
539
+
540
+ self.parser.add_argument(
541
+ "--experimental.pipeline_parallel_microbatches",
542
+ type=int,
543
+ default=None,
544
+ help="""
545
+ How many microbatches to split the global training batch into when using pipeline parallelism.
546
+
547
+ The global training batch size must be evenly divisible by the number of microbatches.
548
+
549
+ The default value will be the number of pipeline stages, if unspecified.
550
+ """,
551
+ )
552
+ self.parser.add_argument(
553
+ "--experimental.enable_compiled_autograd",
554
+ action="store_true",
555
+ help="Enable CompiledAutograd to compile the backward.",
556
+ )
557
+ self.parser.add_argument(
558
+ "--experimental.context_parallel_degree",
559
+ type=int,
560
+ default=1,
561
+ help="Context parallelism degree. 1 means disabled.",
562
+ )
563
+ self.parser.add_argument(
564
+ "--experimental.context_parallel_rotate_method",
565
+ type=str,
566
+ default="allgather",
567
+ help="""
568
+ The collective to use in context parallel SDPA for kv shards exchange.
569
+
570
+ 'allgather' means to all-gather all kv shards on ranks after the first sub-SDPA computation,
571
+
572
+ 'alltoall' means to all-to-all shuffle the kv shards.
573
+
574
+ The default value is 'allgather'.
575
+ """,
576
+ )
577
+ # I'm not particularly fond of this. Users can choose to write their own wrapper
578
+ # module and import TorchTitan training loop and execute it, which look cleaner.
579
+ # One reason to provide this option is to allow users to use the existing run script.
580
+ # While the script is pretty trivial now, we may add more logic when integrating
581
+ # with TorchFT.
582
+ # This option is subject to change and may be deleted in the future.
583
+ self.parser.add_argument(
584
+ "--experimental.custom_model_path",
585
+ type=str,
586
+ default="",
587
+ help="""
588
+ The --custom_model_path option allows you to specify a custom path to a model module
589
+ that is not natively implemented within TorchTitan.
590
+ Acceptable values are the file system path to the module (e.g., my_models/model_x)
591
+ or the dotted import path of the module (e.g., some_package.model_x).
592
+ """,
593
+ )
594
+ # checkpointing configs
595
+ self.parser.add_argument(
596
+ "--checkpoint.enable_checkpoint",
597
+ action="store_true",
598
+ help="Whether to enable checkpoint",
599
+ )
600
+ self.parser.add_argument(
601
+ "--checkpoint.folder",
602
+ type=str,
603
+ default="checkpoint",
604
+ help="""
605
+ The folder to store the checkpoints.
606
+ When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607
+ """,
608
+ )
609
+ self.parser.add_argument(
610
+ "--checkpoint.interval",
611
+ type=int,
612
+ default=500,
613
+ help="Checkpointing interval in steps.",
614
+ )
615
+ self.parser.add_argument(
616
+ "--checkpoint.model_weights_only",
617
+ action="store_true",
618
+ help="""
619
+ When model_weights_only=True, only model weights will be saved at the end of training.
620
+ With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621
+ When model_weights_only=False, the full checkpoint will be saved.
622
+ A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623
+ The default value is false.
624
+ """,
625
+ )
626
+ self.parser.add_argument(
627
+ "--checkpoint.export_dtype",
628
+ type=str,
629
+ default="float32",
630
+ choices=["float16", "bfloat16", "float32"],
631
+ help="""
632
+ Converts to the specified precision when training completes and model_weights_only=true.
633
+ Currently supports float32, float16, and bfloat16.
634
+ The default value is float32.
635
+ """,
636
+ )
637
+ self.parser.add_argument(
638
+ "--checkpoint.create_seed_checkpoint",
639
+ action="store_true",
640
+ help="""
641
+ Initializes the full model without applying parallelisms, and then saves it as a seed checkpoint.
642
+ Note: requires user to call train.py without specifying any parallelisms, e.g. NGPU=1.
643
+ Could be implemented as a separate script, but this way shares more code.
644
+ """,
645
+ )
646
+ self.parser.add_argument(
647
+ "--checkpoint.async_mode",
648
+ type=str,
649
+ default="disabled",
650
+ help="""
651
+ Which async checkpoint mode to use. Currently there are 3 different modes.
652
+ 1. "disabled": synchronized checkpointing will be used.
653
+ 2. "async": torch.distributed.checkpoint.async_save will be used.
654
+ 3. "async_with_pinned_mem": this option utilizes a dedicated pinned memory
655
+ space and creates a separate process for faster GPU->CPU transfer
656
+ performance and eliminating GIL contention. The cost is increased CPU
657
+ memory usage. If insufficient CPU memory is available, performance may
658
+ degrade due to memory paging. For most users, "async" should suffice as
659
+ the performance overhead is typically small (on the order of tens of
660
+ seconds) compared to checkpointing frequency. This mode can be employed
661
+ to pursue near-zero checkpointing times (e.g., < 1 second) given
662
+ appropriate hardware support such as ample CPU memory and fast PCIe.
663
+
664
+ "disabled" is the default mode.
665
+ """,
666
+ )
667
+ self.parser.add_argument(
668
+ "--checkpoint.keep_latest_k",
669
+ type=int,
670
+ default=0,
671
+ help="""
672
+ Keeps only the latest k checkpoints, purging older ones. If 0, keep all checkpoints.
673
+ 0 is the default value. k cannot be 1 as the last one may be in the process of being
674
+ saved. As a result, the metadata of the last one may not be ready yet.
675
+ """,
676
+ )
677
+ self.parser.add_argument(
678
+ "--checkpoint.load_step",
679
+ type=int,
680
+ default=-1,
681
+ help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
682
+ )
683
+ self.parser.add_argument(
684
+ "--checkpoint.exclude_from_loading",
685
+ type=string_list,
686
+ nargs="*",
687
+ default=[],
688
+ help="""
689
+ Exclude specific keys from being loaded from the checkpoint.
690
+ Provide a comma-separated list of keys to exclude, e.g. 'optimizer,lr_scheduler,dataloader'.
691
+ This will load the model only, excluding the specified keys.
692
+ """,
693
+ )
694
+ self.parser.add_argument(
695
+ "--checkpoint.convert_to_hf_on_save",
696
+ action="store_true",
697
+ help="""
698
+ If true, automatically convert the saved DCP checkpoint to Hugging Face format
699
+ in a parallel directory (e.g., step-1000-hf) after each save.
700
+ """,
701
+ )
702
+ self.parser.add_argument(
703
+ "--checkpoint.hf_upload_enabled",
704
+ action="store_true",
705
+ help="Enable uploading converted Hugging Face checkpoints to the Hub.",
706
+ )
707
+ self.parser.add_argument(
708
+ "--checkpoint.hf_repo_base_name",
709
+ type=str,
710
+ default=None,
711
+ help="Hugging Face Hub repository ID to upload checkpoints to (e.g., 'username/repo').",
712
+ )
713
+ self.parser.add_argument(
714
+ "--checkpoint.hf_upload_format",
715
+ type=str,
716
+ default="dcp",
717
+ choices=["dcp", "hf"],
718
+ help="""
719
+ Format to upload to Hugging Face Hub. 'dcp' for DCP format, 'hf' for Hugging Face format.
720
+ Note: 'hf' is only supported for models with a single pipeline stage.
721
+ """,
722
+ )
723
+ # activation checkpointing configs
724
+ self.parser.add_argument(
725
+ "--activation_checkpoint.mode",
726
+ type=str,
727
+ default="selective",
728
+ help="Type of activation checkpointing to use ['none', 'full', 'selective']",
729
+ )
730
+ self.parser.add_argument(
731
+ "--activation_checkpoint.selective_ac_option",
732
+ type=str,
733
+ default="2", # 2 = checkpoint every other layer
734
+ help="""
735
+ Selective activation checkpointing options ['int', 'op'].
736
+ 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
737
+ """,
738
+ )
739
+
740
+ self.parser.add_argument(
741
+ "--activation_offload.mode",
742
+ type=str,
743
+ default="none",
744
+ help="""
745
+ if we are using activation offload or not. Options are ['none', 'full'].
746
+ """,
747
+ )
748
+
749
+ # float8 configs
750
+ self.parser.add_argument(
751
+ "--float8.enable_fsdp_float8_all_gather",
752
+ action="store_true",
753
+ help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
754
+ )
755
+ self.parser.add_argument(
756
+ "--float8.precompute_float8_dynamic_scale_for_fsdp",
757
+ action="store_true",
758
+ help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
759
+ )
760
+ self.parser.add_argument(
761
+ "--float8.force_recompute_fp8_weight_in_bwd",
762
+ action="store_true",
763
+ help="""
764
+ Whether to force the recomputation of FP8 weights during backward pass.
765
+ When using FSDP with tensorwise scaling, it is recommended to enable
766
+ `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
767
+ for backward computation.
768
+ """,
769
+ )
770
+ self.parser.add_argument(
771
+ "--float8.recipe_name",
772
+ type=str,
773
+ default=None,
774
+ choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
775
+ help="""
776
+ If specified, creates float8 config from recipe name, valid choices are
777
+ `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
778
+ """,
779
+ )
780
+
781
+ # communications library settings
782
+ self.parser.add_argument(
783
+ "--comm.init_timeout_seconds",
784
+ type=int,
785
+ default=300,
786
+ help="Timeout for communication operations, during initialization and first train step.",
787
+ )
788
+ self.parser.add_argument(
789
+ "--comm.train_timeout_seconds",
790
+ type=int,
791
+ default=100,
792
+ help=(
793
+ "Timeout for communication operations after the first train step -- "
794
+ "usually a tighter bound than during initialization."
795
+ ),
796
+ )
797
+ self.parser.add_argument(
798
+ "--comm.trace_buf_size",
799
+ type=int,
800
+ default=20000,
801
+ help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
802
+ )
803
+
804
+ # memory estimation settings
805
+ self.parser.add_argument(
806
+ "--memory_estimation.enabled",
807
+ help="Whether to estimate memory usage for FSDP",
808
+ action="store_true",
809
+ )
810
+
811
+ self.parser.add_argument(
812
+ "--memory_estimation.disable_fake_mode",
813
+ help="Whether to estimate memory under FakeTensorMode",
814
+ action="store_true",
815
+ )
816
+
817
+ self.parser.add_argument(
818
+ "--fault_tolerance.enable",
819
+ action="store_true",
820
+ help="""
821
+ Enable TorchFT integration. When TorchFT is enabled, HSDP will be used.
822
+ And --fault_tolerance.data_parallel_replicate_degree should be 1 and
823
+ --fault_tolerance.group_size will be used to control the maximum
824
+ replicate group size as the replicate group size is dynamic.
825
+
826
+ Note that this is still an experimental feature.
827
+ """,
828
+ )
829
+
830
+ self.parser.add_argument(
831
+ "--fault_tolerance.replica_id",
832
+ type=int,
833
+ default=0,
834
+ help="The TorchFT replica ID of this run.",
835
+ )
836
+
837
+ self.parser.add_argument(
838
+ "--fault_tolerance.group_size",
839
+ type=int,
840
+ default=0,
841
+ help="""
842
+ The number of TorchFT replicate groups. This number will be used by the
843
+ dataloader to split the dataset across the replicate groups and the FSDP
844
+ dimension.
845
+ """,
846
+ )
847
+
848
+ self.parser.add_argument(
849
+ "--fault_tolerance.min_replica_size",
850
+ type=int,
851
+ default=1,
852
+ help="The minimum number of FT replica for each step.",
853
+ )
854
+
855
+ def to_dict(self):
856
+ return self.args_dict
857
+
858
+ def parse_args(self, args_list: list = sys.argv[1:]):
859
+ args, cmd_args = self.parse_args_from_command_line(args_list)
860
+ config_file = getattr(args, "job.config_file", None)
861
+ # build up a two level dict
862
+ args_dict = self._args_to_two_level_dict(args)
863
+ if config_file is not None:
864
+ try:
865
+ with open(config_file, "rb") as f:
866
+ for k, v in tomllib.load(f).items():
867
+ # to prevent overwrite of non-specified keys
868
+ args_dict[k] |= v
869
+ except (FileNotFoundError, tomllib.TOMLDecodeError) as e:
870
+ logger.exception(
871
+ f"Error while loading the configuration file: {config_file}"
872
+ )
873
+ logger.exception(f"Error details: {str(e)}")
874
+ raise e
875
+
876
+ # Checking string-list arguments are properly split into a list
877
+ # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
878
+ string_list_argnames = self._get_string_list_argument_names()
879
+ for n in string_list_argnames:
880
+ check_string_list_argument(args_dict, n)
881
+
882
+ # override args dict with cmd_args
883
+ cmd_args_dict = self._args_to_two_level_dict(cmd_args)
884
+ for section, section_args in cmd_args_dict.items():
885
+ for k, v in section_args.items():
886
+ args_dict[section][k] = v
887
+
888
+ self.args_dict = args_dict
889
+
890
+ for k, v in args_dict.items():
891
+ class_type = type(k.title(), (), v)
892
+ setattr(self, k, class_type())
893
+ self._validate_config()
894
+
895
+ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
896
+ args_dict = defaultdict(defaultdict)
897
+ for k, v in vars(args).items():
898
+ first_level_key, second_level_key = k.split(".", 1)
899
+ args_dict[first_level_key][second_level_key] = v
900
+ return args_dict
901
+
902
+ def _validate_config(self) -> None:
903
+ # TODO: Add more mandatory validations
904
+ assert self.model.config
905
+ assert self.model.tokenizer_path
906
+
907
+ def _get_string_list_argument_names(self) -> list[str]:
908
+ """Get the parser argument names of type `string_list`."""
909
+ string_list_args = [
910
+ v.dest for v in self.parser._actions if v.type is string_list
911
+ ]
912
+ return string_list_args
913
+
914
+ def parse_args_from_command_line(
915
+ self, args_list
916
+ ) -> Tuple[argparse.Namespace, argparse.Namespace]:
917
+ """
918
+ Parse command line arguments and return the parsed args and the command line only args
919
+ """
920
+ args = self.parser.parse_args(args_list)
921
+ string_list_argnames = set(self._get_string_list_argument_names())
922
+
923
+ # aux parser to parse the command line only args, with no defaults from main parser
924
+ aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
925
+ for arg, val in vars(args).items():
926
+ if isinstance(val, bool):
927
+ aux_parser.add_argument(
928
+ "--" + arg, action="store_true" if val else "store_false"
929
+ )
930
+ elif arg in string_list_argnames:
931
+ # without this special case, type inference breaks here,
932
+ # since the inferred type is just 'list' and it ends up flattening
933
+ # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
934
+ aux_parser.add_argument("--" + arg, type=string_list)
935
+ else:
936
+ aux_parser.add_argument("--" + arg, type=type(val))
937
+
938
+ cmd_args, _ = aux_parser.parse_known_args(args_list)
939
+
940
+ return args, cmd_args
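A hedged sketch of the precedence that `JobConfig` documents (cmdline > toml > argparse default); the toml path and its contents are hypothetical, and the file must exist or `parse_args` will raise:

```python
from flame.config_manager import JobConfig

job_config = JobConfig()
job_config.parse_args([
    "--job.config_file", "configs/example.toml",  # hypothetical toml, say it sets optimizer.lr = 3e-4
    "--optimizer.lr", "1e-4",                      # explicit cmdline value wins over the toml value
])
print(job_config.optimizer.lr)         # 0.0001 -- the cmdline value wins
print(job_config.training.batch_size)  # 8 (argparse default) unless the toml overrides it
```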
flame/data.py ADDED
@@ -0,0 +1,570 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ import pickle
7
+ from copy import deepcopy
8
+ from dataclasses import dataclass
9
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
10
+
11
+ import datasets
12
+ import numpy as np
13
+ import torch
14
+ from datasets import Dataset, IterableDataset
15
+ from datasets.iterable_dataset import ShufflingConfig
16
+ from torch.distributed.checkpoint.stateful import Stateful
17
+ from torchdata.stateful_dataloader import StatefulDataLoader
18
+ from transformers import PreTrainedTokenizer
19
+
20
+ from torchtitan.tools.logging import logger
21
+
22
+
23
+ class BufferShuffledIterableDataset(IterableDataset):
24
+ def __init__(
25
+ self,
26
+ dataset: Dataset,
27
+ tokenizer: PreTrainedTokenizer,
28
+ seq_len: int = 2048,
29
+ rank: int = 0,
30
+ world_size: int = 1,
31
+ buffer_size: int = 1024,
32
+ ) -> BufferShuffledIterableDataset:
33
+ self.dataset = dataset
34
+ self.tokenizer = tokenizer
35
+
36
+ self.data = dataset.shard(world_size, rank)
37
+ self.seq_len = seq_len
38
+
39
+ self.rank = rank
40
+ self.world_size = world_size
41
+ self.buffer_size = buffer_size
42
+
43
+ if tokenizer.vocab_size < torch.iinfo(torch.int16).max:
44
+ self.dtype = torch.int16
45
+ elif tokenizer.vocab_size < torch.iinfo(torch.int32).max:
46
+ self.dtype = torch.int32
47
+ else:
48
+ self.dtype = torch.int64
49
+ self.states = None
50
+ self.buffer = torch.tensor([], dtype=self.dtype)
51
+ self.tokens = []
52
+ self.rand_id = 0
53
+ self.token_id = 0
54
+ self.rng_state = None
55
+ self._epoch = 0
56
+
57
+ def __iter__(self):
58
+ g = torch.Generator()
59
+ g.manual_seed(self._epoch + self.rank)
60
+ if self.rng_state is not None:
61
+ g.set_state(self.rng_state)
62
+
63
+ rand_it = self.randint(0, self.buffer_size, g=g)
64
+ if self.states is not None:
65
+ self.data.load_state_dict(self.states)
66
+
67
+ # max number of tokens allowed in the chunk buffer
68
+ n_tokens = self.buffer_size * self.seq_len
69
+
70
+ while True:
71
+ for sample in self.tokenize(self.data):
72
+ # keep appending the samples to the token buffer
73
+ self.tokens += sample
74
+ # if the token buffer is full, start sampling
75
+ # NOTE: we first convert the token ids to a tensor of shape [n_chunks, seq_len] for efficiency
76
+ if len(self.buffer) == 0 and len(self.tokens) >= n_tokens:
77
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=self.dtype).view(self.buffer_size, -1)
78
+ self.tokens = self.tokens[n_tokens:]
79
+ if len(self.buffer) == self.buffer_size:
80
+ yield from self.sample(rand_it)
81
+
82
+ n_chunks = len(self.tokens) // self.seq_len
83
+ # handle the left tokens in the buffer
84
+ if n_chunks > 0:
85
+ n_tokens = n_chunks * self.seq_len
86
+ indices = torch.randperm(n_chunks, generator=g).tolist()
87
+ self.buffer = torch.tensor(self.tokens[:n_tokens], dtype=torch.long).view(n_chunks, -1)
88
+ self.tokens = self.tokens[n_tokens:]
89
+ for i in indices:
90
+ yield {'input_ids': self.buffer[i]}
91
+
92
+ def tokenize(self, data, batch_size: int = 64):
93
+ texts, states = [], []
94
+ for sample in data:
95
+ texts.append(sample['text'])
96
+ states.append(self.data.state_dict())
97
+ if len(texts) == batch_size:
98
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
99
+ self.states = s
100
+ yield tokenized
101
+ texts, states = [], []
102
+ if len(texts) > 0:
103
+ for s, tokenized in zip(states, self.tokenizer(texts, return_attention_mask=False)['input_ids']):
104
+ self.states = s
105
+ yield tokenized
106
+
107
+ def sample(self, indices):
108
+ n_tokens = (len(self.tokens) // self.seq_len) * self.seq_len
109
+ while self.token_id < n_tokens:
110
+ i = next(indices)
111
+ start, end = self.token_id, self.token_id + self.seq_len
112
+ self.token_id += self.seq_len
113
+ yield {'input_ids': self.buffer[i].to(torch.long)}
114
+ self.buffer[i] = torch.tensor(self.tokens[start:end], dtype=self.dtype)
115
+ self.token_id = 0
116
+ self.tokens = self.tokens[n_tokens:]
117
+
118
+ def randint(self, low: int, high: int, buffer_size: int = 1024, g: torch.Generator = torch.Generator()) -> Iterable[int]:
119
+ indices = torch.empty(buffer_size, dtype=torch.long)
120
+ while True:
121
+ # record the generator states before sampling
122
+ self.rng_state = g.get_state()
123
+ indices = torch.randint(low, high, (buffer_size,), out=indices, generator=g)
124
+ for i in indices[self.rand_id:].tolist():
125
+ self.rand_id += 1
126
+ yield i
127
+ self.rand_id = 0
128
+
129
+ def set_epoch(self, epoch):
130
+ self._epoch = epoch
131
+ if hasattr(self.dataset, 'set_epoch'):
132
+ self.dataset.set_epoch(epoch)
133
+
134
+ def state_dict(self):
135
+ return {
136
+ 'states': self.states,
137
+ 'buffer': self.buffer.clone(),
138
+ 'tokens': deepcopy(self.tokens),
139
+ 'rand_id': self.rand_id,
140
+ 'token_id': self.token_id,
141
+ 'rng_state': self.rng_state,
142
+ 'epoch': self._epoch,
143
+ }
144
+
145
+ def load_state_dict(self, state_dict):
146
+ self.states = state_dict['states']
147
+ self.buffer = state_dict['buffer'].clone()
148
+ self.tokens = deepcopy(state_dict['tokens'])
149
+ self.rand_id = state_dict['rand_id']
150
+ self.token_id = state_dict['token_id']
151
+ self.rng_state = state_dict['rng_state'].clone() if state_dict['rng_state'] is not None else None
152
+ self._epoch = state_dict['epoch']
153
+
154
+
155
+ class OnlineTokenizedIterableDataset(IterableDataset):
156
+ def __init__(
157
+ self, dataset: Dataset, tokenizer: PreTrainedTokenizer, seq_len: int = 2048, rank: int = 0, world_size: int = 1
158
+ ) -> OnlineTokenizedIterableDataset:
159
+ self.dataset = dataset
160
+ self.tokenizer = tokenizer
161
+
162
+ self.data = dataset.shard(world_size, rank)
163
+ self.seq_len = seq_len
164
+ self.rank = rank
165
+ self.world_size = world_size
166
+
167
+ self.states = None
168
+ self.tokens = []
169
+
170
+ def __iter__(self):
171
+ if self.states is not None:
172
+ self.data.load_state_dict(self.states)
173
+
174
+ while True:
175
+ for sample in self.tokenize(self.data):
176
+ # keep appending the samples to the token buffer
177
+ self.tokens += sample
178
+
179
+ while len(self.tokens) >= self.seq_len:
180
+ input_ids = torch.tensor(self.tokens[:self.seq_len], dtype=torch.long)
181
+ self.tokens = self.tokens[self.seq_len:]
182
+ yield {'input_ids': input_ids}
183
+
184
+ def tokenize(self, data, buffer_size: int = 64):
185
+ buffer, states = [], []
186
+ for sample in data:
187
+ if sample.get('text', None) is not None:
188
+ buffer.append(sample['text'])
189
+ elif sample.get('content', None) is not None:
190
+ buffer.append(sample['content'])
191
+ else:
192
+ raise ValueError(f"No 'text' or 'content' field found in sample:\n{sample}")
193
+ states.append(self.data.state_dict())
194
+ if len(buffer) == buffer_size:
195
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
196
+ self.states = s
197
+ yield tokenized
198
+ buffer, states = [], []
199
+ if len(buffer) > 0:
200
+ for s, tokenized in zip(states, self.tokenizer(buffer, return_attention_mask=False)['input_ids']):
201
+ self.states = s
202
+ yield tokenized
203
+
204
+ def state_dict(self):
205
+ return {'states': self.states, 'tokens': deepcopy(self.tokens)}
206
+
207
+ def load_state_dict(self, state_dict):
208
+ self.states = state_dict['states']
209
+ self.tokens = deepcopy(state_dict['tokens'])
210
+
211
+
212
+ class BufferShuffledExamplesIterable(datasets.iterable_dataset.BufferShuffledExamplesIterable):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args, **kwargs)
215
+
216
+ def _init_state_dict(self) -> dict:
217
+ self._state_dict = self.ex_iterable._init_state_dict()
218
+ self._state_dict['mem_buffer'] = ([],)
219
+ self._state_dict['bit_generator_state'] = self.generator.bit_generator.state
220
+ self._state_dict['bit_generator_index_offset'] = 0
221
+ self._state_dict['bit_generator_index_offset_shuffle'] = 0
222
+ return self._state_dict
223
+
224
+ def __iter__(self):
225
+ buffer_size = self.buffer_size
226
+ rng = deepcopy(self.generator)
227
+ # this is the shuffle buffer that we keep in memory
228
+ mem_buffer = self._state_dict['mem_buffer'][0]
229
+ # this is an infinite iterator that randomly samples the index of the source to pick examples from
230
+ index_offset = self._state_dict['bit_generator_index_offset'] if self._state_dict else 0
231
+ if self._state_dict:
232
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
233
+ indices_iterator = self._iter_random_indices(rng, buffer_size, random_batch_size=buffer_size)
234
+ # skip already consumed ones
235
+ for _ in range(index_offset):
236
+ i = next(indices_iterator)
237
+
238
+ for x in self.ex_iterable:
239
+ if len(mem_buffer) < buffer_size: # if the buffer is not full, keep filling the buffer
240
+ mem_buffer.append(x)
241
+ else: # otherwise, pick an example from it
242
+ i = next(indices_iterator)
243
+ index_offset = (index_offset + 1) % buffer_size
244
+ if self._state_dict:
245
+ self._state_dict['bit_generator_index_offset'] = index_offset
246
+ if index_offset == 0:
247
+ self._state_dict['bit_generator_state'] = rng.bit_generator.state
248
+ selected = mem_buffer[i]
249
+ mem_buffer[i] = x # replace the picked example by a new one
250
+ yield selected
251
+
252
+ index_offset = self._state_dict['bit_generator_index_offset_shuffle'] if self._state_dict else 0
253
+ if self._state_dict:
254
+ rng.bit_generator.state = self._state_dict['bit_generator_state']
255
+
256
+ # when we run out of examples, we shuffle the remaining examples in the buffer and yield them
257
+ for i in rng.permutation(len(mem_buffer))[index_offset:].tolist():
258
+ index_offset = index_offset + 1
259
+ if self._state_dict:
260
+ self._state_dict['bit_generator_index_offset_shuffle'] = index_offset
261
+ yield mem_buffer[i]
262
+
263
+ def shuffle_data_sources(self, generator: np.random.Generator) -> BufferShuffledExamplesIterable:
264
+ """Shuffle the wrapped examples iterable as well as the shuffling buffer."""
265
+ return BufferShuffledExamplesIterable(
266
+ self.ex_iterable.shuffle_data_sources(generator), buffer_size=self.buffer_size, generator=generator
267
+ )
268
+
269
+ def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> BufferShuffledExamplesIterable:
270
+ """Keep only the requested shard."""
271
+ return BufferShuffledExamplesIterable(
272
+ self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
273
+ buffer_size=self.buffer_size,
274
+ generator=self.generator,
275
+ )
276
+
277
+ def load_state_dict(self, state_dict: dict) -> dict:
278
+ def _inner_load_state_dict(state, new_state):
279
+ if new_state is not None and isinstance(state, dict):
280
+ for key in new_state:
281
+ state[key] = _inner_load_state_dict(state[key], new_state[key])
282
+ return state
283
+ elif new_state is not None and isinstance(state, list):
284
+ for i in range(len(state)):
285
+ state[i] = _inner_load_state_dict(state[i], new_state[i])
286
+ return state
287
+ return new_state
288
+
289
+ return _inner_load_state_dict(self._state_dict, state_dict)
290
+
291
+
292
+ def shuffle(
293
+ dataset: IterableDataset,
294
+ seed: int = 42,
295
+ generator: np.random.Generator = None,
296
+ buffer_size: int = 1024,
297
+ ):
298
+ generator = np.random.default_rng(seed) if generator is None else deepcopy(generator)
299
+ return IterableDataset(
300
+ ex_iterable=BufferShuffledExamplesIterable(dataset._ex_iterable, buffer_size=buffer_size, generator=generator),
301
+ info=dataset._info.copy(),
302
+ split=dataset._split,
303
+ formatting=dataset._formatting,
304
+ shuffling=ShufflingConfig(generator=generator, _original_seed=seed),
305
+ distributed=copy.deepcopy(dataset._distributed),
306
+ token_per_repo_id=dataset._token_per_repo_id,
307
+ )
308
+
309
+
310
+ @dataclass
311
+ class DataCollatorForLanguageModeling:
312
+ """
313
+ Data collator used for language modeling. Inputs are dynamically padded if `varlen=False`.
314
+ If `varlen=True`, sequences are expected to be concatenated, and labels match inputs.
315
+
316
+ Args:
317
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
318
+ The tokenizer used for encoding the data.
319
+ context_len (`int`, optional):
320
+ When `varlen=True`, sequences longer than this length within a document
321
+ (as determined by `cu_seqlens`) will be further chunked.
322
+ varlen (`bool`):
323
+ Whether to handle variable length concatenated sequences (`True`) or padded batches (`False`).
324
+
325
+ Returns:
326
+ A dictionary with the following keys:
327
+ - `input_ids`: Tensor of input IDs. Shape `[batch_size, seq_len]` if `varlen=False`, `[1, total_len]` if `varlen=True`.
328
+ - `labels`: Tensor of labels. Shape matches `input_ids`. Padding positions are masked with -100 if `varlen=False`.
329
+ - `attention_mask`: Tensor indicating non-padding tokens (only if `varlen=False`). Shape matches `input_ids`.
330
+ - `cu_seqlens`: Tensor of cumulative sequence lengths (only if `varlen=True`). Shape `[1, num_sequences + 1]`.
331
+
332
+ NOTE: When `varlen=True`, the `batch_size` must be 1.
333
+ """
334
+
335
+ tokenizer: PreTrainedTokenizer
336
+ context_len: Optional[int] = None
337
+ varlen: bool = False
338
+
339
+ def __call__(self, examples: List[Union[List[int], Dict[str, Any]]]) -> Dict[str, Any]:
340
+ if not isinstance(examples[0], Dict):
341
+ examples = [{'input_ids': example} for example in examples]
342
+
343
+ def tensorize(example: Dict[str, Any]) -> Dict[str, Any]:
344
+ tensorized = {}
345
+ for key in ['input_ids', 'cu_seqlens']:
346
+ if key not in example:
347
+ continue
348
+ if isinstance(example[key], List):
349
+ tensorized[key] = torch.tensor(example[key], dtype=torch.long)
350
+ elif isinstance(example[key], np.ndarray):
351
+ tensorized[key] = torch.from_numpy(example[key])
352
+ else:
353
+ tensorized[key] = example[key]
354
+ return tensorized
355
+
356
+ examples = list(map(tensorize, examples))
357
+
358
+ if not self.varlen:
359
+ # --- Handling for varlen=False (Batch Padding) ---
360
+ length_of_first = examples[0]['input_ids'].size(0)
361
+ needs_padding = not all(example['input_ids'].size(0) == length_of_first for example in examples)
362
+
363
+ if needs_padding:
364
+ # Check for pad token if padding is actually required
365
+ if self.tokenizer.pad_token_id is None:
366
+ raise ValueError(
367
+ f'You are attempting to pad samples but the tokenizer you are using '
368
+ f'({self.tokenizer.__class__.__name__}) does not have a pad token.'
369
+ )
370
+ # Pad using the tokenizer, ensuring attention_mask is returned
371
+ batch = self.tokenizer.pad(examples, return_tensors='pt', return_attention_mask=True)
372
+ else:
373
+ # No padding needed, stack directly and create a full attention mask
374
+ input_ids = torch.stack([example['input_ids'] for example in examples], dim=0)
375
+ batch = {
376
+ 'input_ids': input_ids,
377
+ # Create attention mask of all ones
378
+ 'attention_mask': torch.ones_like(input_ids),
379
+ }
380
+
381
+ # Create labels by cloning input_ids
382
+ labels = batch['input_ids'].clone()
383
+ # Mask labels only where attention_mask is 0 (padding positions)
384
+ if 'attention_mask' in batch:
385
+ labels[batch['attention_mask'] == 0] = -100
386
+ batch['labels'] = labels
387
+
388
+ else:
389
+ # --- Handling for varlen=True (Concatenated Sequences) ---
390
+ if len(examples) > 1:
391
+ raise ValueError('The batch size must be 1 for inputs with variable lengths (varlen=True).')
392
+
393
+ batch = {'input_ids': torch.cat([example['input_ids'] for example in examples], dim=0).unsqueeze(0)}
394
+
395
+ # --- cu_seqlens calculation logic remains the same ---
396
+ if 'cu_seqlens' in examples[0]:
397
+ batch['cu_seqlens'] = (
398
+ torch.cat([example['cu_seqlens'] for example in examples], dim=0).unsqueeze(0).to(dtype=torch.int32)
399
+ ) # Ensure int32
400
+ else:
401
+ # determine boundaries by bos/eos positions
402
+ # Check for bos_token_id first
403
+ if self.tokenizer.bos_token_id is not None:
404
+ cu_seqlens = []
405
+ # Handle case where the sequence doesn't start with BOS
406
+ if batch['input_ids'][0, 0] != self.tokenizer.bos_token_id:
407
+ cu_seqlens.append(torch.tensor([0], device=batch['input_ids'].device)) # Match device
408
+ # Find all BOS token positions
409
+ bos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.bos_token_id))[1]
410
+ # Ensure bos_positions is on the correct device if empty
411
+ if bos_positions.numel() == 0 and len(cu_seqlens) > 0:
412
+ cu_seqlens.append(bos_positions.to(cu_seqlens[0].device))
413
+ elif bos_positions.numel() > 0:
414
+ cu_seqlens.append(bos_positions)
415
+ # Add the end of the entire batch
416
+ cu_seqlens.append(
417
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
418
+ ) # Match device and use size(1)
419
+ # Filter out empty tensors before cat
420
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
421
+ if not cu_seqlens: # Handle case where input is empty or has no BOS
422
+ batch['cu_seqlens'] = torch.tensor(
423
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
424
+ )
425
+ else:
426
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
427
+
428
+ # Else, check for eos_token_id
429
+ elif self.tokenizer.eos_token_id is not None:
430
+ cu_seqlens = [torch.tensor([0], device=batch['input_ids'].device)] # Match device
431
+ # Find positions *after* EOS tokens
432
+ eos_positions = torch.where(batch['input_ids'].eq(self.tokenizer.eos_token_id))[1] + 1
433
+ # Ensure eos_positions is on the correct device if empty
434
+ if eos_positions.numel() > 0:
435
+ cu_seqlens.append(eos_positions)
436
+ # Handle case where the sequence doesn't end with EOS
437
+ if batch['input_ids'][0, -1] != self.tokenizer.eos_token_id:
438
+ # Only add the final length if the last found EOS wasn't already the end
439
+ if eos_positions.numel() == 0 or eos_positions[-1] != batch['input_ids'].size(1):
440
+ cu_seqlens.append(
441
+ torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
442
+ ) # Match device and use size(1)
443
+ # Filter out empty tensors before cat
444
+ cu_seqlens = [t for t in cu_seqlens if t.numel() > 0]
445
+ if not cu_seqlens: # Handle case where input is empty or has no EOS
446
+ batch['cu_seqlens'] = torch.tensor(
447
+ [0, batch['input_ids'].size(1)], dtype=torch.int32, device=batch['input_ids'].device
448
+ )
449
+ else:
450
+ batch['cu_seqlens'] = torch.cat(cu_seqlens, dim=0).to(dtype=torch.int32)
451
+ # Else, neither BOS nor EOS is usable
452
+ else:
453
+ raise ValueError(
454
+ 'For varlen=True without precomputed cu_seqlens, the tokenizer must have either a bos_token_id '
455
+ 'or an eos_token_id defined to act as sequence separators.'
456
+ )
457
+
458
+ # --- cu_seqlens validation checks remain the same ---
459
+ if batch['cu_seqlens'].numel() < 2:
460
+ raise ValueError(f'Calculated cu_seqlens must have at least start and end: {batch["cu_seqlens"]}')
461
+ if not torch.all(batch['cu_seqlens'][1:] >= batch['cu_seqlens'][:-1]):
462
+ raise ValueError(f'Calculated cu_seqlens are not monotonically increasing: {batch["cu_seqlens"]}')
463
+ if batch['cu_seqlens'][0] != 0:
464
+ raise ValueError(f'Calculated cu_seqlens do not start at 0: {batch["cu_seqlens"]}')
465
+ if batch['cu_seqlens'][-1] != batch['input_ids'].size(1):
466
+ # Allow empty sequence case where cu_seqlens=[0, 0] and input_ids.size(1)=0
467
+ if not (batch['cu_seqlens'].tolist() == [0, 0] and batch['input_ids'].size(1) == 0):
468
+ raise ValueError(
469
+ f'Calculated cu_seqlens do not end at total length {batch["input_ids"].size(1)}: '
470
+ f'{batch["cu_seqlens"]}'
471
+ )
472
+
473
+ # --- context_len splitting logic remains the same ---
474
+ if self.context_len is not None:
475
+ # This logic splits sequences based on context_len *after* initial boundaries are found
476
+ bos = batch['cu_seqlens'][:-1].tolist()
477
+ eos = batch['cu_seqlens'][1:].tolist()
478
+ # Handle empty sequences between boundaries
479
+ split_boundaries = []
480
+ for i, j in zip(bos, eos):
481
+ if i < j: # Only process non-empty sequences
482
+ split_boundaries.append(torch.arange(i, j, self.context_len, device=batch['input_ids'].device))
483
+ # Add the final end point if it wasn't included by arange
484
+ final_end_point = torch.tensor([batch['input_ids'].size(1)], device=batch['input_ids'].device)
485
+ # Concatenate all boundaries
486
+ if not split_boundaries: # Handle case of completely empty input
487
+ batch['cu_seqlens'] = torch.tensor([0, 0], dtype=torch.int32, device=batch['input_ids'].device)
488
+ else:
489
+ batch['cu_seqlens'] = torch.cat(split_boundaries + [final_end_point]).to(dtype=torch.int32)
490
+ # Ensure uniqueness and sort, as arange might duplicate the endpoint
491
+ batch['cu_seqlens'] = torch.unique(batch['cu_seqlens'])
492
+
493
+ # Create labels directly from input_ids, NO padding mask needed for varlen
494
+ labels = batch['input_ids'].clone()
495
+ batch['labels'] = labels
496
+
497
+ return batch
498
+
499
+
500
+ class ParallelAwareDataLoader(StatefulDataLoader, Stateful):
501
+ """
502
+ A wrapper around the StatefulDataLoader that ensures that the state is stored only once per DP rank.
503
+ """
504
+
505
+ def __init__(
506
+ self,
507
+ rank: int,
508
+ dataset: IterableDataset,
509
+ batch_size: int,
510
+ collate_fn: Callable,
511
+ num_workers: int = 0,
512
+ pin_memory: bool = False,
513
+ prefetch_factor: int = 2,
514
+ persistent_workers: bool = False,
515
+ snapshot_every_n_steps: Optional[int] = 1,
516
+ ):
517
+ super().__init__(
518
+ dataset=dataset,
519
+ batch_size=batch_size,
520
+ collate_fn=collate_fn,
521
+ num_workers=num_workers,
522
+ pin_memory=pin_memory,
523
+ prefetch_factor=prefetch_factor,
524
+ persistent_workers=persistent_workers,
525
+ snapshot_every_n_steps=snapshot_every_n_steps,
526
+ )
527
+ self.rank = rank
528
+
529
+ def state_dict(self) -> Dict[str, Any]:
530
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
531
+ return {f'rank_{self.rank}': pickle.dumps(super().state_dict())}
532
+
533
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
534
+ # State being empty is valid
535
+ if not state_dict:
536
+ return
537
+
538
+ if f'rank_{self.rank}' not in state_dict:
539
+ logger.warning(f'DataLoader state has no entry for dp rank {self.rank} (expected key rank_{self.rank}); skipping load')
540
+ return
541
+ super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
542
+
543
+
544
+ def build_dataloader(
545
+ dataset: IterableDataset,
546
+ tokenizer: PreTrainedTokenizer,
547
+ rank: int,
548
+ world_size: int,
549
+ batch_size: int,
550
+ seq_len: int,
551
+ context_len: Optional[int] = None,
552
+ varlen: bool = False,
553
+ num_workers: int = 0,
554
+ pin_memory: bool = False,
555
+ persistent_workers: bool = False,
556
+ snapshot_every_n_steps: Optional[int] = 1,
557
+ ):
558
+ dataset = OnlineTokenizedIterableDataset(
559
+ dataset=dataset, tokenizer=tokenizer, seq_len=seq_len, rank=rank, world_size=world_size
560
+ )
561
+ return ParallelAwareDataLoader(
562
+ rank=rank,
563
+ dataset=dataset,
564
+ batch_size=batch_size,
565
+ collate_fn=DataCollatorForLanguageModeling(tokenizer=tokenizer, context_len=context_len, varlen=varlen),
566
+ num_workers=num_workers,
567
+ pin_memory=pin_memory,
568
+ persistent_workers=persistent_workers,
569
+ snapshot_every_n_steps=snapshot_every_n_steps,
570
+ )
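As a quick illustration of the varlen path of DataCollatorForLanguageModeling above, the sketch below reproduces the cu_seqlens convention on toy token ids (the token values, the eos id and the context_len are made up for illustration): documents are packed into a single batch row, and cu_seqlens records cumulative document boundaries.

import torch

eos_token_id = 0
# three documents, each terminated by EOS, packed into one row of batch size 1
input_ids = torch.tensor([[5, 6, 7, 0, 8, 9, 0, 3, 4, 2, 0]])

# boundaries = positions right after each EOS, prefixed with 0
eos_positions = torch.where(input_ids.eq(eos_token_id))[1] + 1
cu_seqlens = torch.cat([torch.tensor([0]), eos_positions]).to(torch.int32)
print(cu_seqlens.tolist())   # [0, 4, 7, 11]
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
print(lengths.tolist())      # [4, 3, 4] -> per-document lengths

# optional context_len chunking, mirroring the collator's splitting step
context_len = 2
starts = [torch.arange(i, j, context_len) for i, j in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist())]
chunked = torch.unique(torch.cat(starts + [torch.tensor([input_ids.size(1)])])).to(torch.int32)
print(chunked.tolist())      # [0, 2, 4, 6, 7, 9, 11]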
logs/none_75lcom2m/attempt_0/3/stderr.log ADDED
@@ -0,0 +1,17 @@
1
+ Traceback (most recent call last):
2
+ File "<frozen runpy>", line 198, in _run_module_as_main
3
+ File "<frozen runpy>", line 88, in _run_code
4
+ File "/workspace/flame/flame/train.py", line 19, in <module>
5
+ import fla # noqa
6
+ ^^^^^^^^^^
7
+ File "/workspace/flame/fla/__init__.py", line 23, in <module>
8
+ from fla.models import (
9
+ File "/workspace/flame/fla/models/__init__.py", line 4, in <module>
10
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
11
+ File "/workspace/flame/fla/models/bitnet/__init__.py", line 8, in <module>
12
+ AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
13
+ File "/workspace/flame/.venv/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py", line 1211, in register
14
+ CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
15
+ File "/workspace/flame/.venv/lib/python3.11/site-packages/transformers/models/auto/configuration_auto.py", line 905, in register
16
+ raise ValueError(f"'{key}' is already used by a Transformers config, pick another name.")
17
+ ValueError: 'bitnet' is already used by a Transformers config, pick another name.
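The failure above comes from registering the local BitNetConfig under the model_type 'bitnet', which the installed transformers release already claims. A common workaround, sketched below with a hypothetical config class rather than this repository's own code, is to register under a project-specific model_type; the traceback also shows that CONFIG_MAPPING.register accepts an exist_ok flag in this transformers version.

from transformers import AutoConfig, PretrainedConfig

class LocalBitNetConfig(PretrainedConfig):  # hypothetical stand-in for fla's BitNetConfig
    model_type = "fla-bitnet"               # avoids colliding with the upstream 'bitnet' key

AutoConfig.register(LocalBitNetConfig.model_type, LocalBitNetConfig)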
logs/none_75lcom2m/attempt_0/3/stdout.log ADDED
File without changes
logs/none_vngrbiu1/attempt_0/0/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_vngrbiu1/attempt_0/0/stdout.log ADDED
File without changes
logs/none_vngrbiu1/attempt_0/1/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_vngrbiu1/attempt_0/1/stdout.log ADDED
File without changes
logs/none_vngrbiu1/attempt_0/2/stdout.log ADDED
File without changes
logs/none_vngrbiu1/attempt_0/3/stderr.log ADDED
The diff for this file is too large to render. See raw diff
 
logs/none_vngrbiu1/attempt_0/3/stdout.log ADDED
File without changes
profile_trace/iteration_1024/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1024/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff