+
FluxMusic Generator
+
+ Generate music from text prompts using the FluxMusic model.
+
+ """)
+
+ with gr.Row():
+ model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model if default_model in model_choices else model_choices[0])
+
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt")
+ seed = gr.Number(label="Seed", value=0)
+
+ with gr.Row():
+ cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
+ steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
+ duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
+
+ generate_button = gr.Button("Generate Music")
+ output_status = gr.Textbox(label="Generation Status")
+ output_audio = gr.Audio(type="filepath")
+
+ def on_model_change(model_name):
+ load_model(model_name)
+
+ model_dropdown.change(on_model_change, inputs=[model_dropdown])
+ generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
+
+ # Load default model on startup
+ default_model_path = os.path.join(MODELS_DIR, default_model)
+ if os.path.exists(default_model_path):
+ iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
+
+if __name__ == "__main__":
+ iface.launch()
\ No newline at end of file
diff --git a/generations/generationsShowUpHere.txt b/generations/generationsShowUpHere.txt
new file mode 100644
index 0000000000000000000000000000000000000000..eb364e3140ba22981f9450bbdf035c9e9ebe1f35
--- /dev/null
+++ b/generations/generationsShowUpHere.txt
@@ -0,0 +1 @@
+generations show up here
\ No newline at end of file
diff --git a/model.py b/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7334fb49e34d8ed83b672b167d206dbbec8af37
--- /dev/null
+++ b/model.py
@@ -0,0 +1,112 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+
+from modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+ MLPEmbedder, SingleStreamBlock,
+ timestep_embedding)
+
+
+@dataclass
+class FluxParams:
+ in_channels: int
+ vec_in_dim: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list[int]
+ theta: int
+ qkv_bias: bool
+ guidance_embed: bool
+
+
+class Flux(nn.Module):
+ """
+ Transformer model for flow matching on sequences.
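+
+    forward() consumes token sequences: x holds patchified latent tokens, txt holds
+    T5 text features, y is a pooled text embedding, t holds flow timesteps in [0, 1],
+    and img_ids / txt_ids carry 3-axis positional ids for the rotary embeddings.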
+ """
+
+ def __init__(self, params: FluxParams):
+ super().__init__()
+
+ self.params = params
+ self.in_channels = params.in_channels
+ self.out_channels = self.in_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+ # self.guidance_in = (
+ # MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+ # )
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+ def forward(
+ self,
+ x: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ t: Tensor,
+ y: Tensor,
+ guidance: Tensor | None = None,
+ ) -> Tensor:
+ if x.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # project the input token sequence into the hidden size and build the conditioning vector
+ img = self.img_in(x)
+ vec = self.time_in(timestep_embedding(t, 256))
+ # if self.params.guidance_embed:
+ # if guidance is None:
+ # raise ValueError("Didn't get guidance strength for guidance distilled model.")
+ # vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+ vec = vec + self.vector_in(y)
+ txt = self.txt_in(txt)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+
+ for block in self.double_blocks:
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+ img = torch.cat((txt, img), 1)
+ for block in self.single_blocks:
+ img = block(img, vec=vec, pe=pe)
+ img = img[:, txt.shape[1] :, ...]
+
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ return img
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/models/modelsgohere.txt b/models/modelsgohere.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0bd5452ee55c06d2bfb4de595ea2a84a584b686
--- /dev/null
+++ b/models/modelsgohere.txt
@@ -0,0 +1 @@
+models go here
\ No newline at end of file
diff --git a/modelsgohere.txt b/modelsgohere.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0bd5452ee55c06d2bfb4de595ea2a84a584b686
--- /dev/null
+++ b/modelsgohere.txt
@@ -0,0 +1 @@
+models go here
\ No newline at end of file
diff --git a/modules/__pycache__/autoencoder.cpython-310.pyc b/modules/__pycache__/autoencoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8755bee79d551ba7f59a98321d4a7405e7ae087c
Binary files /dev/null and b/modules/__pycache__/autoencoder.cpython-310.pyc differ
diff --git a/modules/__pycache__/conditioner.cpython-310.pyc b/modules/__pycache__/conditioner.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcae9f1fd1b5d18a4776e7ae917a6172a7f9018c
Binary files /dev/null and b/modules/__pycache__/conditioner.cpython-310.pyc differ
diff --git a/modules/__pycache__/layers.cpython-310.pyc b/modules/__pycache__/layers.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4aa8a8df688183d46ae1a6391f97e2c70f3ce831
Binary files /dev/null and b/modules/__pycache__/layers.cpython-310.pyc differ
diff --git a/modules/autoencoder.py b/modules/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..41d8080dfa098a82484fe48876cc25694c0e72e4
--- /dev/null
+++ b/modules/autoencoder.py
@@ -0,0 +1,312 @@
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+
+@dataclass
+class AutoEncoderParams:
+ resolution: int
+ in_channels: int
+ ch: int
+ out_ch: int
+ ch_mult: list[int]
+ num_res_blocks: int
+ z_channels: int
+ scale_factor: float
+ shift_factor: float
+
+
+def swish(x: Tensor) -> Tensor:
+ return x * torch.sigmoid(x)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+
+ def attention(self, h_: Tensor) -> Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x + self.proj_out(self.attention(x))
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int | None = None):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x):
+ h = x
+ h = self.norm1(h)
+ h = swish(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = swish(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x: Tensor):
+ pad = (0, 1, 0, 1)
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor):
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ resolution: int,
+ in_channels: int,
+ ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ # downsampling
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ block_in = self.ch
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x: Tensor) -> Tensor:
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1])
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ ch: int,
+ out_ch: int,
+ ch_mult: list[int],
+ num_res_blocks: int,
+ in_channels: int,
+ resolution: int,
+ z_channels: int,
+ ):
+ super().__init__()
+ self.ch = ch
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.ffactor = 2 ** (self.num_resolutions - 1)
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+
+ # z to block_in
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for _ in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+ block_in = block_out
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, z: Tensor) -> Tensor:
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = swish(h)
+ h = self.conv_out(h)
+ return h
+
+
+class DiagonalGaussian(nn.Module):
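+    """Split the input into (mean, logvar) along chunk_dim and, if sample=True,
+    draw a latent via the reparameterization trick; otherwise return the mean."""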
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
+ super().__init__()
+ self.sample = sample
+ self.chunk_dim = chunk_dim
+
+ def forward(self, z: Tensor) -> Tensor:
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
+ if self.sample:
+ std = torch.exp(0.5 * logvar)
+ return mean + std * torch.randn_like(mean)
+ else:
+ return mean
+
+
+class AutoEncoder(nn.Module):
+ def __init__(self, params: AutoEncoderParams):
+ super().__init__()
+ self.encoder = Encoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.decoder = Decoder(
+ resolution=params.resolution,
+ in_channels=params.in_channels,
+ ch=params.ch,
+ out_ch=params.out_ch,
+ ch_mult=params.ch_mult,
+ num_res_blocks=params.num_res_blocks,
+ z_channels=params.z_channels,
+ )
+ self.reg = DiagonalGaussian()
+
+ self.scale_factor = params.scale_factor
+ self.shift_factor = params.shift_factor
+
+ def encode(self, x: Tensor) -> Tensor:
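+        # sample a latent from the encoder posterior, then shift and scale it
+        # into the range expected by the diffusion model (inverse of decode)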
+ z = self.reg(self.encoder(x))
+ z = self.scale_factor * (z - self.shift_factor)
+ return z
+
+ def decode(self, z: Tensor) -> Tensor:
+ z = z / self.scale_factor + self.shift_factor
+ return self.decoder(z)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.decode(self.encode(x))
diff --git a/modules/conditioner.py b/modules/conditioner.py
new file mode 100644
index 0000000000000000000000000000000000000000..786dbed55628f5ef73cfc2fc142174ed073236e2
--- /dev/null
+++ b/modules/conditioner.py
@@ -0,0 +1,46 @@
+from torch import Tensor, nn
+from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+ T5Tokenizer, AutoTokenizer, ClapTextModel)
+
+
+class HFEmbedder(nn.Module):
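+    """Thin wrapper around a HuggingFace text encoder (CLIP, CLAP or T5).
+
+    T5 variants return the full last_hidden_state sequence; CLIP/CLAP variants
+    return the pooled output (see output_key below).
+    """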
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
+ super().__init__()
+ self.is_t5 = version.startswith("google")
+ self.max_length = max_length
+ self.output_key = "last_hidden_state" if self.is_t5 else "pooler_output"
+
+ if version.startswith("openai"):
+ local_path = 'ckpt/stable-diffusion-3-medium-diffusers'
+ local_path_tokenizer = 'ckpt/stable-diffusion-3-medium-diffusers'
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(local_path_tokenizer, subfolder="tokenizer", max_length=max_length)
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(local_path, subfolder="text_encoder", **hf_kwargs).half()
+ elif version.startswith("laion"):
+ local_path = "laion/clap-htsat-fused"
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path, max_length=max_length)
+ self.hf_module: ClapTextModel = ClapTextModel.from_pretrained(local_path, **hf_kwargs).half()
+ else:
+ local_path = 'ckpt/stable-diffusion-3-medium-diffusers'
+ local_path_tokenizer = 'ckpt/stable-diffusion-3-medium-diffusers'
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(local_path_tokenizer, subfolder="tokenizer_3", max_length=max_length)
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(local_path, subfolder="text_encoder_3", **hf_kwargs).half()
+
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+ def forward(self, text: list[str]) -> Tensor:
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ outputs = self.hf_module(
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+ attention_mask=None,
+ output_hidden_states=False,
+ )
+ return outputs[self.output_key]
diff --git a/modules/layers.py b/modules/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..38470d14a19b0f3494c58c6e25e22bfd4161d134
--- /dev/null
+++ b/modules/layers.py
@@ -0,0 +1,348 @@
+import math
+from dataclasses import dataclass
+
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+
+try:
+ import flash_attn
+ if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2:
+ from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
+ from flash_attn.modules.mha import FlashSelfAttention
+ else:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+ from flash_attn.modules.mha import FlashSelfAttention
+except Exception as e:
+ print(f'flash_attn import failed: {e}')
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
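+    # apply rotary position embeddings to q/k, run scaled dot-product attention,
+    # then merge the head dimension back into the channel dimension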
+ q, k = apply_rope(q, k, pe)
+
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+
+ return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
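+    # build per-position, per-frequency 2x2 rotation matrices; output has shape
+    # (..., n, dim/2, 2, 2) and is consumed by apply_rope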
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+ return out.float()
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+class EmbedND(nn.Module):
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: Tensor) -> Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+
+ return emb.unsqueeze(1)
+
+
+def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
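+    :param time_factor: scaling applied to t before embedding; the flow timesteps here are in [0, 1].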
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ t = time_factor * t
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ t.device
+ )
+
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ if torch.is_floating_point(t):
+ embedding = embedding.to(t)
+ return embedding
+
+
+class MLPEmbedder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class RMSNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.scale = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x: Tensor):
+ x_dtype = x.dtype
+ x = x.float()
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+ return (x * rrms).to(dtype=x_dtype) * self.scale
+
+
+class QKNorm(torch.nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = RMSNorm(dim)
+ self.key_norm = RMSNorm(dim)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+ q = self.query_norm(q)
+ k = self.key_norm(k)
+ return q.to(v), k.to(v)
+
+
+class SelfAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.norm = QKNorm(head_dim)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+ qkv = self.qkv(x)
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ q, k = self.norm(q, k, v)
+ x = attention(q, k, v, pe=pe)
+ x = self.proj(x)
+ return x
+
+
+
+class FlashSelfMHAModified(nn.Module):
+ """
+    Self-attention using the FlashAttention kernel (flash_attn).
+ """
+ def __init__(self,
+ dim,
+ num_heads,
+ qkv_bias=False,
+ qk_norm=True,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ device=None,
+ dtype=None,
+ norm_layer=RMSNorm,
+ ):
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+        assert self.dim % num_heads == 0, "self.dim must be divisible by num_heads"
+ self.head_dim = self.dim // num_heads
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
+ # TODO: eps should be 1 / 65530 if using fp16
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop)
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, pe):
+ """
+ Parameters
+ ----------
+ x: torch.Tensor
+ (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+ """
+ b, s, d = x.shape
+
+ qkv = self.Wqkv(x)
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d]
+ q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
+ q = self.q_norm(q).half() # [b, s, h, d]
+ k = self.k_norm(k).half()
+ q, k = apply_rope(q, k, pe)
+
+ qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d]
+ context = self.inner_attn(qkv)
+ out = self.out_proj(context.view(b, s, d))
+ out = self.proj_drop(out)
+
+ return out
+
+
+@dataclass
+class ModulationOut:
+ shift: Tensor
+ scale: Tensor
+ gate: Tensor
+
+
+class Modulation(nn.Module):
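+    """Project the conditioning vector into (shift, scale, gate) triplets;
+    two triplets when double=True, otherwise one."""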
+ def __init__(self, dim: int, double: bool):
+ super().__init__()
+ self.is_double = double
+ self.multiplier = 6 if double else 3
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+ return (
+ ModulationOut(*out[:3]),
+ ModulationOut(*out[3:]) if self.is_double else None,
+ )
+
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
+ super().__init__()
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.img_mod = Modulation(hidden_size, double=True)
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ self.txt_mod = Modulation(hidden_size, double=True)
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.txt_mlp = nn.Sequential(
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+ nn.GELU(approximate="tanh"),
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+ )
+
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+ img_mod1, img_mod2 = self.img_mod(vec)
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
+
+ # prepare image for attention
+ img_modulated = self.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = self.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # run actual attention
+ q = torch.cat((txt_q, img_q), dim=2)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ attn = attention(q, k, v, pe=pe)
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+ return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: float | None = None,
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ head_dim = hidden_size // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+ # proj and mlp_out
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+ self.norm = QKNorm(head_dim)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ self.mlp_act = nn.GELU(approximate="tanh")
+ self.modulation = Modulation(hidden_size, double=False)
+
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+ mod, _ = self.modulation(vec)
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ q, k = self.norm(q, k, v)
+
+ # compute attention
+ attn = attention(q, k, v, pe=pe)
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+ return x + mod.gate * output
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7797d12a91019355e5517ccc5201042773318ec2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+torch
+gradio
+einops
+diffusers
+transformers
+scipy
+numpy
+regex
+tqdm
+accelerate
+soundfile
+unidecode
+phonemizer
+torchlibrosa
+ftfy
+pandas
+timm
+matplotlib
+thop
+flash-attn==2.6.3
+sentencepiece
+Pillow
+safetensors
+PyYAML
\ No newline at end of file
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1c02f31358a065a50a125ae830874087c42531e
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,122 @@
+import os
+import torch
+import argparse
+import math
+from einops import rearrange, repeat
+from PIL import Image
+from diffusers import AutoencoderKL
+from transformers import SpeechT5HifiGan
+
+from utils import load_t5, load_clap, load_ae
+from train import RF
+from constants import build_model
+
+
+def prepare(t5, clip, img, prompt):
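+    # patchify the latent into (h/2 * w/2) tokens of dim 4c, build 3-channel
+    # positional ids, and encode the prompt with the sequence encoder (t5) and
+    # the pooled encoder (clip/clap) to form the model's conditioning inputs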
+ bs, c, h, w = img.shape
+ if bs == 1 and not isinstance(prompt, str):
+ bs = len(prompt)
+
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+ if img.shape[0] == 1 and bs > 1:
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ txt = t5(prompt)
+ if txt.shape[0] == 1 and bs > 1:
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
+
+ vec = clip(prompt)
+ if vec.shape[0] == 1 and bs > 1:
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
+
+ print(img_ids.size(), txt.size(), vec.size())
+ return img, {
+ "img_ids": img_ids.to(img.device),
+ "txt": txt.to(img.device),
+ "txt_ids": txt_ids.to(img.device),
+ "y": vec.to(img.device),
+ }
+
+def main(args):
+    print('generate with FluxMusic')
+ torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ latent_size = (256, 16)
+
+ model = build_model(args.version).to(device)
+ local_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/music-flow/results/base/checkpoints/0050000.pt'
+ state_dict = torch.load(local_path, map_location=lambda storage, loc: storage)
+ model.load_state_dict(state_dict['ema'])
+ model.eval() # important!
+ diffusion = RF()
+
+ model_path = '/maindata/data/shared/multimodal/public/ckpts/FLUX.1-dev'
+
+ # Setup VAE
+ t5 = load_t5(device, max_length=256)
+ clap = load_clap(device, max_length=256)
+
+ model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
+ vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)
+
+ with open(args.prompt_file, 'r') as f:
+ conds_txt = f.readlines()
+ L = len(conds_txt)
+ unconds_txt = ["low quality, gentle"] * L
+ print(L, conds_txt, unconds_txt)
+
+    init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
+
+ STEPSIZE = 50
+ img, conds = prepare(t5, clap, init_noise, conds_txt)
+ _, unconds = prepare(t5, clap, init_noise, unconds_txt)
+ with torch.autocast(device_type='cuda'):
+ images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps = STEPSIZE, cfg = 7.0)
+
+ print(images[-1].size(), )
+
+ images = rearrange(
+ images[-1],
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+ h=128,
+ w=8,
+ ph=2,
+ pw=2,)
+ # print(images.size())
+ latents = 1 / vae.config.scaling_factor * images
+ mel_spectrogram = vae.decode(latents).sample
+ print(mel_spectrogram.size())
+
+    os.makedirs('wav', exist_ok=True)
+    for i in range(L):
+ x_i = mel_spectrogram[i]
+ if x_i.dim() == 4:
+ x_i = x_i.squeeze(1)
+ waveform = vocoder(x_i)
+ waveform = waveform[0].cpu().float().detach().numpy()
+ print(waveform.shape)
+ # import soundfile as sf
+ # sf.write('reconstruct.wav', waveform, samplerate=16000)
+ from scipy.io import wavfile
+ wavfile.write('wav/sample_' + str(i) + '.wav', 16000, waveform)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--version", type=str, default="base")
+ parser.add_argument("--prompt_file", type=str, default='config/example.txt')
+ parser.add_argument("--seed", type=int, default=2024)
+ args = parser.parse_args()
+ main(args)
+
+
diff --git a/scripts/train_b.sh b/scripts/train_b.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f21013df32dc0bff9ec0a8bad713d7e26e5476dd
--- /dev/null
+++ b/scripts/train_b.sh
@@ -0,0 +1,6 @@
+torchrun --nnodes=2 --nproc_per_node=8 train.py \
+--version base \
+--data-path combine_dataset.json \
+--global_batch_size 128 \
+--resume xxx \
+--global-seed 2023
diff --git a/scripts/train_g.sh b/scripts/train_g.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bf0769f2f64cef66b5a9e0f02e57b0ddbbf73a50
--- /dev/null
+++ b/scripts/train_g.sh
@@ -0,0 +1,7 @@
+torchrun --nnodes=4 --nproc_per_node=8 train.py \
+--version giant \
+--data-path combine_dataset.json \
+--resume results/giant/checkpoints/0050000.pt \
+--global_batch_size 128 \
+--global-seed 2023 \
+--accum_iter 8
diff --git a/scripts/train_l.sh b/scripts/train_l.sh
new file mode 100644
index 0000000000000000000000000000000000000000..93dcd890c12dc6e38e1559dce6e341abf5ab1b95
--- /dev/null
+++ b/scripts/train_l.sh
@@ -0,0 +1,5 @@
+torchrun --nnodes=4 --nproc_per_node=8 train.py \
+--version large \
+--data-path combine_dataset.json \
+--global_batch_size 128 \
+--global-seed 2023
\ No newline at end of file
diff --git a/scripts/train_s.sh b/scripts/train_s.sh
new file mode 100644
index 0000000000000000000000000000000000000000..653a4c06bc1ab4ea9cc717765cf6fbe5a3147a18
--- /dev/null
+++ b/scripts/train_s.sh
@@ -0,0 +1,5 @@
+torchrun --nnodes=1 --nproc_per_node=8 train.py \
+--version small \
+--data-path combine_dataset.json \
+--global_batch_size 128 \
+--global-seed 2023
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fb263045b8a4999d34bf2c88889e21a9421213
--- /dev/null
+++ b/test.py
@@ -0,0 +1,172 @@
+import os
+import json
+
+def test_reconstruct():
+ import yaml
+ from diffusers import AutoencoderKL
+ from transformers import SpeechT5HifiGan
+ from audioldm2.utilities.data.dataset import AudioDataset
+ from utils import load_clip, load_clap, load_t5
+
+ model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
+ config = yaml.load(
+ open(
+ 'config/16k_64.yaml',
+ 'r'
+ ),
+ Loader=yaml.FullLoader,
+ )
+ print(config)
+ t5 = load_t5('cuda', max_length=256)
+ clap = load_clap('cuda', max_length=256)
+
+ dataset = AudioDataset(
+ config=config, split="train", waveform_only=False, dataset_json_path='mini_dataset.json',
+ tokenizer=clap.tokenizer,
+ uncond_pro=0.1,
+ text_ctx_len=77,
+ tokenizer_t5=t5.tokenizer,
+ text_ctx_len_t5=256,
+ uncond_pro_t5=0.1,
+ )
+ print(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0).size())
+
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
+ vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))
+ latents = vae.encode(dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)).latent_dist.sample().mul_(vae.config.scaling_factor)
+ print('laten size:', latents.size())
+
+ latents = 1 / vae.config.scaling_factor * latents
+ mel_spectrogram = vae.decode(latents).sample
+ print(mel_spectrogram.size())
+ if mel_spectrogram.dim() == 4:
+ mel_spectrogram = mel_spectrogram.squeeze(1)
+ waveform = vocoder(mel_spectrogram)
+ waveform = waveform[0].cpu().float().detach().numpy()
+ print(waveform.shape)
+ # import soundfile as sf
+ # sf.write('reconstruct.wav', waveform, samplerate=16000)
+ from scipy.io import wavfile
+ # wavfile.write('reconstruct.wav', 16000, waveform)
+
+
+
+def mini_dataset(num=32):
+ data = []
+ for i in range(num):
+ data.append(
+ {
+ 'wav': 'case.mp3',
+ 'label': 'a beautiful music',
+ }
+ )
+
+ with open('mini_dataset.json', 'w') as f:
+ json.dump(data, f, indent=4)
+
+
+def fma_dataset():
+ import pandas as pd
+
+ annotation_prex = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
+ annotation_list = ['test-00000-of-00001.parquet', 'train-00000-of-00001.parquet', 'valid-00000-of-00001.parquet']
+ dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
+
+ data = []
+ for annotation_file in annotation_list:
+ annotation_file = os.path.join(annotation_prex, annotation_file)
+ df=pd.read_parquet(annotation_file)
+ print(df.shape)
+ for id, row in df.iterrows():
+ #print(id, row['pseudo_caption'], row['path'])
+ tmp_path = os.path.join(dataset_prex, row['path'] + '.mp3')
+ # print(tmp_path)
+ if os.path.exists(tmp_path):
+ data.append(
+ {
+ 'wav': tmp_path,
+ 'label': row['pseudo_caption'],
+ }
+ )
+ # break
+ print(len(data))
+ with open('fma_dataset.json', 'w') as f:
+ json.dump(data, f, indent=4)
+
+
+
+
+
+def audioset_dataset():
+ import pandas as pd
+ dataset_prex = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
+ annotation_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset/balanced_train-00000-of-00001.parquet'
+ df=pd.read_parquet(annotation_path)
+ print(df.shape)
+
+ data = []
+ for id, row in df.iterrows():
+ #print(id, row['pseudo_caption'], row['path'])
+        try:
+            tmp_path = os.path.join(dataset_prex, row['path'] + '.flac')
+        except Exception:
+            print(row['path'])
+            continue
+
+ if os.path.exists(tmp_path):
+ # print(tmp_path)
+ data.append(
+ {
+ 'wav': tmp_path,
+ 'label': row['pseudo_caption'],
+ }
+ )
+ print(len(data))
+ with open('audioset_dataset.json', 'w') as f:
+ json.dump(data, f, indent=4)
+
+
+
+def combine_dataset():
+ data_list = ['fma_dataset.json', 'audioset_dataset.json']
+
+ data = []
+ for data_file in data_list:
+ with open(data_file, 'r') as f:
+ data += json.load(f)
+ print(len(data))
+ with open('combine_dataset.json', 'w') as f:
+ json.dump(data, f, indent=4)
+
+
+
+def test_music_format():
+ import torchaudio
+ filename = '2.flac'
+ waveform, sr = torchaudio.load(filename,)
+ print(waveform, sr )
+
+
+def test_flops():
+ version = 'giant'
+ import torch
+ from constants import build_model
+ from thop import profile
+
+ model = build_model(version).cuda()
+ img_ids = torch.randn((1, 1024, 3)).cuda()
+ txt = torch.randn((1, 256, 4096)).cuda()
+ txt_ids = torch.randn((1, 256, 3)).cuda()
+ y = torch.randn((1, 768)).cuda()
+ x = torch.randn((1, 1024, 32)).cuda()
+ t = torch.tensor([1] * 1).cuda()
+ flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y,))
+ print('FLOPs = ' + str(flops * 2/1000**3) + 'G')
+
+
+# test_music_format()
+# test_reconstruct()
+# mini_dataset()
+# fma_dataset()
+# audioset_dataset()
+# combine_dataset()
+test_flops()
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5eb3c98c29d2c821fdac74a895cb8158dd8ffb
--- /dev/null
+++ b/train.py
@@ -0,0 +1,374 @@
+import torch
+import os
+import argparse
+import logging
+import math
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from copy import deepcopy
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data import DataLoader
+from glob import glob
+import yaml
+from collections import OrderedDict
+from time import time
+from einops import rearrange, repeat
+
+from diffusers import AutoencoderKL
+from transformers import SpeechT5HifiGan
+from audioldm2.utilities.data.dataset import AudioDataset
+
+from constants import build_model
+from utils import load_clip, load_clap, load_t5
+from thop import profile
+
+
+@torch.no_grad()
+def update_ema(ema_model, model, decay=0.9999):
+ """
+ Step the EMA model towards the current model.
+ """
+ ema_params = OrderedDict(ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+
+ for name, param in model_params.items():
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+
+def requires_grad(model, flag=True):
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+
+def cleanup():
+ """
+ End DDP training.
+ """
+ dist.destroy_process_group()
+
+
+def create_logger(logging_dir):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if dist.get_rank() == 0: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+
+class RF(torch.nn.Module):
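+    """
+    Rectified-flow objective and Euler sampler.
+
+    forward() draws t, forms zt = (1 - t) * x + t * z1 with z1 ~ N(0, I), and
+    regresses the model output onto the velocity z1 - x; sample() integrates
+    z <- z - dt * v with optional classifier-free guidance.
+
+    Minimal usage sketch (mirrors main() below and sample.py):
+        rf = RF()
+        loss, _ = rf(model, latents, **model_kwargs)                        # training
+        traj = rf.sample(model, noise, conds, null_cond=unconds, cfg=7.0)   # sampling
+    """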
+ def __init__(self, ln=True):
+ super().__init__()
+ self.ln = ln
+ self.stratified = False
+
+ def forward(self, model, x, **kwargs):
+
+ b = x.size(0)
+ if self.ln:
+ if self.stratified:
+ # stratified sampling of normals
+ # first stratified sample from uniform
+ quantiles = torch.linspace(0, 1, b + 1).to(x.device)
+ z = quantiles[:-1] + torch.rand((b,)).to(x.device) / b
+ # now transform to normal
+ z = torch.erfinv(2 * z - 1) * math.sqrt(2)
+ t = torch.sigmoid(z)
+ else:
+ nt = torch.randn((b,)).to(x.device)
+ t = torch.sigmoid(nt)
+ else:
+ t = torch.rand((b,)).to(x.device)
+ texp = t.view([b, *([1] * len(x.shape[1:]))])
+ z1 = torch.randn_like(x)
+ zt = (1 - texp) * x + texp * z1
+
+ # make t, zt into same dtype as x
+ zt, t = zt.to(x.dtype), t.to(x.dtype)
+ vtheta = model(x=zt, t=t, **kwargs)
+ # print(z1.size(), x.size(), vtheta.size())
+ batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, len(x.shape))))
+ tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
+ ttloss = [(tv, tloss) for tv, tloss in zip(t, tlist)]
+ return batchwise_mse.mean(), {"batchwise_loss": ttloss}
+
+ @torch.no_grad()
+ def sample(self, model, z, conds, null_cond=None, sample_steps=50, cfg=2.0, **kwargs):
+ b = z.size(0)
+ dt = 1.0 / sample_steps
+ dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
+ images = [z]
+ for i in range(sample_steps, 0, -1):
+ t = i / sample_steps
+ t = torch.tensor([t] * b).to(z.device)
+
+ vc = model(x=z, t=t, **conds)
+ if null_cond is not None:
+ vu = model(x=z, t=t, **null_cond)
+ vc = vu + cfg * (vc - vu)
+
+ z = z - dt * vc
+ images.append(z)
+ return images
+
+ @torch.no_grad()
+ def sample_with_xps(self, model, z, conds, null_cond=None, sample_steps=50, cfg=2.0, **kwargs):
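+        # like sample(), but at each step appends the one-step endpoint
+        # prediction x = z - (i * dt) * v rather than the current z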
+ b = z.size(0)
+ dt = 1.0 / sample_steps
+ dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))])
+ images = [z]
+ for i in range(sample_steps, 0, -1):
+ t = i / sample_steps
+ t = torch.tensor([t] * b).to(z.device)
+
+ # print(z.size(), t.size())
+ vc = model(x=z, t=t, **conds)
+ if null_cond is not None:
+ vu = model(x=z, t=t, **null_cond)
+ vc = vu + cfg * (vc - vu)
+ x = z - i * dt * vc
+ z = z - dt * vc
+ images.append(x)
+ return images
+
+
+def prepare_model_inputs(args, batch, device, vae, clip, t5,):
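+    # build the Flux conditioning from a dataset batch: pooled CLAP/CLIP text
+    # embedding (y), T5 sequence embedding (txt), and VAE-encoded mel latents
+    # patchified into tokens (image), plus 3-channel positional ids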
+ text_embedding, text_embedding_mask = batch['text_embedding'], batch['text_embedding_mask']
+ text_embedding_t5, text_embedding_mask_t5 = batch['text_embedding_t5'], batch['text_embedding_mask_t5']
+ # print(image.size(), text_embedding.size(), text_embedding_t5.size())
+
+ # clip & mT5 text embedding
+ text_embedding = text_embedding.to(device)
+ text_embedding_mask = text_embedding_mask.to(device)
+ with torch.no_grad():
+ encoder_hidden_states = clip.hf_module(
+ text_embedding.to(device),
+ attention_mask=text_embedding_mask,
+ output_hidden_states=False,
+ )["pooler_output"] # ()
+
+ # print(encoder_hidden_states.size())
+
+ text_embedding_t5 = text_embedding_t5.to(device).squeeze(1)
+ text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1)
+ with torch.no_grad():
+ output_t5 = t5.hf_module(
+ input_ids=text_embedding_t5,
+ attention_mask=text_embedding_mask_t5,
+ output_hidden_states=False,
+ )
+ encoder_hidden_states_t5 = output_t5["last_hidden_state"].detach()
+
+ with torch.no_grad():
+ image = vae.encode(batch['log_mel_spec'].unsqueeze(1).to(device)).latent_dist.sample().mul_(vae.config.scaling_factor)
+
+ # positional embedding
+ bs, c, h, w = image.shape
+ image = rearrange(image, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2).float()
+ img_ids = torch.zeros(h // 2, w // 2, 3)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ txt_ids = torch.zeros(bs, encoder_hidden_states_t5.shape[1], 3)
+ # Model conditions
+ model_kwargs = dict(
+ img_ids=img_ids.to(image.device),
+ txt = encoder_hidden_states_t5.to(image.device).float(),
+ txt_ids = txt_ids.to(image.device),
+ y = encoder_hidden_states.to(image.device).float(),
+ )
+
+ return image, model_kwargs
+
+
+
+def main(args):
+ assert torch.cuda.is_available(), "Training currently requires at least one GPU."
+
+ dist.init_process_group("nccl")
+    assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size {args.global_batch_size} must be divisible by world size {dist.get_world_size()}."
+ rank = dist.get_rank()
+ device = rank % torch.cuda.device_count()
+ seed = args.global_seed * dist.get_world_size() + rank
+ torch.manual_seed(seed)
+ torch.cuda.set_device(device)
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
+
+ # Setup an experiment folder:
+ if rank == 0:
+ os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
+ experiment_index = len(glob(f"{args.results_dir}/*"))
+ model_string_name = args.version.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
+ experiment_dir = f"{args.results_dir}/{model_string_name}" # Create an experiment folder
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
+ os.makedirs(checkpoint_dir, exist_ok=True)
+ logger = create_logger(experiment_dir)
+ logger.info(f"Experiment directory created at {experiment_dir}")
+ else:
+ logger = create_logger(None)
+
+
+ model = build_model(args.version).to(device)
+ parameters_sum = sum(x.numel() for x in model.parameters())
+ logger.info(f"{parameters_sum / 1000000.0} M")
+
+ if args.resume is not None:
+ print('load from: ', args.resume)
+ resume_ckpt = torch.load(args.resume, map_location=lambda storage, loc: storage)['ema']
+ model.load_state_dict(resume_ckpt)
+
+ # Note that parameter initialization is done within the DiT constructor
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
+ requires_grad(ema, False)
+ model = DDP(model.to(device), device_ids=[rank])
+
+ diffusion = RF()
+ model_path = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioldm2'
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
+ # vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)
+ t5 = load_t5(device, max_length=256)
+ clap = load_clap(device, max_length=256)
+ # clip = load_clip(device)
+
+ opt = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0)
+
+
+ config = yaml.load(
+ open(
+ 'config/16k_64.yaml',
+ 'r'
+ ),
+ Loader=yaml.FullLoader,
+ )
+ dataset = AudioDataset(
+ config=config, split="train",
+ waveform_only=False,
+ dataset_json_path=args.data_path,
+ tokenizer=clap.tokenizer,
+ uncond_pro=0.1,
+ text_ctx_len=77,
+ tokenizer_t5=t5.tokenizer,
+ text_ctx_len_t5=256,
+ uncond_pro_t5=0.1,
+ )
+ sampler = DistributedSampler(
+ dataset,
+ num_replicas=dist.get_world_size(),
+ rank=rank,
+ shuffle=True,
+ seed=args.global_seed
+ )
+ loader = DataLoader(
+ dataset,
+ batch_size=int(args.global_batch_size // dist.get_world_size()),
+ shuffle=False,
+ sampler=sampler,
+ num_workers=args.num_workers,
+ pin_memory=True,
+ drop_last=True
+ )
+ logger.info(f"Dataset contains {len(dataset):,}")
+
+ # Prepare models for training:
+ update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
+ model.train() # important! This enables embedding dropout for classifier-free guidance
+ ema.eval() # EMA model should always be in eval mode
+
+ # Variables for monitoring/logging purposes:
+ train_steps = 0
+ log_steps = 0
+ running_loss = 0
+ start_time = time()
+ logger.info(f"Training for {args.epochs} epochs...")
+ for epoch in range(args.epochs):
+ sampler.set_epoch(epoch)
+ logger.info(f"Beginning epoch {epoch}...")
+ data_iter_step = 0
+ for batch in loader:
+ latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, clap, t5,)
+ loss, _ = diffusion.forward(model=model, x=latents, **model_kwargs)
+ # print(loss)
+            # accumulate gradients over accum_iter micro-batches before each optimizer step
+            (loss / args.accum_iter).backward()
+            if (data_iter_step + 1) % args.accum_iter == 0:
+                opt.step()
+                opt.zero_grad()
+                update_ema(ema, model.module)
+
+ data_iter_step += 1
+ # Log loss values:
+ running_loss += loss.item()
+ log_steps += 1
+ train_steps += 1
+ if train_steps % args.log_every == 0:
+ # Measure training speed:
+ torch.cuda.synchronize()
+ end_time = time()
+ steps_per_sec = log_steps / (end_time - start_time)
+ # Reduce loss history over all processes:
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
+ avg_loss = avg_loss.item() / dist.get_world_size()
+ logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
+ # Reset monitoring variables:
+ running_loss = 0
+ log_steps = 0
+ start_time = time()
+
+ # Save DiT checkpoint:
+ if train_steps % args.ckpt_every == 0 and train_steps > 0:
+ if rank == 0:
+ checkpoint = {
+ # "model": model.module.state_dict(),
+ "ema": ema.state_dict(),
+ "opt": opt.state_dict(),
+ "args": args
+ }
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
+ try:
+ torch.save(checkpoint, checkpoint_path)
+ except Exception as e:
+ print(e)
+
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
+ dist.barrier()
+
+ # model.eval() # important! This disables randomized embedding dropout
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
+
+ logger.info("Done!")
+ cleanup()
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data-path", type=str, default='fma_dataset.json')
+ parser.add_argument("--results-dir", type=str, default="results")
+ parser.add_argument("--resume", type=str, default=None)
+ parser.add_argument("--version", type=str, default="large")
+ parser.add_argument("--vae-path", type=str, default='audioldm2/vae')
+ parser.add_argument("--epochs", type=int, default=1400)
+ parser.add_argument("--global_batch_size", type=int, default=32)
+ parser.add_argument("--global-seed", type=int, default=1234)
+ parser.add_argument("--num-workers", type=int, default=4)
+ parser.add_argument("--log-every", type=int, default=100)
+ parser.add_argument('--accum_iter', default=16, type=int,)
+ parser.add_argument("--ckpt-every", type=int, default=100_000)
+ parser.add_argument('--local-rank', type=int, default=-1, help='local rank passed from distributed launcher')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4f580e1ddb49f80b22a51cd93e2ea897da3b8f
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,36 @@
+import torch
+from modules.autoencoder import AutoEncoder, AutoEncoderParams
+from modules.conditioner import HFEmbedder
+from safetensors.torch import load_file as load_sft
+
+
+def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
+ return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+
+
+def load_clap(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
+ return HFEmbedder("laion/larger_clap_music", max_length=256, torch_dtype=torch.bfloat16).to(device)
+
+def load_ae(ckpt_path, device: str | torch.device = "cuda",) -> AutoEncoder:
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ )
+ # Loading the autoencoder
+ ae = AutoEncoder(ae_params)
+ sd = load_sft(ckpt_path,)
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
+ ae.to(device)
+ return ae
diff --git a/visuals/framework.png b/visuals/framework.png
new file mode 100644
index 0000000000000000000000000000000000000000..e46b1f72741854ad8256f99d160ea0398e286902
Binary files /dev/null and b/visuals/framework.png differ