Update llama_xformers_attention.py
Browse files- llama_xformers_attention.py +4 -27
llama_xformers_attention.py
CHANGED
@@ -23,27 +23,9 @@ class LlamaXFormersAttention(LlamaAttention):
|
|
23 |
|
24 |
bsz, q_len, _ = hidden_states.size()
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
30 |
-
)
|
31 |
-
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
32 |
-
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
33 |
-
|
34 |
-
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
35 |
-
query_states = torch.cat(query_states, dim=-1)
|
36 |
-
|
37 |
-
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
38 |
-
key_states = torch.cat(key_states, dim=-1)
|
39 |
-
|
40 |
-
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
41 |
-
value_states = torch.cat(value_states, dim=-1)
|
42 |
-
|
43 |
-
else:
|
44 |
-
query_states = self.q_proj(hidden_states)
|
45 |
-
key_states = self.k_proj(hidden_states)
|
46 |
-
value_states = self.v_proj(hidden_states)
|
47 |
|
48 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
49 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
@@ -95,12 +77,7 @@ class LlamaXFormersAttention(LlamaAttention):
|
|
95 |
|
96 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
97 |
|
98 |
-
|
99 |
-
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
100 |
-
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
101 |
-
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
102 |
-
else:
|
103 |
-
attn_output = self.o_proj(attn_output)
|
104 |
|
105 |
if not output_attentions:
|
106 |
attn_weights = None
|
|
|
23 |
|
24 |
bsz, q_len, _ = hidden_states.size()
|
25 |
|
26 |
+
query_states = self.q_proj(hidden_states)
|
27 |
+
key_states = self.k_proj(hidden_states)
|
28 |
+
value_states = self.v_proj(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
31 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
77 |
|
78 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
79 |
|
80 |
+
attn_output = self.o_proj(attn_output)
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
if not output_attentions:
|
83 |
attn_weights = None
|