cbensimon HF Staff commited on
Commit
da26724
·
1 Parent(s): cfad87d
Files changed (1) hide show
  1. fa3.py +3 -3
fa3.py CHANGED
@@ -36,9 +36,9 @@ class FlashFusedFluxAttnProcessor3_0:
36
  batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
37
 
38
  # `sample` projections.
39
- query = attn.to_q(hidden_states)
40
- key = attn.to_k(hidden_states)
41
- value = attn.to_v(hidden_states)
42
 
43
  inner_dim = key.shape[-1]
44
  head_dim = inner_dim // attn.heads
 
36
  batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
37
 
38
  # `sample` projections.
39
+ qkv = attn.to_qkv(hidden_states)
40
+ split_size = qkv.shape[-1] // 3
41
+ query, key, value = torch.split(qkv, split_size, dim=-1)
42
 
43
  inner_dim = key.shape[-1]
44
  head_dim = inner_dim // attn.heads