ReRoPE-test-report
#47 by tpoisonooo (opened)

Files changed:
- config.json +7 -4
- modeling_qwen.py +181 -201
- test_passkey_retrieval.py +97 -0
config.json
CHANGED
@@ -7,7 +7,7 @@
     "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
   },
   "attn_dropout_prob": 0.0,
-  "bf16":
+  "bf16": true,
   "emb_dropout_prob": 0.0,
   "fp16": false,
   "fp32": false,
@@ -30,8 +30,11 @@
   "tokenizer_class": "QWenTokenizer",
   "transformers_version": "4.32.0",
   "use_cache": true,
-  "use_dynamic_ntk":
-  "use_flash_attn":
+  "use_dynamic_ntk": false,
+  "use_flash_attn": false,
+  "use_rerope": true,
+  "rerope_window": 512,
+  "forward_max_length": 32768,
   "use_logn_attn": true,
   "vocab_size": 151936
-}
+}
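The three new keys are plain additions to the checkpoint config, so they are exposed as attributes once the config is loaded. A minimal check, assuming a local copy of the patched checkpoint at /models/Qwen-7B-Chat-ReRoPE (the same hypothetical path the test script below uses):

from transformers import AutoConfig

# Assumed local path; substitute wherever the patched Qwen-7B-Chat checkpoint lives.
cfg = AutoConfig.from_pretrained("/models/Qwen-7B-Chat-ReRoPE", trust_remote_code=True)
print(cfg.use_rerope, cfg.rerope_window, cfg.forward_max_length)
# expected, per the diff above: True 512 32768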
modeling_qwen.py
CHANGED
@@ -7,6 +7,7 @@ import copy
 import importlib
 import math
 import pathlib
+import pdb
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

 import torch
@@ -28,6 +29,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
+import numpy as np

 try:
     from einops import rearrange
@@ -241,6 +243,7 @@ class QWenAttention(nn.Module):

         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
         self.seq_length = config.seq_length
+        self.forward_max_length = config.forward_max_length

         self.hidden_size = config.hidden_size
         self.split_size = config.hidden_size
@@ -276,17 +279,19 @@ class QWenAttention(nn.Module):

         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
+        self.use_rerope = config.use_rerope
+        self.rerope_window = config.rerope_window
+        self.causal = True

         logn_list = [
             math.log(i, self.seq_length) if i > self.seq_length else 1
-            for i in range(1, 32768)
+            for i in range(1, self.forward_max_length)
         ]
         logn_tensor = torch.tensor(logn_list)[None, :, None, None]
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)

         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
         self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
         self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
         cache_dtype = torch.float
         if self.bf16:
@@ -296,102 +301,60 @@ class QWenAttention(nn.Module):
             self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
             self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)

-        if config.use_cache_quantization and config.use_cache_kernel:
-            # pre check if the support files existing
-            module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
-                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
-                self.cache_kernels = None
-            else:
-                try:
-                    from .cpp_kernels import cache_autogptq_cuda_256
-                    self.cache_kernels = cache_autogptq_cuda_256
-                except ImportError:
-                    warnings.warn("Failed to import KV cache kernels.")
-                    self.cache_kernels = None
-
-    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
-        device = query.device
-        if self.use_cache_quantization:
-            qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = query.shape[:-1] + (qk.shape[-2],)
-                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
-                    qk.transpose(-1, -2).contiguous(),
-                    attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous() if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
-                # attn_weights = attn_weights.to(query.dtype).contiguous()
-            else:
-                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
-                attn_weights = torch.matmul(query, key.transpose(-1, -2))
-        else:
-            attn_weights = torch.matmul(query, key.transpose(-1, -2))
+    def _upcast_and_reordered_attn(
+        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
+    ):
+        bsz, num_heads, q_seq_len, dk = query.size()
+        _, _, k_seq_len, _ = key.size()
+
+        attn_weights = torch.empty(
+            bsz * num_heads,
+            q_seq_len,
+            k_seq_len,
+            dtype=torch.float32,
+            device=query.device,
+        )
+
+        scale_factor = 1.0
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [],
-                size_temp ** 0.5,
-                dtype=attn_weights.dtype,
-                device=attn_weights.device,
-            )
+            scale_factor /= float(value.size(-1)) ** 0.5

+        with autocast(enabled=False):
+            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
+                -1, dk, k_seq_len
+            )
+            attn_weights = torch.baddbmm(
+                attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
+            )
+            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
+
+        query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = registered_causal_mask[
             :, :, key_length - query_length : key_length, :key_length
         ]
         mask_value = torch.finfo(attn_weights.dtype).min
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
             attn_weights.device
         )
-        attn_weights = torch.where(
-            causal_mask, attn_weights.to(attn_weights.dtype), mask_value
-        )
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask

-        if self.softmax_in_fp32:
-            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
-        else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

+        if attn_weights.dtype != torch.float32:
+            raise RuntimeError(
+                "Error with upcasting, attn_weights does not have dtype torch.float32"
+            )
+        attn_weights = attn_weights.type(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)

         if head_mask is not None:
             attn_weights = attn_weights * head_mask

-        if self.use_cache_quantization:
-            qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = attn_weights.shape[:-1] + (query.shape[-1],)
-                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
-                    qv.contiguous(),  # dtype: int32
-                    attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
-                if attn_output.dtype != query.dtype:
-                    attn_output = attn_output.to(query.dtype)
-                    attn_weights = attn_weights.to(query.dtype)
-            else:
-                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
-                attn_output = torch.matmul(attn_weights, value)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
-
-        attn_output = attn_output.transpose(1, 2)
+        attn_output = torch.matmul(attn_weights, value)

         return attn_output, attn_weights

@@ -404,11 +367,31 @@ class QWenAttention(nn.Module):
         tensor = tensor.contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
+
+    def rotate_half(self, x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., :x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2:]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb_rerope(self, query, key, cos, sin, position_ids):
+        # take bsz into consideration
+        assert 1 == position_ids.shape[0]
+
+        cos = cos.squeeze(0).squeeze(1)
+        cos = cos[position_ids][:, :, None, :]  # [bs, seq_len, 1, dim] to [1, pos_len, 1, dim]
+        sin = sin.squeeze(0).squeeze(1)
+        sin = sin[position_ids][:, :, None, :]  # [bs, seq_len, 1, dim] to [1, pos_len, 1, dim]
+
+        q_embed = ((query * cos[:, -query.shape[1]:]) + (self.rotate_half(query) * sin[:, -query.shape[1]:])).to(query.dtype) if query is not None else None
+        k_embed = ((key * cos) + (self.rotate_half(key) * sin)).to(key.dtype) if key is not None else None
+        return q_embed, k_embed

     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -425,116 +408,101 @@ class QWenAttention(nn.Module):
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)

-            query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-            key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
-        query = torch.cat(query_list, dim=0)
-        key = torch.cat(key_list, dim=0)
-
-        if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
-                                   bits=8,
-                                   qmin=self.cache_qmin,
-                                   qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
-                                     bits=8,
-                                     qmin=self.cache_qmin,
-                                     qmax=self.cache_qmax)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-            if self.use_cache_quantization:
-                # use_cache_quantization:
-                #   present=((q_key,key_scale,key_zero_point),
-                #            (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
-                       torch.cat((past_key[1], key[1]), dim=2),
-                       torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
-                         torch.cat((past_value[1], value[1]), dim=2),
-                         torch.cat((past_value[2], value[2]), dim=2))
-            else:
-                # not use_cache_quantization:
-                #   present=(key,value)
-                key = torch.cat((past_key, key), dim=1)
-                value = torch.cat((past_value, value), dim=1)
-
-        if use_cache:
-            present = (key, value)
-        else:
-            present = None
-
-        if (
-            self.use_flash_attn
-            and flash_attn_unpadded_func is not None
-            and not self.is_fp32
-            and query.is_cuda
-        ):
-            q, k, v = query, key, value
-            attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
-        else:
-            registered_causal_mask = torch.tril(
-                torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
-            ).view(1, 1, key.size(1), key.size(1))
-            query = query.permute(0, 2, 1, 3)
-
-            if not self.use_cache_quantization and SUPPORT_TORCH2:
-                causal_mask = registered_causal_mask[
-                    :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
-                ]
-                if attention_mask is not None:
-                    attention_mask = attention_mask.expand(
-                        -1, -1, causal_mask.size(2), -1
-                    ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
-                else:
-                    attention_mask = causal_mask
-                attn_output = F.scaled_dot_product_attention(
-                    query, key, value, attn_mask=attention_mask
-                ).transpose(1, 2)
-                attn_weight = None
-            else:
+        q_len = hidden_states.shape[1]
+        assert rotary_pos_emb_list is not None
+        assert output_attentions is False
+
+        # TODO
+        # 1. remove dynamic quantization
+        # 2. logn is used
+        # 3. prepare to add context rotary_emb_apply
+
+        cos, sin = rotary_pos_emb_list[0]
+        assert len(rotary_pos_emb_list) == 1
+
+        if q_len == 1:
+            # position_ids = torch.tensor([[layer_past[0].shape[1]]], dtype=torch.int64, device=query.device)
+            # query *= ((position_ids.flatten() + 1)[None, :, None, None].log() / np.log(self.train_length)).clip(1).to(query.dtype)
+
+            if layer_past is not None:
+                past_key, past_value = layer_past[0], layer_past[1]
+                key = torch.cat((past_key, key), dim=1)
+                value = torch.cat((past_value, value), dim=1)
+            present = (key, value)
+
+            # position embedding
+            position_ids = torch.arange(layer_past[0].shape[1] + 1, device=query.device).unsqueeze(0)
+            position_ids = (position_ids[:, -1] - position_ids).clip(max=self.rerope_window)
+            _, key = self.apply_rotary_pos_emb_rerope(None, key, cos, -sin, position_ids)
+
+            if self.use_logn_attn:
+                seq_start = key.size(1) - query.size(1)
+                seq_end = key.size(1)
+                logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+                query = query * logn_tensor.expand_as(query)
+            # attn
+            query = query.permute(0, 2, 1, 3)
+            key = key.permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+            causal_mask = registered_causal_mask[
+                :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
+            ]
+            if attention_mask is not None:
+                attention_mask = attention_mask.expand(
+                    -1, -1, causal_mask.size(2), -1
+                ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+            else:
+                attention_mask = causal_mask
+
+            attn_output = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask
+            ).transpose(1, 2)
+
+        else:
+            # prefill
+            position_ids = torch.arange(query.shape[1], device=query.device).unsqueeze(0)
+            # query *= ((position_ids.flatten() + 1)[None, :, None, None].log() / np.log(self.train_length)).clip(1).to(query.dtype)
+            present = (key, value)
+
+            if self.use_logn_attn:
+                seq_start = key.size(1) - query.size(1)
+                seq_end = key.size(1)
+                logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
+                query = query * logn_tensor.expand_as(query)
+
+            query_states1, key_states1 = self.apply_rotary_pos_emb_rerope(query, key, cos, sin, position_ids)
+            query_states2, _ = self.apply_rotary_pos_emb_rerope(query, None, cos, sin, position_ids * 0 + self.rerope_window)
+
+            query_states1 = query_states1.permute(0, 2, 1, 3)
+            query_states2 = query_states2.permute(0, 2, 1, 3)
+            key_states1 = key_states1.permute(0, 2, 1, 3)
+            key_states2 = key.to(key_states1.dtype).permute(0, 2, 1, 3)
+            value = value.permute(0, 2, 1, 3)
+
+            sm_scale = 1.0 / math.sqrt(self.head_dim)
+            attn_weights1 = torch.matmul(query_states1, key_states1.transpose(2, 3)) * sm_scale
+            attn_weights2 = torch.matmul(query_states2, key_states2.transpose(2, 3)) * sm_scale
+            rectified_mask = (position_ids[:, -q_len:, None] - position_ids[:, None]).abs() < self.rerope_window
+            attn_weights = torch.where(rectified_mask, attn_weights1, attn_weights2)
+
+            if self.causal:
+                tgt_len = attn_weights.shape[-1]
+                dtype = attn_weights.dtype
+                device = attn_weights.device
+                mask = torch.full((tgt_len, tgt_len),
+                                  torch.finfo(dtype).min,
+                                  device=device)
+                mask_cond = torch.arange(mask.size(-1), device=device)
+                mask.masked_fill_(
+                    mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+                mask = mask.to(dtype)
+                attn_weights = attn_weights + mask
+
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+            attn_output = torch.matmul(attn_weights, value).transpose(1, 2)
+
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )
@@ -542,15 +510,6 @@ class QWenAttention(nn.Module):
         attn_output = self.c_proj(context_layer)

         outputs = (attn_output, present)
-        if output_attentions:
-            if (
-                self.use_flash_attn
-                and flash_attn_unpadded_func is not None
-                and not self.is_fp32
-            ):
-                raise ValueError("Cannot output attentions while using flash-attn")
-            else:
-                outputs += (attn_weight,)

         return outputs
@@ -596,6 +555,7 @@ class QWenBlock(nn.Module):
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -609,6 +569,7 @@ class QWenBlock(nn.Module):
        attn_outputs = self.attn(
            layernorm_output,
            rotary_pos_emb_list,
+           registered_causal_mask=registered_causal_mask,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
@@ -682,10 +643,13 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False

         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
+        assert self.use_dynamic_ntk is False
+        self.use_rerope = config.use_rerope
+        self.rerope_window = config.rerope_window
+        assert self.use_rerope is True
         self.seq_length = config.seq_length

         self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
@@ -708,6 +672,21 @@ class QWenModel(QWenPreTrainedModel):

         self.use_flash_attn = config.use_flash_attn
         self.is_fp32 = not (config.bf16 or config.fp16)
+        if (
+            self.use_flash_attn
+            and flash_attn_unpadded_func is not None
+            and not self.is_fp32
+        ):
+            self.registered_causal_mask = None
+        else:
+            max_positions = config.max_position_embeddings
+            self.register_buffer(
+                "registered_causal_mask",
+                torch.tril(
+                    torch.ones((max_positions, max_positions), dtype=torch.bool)
+                ).view(1, 1, max_positions, max_positions),
+                persistent=False,
+            )

         self.h = nn.ModuleList(
             [
@@ -792,10 +771,7 @@ class QWenModel(QWenPreTrainedModel):
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-            if self.use_cache_quantization:
-                past_length = past_key_values[0][0][0].size(2)
-            else:
-                past_length = past_key_values[0][0].size(-2)
+            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(
                 past_length,
@@ -823,10 +799,7 @@ class QWenModel(QWenPreTrainedModel):
         kv_seq_len = hidden_states.size()[1]
         if past_key_values[0] is not None:
             # past key values[0][0] shape: bs * seq_len * head_num * dim
-            if self.use_cache_quantization:
-                kv_seq_len += past_key_values[0][0][0].shape[2]
-            else:
-                kv_seq_len += past_key_values[0][0].shape[1]
+            kv_seq_len += past_key_values[0][0].shape[1]

         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
@@ -844,10 +817,15 @@ class QWenModel(QWenPreTrainedModel):
             ntk_alpha = self.get_ntk_alpha(kv_seq_len)
             ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
+        if kv_seq_len > 1:
+            # prefill
+            rotary_emb_seq_len = max(kv_seq_len, self.rerope_window + 1)
+        else:
+            rotary_emb_seq_len = kv_seq_len
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(rotary_emb_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
         ]

         hidden_states = self.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
@@ -879,6 +857,7 @@ class QWenModel(QWenPreTrainedModel):
                     create_custom_forward(block),
                     hidden_states,
                     rotary_pos_emb_list,
+                    self.registered_causal_mask,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -890,6 +869,7 @@ class QWenModel(QWenPreTrainedModel):
                     hidden_states,
                     layer_past=layer_past,
                     rotary_pos_emb_list=rotary_pos_emb_list,
+                    registered_causal_mask=self.registered_causal_mask,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
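The core of the change is the prefill branch above: attention scores are computed twice, once with standard RoPE on both queries and keys (attn_weights1) and once with the queries pinned to position rerope_window against un-rotated keys (attn_weights2), and the rectified mask keeps attn_weights1 wherever the relative distance is inside the window. A self-contained sketch of that selection rule, with toy shapes and illustrative helper names (rotate_half, apply_rope, w) rather than the exact methods added in this PR:

import math
import torch

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, pos):
    # x: [b, h, s, d]; cos/sin: [max_pos, d]; pos: [1, s] integer positions
    return x * cos[pos][:, None] + rotate_half(x) * sin[pos][:, None]

b, h, s, d, w = 1, 2, 16, 8, 4          # w plays the role of rerope_window

q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)

# standard RoPE tables
inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
freqs = torch.outer(torch.arange(s).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

pos = torch.arange(s)[None]                              # [1, s]
q1 = apply_rope(q, cos, sin, pos)                        # queries/keys at their true positions
k1 = apply_rope(k, cos, sin, pos)
q2 = apply_rope(q, cos, sin, torch.full_like(pos, w))    # queries pinned to position w, keys left un-rotated

scale = 1.0 / math.sqrt(d)
inner = (q1 @ k1.transpose(-1, -2)) * scale              # exact relative positions
outer = (q2 @ k.transpose(-1, -2)) * scale               # constant relative position w

rel = pos[:, :, None] - pos[:, None, :]                  # [1, s, s] signed distances
scores = torch.where((rel.abs() < w)[:, None], inner, outer)

causal = torch.tril(torch.ones(s, s, dtype=torch.bool))
scores = scores.masked_fill(~causal, torch.finfo(scores.dtype).min)
out = torch.softmax(scores, dim=-1) @ v                  # [b, h, s, d]
print(out.shape)

The decode branch (q_len == 1) realizes the same rule incrementally: the distances from the current token to the cached keys are clipped at rerope_window before the inverse rotation is applied to the keys, so a single query position covers both the in-window and out-of-window cases.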
test_passkey_retrieval.py
ADDED
@@ -0,0 +1,97 @@
import argparse
import random
from numpy import random
from transformers import AutoTokenizer, AutoModelForCausalLM
import pdb

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--max_tokens', type=int, default=20000, help='maximum token length for evaluation')
    parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
    parser.add_argument('--num_tests', type=int, default=30, help='number of repeat testing for each length')

    args = parser.parse_args()
    return args

# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
def generate_prompt_landmark(n_garbage=60000, seed=666):
    """Generates a text file and inserts an passkey at a random position."""
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix

    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    print('idx : {}'.format(len(task_description) + len(garbage_prefix)))
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)

# NTK+log on Qwen-7B tokens {'5801': 0.95, '7986': 0.9, '8805': 0.85, '9897': 0.8, '11809': 0.95, '12900': 0.78, '13993': 0.06, '14812': 0.0}
# ReRoPE on Qwen-7B
def main(args):
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True).eval().cuda('cuda:3')
    # tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True)
    # model = AutoModelForCausalLM.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True).eval().cuda('cuda:3')

    all_accuries = {}
    # This is a rough ratio to control the number of texts and tokens
    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
    for val in range(2000, 12000, args.interval):
        n_garbage = int(3.75 * val // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0

        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
            response, _ = model.chat(tokenizer, prompt, history=[], top_k=1)
            print((response, pass_key))
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens//args.num_tests
        accuracy = passed_tests/args.num_tests
        print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
        all_accuries[str(avg_tokens)] = accuracy

    all_accuries = {}
    # This is a rough ratio to control the number of texts and tokens
    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
    for val in range(2000, 12000, args.interval):
        n_garbage = int(3.75 * val // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0

        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j+val)
            response, _ = model.chat(tokenizer, prompt, history=[])
            print((response, pass_key))
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens//args.num_tests
        accuracy = passed_tests/args.num_tests
        print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
        all_accuries[str(avg_tokens)] = accuracy
    print("accuries over tokens", all_accuries)


if __name__ == "__main__":
    args = parse_config()
    main(args)
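For reference, the script reads the model from the hardcoded '/models/Qwen-7B-Chat-ReRoPE' directory above, and as written --max_tokens is parsed but not used: the sweep is driven by the fixed range(2000, 12000, args.interval), printing per-length passkey accuracy and the final accuracy dict. A typical invocation, assuming the patched checkpoint is available at that path, would look like:

python test_passkey_retrieval.py --interval 1000 --num_tests 30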