cbensimon (HF Staff) committed
Commit cfad87d · 1 Parent(s): 69667cb
Files changed (1)
  1. fa3.py +3 -3
fa3.py CHANGED
@@ -36,9 +36,9 @@ class FlashFusedFluxAttnProcessor3_0:
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
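
For context, the removed fused path and the added split path compute the same query/key/value tensors whenever the fused to_qkv weight and bias are the concatenation of the per-projection weights and biases along the output dimension. The following is a standalone sketch of that equivalence, not code from fa3.py; the module names to_q/to_k/to_v/to_qkv simply mirror the identifiers in the diff.

import torch
import torch.nn as nn

# Standalone sketch (not from fa3.py): the fused projection removed by this
# commit and the three split projections added by it produce identical
# query/key/value tensors when to_qkv's weight/bias are the concatenation of
# the per-projection weights/biases along the output dimension.
dim, batch, seq = 64, 2, 8
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

hidden_states = torch.randn(batch, seq, dim)

# Old (fused) path, as in the removed lines.
qkv = to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
q_fused, k_fused, v_fused = torch.split(qkv, split_size, dim=-1)

# New (split) path, as in the added lines.
q, k, v = to_q(hidden_states), to_k(hidden_states), to_v(hidden_states)

assert torch.allclose(q_fused, q, atol=1e-5)
assert torch.allclose(k_fused, k, atol=1e-5)
assert torch.allclose(v_fused, v, atol=1e-5)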