Wendy-Fly committed
Commit aa0b1e2 · verified · parent: 9226937

Upload crossattention.py with huggingface_hub

Files changed (1)
crossattention.py  +541 -0
crossattention.py ADDED
@@ -0,0 +1,541 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils.import_utils import is_xformers_available


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                raise NotImplementedError(
                    "Memory efficient attention with `xformers` is currently not supported when"
                    " `self.added_kv_proj_dim` is defined."
                )
            elif not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e

                processor = XFormersCrossAttnProcessor(attention_op=attention_op)
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `CrossAttention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor):
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, device=attention_mask.device)
                attention_mask = torch.concat([attention_mask, padding], dim=2)
            else:
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
        attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask


class CrossAttnProcessor:
    r"""Default processor: plain scaled dot-product attention computed with `torch.baddbmm`/`torch.bmm`."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class LoRALinearLayer(nn.Module):
    r"""Low-rank (LoRA) adapter: a down-projection to `rank` channels followed by an up-projection,
    with the up weights initialized to zero."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.to(orig_dtype)


class LoRACrossAttnProcessor(nn.Module):
    r"""Cross attention processor that adds trainable LoRA updates to the query, key, value and output projections."""

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CrossAttnAddedKVProcessor:
    r"""Processor for attention blocks with extra learned key/value projections (`add_k_proj`/`add_v_proj`)
    whose outputs are concatenated to the self-attention keys and values."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class XFormersCrossAttnProcessor:
    r"""Processor that computes attention with xformers' memory-efficient kernels."""

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class LoRAXFormersCrossAttnProcessor(nn.Module):
    r"""LoRA processor that computes attention with xformers' memory-efficient kernels."""

    def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        # fold the (batch * heads) dimension back before the output projection
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SlicedAttnProcessor:
    r"""Processor that computes attention in slices along the (batch * heads) dimension to save memory."""

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SlicedAttnAddedKVProcessor:
    r"""Sliced variant of `CrossAttnAddedKVProcessor`: attention is computed in slices along the
    (batch * heads) dimension."""

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)

        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


AttnProcessor = Union[
    CrossAttnProcessor,
    XFormersCrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
]
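
For reference, a minimal usage sketch (not part of the uploaded file): it instantiates the `CrossAttention` layer defined above with its default `CrossAttnProcessor`, then switches to sliced attention via `set_attention_slice` and to a LoRA processor via `set_processor`. The tensor shapes and hyperparameters (`query_dim=320`, `cross_attention_dim=768`, 77 context tokens, `rank=4`) are illustrative assumptions, not values prescribed by this commit.

import torch

# Assumes the classes from crossattention.py above are in scope.
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)

hidden_states = torch.randn(2, 64, 320)          # (batch, query tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, context tokens, cross_attention_dim)

# Default processor: standard softmax attention.
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])

# Trade speed for memory by slicing over the (batch * heads) dimension.
attn.set_attention_slice(4)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)

# Attach LoRA adapters; `scale` (forwarded through **cross_attention_kwargs)
# blends the low-rank updates with the frozen projections.
lora_processor = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
attn.set_processor(lora_processor)
out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, scale=0.5)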