# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp

from ..modules.model import sinusoidal_embedding_1d
from .ulysses import distributed_attention
from .util import gather_forward, get_rank, get_world_size

def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor
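
# Why pad with ones rather than zeros: the table is consumed as complex
# rotation multipliers in rope_apply below, and 1 is the multiplicative
# identity, so padded positions rotate nothing. A minimal shape check
# (illustrative sketch, not part of the original file):
#
#     tbl = torch.ones(10, 1, 64, dtype=torch.complex128)
#     assert pad_freqs(tbl, 16).shape == (16, 1, 64)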

@amp.autocast(enabled=False)  # keep rotary math in full precision even under autocast
def rope_apply(x, grid_sizes, freqs):
    """
    x:          [B, L, N, C].
    grid_sizes: [B, 3].
    freqs:      [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2

    # split freqs into temporal / height / width bands
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding: pad the table to the full (sharded)
        # sequence length, then rotate only this rank's slice
        sp_size = get_world_size()
        sp_rank = get_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
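
# Sharding sketch (illustrative, with assumed numbers; not part of the
# original file): under sequence parallelism each rank holds a contiguous
# shard of length s, so rank r rotates positions [r * s, (r + 1) * s) of
# the padded full-sequence table.
#
#     s, sp_size = 8, 4
#     full = pad_freqs(torch.ones(30, 1, 64), s * sp_size)  # 30 -> 32
#     shards = [full[r * s:(r + 1) * s] for r in range(sp_size)]
#     assert all(sh.shape[0] == s for sh in shards)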

def sp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
):
    """
    x:       A list of videos, each with shape [C, T, H, W].
    t:       [B].
    context: A list of text embeddings, each with shape [L, C].
    seq_len: Maximum sequence length after patchification.
    y:       Optional conditioning videos (i2v), concatenated with x
             along the channel dimension.
    """
    if self.model_type == 'i2v':
        assert y is not None

    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    if t.dim() == 1:
        # broadcast each sample's timestep across the sequence dimension
        t = t.unsqueeze(1).expand(t.size(0), seq_len)
    with torch.amp.autocast('cuda', dtype=torch.float32):
        bt = t.size(0)
        t = t.flatten()
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim,
                                    t).unflatten(0, (bt, seq_len)).float())
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # context parallel: scatter the sequence dimension across ranks
    x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
    e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # context parallel: gather the sequence dimension back
    x = gather_forward(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]
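
# Shape sketch (illustrative numbers, not from the original file): for
# seq_len = 32760 and a world size of 8, each rank's transformer blocks see
# a [B, 4095, dim] shard after torch.chunk, and gather_forward reassembles
# the full [B, 32760, out_dim] sequence before unpatchify. Note that
# torch.chunk only yields equal shards when seq_len is divisible by the
# world size, which the caller is expected to ensure.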

def sp_attn_forward(self,
                    x,
                    seq_lens,
                    grid_sizes,
                    freqs,
                    dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    x = distributed_attention(
        half(q),
        half(k),
        half(v),
        seq_lens,
        window_size=self.window_size,
    )

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
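
# Usage sketch (assumption, not part of the original file): these forwards
# are typically bound onto an already-constructed model with
# types.MethodType, replacing the single-GPU forwards once the process
# group behind get_rank()/get_world_size() is initialized. The `model`
# variable below is hypothetical.
#
#     import types
#     for block in model.blocks:
#         block.self_attn.forward = types.MethodType(
#             sp_attn_forward, block.self_attn)
#     model.forward = types.MethodType(sp_dit_forward, model)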