minho commited on
Commit
eb7c31f
·
verified ·
1 Parent(s): de40af9

Update modeling_orion.py

Browse files
Files changed (1) hide show
  1. modeling_orion.py +6 -3
modeling_orion.py CHANGED
@@ -30,9 +30,12 @@ from transformers.utils import (
30
  )
31
  import logging
32
  if is_flash_attn_2_available():
33
- from flash_attn import flash_attn_func, flash_attn_varlen_func
34
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
35
-
 
 
 
36
  logger = logging.getLogger(__name__)
37
 
38
  _CONFIG_FOR_DOC = "OrionConfig"
 
30
  )
31
  import logging
32
  if is_flash_attn_2_available():
33
+ try:
34
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
35
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
36
+ except ImportError:
37
+ pass
38
+
39
  logger = logging.getLogger(__name__)
40
 
41
  _CONFIG_FOR_DOC = "OrionConfig"