from typing import Optional, Union

import torch
import torch.nn as nn

from mld.models.operator.embeddings import TimestepEmbedding, Timesteps
from mld.models.operator.attention import (SkipTransformerEncoder,
                                           SkipTransformerDecoder,
                                           TransformerDecoder,
                                           TransformerDecoderLayer,
                                           TransformerEncoder,
                                           TransformerEncoderLayer)
from mld.models.operator.moe import MoeTransformerEncoderLayer, MoeTransformerDecoderLayer
from mld.models.operator.utils import get_clones, get_activation_fn, zero_module
from mld.models.operator.position_encoding import build_position_encoding


def load_balancing_loss_func(router_logits: tuple, num_experts: int = 4, topk: int = 2):
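    """Auxiliary load-balancing loss for MoE routers (Switch Transformer style).

    Concatenates router logits from all layers, selects the top-k experts per token,
    and penalizes the product of the fraction of tokens routed to each expert and the
    mean router probability for that expert, scaled by the number of experts. The
    loss is smallest when tokens are spread evenly across experts.
    """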
    router_logits = torch.cat(router_logits, dim=0)
    routing_weights = torch.nn.functional.softmax(router_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, topk, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
    router_prob_per_expert = torch.mean(routing_weights, dim=0)
    overall_loss = num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss


class MldDenoiser(nn.Module):
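    """Transformer denoiser for latent diffusion (MLD-style), conditioned on text and timestep.

    The backbone is either a skip-connected transformer encoder ('trans_enc') that
    processes latent and condition tokens as one sequence, or a transformer decoder
    ('trans_dec') that cross-attends to the conditions. Feed-forward blocks can be
    swapped for Mixture-of-Experts layers, and an optional ControlNet variant embeds
    a control signal and returns zero-initialized per-layer residuals.
    """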

    def __init__(self,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 text_dim: int = 768,
                 time_dim: int = 768,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 flip_sin_to_cos: bool = True,
                 freq_shift: float = 0,
                 time_act_fn: str = 'silu',
                 time_post_act_fn: Optional[str] = None,
                 position_embedding: str = "learned",
                 arch: str = "trans_enc",
                 add_mem_pos: bool = True,
                 force_pre_post_proj: bool = False,
                 text_act_fn: str = 'relu',
                 time_cond_proj_dim: Optional[int] = None,
                 zero_init_cond: bool = True,
                 is_controlnet: bool = False,
                 controlnet_embed_dim: Optional[int] = None,
                 controlnet_act_fn: str = 'silu',
                 moe: bool = False,
                 moe_num_experts: int = 4,
                 moe_topk: int = 2,
                 moe_loss_weight: float = 1e-2,
                 moe_jitter_noise: Optional[float] = None
                 ) -> None:
        super(MldDenoiser, self).__init__()
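
        # Project between the input latent width (latent_dim[-1]) and the transformer
        # hidden width when they differ, or when a projection is explicitly forced.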
        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
        add_pre_post_proj = force_pre_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_pre = nn.Linear(latent_dim[-1], self.latent_dim) if add_pre_post_proj else nn.Identity()
        self.latent_post = nn.Linear(self.latent_dim, latent_dim[-1]) if add_pre_post_proj else nn.Identity()

        self.arch = arch
        self.time_cond_proj_dim = time_cond_proj_dim

        self.moe_num_experts = moe_num_experts
        self.moe_topk = moe_topk
        self.moe_loss_weight = moe_loss_weight
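
        # Sinusoidal timestep features mapped through an MLP, plus a projection of the
        # text-encoder features to the transformer width.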
        self.time_proj = Timesteps(time_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(time_dim, self.latent_dim, time_act_fn, post_act_fn=time_post_act_fn,
                                                cond_proj_dim=time_cond_proj_dim, zero_init_cond=zero_init_cond)
        self.emb_proj = nn.Sequential(get_activation_fn(text_act_fn), nn.Linear(text_dim, self.latent_dim))
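
        # Transformer backbone: a skip-connected encoder ('trans_enc') or decoder
        # ('trans_dec'); MoE variants swap in expert feed-forward layers.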
        self.query_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
        if self.arch == "trans_enc":
            if moe:
                encoder_layer = MoeTransformerEncoderLayer(
                    self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
                    dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
            else:
                encoder_layer = TransformerEncoderLayer(
                    self.latent_dim, num_heads, ff_size, dropout,
                    activation, normalize_before, norm_eps)

            encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
            self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post,
                                                  is_controlnet=is_controlnet, is_moe=moe)

        elif self.arch == 'trans_dec':
            if add_mem_pos:
                self.mem_pos = build_position_encoding(self.latent_dim, position_embedding=position_embedding)
            else:
                self.mem_pos = None
            if moe:
                decoder_layer = MoeTransformerDecoderLayer(
                    self.latent_dim, num_heads, moe_num_experts, moe_topk, ff_size,
                    dropout, activation, normalize_before, norm_eps, moe_jitter_noise)
            else:
                decoder_layer = TransformerDecoderLayer(
                    self.latent_dim, num_heads, ff_size, dropout,
                    activation, normalize_before, norm_eps)

            decoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post and not is_controlnet else None
            self.decoder = SkipTransformerDecoder(decoder_layer, num_layers, decoder_norm, activation_post,
                                                  is_controlnet=is_controlnet, is_moe=moe)
        else:
            raise ValueError(f"Unsupported architecture: {self.arch}!")
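
        # Optional ControlNet branch: the condition embedding and the per-layer output
        # projections are zero-initialized, so the branch has no effect at initialization.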
        self.is_controlnet = is_controlnet
        if self.is_controlnet:
            embed_dim = controlnet_embed_dim if controlnet_embed_dim is not None else self.latent_dim
            modules = [
                nn.Linear(latent_dim[-1], embed_dim),
                get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
                nn.Linear(embed_dim, embed_dim),
                get_activation_fn(controlnet_act_fn) if controlnet_act_fn else None,
                zero_module(nn.Linear(embed_dim, latent_dim[-1]))
            ]
            self.controlnet_cond_embedding = nn.Sequential(*[m for m in modules if m is not None])

            self.controlnet_down_mid_blocks = nn.ModuleList([
                zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])

    def forward(self,
                sample: torch.Tensor,
                timestep: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                timestep_cond: Optional[torch.Tensor] = None,
                controlnet_cond: Optional[torch.Tensor] = None,
                controlnet_residuals: Optional[list[torch.Tensor]] = None
                ) -> tuple:
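        """Apply the denoiser to a batch of noisy latents for the given timestep.

        Args:
            sample: noisy latent of shape (batch, seq, latent_dim[-1]).
            timestep: diffusion timestep, broadcast over the batch.
            encoder_hidden_states: text-encoder features of shape (batch, tokens, text_dim).
            timestep_cond: optional extra conditioning fed to the timestep embedding.
            controlnet_cond: control signal, used only when is_controlnet=True.
            controlnet_residuals: residuals from a ControlNet companion, injected per layer.

        Returns:
            A (sample, router_loss) tuple, or (controlnet residuals, router_loss) when
            is_controlnet=True; router_loss is None unless MoE layers are enabled.
        """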
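
        # ControlNet variant: add the embedded control signal to the noisy latent.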
        if self.is_controlnet:
            sample = sample + self.controlnet_cond_embedding(controlnet_cond)
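
        # Switch to (seq, batch, dim) layout and project to the transformer width.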
        sample = sample.permute(1, 0, 2)
        sample = self.latent_pre(sample)
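
        # Embed the diffusion timestep (one embedding per batch element) and treat the
        # result as a single extra token.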
        timesteps = timestep.expand(sample.shape[1]).clone()
        time_emb = self.time_proj(timesteps)
        time_emb = time_emb.to(dtype=sample.dtype)

        time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)
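
        # Project the text features and concatenate them with the time token to form
        # the conditioning sequence.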
        encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
        text_emb_latent = self.emb_proj(encoder_hidden_states)
        emb_latent = torch.cat((time_emb, text_emb_latent), 0)
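
        # Run the backbone: the encoder sees latent and condition tokens as one sequence,
        # while the decoder cross-attends from the latent queries to the conditions.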
        if self.arch == "trans_enc":
            xseq = torch.cat((sample, emb_latent), dim=0)
            xseq = self.query_pos(xseq)
            tokens, intermediates, router_logits = self.encoder(xseq, controlnet_residuals=controlnet_residuals)
        elif self.arch == 'trans_dec':
            sample = self.query_pos(sample)
            if self.mem_pos:
                emb_latent = self.mem_pos(emb_latent)
            tokens, intermediates, router_logits = self.decoder(sample, emb_latent,
                                                                controlnet_residuals=controlnet_residuals)
        else:
            raise ValueError(f"Unsupported architecture: {self.arch}!")
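
        # Auxiliary load-balancing loss for the MoE router (None for non-MoE models).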
        router_loss = None
        if router_logits is not None:
            router_loss = load_balancing_loss_func(router_logits, self.moe_num_experts, self.moe_topk)
            router_loss = self.moe_loss_weight * router_loss
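
        # ControlNet mode returns the zero-initialized per-layer residuals instead of a
        # denoised sample; otherwise strip the condition tokens (encoder) or take the
        # decoder output directly.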
        if self.is_controlnet:
            control_res_samples = []
            for res, block in zip(intermediates, self.controlnet_down_mid_blocks):
                r = block(res)
                control_res_samples.append(r)
            return control_res_samples, router_loss
        elif self.arch == "trans_enc":
            sample = tokens[:sample.shape[0]]
        elif self.arch == 'trans_dec':
            sample = tokens
        else:
            raise ValueError(f"Unsupported architecture: {self.arch}!")
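
        # Project back to the input latent width and restore (batch, seq, dim) layout.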
        sample = self.latent_post(sample)
        sample = sample.permute(1, 0, 2)
        return sample, router_loss