File size: 4,642 Bytes
f62c8b9
 
 
 
 
 
c2a6cd2
 
 
f62c8b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import (PixArtAlphaTextProjection,
                                         TimestepEmbedding, Timesteps,
                                         get_timestep_embedding)
from einops import rearrange
from torch import nn


class HunyuanDiTAttentionPool(nn.Module):
    """Pool a token sequence into a single vector with multi-head attention.

    The mean of the input tokens is prepended to the sequence as an extra
    token. After a learned positional embedding is added, that mean token is
    used as the sole attention query over the full (mean + original) sequence,
    and the attended result is linearly projected to the output width.

    Args:
        spacial_dim: Number of tokens expected along dim 1 of the input; the
            positional embedding holds ``spacial_dim + 1`` rows (one extra for
            the prepended mean token).
        embed_dim: Feature width of every input token.
        num_heads: Number of attention heads the feature width is split into.
        output_dim: Width of the pooled output. Falls back to ``embed_dim``
            when it is ``None`` (or otherwise falsy).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        init_scale = embed_dim**0.5
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / init_scale)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return one pooled vector per batch element.

        Args:
            x: 3-D tensor; the positional embedding and the linear layers
                require it to be ``(batch, spacial_dim, embed_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim or embed_dim)``.
        """
        # Prepend the sequence mean as a summary token, then add positions.
        mean_token = x.mean(dim=1, keepdim=True)
        tokens = torch.cat([mean_token, x], dim=1)
        tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)

        batch_size = tokens.size(0)
        heads = self.num_heads

        def split_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, embed // heads)
            return t.reshape(batch_size, -1, heads, t.size(-1) // heads).transpose(1, 2)

        # Only the summary token queries; every token serves as key and value.
        query = split_heads(self.q_proj(tokens[:, :1]))
        key = split_heads(self.k_proj(tokens))
        value = split_heads(self.v_proj(tokens))

        attended = F.scaled_dot_product_attention(
            query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        # Merge the heads back: (batch, heads, 1, head_dim) -> (batch, 1, embed).
        attended = attended.transpose(1, 2).reshape(batch_size, 1, -1)
        attended = attended.to(query.dtype)

        return self.c_proj(attended).squeeze(1)


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """Build one conditioning vector from timestep, text, image-size and style inputs.

    The timestep is embedded on its own. The three "extra" conditions (an
    attention-pooled text vector, sinusoidal embeddings of the image meta
    values, and a learned style embedding) are concatenated, projected to
    ``embedding_dim`` and added to the timestep embedding.

    Args:
        embedding_dim: Width of the returned conditioning vector.
        pooled_projection_dim: Output width of the text attention pooler.
        seq_len: Token count handed to the text pooler as its ``spacial_dim``.
        cross_attention_dim: Feature width of the text tokens fed to the pooler.
    """

    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
        super().__init__()

        sinusoid_dim = 256  # width of each sinusoidal embedding (timestep and meta values)
        meta_values = 6  # number of image meta values embedded per sample

        self.time_proj = Timesteps(num_channels=sinusoid_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_dim, time_embed_dim=embedding_dim)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )
        # A single-entry learned embedding table, kept as a hook for future style options.
        self.style_embedder = nn.Embedding(1, embedding_dim)

        # Input width of the projection = pooled text + meta embeddings + style embedding.
        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=sinusoid_dim * meta_values + embedding_dim + pooled_projection_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        """Compute the combined conditioning embedding.

        Args:
            timestep: Timesteps passed to the sinusoidal projection.
            encoder_hidden_states: Text tokens reduced to one vector per sample
                by the attention pooler.
            image_meta_size: Tensor that is flattened, embedded element-wise and
                regrouped six values per row, so it must hold six values per
                sample (presumably original size, target size and crop
                offsets; confirm against the caller).
            style: Indices into the style embedding table; the table has a
                single entry, so only index 0 is valid.
            hidden_dtype: Dtype the sinusoidal features are cast to before use;
                ``None`` leaves them unchanged.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        # Timestep branch: sinusoidal features -> MLP, giving (N, embedding_dim).
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        timesteps_emb = self.timestep_embedder(time_features)

        # Extra condition 1: pooled text, (N, pooled_projection_dim).
        text_cond = self.pooler(encoder_hidden_states)

        # Extra condition 2: each meta value gets a 256-wide sinusoidal
        # embedding; six of them are laid side by side per sample, (N, 1536).
        size_cond = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
        size_cond = size_cond.to(dtype=hidden_dtype)
        size_cond = size_cond.view(-1, 6 * 256)

        # Extra condition 3: learned style vector, (N, embedding_dim).
        style_cond = self.style_embedder(style)

        # Fuse the extras, project them, and add them to the timestep embedding.
        extras = torch.cat([text_cond, size_cond, style_cond], dim=1)
        return timesteps_emb + self.extra_embedder(extras)


class TimePositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings along the frame axis.

    A table ``pe`` of shape ``(1, max_len, d_model)`` is precomputed and stored
    as a buffer: even channels hold ``sin(pos * freq)`` and odd channels hold
    ``cos(pos * freq)``, with frequencies ``10000 ** (-2i / d_model)``.

    Args:
        d_model: Channel width of the encoding; may be even or odd.
        dropout: Dropout probability applied to the encoded output.
        max_len: Largest number of frames that can be encoded.
    """

    def __init__(
        self,
        d_model,
        dropout = 0.,
        max_len = 24
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to avoid a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: 5-D tensor laid out as ``(batch, channels, frames, height,
                width)``; ``channels`` must be broadcast-compatible with
                ``d_model``.

        Returns:
            Tensor of the broadcast shape of ``x`` and the encoding, with the
            same encoding added at every batch index and spatial location.

        Raises:
            ValueError: If ``x`` has more frames than ``max_len``.
        """
        b, c, f, h, w = x.size()
        max_len = self.pe.size(1)
        if f > max_len:
            raise ValueError(
                f"Input has {f} frames but positional encodings exist only for max_len={max_len}."
            )
        # (1, f, d_model) -> (1, d_model, f, 1, 1): the encoding depends only on
        # (frame, channel), so broadcasting over batch/height/width gives the
        # same values as flattening to "(b h w) f c" and back, without copies.
        encoding = self.pe[:, :f].transpose(1, 2)[:, :, :, None, None]
        return self.dropout(x + encoding)