{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "VjYy0F2gZIPR"
            },
            "outputs": [],
            "source": [
                "%cd /content\n",
                "!git clone -b dev https://github.com/camenduru/FluxMusic\n",
                "%cd C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic\n",
                "\n",
                "!apt -y install -qq aria2\n",
                "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/FluxMusic/resolve/main/musicflow_b.pt -d C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic -o musicflow_b.pt\n",
                "\n",
                "!pip install transformers diffusers accelerate einops soundfile progressbar unidecode phonemizer torchlibrosa ftfy pandas timm matplotlib numpy==1.26.4 thop flash-attn==2.6.3 sentencepiece"
            ]
        },
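        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Optional sanity check (added; not part of the original notebook): confirm\n",
                "# a CUDA GPU is visible and the checkpoint downloaded above is where the\n",
                "# next cell expects it.\n",
                "import os\n",
                "import torch\n",
                "\n",
                "print('CUDA available:', torch.cuda.is_available())\n",
                "if torch.cuda.is_available():\n",
                "    print('GPU:', torch.cuda.get_device_name(0))\n",
                "\n",
                "ckpt = '/content/FluxMusic/musicflow_b.pt'\n",
                "if os.path.exists(ckpt):\n",
                "    print(ckpt, f'{os.path.getsize(ckpt) / 2**20:.0f} MiB')\n",
                "else:\n",
                "    print('missing checkpoint:', ckpt)"
            ]
        },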
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "NoTEt9Wto70D"
            },
            "outputs": [],
            "source": [
                "%cd C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter\n",
                "\n",
                "import os\n",
                "import torch\n",
                "import argparse\n",
                "import math\n",
                "from einops import rearrange, repeat\n",
                "from PIL import Image\n",
                "from diffusers import AutoencoderKL\n",
                "from transformers import SpeechT5HifiGan\n",
                "\n",
                "from utils import load_t5, load_clap, load_ae\n",
                "from train import RF\n",
                "from constants import build_model\n",
                "\n",
                "def prepare(t5, clip, img, prompt):\n",
                "    bs, c, h, w = img.shape\n",
                "    if bs == 1 and not isinstance(prompt, str):\n",
                "        bs = len(prompt)\n",
                "\n",
                "    img = rearrange(img, \"b c (h ph) (w pw) -> b (h w) (c ph pw)\", ph=2, pw=2)\n",
                "    if img.shape[0] == 1 and bs > 1:\n",
                "        img = repeat(img, \"1 ... -> bs ...\", bs=bs)\n",
                "\n",
                "    img_ids = torch.zeros(h // 2, w // 2, 3)\n",
                "    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]\n",
                "    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]\n",
                "    img_ids = repeat(img_ids, \"h w c -> b (h w) c\", b=bs)\n",
                "\n",
                "    if isinstance(prompt, str):\n",
                "        prompt = [prompt]\n",
                "    txt = t5(prompt)\n",
                "    if txt.shape[0] == 1 and bs > 1:\n",
                "        txt = repeat(txt, \"1 ... -> bs ...\", bs=bs)\n",
                "    txt_ids = torch.zeros(bs, txt.shape[1], 3)\n",
                "\n",
                "    vec = clip(prompt)\n",
                "    if vec.shape[0] == 1 and bs > 1:\n",
                "        vec = repeat(vec, \"1 ... -> bs ...\", bs=bs)\n",
                "\n",
                "    print(img_ids.size(), txt.size(), vec.size())\n",
                "    return img, {\n",
                "        \"img_ids\": img_ids.to(img.device),\n",
                "        \"txt\": txt.to(img.device),\n",
                "        \"txt_ids\": txt_ids.to(img.device),\n",
                "        \"y\": vec.to(img.device),\n",
                "    }\n",
                "\n",
                "version=\"base\"\n",
                "seed=2024\n",
                "prompt_file=\"C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/config/example.txt\"\n",
                "\n",
                "print('generate with MusicFlux')\n",
                "torch.manual_seed(seed)\n",
                "torch.set_grad_enabled(False)\n",
                "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
                "\n",
                "latent_size = (256, 16)\n",
                "\n",
                "model = build_model(version).to(device)\n",
                "local_path = 'C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/musicflow_b.pt'\n",
                "state_dict = torch.load(local_path, map_location=lambda storage, loc: storage, weights_only=True)\n",
                "model.load_state_dict(state_dict['ema'])\n",
                "model.eval()  # important!\n",
                "diffusion = RF()"
            ]
        },
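        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Optional (added): report the transformer's parameter count and current\n",
                "# GPU memory use, to make out-of-memory issues on a T4 easier to diagnose.\n",
                "n_params = sum(p.numel() for p in model.parameters())\n",
                "print(f'model parameters: {n_params / 1e6:.1f}M')\n",
                "if torch.cuda.is_available():\n",
                "    print(f'GPU memory allocated: {torch.cuda.memory_allocated() / 2**30:.2f} GiB')"
            ]
        },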
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "B5ebyTmto70D"
            },
            "outputs": [],
            "source": [
                "t5 = load_t5(device, max_length=256)\n",
                "clap = load_clap(device, max_length=256)\n",
                "\n",
                "vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder=\"vae\").to(device)\n",
                "vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder=\"vocoder\").to(device)"
            ]
        },
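        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Optional (added): encode a throwaway prompt to confirm the text encoders\n",
                "# return what prepare() consumes -- a per-token T5 sequence and a pooled\n",
                "# CLAP vector. The exact dims are checked empirically here, not assumed\n",
                "# from documentation.\n",
                "_txt = t5(['a test prompt'])\n",
                "_vec = clap(['a test prompt'])\n",
                "print('t5:', tuple(_txt.shape), 'clap:', tuple(_vec.shape))"
            ]
        },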
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "xqG8Px6xo70D"
            },
            "outputs": [],
            "source": [
                "prompt_file=\"C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/config/example.txt\"\n",
                "\n",
                "with open(prompt_file, 'r') as f:\n",
                "    conds_txt = f.readlines()\n",
                "L = len(conds_txt)\n",
                "unconds_txt = [\"low quality, gentle\"] * L\n",
                "print(L, conds_txt, unconds_txt)\n",
                "\n",
                "init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).cuda()\n",
                "\n",
                "STEPSIZE = 50\n",
                "img, conds = prepare(t5, clap, init_noise, conds_txt)\n",
                "_, unconds = prepare(t5, clap, init_noise, unconds_txt)\n",
                "with torch.autocast(device_type='cuda'):\n",
                "    images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps = STEPSIZE, cfg = 7.0)\n",
                "\n",
                "print(images[-1].size(), )\n",
                "\n",
                "images = rearrange(\n",
                "    images[-1],\n",
                "    \"b (h w) (c ph pw) -> b c (h ph) (w pw)\",\n",
                "    h=128,\n",
                "    w=8,\n",
                "    ph=2,\n",
                "    pw=2,)\n",
                "# print(images.size())\n",
                "latents = 1 / vae.config.scaling_factor * images\n",
                "mel_spectrogram = vae.decode(latents).sample\n",
                "print(mel_spectrogram.size())"
            ]
        },
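        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Optional (added): visualize the decoded mel spectrograms with matplotlib\n",
                "# (installed in the first cell). Axes are time frames vs. mel bins, purely\n",
                "# a visual sanity check before vocoding.\n",
                "import matplotlib.pyplot as plt\n",
                "\n",
                "fig, axes = plt.subplots(1, L, figsize=(4 * L, 3), squeeze=False)\n",
                "for i in range(L):\n",
                "    mel = mel_spectrogram[i].squeeze().cpu().float().numpy()\n",
                "    axes[0][i].imshow(mel.T, origin='lower', aspect='auto')\n",
                "    axes[0][i].set_title(f'sample {i}')\n",
                "plt.tight_layout()\n",
                "plt.show()"
            ]
        },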
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {
                "id": "ytAXlAEdo70D"
            },
            "outputs": [],
            "source": [
                "!mkdir C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic/b_output\n",
                "\n",
                "for i in range(L):\n",
                "    x_i = mel_spectrogram[i]\n",
                "    if x_i.dim() == 4:\n",
                "        x_i = x_i.squeeze(1)\n",
                "    waveform = vocoder(x_i)\n",
                "    waveform = waveform[0].cpu().float().detach().numpy()\n",
                "    print(waveform.shape)\n",
                "    # import soundfile as sf\n",
                "    # sf.write('reconstruct.wav', waveform, samplerate=16000)\n",
                "    from  scipy.io import wavfile\n",
                "    wavfile.write('C:/Users/Curt/Developer/AItools/AIaudio/AudioCreation/FluxMusicJupyter/FluxMusic/b_output/sample_' + str(i) + '.wav', 16000, waveform)"
            ]
        }
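        ,
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Optional (added): play the generated clips inline with IPython's Audio\n",
                "# widget, which ships with Colab/Jupyter.\n",
                "from IPython.display import Audio, display\n",
                "\n",
                "for i in range(L):\n",
                "    display(Audio(f'/content/FluxMusic/b_output/sample_{i}.wav'))"
            ]
        }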
    ],
    "metadata": {
        "accelerator": "GPU",
        "colab": {
            "gpuType": "T4",
            "provenance": []
        },
        "kernelspec": {
            "display_name": "Python 3",
            "name": "python3"
        },
        "language_info": {
            "name": "python"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 0
}