Files changed (3)
  1. config.json +7 -4
  2. modeling_qwen.py +181 -201
  3. test_passkey_retrieval.py +97 -0
config.json CHANGED
@@ -7,7 +7,7 @@
     "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
   },
   "attn_dropout_prob": 0.0,
-  "bf16": false,
+  "bf16": true,
   "emb_dropout_prob": 0.0,
   "fp16": false,
   "fp32": false,
@@ -30,8 +30,11 @@
   "tokenizer_class": "QWenTokenizer",
   "transformers_version": "4.32.0",
   "use_cache": true,
-  "use_dynamic_ntk": true,
-  "use_flash_attn": "auto",
+  "use_dynamic_ntk": false,
+  "use_flash_attn": false,
+  "use_rerope": true,
+  "rerope_window": 512,
+  "forward_max_length": 32768,
   "use_logn_attn": true,
   "vocab_size": 151936
-}
+}
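
The new keys are consumed by the patched modeling_qwen.py below: QWenAttention reads use_rerope, rerope_window and forward_max_length, and QWenModel asserts that use_dynamic_ntk is off and use_rerope is on. As a quick sanity check, here is a minimal sketch, assuming the patched files live in a local checkpoint directory (the path is only an example):

from transformers import AutoConfig

# Example path; point it at the directory holding this config.json and modeling_qwen.py.
config = AutoConfig.from_pretrained("/models/Qwen-7B-Chat-ReRoPE", trust_remote_code=True)

# QWenModel.__init__ in the patched modeling_qwen.py asserts these two flags,
# so it is worth checking them before loading the full model.
assert config.use_rerope is True
assert config.use_dynamic_ntk is False

# The flash-attn branch is not used by the ReRoPE code paths in this patch.
assert config.use_flash_attn is False

print(config.rerope_window, config.forward_max_length)  # 512, 32768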
modeling_qwen.py CHANGED
@@ -7,6 +7,7 @@ import copy
 import importlib
 import math
 import pathlib
+import pdb
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
 import torch
@@ -28,6 +29,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
+import numpy as np
 
 try:
     from einops import rearrange
@@ -241,6 +243,7 @@ class QWenAttention(nn.Module):
 
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
         self.seq_length = config.seq_length
+        self.forward_max_length = config.forward_max_length
 
         self.hidden_size = config.hidden_size
         self.split_size = config.hidden_size
@@ -276,17 +279,19 @@ class QWenAttention(nn.Module):
 
         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
+        self.use_rerope = config.use_rerope
+        self.rerope_window = config.rerope_window
+        self.causal = True
 
         logn_list = [
             math.log(i, self.seq_length) if i > self.seq_length else 1
-            for i in range(1, 32768)
+            for i in range(1, self.forward_max_length)
         ]
         logn_tensor = torch.tensor(logn_list)[None, :, None, None]
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
         self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
         cache_dtype = torch.float
         if self.bf16:
@@ -296,102 +301,60 @@ class QWenAttention(nn.Module):
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
-        if config.use_cache_quantization and config.use_cache_kernel:
-            # pre check if the support files existing
-            module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
-                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
-                self.cache_kernels = None
-            else:
-                try:
-                    from .cpp_kernels import cache_autogptq_cuda_256
-                    self.cache_kernels = cache_autogptq_cuda_256
-                except ImportError:
-                    warnings.warn("Failed to import KV cache kernels.")
-                    self.cache_kernels = None
-
-    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
-        device = query.device
-        if self.use_cache_quantization:
-            qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = query.shape[:-1] + (qk.shape[-2],)
-                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
-                    qk.transpose(-1, -2).contiguous(),
-                    attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
-                # attn_weights = attn_weights.to(query.dtype).contiguous()
-            else:
-                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
-                attn_weights = torch.matmul(query, key.transpose(-1, -2))
-        else:
-            attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
+    def _upcast_and_reordered_attn(
+        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
+    ):
+        bsz, num_heads, q_seq_len, dk = query.size()
+        _, _, k_seq_len, _ = key.size()
+
+        attn_weights = torch.empty(
+            bsz * num_heads,
+            q_seq_len,
+            k_seq_len,
+            dtype=torch.float32,
+            device=query.device,
+        )
+
+        scale_factor = 1.0
         if self.scale_attn_weights:
-            if self.use_cache_quantization:
-                size_temp = value[0].size(-1)
-            else:
-                size_temp = value.size(-1)
-            attn_weights = attn_weights / torch.full(
-                [],
-                size_temp ** 0.5,
-                dtype=attn_weights.dtype,
-                device=attn_weights.device,
+            scale_factor /= float(value.size(-1)) ** 0.5
+
+        with autocast(enabled=False):
+            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
+                -1, dk, k_seq_len
             )
-        if self.use_cache_quantization:
-            query_length, key_length = query.size(-2), key[0].size(-2)
-        else:
-            query_length, key_length = query.size(-2), key.size(-2)
+            attn_weights = torch.baddbmm(
+                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
+            )
+            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+        query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = registered_causal_mask[
             :, :, key_length - query_length : key_length, :key_length
         ]
         mask_value = torch.finfo(attn_weights.dtype).min
-        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
             attn_weights.device
         )
-        attn_weights = torch.where(
-            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
-        )
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
 
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
 
-        if self.softmax_in_fp32:
-            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
-        else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
-        attn_weights = attn_weights.type(query.dtype)
+        if attn_weights.dtype != torch.float32:
+            raise RuntimeError(
+                "Error with upcasting, attn_weights does not have dtype torch.float32"
+            )
+        attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
 
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        if self.use_cache_quantization:
-            qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = attn_weights.shape[:-1] + (query.shape[-1],)
-                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
-                    qv.contiguous(),  # dtype: int32
-                    attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
-                if attn_output.dtype != query.dtype:
-                    attn_output = attn_output.to(query.dtype)
-                    attn_weights = attn_weights.to(query.dtype)
-            else:
-                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
-                attn_output = torch.matmul(attn_weights, value)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
-
-        attn_output = attn_output.transpose(1, 2)
+        attn_output = torch.matmul(attn_weights, value)
 
         return attn_output, attn_weights
 
@@ -404,11 +367,31 @@ class QWenAttention(nn.Module):
         tensor = tensor.contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
+
+    def rotate_half(self, x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., :x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2:]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb_rerope(self, query, key, cos, sin, position_ids):
+        # take bsz into consideration
+        assert 1 == position_ids.shape[0]
+
+        cos = cos.squeeze(0).squeeze(1)
+        cos = cos[position_ids][:, :, None, :]  # [bs, seq_len, 1, dim] to [1, pos_len, 1, dim]
+        sin = sin.squeeze(0).squeeze(1)
+        sin = sin[position_ids][:, :, None, :]  # [bs, seq_len, 1, dim] to [1, pos_len, 1, dim]
+
+        q_embed = ((query * cos[:, -query.shape[1]:]) + (self.rotate_half(query) * sin[:, -query.shape[1]:])).to(query.dtype) if query is not None else None
+        k_embed = ((key * cos) + (self.rotate_half(key) * sin)).to(key.dtype) if key is not None else None
+        return q_embed, k_embed
 
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
@@ -425,116 +408,101 @@ class QWenAttention(nn.Module):
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
-        if rotary_pos_emb_list is not None:
-            cur_len = query.shape[1]
-            if len(rotary_pos_emb_list) == 1:
-                rotary_pos_emb = rotary_pos_emb_list[0]
-                rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-                q_pos_emb, k_pos_emb = rotary_pos_emb
-                # Slice the pos emb for current inference
-                query = apply_rotary_pos_emb(query, q_pos_emb)
-                key = apply_rotary_pos_emb(key, k_pos_emb)
-            else:
-                query_list = []
-                key_list = []
-                for i, rotary_pos_emb in enumerate(rotary_pos_emb_list):
-                    rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
-                    rotary_pos_emb = (rotary_pos_emb,) * 2
-                    q_pos_emb, k_pos_emb = rotary_pos_emb
-                    # Slice the pos emb for current inference
-                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
-                query = torch.cat(query_list, dim=0)
-                key = torch.cat(key_list, dim=0)
-
-        if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
-                                   bits=8,
-                                   qmin=self.cache_qmin,
-                                   qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
-                                     bits=8,
-                                     qmin=self.cache_qmin,
-                                     qmax=self.cache_qmax)
-
-
-        if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-            if self.use_cache_quantization:
-                # use_cache_quantization:
-                #   present=((q_key,key_scale,key_zero_point),
-                #            (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
-                       torch.cat((past_key[1], key[1]), dim=2),
-                       torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
-                         torch.cat((past_value[1], value[1]), dim=2),
-                         torch.cat((past_value[2], value[2]), dim=2))
-            else:
-                # not use_cache_quantization:
-                #   present=(key,value)
+        q_len = hidden_states.shape[1]
+        assert rotary_pos_emb_list is not None
+        assert output_attentions is False
+
+        # TODO
+        # 1. removed dynamic quantization
+        # 2. logn scaling is applied
+        # 3. about to add context rotary_emb_apply
+
+        cos, sin = rotary_pos_emb_list[0]
+        assert len(rotary_pos_emb_list) == 1
+
+        if q_len == 1:
+            # position_ids = torch.tensor([[layer_past[0].shape[1]]], dtype=torch.int64, device=query.device)
+            # query *= ((position_ids.flatten() + 1)[None, :, None, None].log() / np.log(self.train_length)).clip(1).to(query.dtype)
+
+            if layer_past is not None:
+                past_key, past_value = layer_past[0], layer_past[1]
                 key = torch.cat((past_key, key), dim=1)
                 value = torch.cat((past_value, value), dim=1)
-
-        if use_cache:
             present = (key, value)
-        else:
-            present = None
 
-        if self.use_logn_attn and not self.training:
-            if self.use_cache_quantization:
-                seq_start = key[0].size(2) - query.size(1)
-                seq_end = key[0].size(2)
-            else:
+            # position embedding
+            position_ids = torch.arange(layer_past[0].shape[1] + 1, device=query.device).unsqueeze(0)
+            position_ids = (position_ids[:, -1] - position_ids).clip(max=self.rerope_window)
+            _, key = self.apply_rotary_pos_emb_rerope(None, key, cos, -sin, position_ids)
+
+            if self.use_logn_attn:
                 seq_start = key.size(1) - query.size(1)
                 seq_end = key.size(1)
-            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
-            query = query * logn_tensor.expand_as(query)
-
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and query.is_cuda
-        ):
-            q, k, v = query, key, value
-            attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-        else:
-            registered_causal_mask = torch.tril(
-                torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
-            ).view(1, 1, key.size(1), key.size(1))
+                logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+                query = query * logn_tensor.expand_as(query)
+            # attn
             query = query.permute(0, 2, 1, 3)
-            if not self.use_cache_quantization:
-                key = key.permute(0, 2, 1, 3)
-                value = value.permute(0, 2, 1, 3)
-            if (
-                registered_causal_mask is None
-                and self.use_flash_attn
-                and flash_attn_unpadded_func is not None
-                and not self.is_fp32
-                and not query.is_cuda
-            ):
-                raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
-
-            if not self.use_cache_quantization and SUPPORT_TORCH2:
-                causal_mask = registered_causal_mask[
-                    :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
-                ]
-                if attention_mask is not None:
-                    attention_mask = attention_mask.expand(
-                        -1, -1, causal_mask.size(2), -1
-                    ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
-                else:
-                    attention_mask = causal_mask
-                attn_output = F.scaled_dot_product_attention(
-                    query, key, value, attn_mask=attention_mask
-                ).transpose(1, 2)
-                attn_weight = None
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+            causal_mask = registered_causal_mask[
+                :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
+            ]
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(
+                    -1, -1, causal_mask.size(2), -1
+                ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
             else:
-                attn_output, attn_weight = self._attn(
-                    query, key, value, registered_causal_mask, attention_mask, head_mask
-                )
+                attention_mask = causal_mask
+
+
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask
+            ).transpose(1, 2)
+
+        else:
+            # prefill
+            position_ids = torch.arange(query.shape[1], device=query.device).unsqueeze(0)
+            # query *= ((position_ids.flatten() + 1)[None, :, None, None].log() / np.log(self.train_length)).clip(1).to(query.dtype)
+            present = (key, value)
+
+            if self.use_logn_attn:
+                seq_start = key.size(1) - query.size(1)
+                seq_end = key.size(1)
+                logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+                query = query * logn_tensor.expand_as(query)
+
+            query_states1, key_states1 = self.apply_rotary_pos_emb_rerope(query, key, cos, sin, position_ids)
+            query_states2, _ = self.apply_rotary_pos_emb_rerope(query, None, cos, sin, position_ids * 0 + self.rerope_window)
+
+            query_states1 = query_states1.permute(0, 2, 1, 3)
+            query_states2 = query_states2.permute(0, 2, 1, 3)
+            key_states1 = key_states1.permute(0, 2, 1, 3)
+            key_states2 = key.to(key_states1.dtype).permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+            sm_scale = 1.0 / math.sqrt(self.head_dim)
+            attn_weights1 = torch.matmul(query_states1, key_states1.transpose(2, 3)) * sm_scale
+            attn_weights2 = torch.matmul(query_states2, key_states2.transpose(2, 3)) * sm_scale
+            rectified_mask = (position_ids[:, -q_len:, None] - position_ids[:, None]).abs() < self.rerope_window
+            attn_weights = torch.where(rectified_mask, attn_weights1, attn_weights2)
+
+            if self.causal:
+                tgt_len = attn_weights.shape[-1]
+                dtype = attn_weights.dtype
+                device = attn_weights.device
+                mask = torch.full((tgt_len, tgt_len),
+                                  torch.finfo(dtype).min,
+                                  device=device)
+                mask_cond = torch.arange(mask.size(-1), device=device)
+                mask.masked_fill_(
+                    mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+                mask = mask.to(dtype)
+                attn_weights = attn_weights + mask
+
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+            attn_output = torch.matmul(attn_weights, value).transpose(1, 2)
+
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )
@@ -542,15 +510,6 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)
 
         outputs = (attn_output, present)
-        if output_attentions:
-            if (
-                self.use_flash_attn
-                and flash_attn_unpadded_func is not None
-                and not self.is_fp32
-            ):
-                raise ValueError("Cannot output attentions while using flash-attn")
-            else:
-                outputs += (attn_weight,)
 
         return outputs
 
@@ -596,6 +555,7 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -609,6 +569,7 @@ class QWenBlock(nn.Module):
         attn_outputs = self.attn(
             layernorm_output,
             rotary_pos_emb_list,
+            registered_causal_mask=registered_causal_mask,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -682,10 +643,13 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
 
         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
+        assert self.use_dynamic_ntk is False
+        self.use_rerope = config.use_rerope
+        self.rerope_window = config.rerope_window
+        assert self.use_rerope is True
         self.seq_length = config.seq_length
 
         self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
@@ -708,6 +672,21 @@ class QWenModel(QWenPreTrainedModel):
 
         self.use_flash_attn = config.use_flash_attn
         self.is_fp32 = not (config.bf16 or config.fp16)
+        if (
+            self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+        ):
+            self.registered_causal_mask = None
+        else:
+            max_positions = config.max_position_embeddings
+            self.register_buffer(
+                "registered_causal_mask",
+                torch.tril(
+                    torch.ones((max_positions, max_positions), dtype=torch.bool)
+                ).view(1, 1, max_positions, max_positions),
+                persistent=False,
+            )
 
         self.h = nn.ModuleList(
             [
@@ -792,10 +771,7 @@ class QWenModel(QWenPreTrainedModel):
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-            if self.use_cache_quantization:
-                past_length = past_key_values[0][0][0].size(2)
-            else:
-                past_length = past_key_values[0][0].size(-2)
+            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(
                 past_length,
@@ -823,10 +799,7 @@ class QWenModel(QWenPreTrainedModel):
         kv_seq_len = hidden_states.size()[1]
         if past_key_values[0] is not None:
             # past key values[0][0] shape: bs * seq_len * head_num * dim
-            if self.use_cache_quantization:
-                kv_seq_len += past_key_values[0][0][0].shape[2]
-            else:
-                kv_seq_len += past_key_values[0][0].shape[1]
+            kv_seq_len += past_key_values[0][0].shape[1]
 
         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
@@ -844,10 +817,15 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
+        if kv_seq_len > 1:
+            # prefill
+            rotary_emb_seq_len = max(kv_seq_len, self.rerope_window + 1)
+        else:
+            rotary_emb_seq_len = kv_seq_len
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(rotary_emb_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
         ]
-
+
         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
 
@@ -879,6 +857,7 @@ class QWenModel(QWenPreTrainedModel):
                     create_custom_forward(block),
                     hidden_states,
                     rotary_pos_emb_list,
+                    self.registered_causal_mask,
                     None,
                     attention_mask,
                    head_mask[i],
@@ -890,6 +869,7 @@ class QWenModel(QWenPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     rotary_pos_emb_list=rotary_pos_emb_list,
+                    registered_causal_mask=self.registered_causal_mask,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
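
For orientation, the prefill branch of QWenAttention.forward above computes two score matrices and picks between them per query/key pair: one from queries and keys rotated with their true positions, and one from queries rotated at a fixed offset of rerope_window against unrotated keys. Below is a minimal, self-contained sketch of that rectification step, with toy shapes and illustrative names (no logn scaling, no causal mask; unlike the diff, the relative-position mask is broadcast explicitly over the head dimension):

import torch

def rerope_rectified_scores(q1, k1, q2, k2, position_ids, rerope_window, head_dim):
    # q1/k1: rotary-embedded with their true positions, so scores1 carries exact relative distances.
    # q2: rotated as if each query sat exactly rerope_window positions after its key;
    # k2: left unrotated, so scores2 sees a constant relative distance of rerope_window.
    scale = 1.0 / (head_dim ** 0.5)
    scores1 = torch.matmul(q1, k1.transpose(-1, -2)) * scale  # [bsz, heads, q_len, k_len]
    scores2 = torch.matmul(q2, k2.transpose(-1, -2)) * scale
    # Inside the window keep the exact scores, outside it fall back to the clipped ones.
    rel = (position_ids[:, :, None] - position_ids[:, None, :]).abs()[:, None]  # [bsz, 1, q_len, k_len]
    return torch.where(rel < rerope_window, scores1, scores2)

On the decode path (q_len == 1) the diff gets the same effect without a second score matrix: key-side distances are clipped with (position_ids[:, -1] - position_ids).clip(max=self.rerope_window), the concatenated keys are rotated with cos and -sin at those clipped distances, and attention then goes through F.scaled_dot_product_attention against the registered causal mask.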
test_passkey_retrieval.py ADDED
@@ -0,0 +1,97 @@
+import argparse
+import random
+from numpy import random
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import pdb
+
+def parse_config():
+    parser = argparse.ArgumentParser(description='arg parser')
+    parser.add_argument('--max_tokens', type=int, default=20000, help='maximum token length for evaluation')
+    parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
+    parser.add_argument('--num_tests', type=int, default=30, help='number of repeat testing for each length')
+
+    args = parser.parse_args()
+    return args
+
+# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
+def generate_prompt_landmark(n_garbage=60000, seed=666):
+    """Generates a text file and inserts an passkey at a random position."""
+    rnd_state = random.get_state()
+    random.seed(seed)
+    n_garbage_prefix = random.randint(0, n_garbage)
+    n_garbage_suffix = n_garbage - n_garbage_prefix
+
+    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
+    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
+    garbage_inf = " ".join([garbage] * 5000)
+    assert len(garbage_inf) >= n_garbage
+    garbage_prefix = garbage_inf[:n_garbage_prefix]
+    garbage_suffix = garbage_inf[:n_garbage_suffix]
+    pass_key = random.randint(1, 50000)
+    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
+    final_question = "What is the pass key? The pass key is"
+    print('idx : {}'.format(len(task_description) + len(garbage_prefix)))
+    lines = [
+        task_description,
+        garbage_prefix,
+        information_line,
+        garbage_suffix,
+        final_question,
+    ]
+    random.set_state(rnd_state)
+    return "\n".join(lines), str(pass_key)
+
+# NTK+log on Qwen-7B tokens {'5801': 0.95, '7986': 0.9, '8805': 0.85, '9897': 0.8, '11809': 0.95, '12900': 0.78, '13993': 0.06, '14812': 0.0}
+# ReRoPE on Qwen-7B
+def main(args):
+    # Load model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True).eval().cuda('cuda:3')
+    # tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True)
+    # model = AutoModelForCausalLM.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True).eval().cuda('cuda:3')
+
+    all_accuries = {}
+    # This is a rough ratio to control the number of texts and tokens
+    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
+    for val in range(2000, 12000, args.interval):
+        n_garbage = int(3.75 * val // 1024 * 1024)
+        passed_tests = 0
+        total_tokens = 0
+
+        for j in range(args.num_tests):
+            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
+            response, _ = model.chat(tokenizer, prompt, history=[], top_k=1)
+            print((response, pass_key))
+            if pass_key in response:
+                passed_tests += 1
+            total_tokens += len(tokenizer(prompt).input_ids)
+        avg_tokens = total_tokens//args.num_tests
+        accuracy = passed_tests/args.num_tests
+        print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
+        all_accuries[str(avg_tokens)] = accuracy
+
+    all_accuries = {}
+    # This is a rough ratio to control the number of texts and tokens
+    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
+    for val in range(2000, 12000, args.interval):
+        n_garbage = int(3.75 * val // 1024 * 1024)
+        passed_tests = 0
+        total_tokens = 0
+
+        for j in range(args.num_tests):
+            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j+val)
+            response, _ = model.chat(tokenizer, prompt, history=[])
+            print((response, pass_key))
+            if pass_key in response:
+                passed_tests += 1
+            total_tokens += len(tokenizer(prompt).input_ids)
+        avg_tokens = total_tokens//args.num_tests
+        accuracy = passed_tests/args.num_tests
+        print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
+        all_accuries[str(avg_tokens)] = accuracy
+    print("accuries over tokens", all_accuries)
+
+
+if __name__ == "__main__":
+    args = parse_config()
+    main(args)
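
Assuming the checkpoint at the path hard-coded in main() exists, the benchmark can be run with, for example: python test_passkey_retrieval.py --interval 1000 --num_tests 30. Note that --max_tokens is parsed but never read in main(); the evaluated prompt lengths come from the fixed range(2000, 12000, args.interval) loops, and per-length accuracy is recorded in all_accuries keyed by the average token count of the prompts at that length.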