add flash-attn support
Browse files
configuration_chartmoe.py
CHANGED
@@ -53,6 +53,7 @@ class ChartMoEConfig(PretrainedConfig):
|
|
53 |
rope_scaling=None,
|
54 |
num_experts=4,
|
55 |
num_selected=2,
|
|
|
56 |
**kwargs,
|
57 |
):
|
58 |
self.num_experts = num_experts
|
@@ -77,6 +78,10 @@ class ChartMoEConfig(PretrainedConfig):
|
|
77 |
self.rope_theta = rope_theta
|
78 |
self.rope_scaling = rope_scaling
|
79 |
self._rope_scaling_validation()
|
|
|
|
|
|
|
|
|
80 |
super().__init__(
|
81 |
pad_token_id=pad_token_id,
|
82 |
bos_token_id=bos_token_id,
|
|
|
53 |
rope_scaling=None,
|
54 |
num_experts=4,
|
55 |
num_selected=2,
|
56 |
+
attn_implementation=None,
|
57 |
**kwargs,
|
58 |
):
|
59 |
self.num_experts = num_experts
|
|
|
78 |
self.rope_theta = rope_theta
|
79 |
self.rope_scaling = rope_scaling
|
80 |
self._rope_scaling_validation()
|
81 |
+
|
82 |
+
self.attn_implementation = attn_implementation
|
83 |
+
if self.attn_implementation is None:
|
84 |
+
self.attn_implementation = "eager"
|
85 |
super().__init__(
|
86 |
pad_token_id=pad_token_id,
|
87 |
bos_token_id=bos_token_id,
|