# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.nn.functional as F
import math
import os
import time
import numpy as np
import random
from flash_attn.flash_attention import FlashAttention  # flash-attn v1-style packed-QKV module
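

# FlashAttentionBlock: spatial self-attention over [B, C, H, W] feature maps using the
# flash-attn packed-QKV kernel. Optional context tokens ([B, L, C]) are projected and
# prepended to the keys/values, giving a lightweight cross-attention path.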
class FlashAttentionBlock(nn.Module):

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(FlashAttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # flash-attn kernels require head dims that are multiples of 8, up to 128
        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
            new_scale = math.pow(head_dim, -0.5)  # unused; flash-attn applies its default 1/sqrt(d) scale
            self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)
        # self.apply(self._init_weight)

    def _init_weight(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.15)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
            context: [B, L, C] or None.
        """
        identity = x
        b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        x = self.norm(x)
        q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1)
        if context is not None:
            # project context tokens to key/value and prepend them along the sequence dim
            ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)
            # pad the query with zero tokens (the context length, 4 here) so that q, k and v
            # share the same sequence length required by the packed-QKV interface
            cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device)
            q = torch.cat([q, cq], dim=-1)
        qkv = torch.cat([q, k, v], dim=1)

        origin_dtype = qkv.dtype
        # flash-attn expects [B, S, 3, num_heads, head_dim] in half precision
        qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous()
        out, _ = self.flash_attn(qkv)
        out = out.to(origin_dtype)
        if context is not None:
            # drop the padded query positions
            out = out[:, :-4, :, :]
        out = out.permute(0, 2, 3, 1).reshape(b, c, h, w)

        # output
        x = self.proj(out)
        return x + identity
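

# Quick micro-benchmark: build a block with 1280 channels and 64-dim heads, run 5 warm-up
# passes under autocast, then time 10 forward passes on a CUDA device.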
if __name__ == '__main__':
    batch_size = 8
    flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda()

    x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda()
    context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda()
    # context = None
    flash_net.eval()

    with amp.autocast(enabled=True):
        # warm up
        for i in range(5):
            y = flash_net(x, context)
        torch.cuda.synchronize()

        s1 = time.time()
        for i in range(10):
            y = flash_net(x, context)
        torch.cuda.synchronize()
        s2 = time.time()

    print(f'Average cost time {(s2 - s1) * 1000 / 10} ms')