# Evgeny Zhukov
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# 2ba4412
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.nn.functional as F
import math
import os
import time
import numpy as np
import random
# from flash_attn.flash_attention import FlashAttention
class FlashAttentionBlock(nn.Module):
    """Residual multi-head attention over the spatial positions of a feature map.

    Keys/values can optionally be extended with tokens projected from a
    ``context`` sequence. When the optional v1 ``flash_attn`` package
    (``flash_attn.flash_attention.FlashAttention``) is importable, the head
    size is supported and the input is on CUDA, attention runs through that
    kernel; otherwise a plain PyTorch implementation is used.

    Args:
        dim: channel count of the input feature map (GroupNorm uses 32 groups,
            so it must be divisible by 32).
        context_dim: feature size of the ``context`` tokens, or None when the
            block is only used without context.
        num_heads: number of heads; ignored when ``head_dim`` is given.
        head_dim: size of each head; takes precedence over ``num_heads``.
        batch_size: unused; kept for backward compatibility.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Applied to both q and k in the fallback path, so q.k ends up scaled by
        # head_dim ** -0.5 -- the same scale the flash path uses (softmax_scale=None).
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # Use the flash kernel only for head sizes it handles (multiple of 8, <= 128)
        # and only if the optional package is installed; otherwise fall back to
        # plain PyTorch attention instead of failing at construction/forward time.
        flash_cls = None
        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
            try:
                from flash_attn.flash_attention import FlashAttention as flash_cls
            except ImportError:
                flash_cls = None
        self.flash_attn = (
            flash_cls(softmax_scale=None, attention_dropout=0.0) if flash_cls is not None else None
        )

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)
        # self.apply(self._init_weight)

    def _init_weight(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def _flash_attention(self, q, k, v):
        """Attend with the flash kernel.

        q is [b, n, d, Lq]; k and v are [b, n, d, Lk] with Lk >= Lq.
        Returns [b, n, d, Lq].
        The packed-qkv kernel needs equal lengths, so q is zero-padded up to
        Lk and the padded query rows are dropped afterwards. No mask is
        passed, so every query attends to every key regardless of position.
        """
        b, n, d, q_len = q.shape
        pad = k.size(-1) - q_len
        if pad > 0:
            q = torch.cat([q, q.new_zeros(b, n, d, pad)], dim=-1)
        qkv = torch.cat([q, k, v], dim=1)
        origin_dtype = qkv.dtype
        # [b, 3n, d, L] -> [b, L, 3, n, d], cast to fp16 for the kernel call
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        # Tensor.to is not in-place: keep the result so later layers see the input dtype.
        out = out.to(origin_dtype)
        out = out[:, :q_len]
        return out.permute(0, 2, 3, 1)

    def _math_attention(self, q, k, v):
        """Plain PyTorch attention.

        q is [b, n, d, Lq]; k and v are [b, n, d, Lk]. Returns [b, n, d, Lq].
        """
        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)  # [b, n, Lq, Lk]
        attn = F.softmax(attn.float(), dim=-1).type(attn.dtype)
        return torch.matmul(v, attn.transpose(-1, -2))

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
        context: [B, L, context_dim] or None.
        Returns a tensor shaped like ``x`` (attention output projected and added
        to the input).
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value: each [b, n, d, h*w]
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            # context keys/values, each [b, n, d, L], prepended to the spatial ones
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        if self.flash_attn is not None and q.is_cuda:
            out = self._flash_attention(q, k, v)
        else:
            out = self._math_attention(q, k, v)
        # [b, n, d, h*w] -> [b, c, h, w]
        out = out.reshape(b, c, h, w)

        # output
        x = self.proj(out)
        return x + identity
if __name__ == '__main__':
    # Micro-benchmark of the block on GPU under mixed precision (inference only).
    batch_size = 8
    n_warmup, n_iters = 5, 10
    flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()
    x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
    context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
    # set `context = None` here to benchmark without context tokens
    flash_net.eval()
    # no_grad: measure the forward pass only, without building autograd graphs
    with torch.no_grad(), amp.autocast(enabled=True):
        # warm up
        for _ in range(n_warmup):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time()
        s1 = time.perf_counter()
        for _ in range(n_iters):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        s2 = time.perf_counter()
    print(f'Average cost time {(s2-s1)*1000/n_iters} ms')