Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Upload 288 files
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- .env +8 -0
- .gitattributes +1 -0
- .gitignore +49 -0
- Applio_(Mangio_RVC_Fork).ipynb +169 -0
- Dockerfile +29 -0
- Fixes/local_fixes.py +136 -0
- Fixes/tensor-launch.py +15 -0
- LICENSE +59 -0
- LazyImport.py +13 -0
- MDX-Net_Colab.ipynb +524 -0
- MDXNet.py +272 -0
- Makefile +63 -0
- README.md +222 -12
- assets/hubert/.gitignore +2 -0
- assets/pretrained/.gitignore +2 -0
- assets/pretrained_v2/.gitignore +2 -0
- assets/rmvpe/.gitignore +2 -0
- assets/uvr5_weights/.gitignore +2 -0
- assets/weights/.gitignore +2 -0
- audioEffects.py +37 -0
- audios/.gitignore +0 -0
- colab_for_mdx.py +71 -0
- configs/32k.json +50 -0
- configs/32k_v2.json +50 -0
- configs/40k.json +50 -0
- configs/48k.json +50 -0
- configs/48k_v2.json +50 -0
- configs/config.json +15 -0
- configs/config.py +265 -0
- configs/v1/32k.json +46 -0
- configs/v1/40k.json +46 -0
- configs/v1/48k.json +46 -0
- configs/v2/32k.json +46 -0
- configs/v2/48k.json +46 -0
- csvdb/formanting.csv +0 -0
- csvdb/stop.csv +0 -0
- demucs/__init__.py +7 -0
- demucs/__main__.py +317 -0
- demucs/audio.py +172 -0
- demucs/augment.py +106 -0
- demucs/compressed.py +115 -0
- demucs/model.py +202 -0
- demucs/parser.py +244 -0
- demucs/pretrained.py +107 -0
- demucs/raw.py +173 -0
- demucs/repitch.py +96 -0
- demucs/separate.py +185 -0
- demucs/tasnet.py +452 -0
- demucs/test.py +109 -0
- demucs/train.py +127 -0
    	
        .env
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            OPENBLAS_NUM_THREADS = 1
         | 
| 2 | 
            +
            no_proxy = localhost, 127.0.0.1, ::1
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # You can change the location of the model, etc. by changing here
         | 
| 5 | 
            +
            weight_root = weights
         | 
| 6 | 
            +
            weight_uvr5_root = uvr5_weights
         | 
| 7 | 
            +
            index_root = logs
         | 
| 8 | 
            +
            rmvpe_root = assets/rmvpe
         | 
    	
        .gitattributes
    CHANGED
    
    | @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
|  | 
|  | |
| 33 | 
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 34 | 
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         | 
| 35 | 
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
| 36 | 
            +
            stftpitchshift filter=lfs diff=lfs merge=lfs -text
         | 
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,49 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .DS_Store
         | 
| 2 | 
            +
            __pycache__
         | 
| 3 | 
            +
            /TEMP
         | 
| 4 | 
            +
            /DATASETS
         | 
| 5 | 
            +
            /RUNTIME
         | 
| 6 | 
            +
            *.pyd
         | 
| 7 | 
            +
            hubert_base.pt
         | 
| 8 | 
            +
            .venv
         | 
| 9 | 
            +
            alexforkINSTALL.bat
         | 
| 10 | 
            +
            Changelog_CN.md
         | 
| 11 | 
            +
            Changelog_EN.md
         | 
| 12 | 
            +
            Changelog_KO.md
         | 
| 13 | 
            +
            difdep.py
         | 
| 14 | 
            +
            EasierGUI.py
         | 
| 15 | 
            +
            envfilescheck.bat
         | 
| 16 | 
            +
            export_onnx.py
         | 
| 17 | 
            +
            .vscode/
         | 
| 18 | 
            +
            export_onnx_old.py
         | 
| 19 | 
            +
            ffmpeg.exe
         | 
| 20 | 
            +
            ffprobe.exe
         | 
| 21 | 
            +
            Fixes/Launch_Tensorboard.bat
         | 
| 22 | 
            +
            Fixes/LOCAL_CREPE_FIX.bat
         | 
| 23 | 
            +
            Fixes/local_fixes.py
         | 
| 24 | 
            +
            Fixes/tensor-launch.py
         | 
| 25 | 
            +
            gui.py
         | 
| 26 | 
            +
            infer-web — backup.py
         | 
| 27 | 
            +
            infer-webbackup.py
         | 
| 28 | 
            +
            install_easy_dependencies.py
         | 
| 29 | 
            +
            install_easyGUI.bat
         | 
| 30 | 
            +
            installstft.bat
         | 
| 31 | 
            +
            Launch_Tensorboard.bat
         | 
| 32 | 
            +
            listdepend.bat
         | 
| 33 | 
            +
            LOCAL_CREPE_FIX.bat
         | 
| 34 | 
            +
            local_fixes.py
         | 
| 35 | 
            +
            oldinfer.py
         | 
| 36 | 
            +
            onnx_inference_demo.py
         | 
| 37 | 
            +
            Praat.exe
         | 
| 38 | 
            +
            requirementsNEW.txt
         | 
| 39 | 
            +
            rmvpe.pt
         | 
| 40 | 
            +
            rmvpe.onnx
         | 
| 41 | 
            +
            run_easiergui.bat
         | 
| 42 | 
            +
            tensor-launch.py
         | 
| 43 | 
            +
            values1.json
         | 
| 44 | 
            +
            使用需遵守的协议-LICENSE.txt
         | 
| 45 | 
            +
            !logs/
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            logs/*
         | 
| 48 | 
            +
            logs/mute/0_gt_wavs/mute40k.spec.pt
         | 
| 49 | 
            +
            !logs/mute/
         | 
    	
        Applio_(Mangio_RVC_Fork).ipynb
    ADDED
    
    | @@ -0,0 +1,169 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "cells": [
         | 
| 3 | 
            +
                {
         | 
| 4 | 
            +
                  "cell_type": "code",
         | 
| 5 | 
            +
                  "execution_count": null,
         | 
| 6 | 
            +
                  "metadata": {
         | 
| 7 | 
            +
                    "cellView": "form",
         | 
| 8 | 
            +
                    "id": "izLwNF_8T1TK"
         | 
| 9 | 
            +
                  },
         | 
| 10 | 
            +
                  "outputs": [],
         | 
| 11 | 
            +
                  "source": [
         | 
| 12 | 
            +
                    "#@title <font color='#06ae56'>**🍏 Applio (Mangio-RVC-Fork)**</font>\n",
         | 
| 13 | 
            +
                    "import time\n",
         | 
| 14 | 
            +
                    "import os\n",
         | 
| 15 | 
            +
                    "import subprocess\n",
         | 
| 16 | 
            +
                    "import shutil\n",
         | 
| 17 | 
            +
                    "import threading\n",
         | 
| 18 | 
            +
                    "import base64\n",
         | 
| 19 | 
            +
                    "import threading\n",
         | 
| 20 | 
            +
                    "import time\n",
         | 
| 21 | 
            +
                    "from IPython.display import HTML, clear_output\n",
         | 
| 22 | 
            +
                    "\n",
         | 
| 23 | 
            +
                    "nosv_name1 = base64.b64decode(('ZXh0ZXJuYWxj').encode('ascii')).decode('ascii')\n",
         | 
| 24 | 
            +
                    "nosv_name2 = base64.b64decode(('b2xhYmNvZGU=').encode('ascii')).decode('ascii')\n",
         | 
| 25 | 
            +
                    "guebui = base64.b64decode(('V2U=').encode('ascii')).decode('ascii')\n",
         | 
| 26 | 
            +
                    "guebui2 = base64.b64decode(('YlVJ').encode('ascii')).decode('ascii')\n",
         | 
| 27 | 
            +
                    "pbestm = base64.b64decode(('cm12cGU=').encode('ascii')).decode('ascii')\n",
         | 
| 28 | 
            +
                    "tryre = base64.b64decode(('UmV0cmlldmFs').encode('ascii')).decode('ascii')\n",
         | 
| 29 | 
            +
                    "\n",
         | 
| 30 | 
            +
                    "xdsame = '/content/'+ tryre +'-based-Voice-Conversion-' + guebui + guebui2 +'/'\n",
         | 
| 31 | 
            +
                    "\n",
         | 
| 32 | 
            +
                    "collapsible_section = \"\"\"\n",
         | 
| 33 | 
            +
                    "<br>\n",
         | 
| 34 | 
            +
                    "<br>\n",
         | 
| 35 | 
            +
                    "<details style=\"border: 1px solid #ddd; border-radius: 5px; padding: 10px; margin-bottom: 10px;\">\n",
         | 
| 36 | 
            +
                    "    <summary open style=\"font-weight: bold; cursor: pointer;\">🚀 Click to learn more about Applio</summary>\n",
         | 
| 37 | 
            +
                    "    <div style=\"margin-left: 20px;\">\n",
         | 
| 38 | 
            +
                    "        <ul>\n",
         | 
| 39 | 
            +
                    "            <li><a href=\"https://github.com/Mangio621/Mangio-RVC-Fork\" style=\"color: #06ae56;\">Mangio-RVC-Fork</a> - Source of inspiration and base for this improved code, special thanks to the developers.</li>\n",
         | 
| 40 | 
            +
                    "            <li><a href=\"https://github.com/Anjok07/ultimatevocalremovergui\" style=\"color: #06ae56;\">UltimateVocalRemover</a> - Used for voice and instrument separation.</li>\n",
         | 
| 41 | 
            +
                    "            <li>Vidal, Blaise & Aitron - Contributors to the Applio version.</li>\n",
         | 
| 42 | 
            +
                    "            <li>kalomaze - Creator of external scripts that help the functioning of Applio.</li>\n",
         | 
| 43 | 
            +
                    "        </ul>\n",
         | 
| 44 | 
            +
                    "        <p style=\"color: #fff;\">Join and contribute to the project on <a href=\"https://github.com/IAHispano/Applio-RVC-Fork\" style=\"color: #06ae56;\">our GitHub repository</a>.</p>\n",
         | 
| 45 | 
            +
                    "    </div>\n",
         | 
| 46 | 
            +
                    "</details>\n",
         | 
| 47 | 
            +
                    "<br>\n",
         | 
| 48 | 
            +
                    "<button style=\"font-weight: bold; cursor: pointer; background-color: #06ae56; color: white; border: 1px solid #fff; border-radius: 4px; padding: 10px 20px; text-decoration: none;\" onclick=\"window.open('https://discord.gg/IAHispano', '_blank')\">🍏 Join our support Discord server (IA Hispano)</button>\n",
         | 
| 49 | 
            +
                    "<br>\n",
         | 
| 50 | 
            +
                    "<br>\n",
         | 
| 51 | 
            +
                    "\"\"\"\n",
         | 
| 52 | 
            +
                    "#@markdown **Settings:**\n",
         | 
| 53 | 
            +
                    "ForceUpdateDependencies = True\n",
         | 
| 54 | 
            +
                    "ForceNoMountDrive = False\n",
         | 
| 55 | 
            +
                    "#@markdown Restore your backup from Google Drive.\n",
         | 
| 56 | 
            +
                    "LoadBackupDrive = False #@param{type:\"boolean\"}\n",
         | 
| 57 | 
            +
                    "#@markdown Make regular backups of your model's training.\n",
         | 
| 58 | 
            +
                    "AutoBackups = True #@param{type:\"boolean\"}\n",
         | 
| 59 | 
            +
                    "if not os.path.exists(xdsame):\n",
         | 
| 60 | 
            +
                    " current_path = os.getcwd()\n",
         | 
| 61 | 
            +
                    " shutil.rmtree('/content/')\n",
         | 
| 62 | 
            +
                    " os.makedirs('/content/', exist_ok=True)\n",
         | 
| 63 | 
            +
                    "\n",
         | 
| 64 | 
            +
                    " os.chdir(current_path)\n",
         | 
| 65 | 
            +
                    " !git clone https://github.com/IAHispano/$nosv_name1$nosv_name2 /content/$tryre-based-Voice-Conversion-$guebui$guebui2/utils\n",
         | 
| 66 | 
            +
                    " clear_output()\n",
         | 
| 67 | 
            +
                    "\n",
         | 
| 68 | 
            +
                    " os.chdir(xdsame)\n",
         | 
| 69 | 
            +
                    " from utils.dependency import *\n",
         | 
| 70 | 
            +
                    " from utils.clonerepo_experimental import *\n",
         | 
| 71 | 
            +
                    " os.chdir(\"..\")\n",
         | 
| 72 | 
            +
                    "\n",
         | 
| 73 | 
            +
                    "\n",
         | 
| 74 | 
            +
                    "\n",
         | 
| 75 | 
            +
                    " setup_environment(ForceUpdateDependencies, ForceNoMountDrive)\n",
         | 
| 76 | 
            +
                    " clone_repository(True)\n",
         | 
| 77 | 
            +
                    "\n",
         | 
| 78 | 
            +
                    " !wget https://huggingface.co/lj1995/VoiceConversion$guebui$guebui2/resolve/main/rmvpe.pt -P /content/Retrieval-based-Voice-Conversion-$guebui$guebui2/\n",
         | 
| 79 | 
            +
                    " clear_output()\n",
         | 
| 80 | 
            +
                    "\n",
         | 
| 81 | 
            +
                    "base_path = \"/content/Retrieval-based-Voice-Conversion-$guebui$guebui2/\"\n",
         | 
| 82 | 
            +
                    "clear_output()\n",
         | 
| 83 | 
            +
                    "\n",
         | 
| 84 | 
            +
                    "\n",
         | 
| 85 | 
            +
                    "\n",
         | 
| 86 | 
            +
                    "from utils import backups\n",
         | 
| 87 | 
            +
                    "\n",
         | 
| 88 | 
            +
                    "LOGS_FOLDER = xdsame + '/logs'\n",
         | 
| 89 | 
            +
                    "if not os.path.exists(LOGS_FOLDER):\n",
         | 
| 90 | 
            +
                    "    os.makedirs(LOGS_FOLDER)\n",
         | 
| 91 | 
            +
                    "    clear_output()\n",
         | 
| 92 | 
            +
                    "\n",
         | 
| 93 | 
            +
                    "WEIGHTS_FOLDER = xdsame + '/logs' + '/weights'\n",
         | 
| 94 | 
            +
                    "if not os.path.exists(WEIGHTS_FOLDER):\n",
         | 
| 95 | 
            +
                    "    os.makedirs(WEIGHTS_FOLDER)\n",
         | 
| 96 | 
            +
                    "    clear_output()\n",
         | 
| 97 | 
            +
                    "\n",
         | 
| 98 | 
            +
                    "others_FOLDER = xdsame + '/audio-others'\n",
         | 
| 99 | 
            +
                    "if not os.path.exists(others_FOLDER):\n",
         | 
| 100 | 
            +
                    "    os.makedirs(others_FOLDER)\n",
         | 
| 101 | 
            +
                    "    clear_output()\n",
         | 
| 102 | 
            +
                    "\n",
         | 
| 103 | 
            +
                    "audio_outputs_FOLDER = xdsame + '/audio-outputs'\n",
         | 
| 104 | 
            +
                    "if not os.path.exists(audio_outputs_FOLDER):\n",
         | 
| 105 | 
            +
                    "    os.makedirs(audio_outputs_FOLDER)\n",
         | 
| 106 | 
            +
                    "    clear_output()\n",
         | 
| 107 | 
            +
                    "\n",
         | 
| 108 | 
            +
                    "if LoadBackupDrive:\n",
         | 
| 109 | 
            +
                    "    backups.import_google_drive_backup()\n",
         | 
| 110 | 
            +
                    "    clear_output()\n",
         | 
| 111 | 
            +
                    "\n",
         | 
| 112 | 
            +
                    "#@markdown Choose the language in which you want the interface to be available.\n",
         | 
| 113 | 
            +
                    "i18n_path = xdsame + 'i18n.py'\n",
         | 
| 114 | 
            +
                    "i18n_new_path = xdsame + 'utils/i18n.py'\n",
         | 
| 115 | 
            +
                    "try:\n",
         | 
| 116 | 
            +
                    "    if os.path.exists(i18n_path) and os.path.exists(i18n_new_path):\n",
         | 
| 117 | 
            +
                    "        shutil.move(i18n_new_path, i18n_path)\n",
         | 
| 118 | 
            +
                    "\n",
         | 
| 119 | 
            +
                    "    SelectedLanguage = \"en_US\" #@param [\"es_ES\", \"en_US\", \"zh_CN\", \"ar_AR\", \"id_ID\", \"pt_PT\", \"ru_RU\", \"ur_UR\", \"tr_TR\", \"it_IT\", \"de_DE\"]\n",
         | 
| 120 | 
            +
                    "    new_language_line = '            language = \"' + SelectedLanguage + '\"\\n'\n",
         | 
| 121 | 
            +
                    "#@markdown <a href=\"https://discord.gg/iahispano\"><font>If you need more help, feel free to join our official Discord server!</font></a>\n",
         | 
| 122 | 
            +
                    "    with open(i18n_path, 'r') as file:\n",
         | 
| 123 | 
            +
                    "        lines = file.readlines()\n",
         | 
| 124 | 
            +
                    "\n",
         | 
| 125 | 
            +
                    "    with open(i18n_path, 'w') as file:\n",
         | 
| 126 | 
            +
                    "        for index, line in enumerate(lines):\n",
         | 
| 127 | 
            +
                    "            if index == 14:\n",
         | 
| 128 | 
            +
                    "                file.write(new_language_line)\n",
         | 
| 129 | 
            +
                    "            else:\n",
         | 
| 130 | 
            +
                    "                file.write(line)\n",
         | 
| 131 | 
            +
                    "\n",
         | 
| 132 | 
            +
                    "except FileNotFoundError:\n",
         | 
| 133 | 
            +
                    "    print(\"Translation couldn't be applied successfully. Please restart the environment and run the cell again.\")\n",
         | 
| 134 | 
            +
                    "\n",
         | 
| 135 | 
            +
                    "def start_web_server():\n",
         | 
| 136 | 
            +
                    "    %cd /content/$tryre-based-Voice-Conversion-$guebui$guebui2\n",
         | 
| 137 | 
            +
                    "    %load_ext tensorboard\n",
         | 
| 138 | 
            +
                    "    clear_output()\n",
         | 
| 139 | 
            +
                    "    %tensorboard --logdir /content/$tryre-based-Voice-Conversion-$guebui$guebui2/logs\n",
         | 
| 140 | 
            +
                    "    !mkdir -p /content/$tryre-based-Voice-Conversion-$guebui$guebui2/audios\n",
         | 
| 141 | 
            +
                    "    display(HTML(collapsible_section))\n",
         | 
| 142 | 
            +
                    "    !python3 infer-web.py --colab --pycmd python3\n",
         | 
| 143 | 
            +
                    "\n",
         | 
| 144 | 
            +
                    "if AutoBackups:\n",
         | 
| 145 | 
            +
                    "  web_server_thread = threading.Thread(target=start_web_server)\n",
         | 
| 146 | 
            +
                    "  web_server_thread.start()\n",
         | 
| 147 | 
            +
                    "  backups.backup_files()\n",
         | 
| 148 | 
            +
                    "\n",
         | 
| 149 | 
            +
                    "else:\n",
         | 
| 150 | 
            +
                    "  start_web_server()"
         | 
| 151 | 
            +
                  ]
         | 
| 152 | 
            +
                }
         | 
| 153 | 
            +
              ],
         | 
| 154 | 
            +
              "metadata": {
         | 
| 155 | 
            +
                "accelerator": "GPU",
         | 
| 156 | 
            +
                "colab": {
         | 
| 157 | 
            +
                  "provenance": []
         | 
| 158 | 
            +
                },
         | 
| 159 | 
            +
                "kernelspec": {
         | 
| 160 | 
            +
                  "display_name": "Python 3",
         | 
| 161 | 
            +
                  "name": "python3"
         | 
| 162 | 
            +
                },
         | 
| 163 | 
            +
                "language_info": {
         | 
| 164 | 
            +
                  "name": "python"
         | 
| 165 | 
            +
                }
         | 
| 166 | 
            +
              },
         | 
| 167 | 
            +
              "nbformat": 4,
         | 
| 168 | 
            +
              "nbformat_minor": 0
         | 
| 169 | 
            +
            }
         | 
    	
        Dockerfile
    ADDED
    
    | @@ -0,0 +1,29 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # syntax=docker/dockerfile:1
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            FROM python:3.10-bullseye
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            EXPOSE 7865
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            WORKDIR /app
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            COPY . .
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            RUN apt update && apt install -y -qq ffmpeg aria2 && apt clean
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            RUN pip3 install --no-cache-dir -r requirements.txt
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D40k.pth -d assets/pretrained_v2/ -o D40k.pth
         | 
| 16 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G40k.pth -d assets/pretrained_v2/ -o G40k.pth
         | 
| 17 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D40k.pth -d assets/pretrained_v2/ -o f0D40k.pth
         | 
| 18 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G40k.pth -d assets/pretrained_v2/ -o f0G40k.pth
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d assets/uvr5_weights/ -o HP2-人声vocals+非人声instrumentals.pth
         | 
| 21 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d assets/uvr5_weights/ -o HP5-主旋律人声vocals+其他instrumentals.pth
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d assets/hubert -o hubert_base.pt
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            RUN aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/rmvpe.pt -d assets/hubert -o rmvpe.pt
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            VOLUME [ "/app/weights", "/app/opt" ]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            CMD ["python3", "infer-web.py"]
         | 
    	
        Fixes/local_fixes.py
    ADDED
    
    | @@ -0,0 +1,136 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import time
         | 
| 4 | 
            +
            import shutil
         | 
| 5 | 
            +
            import requests
         | 
| 6 | 
            +
            import zipfile
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            def insert_new_line(file_name, line_to_find, text_to_insert):
         | 
| 9 | 
            +
                lines = []
         | 
| 10 | 
            +
                with open(file_name, 'r', encoding='utf-8') as read_obj:
         | 
| 11 | 
            +
                    lines = read_obj.readlines()
         | 
| 12 | 
            +
                already_exists = False
         | 
| 13 | 
            +
                with open(file_name + '.tmp', 'w', encoding='utf-8') as write_obj:
         | 
| 14 | 
            +
                    for i in range(len(lines)):
         | 
| 15 | 
            +
                        write_obj.write(lines[i])
         | 
| 16 | 
            +
                        if lines[i].strip() == line_to_find:
         | 
| 17 | 
            +
                            # If next line exists and starts with sys.path.append, skip
         | 
| 18 | 
            +
                            if i+1 < len(lines) and lines[i+1].strip().startswith("sys.path.append"):
         | 
| 19 | 
            +
                                print('It was already fixed! Skip adding a line...')
         | 
| 20 | 
            +
                                already_exists = True
         | 
| 21 | 
            +
                                break
         | 
| 22 | 
            +
                            else:
         | 
| 23 | 
            +
                                write_obj.write(text_to_insert + '\n')
         | 
| 24 | 
            +
                # If no existing sys.path.append line was found, replace the original file
         | 
| 25 | 
            +
                if not already_exists:
         | 
| 26 | 
            +
                    os.replace(file_name + '.tmp', file_name)
         | 
| 27 | 
            +
                    return True
         | 
| 28 | 
            +
                else:
         | 
| 29 | 
            +
                    # If existing line was found, delete temporary file
         | 
| 30 | 
            +
                    os.remove(file_name + '.tmp')
         | 
| 31 | 
            +
                    return False
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            def replace_in_file(file_name, old_text, new_text):
         | 
| 34 | 
            +
                with open(file_name, 'r', encoding='utf-8') as file:
         | 
| 35 | 
            +
                    file_contents = file.read()
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                if old_text in file_contents:
         | 
| 38 | 
            +
                    file_contents = file_contents.replace(old_text, new_text)
         | 
| 39 | 
            +
                    with open(file_name, 'w', encoding='utf-8') as file:
         | 
| 40 | 
            +
                        file.write(file_contents)
         | 
| 41 | 
            +
                        return True
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                return False
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            if __name__ == "__main__":
         | 
| 46 | 
            +
                current_path = os.getcwd()
         | 
| 47 | 
            +
                file_name = os.path.join(current_path, "infer", "modules", "train", "extract", "extract_f0_print.py")
         | 
| 48 | 
            +
                line_to_find = 'import numpy as np, logging'
         | 
| 49 | 
            +
                text_to_insert = "sys.path.append(r'" + current_path + "')"
         | 
| 50 | 
            +
                
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                success_1 = insert_new_line(file_name, line_to_find, text_to_insert)
         | 
| 53 | 
            +
                if success_1:
         | 
| 54 | 
            +
                    print('The first operation was successful!')
         | 
| 55 | 
            +
                else:
         | 
| 56 | 
            +
                    print('He skipped the first operation because it was already fixed!')
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                file_name = 'infer-web.py'
         | 
| 59 | 
            +
                old_text = 'with gr.Blocks(theme=gr.themes.Soft()) as app:'
         | 
| 60 | 
            +
                new_text = 'with gr.Blocks() as app:'
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                success_2 = replace_in_file(file_name, old_text, new_text)
         | 
| 63 | 
            +
                if success_2:
         | 
| 64 | 
            +
                    print('The second operation was successful!')
         | 
| 65 | 
            +
                else:
         | 
| 66 | 
            +
                    print('The second operation was omitted because it was already fixed!')
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                print('Local corrections successful! You should now be able to infer and train locally in Applio RVC Fork.')
         | 
| 69 | 
            +
                
         | 
| 70 | 
            +
                time.sleep(5)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            def find_torchcrepe_directory(directory):
         | 
| 73 | 
            +
                """
         | 
| 74 | 
            +
                Recursively searches for the topmost folder named 'torchcrepe' within a directory.
         | 
| 75 | 
            +
                Returns the path of the directory found or None if none is found.
         | 
| 76 | 
            +
                """
         | 
| 77 | 
            +
                for root, dirs, files in os.walk(directory):
         | 
| 78 | 
            +
                    if 'torchcrepe' in dirs:
         | 
| 79 | 
            +
                        return os.path.join(root, 'torchcrepe')
         | 
| 80 | 
            +
                return None
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            def download_and_extract_torchcrepe():
         | 
| 83 | 
            +
                url = 'https://github.com/maxrmorrison/torchcrepe/archive/refs/heads/master.zip'
         | 
| 84 | 
            +
                temp_dir = 'temp_torchcrepe'
         | 
| 85 | 
            +
                destination_dir = os.getcwd()
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                try:
         | 
| 88 | 
            +
                    torchcrepe_dir_path = os.path.join(destination_dir, 'torchcrepe')
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    if os.path.exists(torchcrepe_dir_path):
         | 
| 91 | 
            +
                        print("Skipping the torchcrepe download. The folder already exists.")
         | 
| 92 | 
            +
                        return
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    # Download the file
         | 
| 95 | 
            +
                    print("Starting torchcrepe download...")
         | 
| 96 | 
            +
                    response = requests.get(url)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    # Raise an error if the GET request was unsuccessful
         | 
| 99 | 
            +
                    response.raise_for_status()
         | 
| 100 | 
            +
                    print("Download completed.")
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    # Save the downloaded file
         | 
| 103 | 
            +
                    zip_file_path = os.path.join(temp_dir, 'master.zip')
         | 
| 104 | 
            +
                    os.makedirs(temp_dir, exist_ok=True)
         | 
| 105 | 
            +
                    with open(zip_file_path, 'wb') as file:
         | 
| 106 | 
            +
                        file.write(response.content)
         | 
| 107 | 
            +
                    print(f"Zip file saved to {zip_file_path}")
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    # Extract the zip file
         | 
| 110 | 
            +
                    print("Extracting content...")
         | 
| 111 | 
            +
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
         | 
| 112 | 
            +
                        zip_file.extractall(temp_dir)
         | 
| 113 | 
            +
                    print("Extraction completed.")
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    # Locate the torchcrepe folder and move it to the destination directory
         | 
| 116 | 
            +
                    torchcrepe_dir = find_torchcrepe_directory(temp_dir)
         | 
| 117 | 
            +
                    if torchcrepe_dir:
         | 
| 118 | 
            +
                        shutil.move(torchcrepe_dir, destination_dir)
         | 
| 119 | 
            +
                        print(f"Moved the torchcrepe directory to {destination_dir}!")
         | 
| 120 | 
            +
                    else:
         | 
| 121 | 
            +
                        print("The torchcrepe directory could not be located.")
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                except Exception as e:
         | 
| 124 | 
            +
                    print("Torchcrepe not successfully downloaded", e)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                # Clean up temporary directory
         | 
| 127 | 
            +
                if os.path.exists(temp_dir):
         | 
| 128 | 
            +
                    shutil.rmtree(temp_dir)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            # Run the function
         | 
| 131 | 
            +
            download_and_extract_torchcrepe()
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            temp_dir = 'temp_torchcrepe'
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            if os.path.exists(temp_dir):
         | 
| 136 | 
            +
                shutil.rmtree(temp_dir)
         | 
    	
        Fixes/tensor-launch.py
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import threading
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            from tensorboard import program
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            log_path = "logs"
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            if __name__ == "__main__":
         | 
| 9 | 
            +
                tb = program.TensorBoard()
         | 
| 10 | 
            +
                tb.configure(argv=[None, '--logdir', log_path])
         | 
| 11 | 
            +
                url = tb.launch()
         | 
| 12 | 
            +
                print(f'Tensorboard can be accessed at: {url}')
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                while True:
         | 
| 15 | 
            +
                    time.sleep(600)  # Keep the main thread running
         | 
    	
        LICENSE
    ADDED
    
    | @@ -0,0 +1,59 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            MIT License
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Copyright (c) 2023 liujing04
         | 
| 4 | 
            +
            Copyright (c) 2023 源文雨
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 7 | 
            +
            of this software and associated documentation files (the "Software"), to deal
         | 
| 8 | 
            +
            in the Software without restriction, including without limitation the rights
         | 
| 9 | 
            +
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 10 | 
            +
            copies of the Software, and to permit persons to whom the Software is
         | 
| 11 | 
            +
            furnished to do so, subject to the following conditions:
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 14 | 
            +
            copies or substantial portions of the Software.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 17 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 18 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 19 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 20 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 21 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 22 | 
            +
            SOFTWARE.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            The licenses for related libraries are as follows:
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            ContentVec
         | 
| 27 | 
            +
            https://github.com/auspicious3000/contentvec/blob/main/LICENSE
         | 
| 28 | 
            +
            MIT License
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            VITS
         | 
| 31 | 
            +
            https://github.com/jaywalnut310/vits/blob/main/LICENSE
         | 
| 32 | 
            +
            MIT License
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            HIFIGAN
         | 
| 35 | 
            +
            https://github.com/jik876/hifi-gan/blob/master/LICENSE
         | 
| 36 | 
            +
            MIT License
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            gradio
         | 
| 39 | 
            +
            https://github.com/gradio-app/gradio/blob/main/LICENSE
         | 
| 40 | 
            +
            Apache License 2.0
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            ffmpeg
         | 
| 43 | 
            +
            https://github.com/FFmpeg/FFmpeg/blob/master/COPYING.LGPLv3
         | 
| 44 | 
            +
            https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2021-02-28-12-32/ffmpeg-n4.3.2-160-gfbb9368226-win64-lgpl-4.3.zip
         | 
| 45 | 
            +
            LPGLv3 License
         | 
| 46 | 
            +
            MIT License
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            ultimatevocalremovergui
         | 
| 49 | 
            +
            https://github.com/Anjok07/ultimatevocalremovergui/blob/master/LICENSE
         | 
| 50 | 
            +
            https://github.com/yang123qwe/vocal_separation_by_uvr5
         | 
| 51 | 
            +
            MIT License
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            audio-slicer
         | 
| 54 | 
            +
            https://github.com/openvpi/audio-slicer/blob/main/LICENSE
         | 
| 55 | 
            +
            MIT License
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            PySimpleGUI
         | 
| 58 | 
            +
            https://github.com/PySimpleGUI/PySimpleGUI/blob/master/license.txt
         | 
| 59 | 
            +
            LPGLv3 License
         | 
    	
        LazyImport.py
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from importlib.util import find_spec, LazyLoader, module_from_spec
         | 
| 2 | 
            +
            from sys import modules
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def lazyload(name):
         | 
| 5 | 
            +
                if name in modules:
         | 
| 6 | 
            +
                    return modules[name]
         | 
| 7 | 
            +
                else:
         | 
| 8 | 
            +
                    spec = find_spec(name)
         | 
| 9 | 
            +
                    loader = LazyLoader(spec.loader)
         | 
| 10 | 
            +
                    module = module_from_spec(spec)
         | 
| 11 | 
            +
                    modules[name] = module
         | 
| 12 | 
            +
                    loader.exec_module(module)
         | 
| 13 | 
            +
                    return module
         | 
    	
        MDX-Net_Colab.ipynb
    ADDED
    
    | @@ -0,0 +1,524 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "cells": [
         | 
| 3 | 
            +
                {
         | 
| 4 | 
            +
                  "cell_type": "markdown",
         | 
| 5 | 
            +
                  "metadata": {
         | 
| 6 | 
            +
                    "id": "wX9xzLur4tus"
         | 
| 7 | 
            +
                  },
         | 
| 8 | 
            +
                  "source": [
         | 
| 9 | 
            +
                    "# MDX-Net Colab\n",
         | 
| 10 | 
            +
                    "<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
         | 
| 11 | 
            +
                    "  <img src=\"https://github.githubassets.com/pinned-octocat.svg\" alt=\"icon1\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
         | 
| 12 | 
            +
                    "  <span>Trained models provided in this notebook are from <a href=\"https://github.com/Anjok07\">UVR-GUI</a>.</span>\n",
         | 
| 13 | 
            +
                    "</div>\n",
         | 
| 14 | 
            +
                    "<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
         | 
| 15 | 
            +
                    "  <img src=\"https://github.com/Anjok07/ultimatevocalremovergui/raw/master/gui_data/img/GUI-Icon.ico\" alt=\"icon2\" style=\"margin-right:10px; height: 20px;margin-top:10px\" width=\"1.5%\">\n",
         | 
| 16 | 
            +
                    "  <span>OFFICIAL UVR GITHUB PAGE: <a href=\"https://github.com/Anjok07/ultimatevocalremovergui\">here</a>.</span>\n",
         | 
| 17 | 
            +
                    "</div>\n",
         | 
| 18 | 
            +
                    "<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
         | 
| 19 | 
            +
                    "  <img src=\"https://avatars.githubusercontent.com/u/24620594\" alt=\"icon3\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
         | 
| 20 | 
            +
                    "  <span>OFFICIAL CLI Version: <a href=\"https://github.com/tsurumeso/vocal-remover\">here</a>.</span>\n",
         | 
| 21 | 
            +
                    "</div>\n",
         | 
| 22 | 
            +
                    "<div style=\"display:flex; align-items:center; font-size: 16px;\">\n",
         | 
| 23 | 
            +
                    "  <img src=\"https://icons.getbootstrap.com/assets/icons/discord.svg\" alt=\"icon4\" style=\"margin-right:10px; height: 20px;\" width=\"1.5%\">\n",
         | 
| 24 | 
            +
                    "  <span>Join our <a href=\"https://cutt.ly/0TcDjmo\">Discord server</a>!</span>\n",
         | 
| 25 | 
            +
                    "</div>\n",
         | 
| 26 | 
            +
                    "<sup><br>Ultimate Vocal Remover (unofficial)</sup>\n",
         | 
| 27 | 
            +
                    "<sup><br>MDX-Net by <a href=\"https://github.com/kuielab\">kuielab</a> and adapted for Colaboratory by <a href=\"https://www.youtube.com/channel/UC0NiSV1jLMH-9E09wiDVFYw\">AudioHacker</a>.</sup>\n",
         | 
| 28 | 
            +
                    "\n",
         | 
| 29 | 
            +
                    "<sup><br>Your support means a lot to me. If you enjoy my work, please consider buying me a ko-fi:<br></sup>\n",
         | 
| 30 | 
            +
                    "[](https://ko-fi.com/X8X6M8FR0)"
         | 
| 31 | 
            +
                  ]
         | 
| 32 | 
            +
                },
         | 
| 33 | 
            +
                {
         | 
| 34 | 
            +
                  "cell_type": "code",
         | 
| 35 | 
            +
                  "execution_count": null,
         | 
| 36 | 
            +
                  "metadata": {
         | 
| 37 | 
            +
                    "id": "3J69RV7G8ocb",
         | 
| 38 | 
            +
                    "cellView": "form"
         | 
| 39 | 
            +
                  },
         | 
| 40 | 
            +
                  "outputs": [],
         | 
| 41 | 
            +
                  "source": [
         | 
| 42 | 
            +
                    "import json\n",
         | 
| 43 | 
            +
                    "import os\n",
         | 
| 44 | 
            +
                    "import os.path\n",
         | 
| 45 | 
            +
                    "import gc\n",
         | 
| 46 | 
            +
                    "import psutil\n",
         | 
| 47 | 
            +
                    "import requests\n",
         | 
| 48 | 
            +
                    "import subprocess\n",
         | 
| 49 | 
            +
                    "import glob\n",
         | 
| 50 | 
            +
                    "import time\n",
         | 
| 51 | 
            +
                    "import logging\n",
         | 
| 52 | 
            +
                    "import sys\n",
         | 
| 53 | 
            +
                    "from bs4 import BeautifulSoup\n",
         | 
| 54 | 
            +
                    "from google.colab import drive, files, output\n",
         | 
| 55 | 
            +
                    "from IPython.display import Audio, display\n",
         | 
| 56 | 
            +
                    "\n",
         | 
| 57 | 
            +
                    "if \"first_cell_ran\" in locals():\n",
         | 
| 58 | 
            +
                    "    print(\"You've ran this cell for this session. No need to run it again.\\nif you think something went wrong or you want to change mounting path, restart the runtime.\")\n",
         | 
| 59 | 
            +
                    "else:\n",
         | 
| 60 | 
            +
                    "    print('Setting up... please wait around 1-2 minute(s).')\n",
         | 
| 61 | 
            +
                    "\n",
         | 
| 62 | 
            +
                    "    branch = \"https://github.com/NaJeongMo/Colab-for-MDX_B\"\n",
         | 
| 63 | 
            +
                    "\n",
         | 
| 64 | 
            +
                    "    model_params = \"https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data.json\"\n",
         | 
| 65 | 
            +
                    "    _Models = \"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/\"\n",
         | 
| 66 | 
            +
                    "    # _models = \"https://pastebin.com/raw/jBzYB8vz\"\n",
         | 
| 67 | 
            +
                    "    _models = \"https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json\"\n",
         | 
| 68 | 
            +
                    "    stem_naming = \"https://pastebin.com/raw/mpH4hRcF\"\n",
         | 
| 69 | 
            +
                    "    arl_check_endpoint = 'https://dz.doubledouble.top/check' # param: arl?=<>\n",
         | 
| 70 | 
            +
                    "\n",
         | 
| 71 | 
            +
                    "    file_folder = \"Colab-for-MDX_B\"\n",
         | 
| 72 | 
            +
                    "\n",
         | 
| 73 | 
            +
                    "    model_ids = requests.get(_models).json()\n",
         | 
| 74 | 
            +
                    "    model_ids = model_ids[\"mdx_download_list\"].values()\n",
         | 
| 75 | 
            +
                    "\n",
         | 
| 76 | 
            +
                    "    model_params = requests.get(model_params).json()\n",
         | 
| 77 | 
            +
                    "    stem_naming = requests.get(stem_naming).json()\n",
         | 
| 78 | 
            +
                    "\n",
         | 
| 79 | 
            +
                    "    os.makedirs(\"tmp_models\", exist_ok=True)\n",
         | 
| 80 | 
            +
                    "\n",
         | 
| 81 | 
            +
                    "    # @markdown If you don't wish to mount google drive, uncheck this box.\n",
         | 
| 82 | 
            +
                    "    MountDrive = True  # @param{type:\"boolean\"}\n",
         | 
| 83 | 
            +
                    "    # @markdown The path for the drive to be mounted: Please be cautious when modifying this as it can cause issues if not done properly.\n",
         | 
| 84 | 
            +
                    "    mounting_path = \"/content/drive/MyDrive\"  # @param [\"snippets:\",\"/content/drive/MyDrive\",\"/content/drive/Shareddrives/<your shared drive name>\", \"/content/drive/Shareddrives/Shared Drive\"]{allow-input: true}\n",
         | 
| 85 | 
            +
                    "    # @markdown Force update and disregard local changes: discards all local modifications in your repository, effectively replacing all files with the versions from the original commit.\n",
         | 
| 86 | 
            +
                    "    force_update = False  # @param{type:\"boolean\"}\n",
         | 
| 87 | 
            +
                    "    # @markdown Auto Update (does not discard your changes)\n",
         | 
| 88 | 
            +
                    "    auto_update = True  # @param{type:\"boolean\"}\n",
         | 
| 89 | 
            +
                    "\n",
         | 
| 90 | 
            +
                    "\n",
         | 
| 91 | 
            +
                    "    reqs_apt = []  # !sudo apt-get install\n",
         | 
| 92 | 
            +
                    "    reqs_pip = [\"librosa>=0.6.3,<0.9\", \"onnxruntime_gpu\", \"deemix\", \"yt_dlp\"]  # pip3 install\n",
         | 
| 93 | 
            +
                    "\n",
         | 
| 94 | 
            +
                    "    class hide_opt:  # hide outputs\n",
         | 
| 95 | 
            +
                    "        def __enter__(self):\n",
         | 
| 96 | 
            +
                    "            self._original_stdout = sys.stdout\n",
         | 
| 97 | 
            +
                    "            sys.stdout = open(os.devnull, \"w\")\n",
         | 
| 98 | 
            +
                    "\n",
         | 
| 99 | 
            +
                    "        def __exit__(self, exc_type, exc_val, exc_tb):\n",
         | 
| 100 | 
            +
                    "            sys.stdout.close()\n",
         | 
| 101 | 
            +
                    "            sys.stdout = self._original_stdout\n",
         | 
| 102 | 
            +
                    "\n",
         | 
| 103 | 
            +
                    "    def get_size(bytes, suffix=\"B\"):  # read ram\n",
         | 
| 104 | 
            +
                    "        global svmem\n",
         | 
| 105 | 
            +
                    "        factor = 1024\n",
         | 
| 106 | 
            +
                    "        for unit in [\"\", \"K\", \"M\", \"G\", \"T\", \"P\"]:\n",
         | 
| 107 | 
            +
                    "            if bytes < factor:\n",
         | 
| 108 | 
            +
                    "                return f\"{bytes:.2f}{unit}{suffix}\"\n",
         | 
| 109 | 
            +
                    "            bytes /= factor\n",
         | 
| 110 | 
            +
                    "        svmem = psutil.virtual_memory()\n",
         | 
| 111 | 
            +
                    "\n",
         | 
| 112 | 
            +
                    "\n",
         | 
| 113 | 
            +
                    "    print('installing requirements...',end=' ')\n",
         | 
| 114 | 
            +
                    "    with hide_opt():\n",
         | 
| 115 | 
            +
                    "        for x in reqs_apt:\n",
         | 
| 116 | 
            +
                    "            subprocess.run([\"sudo\", \"apt-get\", \"install\", x])\n",
         | 
| 117 | 
            +
                    "        for x in reqs_pip:\n",
         | 
| 118 | 
            +
                    "            subprocess.run([\"python3\", \"-m\", \"pip\", \"install\", x])\n",
         | 
| 119 | 
            +
                    "    print('done')\n",
         | 
| 120 | 
            +
                    "\n",
         | 
| 121 | 
            +
                    "    def install_or_mount_drive():\n",
         | 
| 122 | 
            +
                    "        print(\n",
         | 
| 123 | 
            +
                    "            \"Please log in to your account by following the prompts in the pop-up tab.\\nThis step is necessary to install the files to your Google Drive.\\nIf you have any concerns about the safety of this notebook, you can choose not to mount your drive by unchecking the \\\"MountDrive\\\" checkbox.\"\n",
         | 
| 124 | 
            +
                    "        )\n",
         | 
| 125 | 
            +
                    "        drive.mount(\"/content/drive\", force_remount=True)\n",
         | 
| 126 | 
            +
                    "        os.chdir(mounting_path)\n",
         | 
| 127 | 
            +
                    "        # check if previous installation is done\n",
         | 
| 128 | 
            +
                    "        if os.path.exists(os.path.join(mounting_path, file_folder)):\n",
         | 
| 129 | 
            +
                    "            # update checking\n",
         | 
| 130 | 
            +
                    "            os.chdir(file_folder)\n",
         | 
| 131 | 
            +
                    "\n",
         | 
| 132 | 
            +
                    "            if force_update:\n",
         | 
| 133 | 
            +
                    "                print('Force updating...')\n",
         | 
| 134 | 
            +
                    "\n",
         | 
| 135 | 
            +
                    "                commands = [\n",
         | 
| 136 | 
            +
                    "                    [\"git\", \"pull\"],\n",
         | 
| 137 | 
            +
                    "                    [\"git\", \"checkout\", \"--\", \".\"],\n",
         | 
| 138 | 
            +
                    "                ]\n",
         | 
| 139 | 
            +
                    "\n",
         | 
| 140 | 
            +
                    "                for cmd in commands:\n",
         | 
| 141 | 
            +
                    "                    subprocess.run(cmd)\n",
         | 
| 142 | 
            +
                    "\n",
         | 
| 143 | 
            +
                    "            elif auto_update:\n",
         | 
| 144 | 
            +
                    "                print('Checking for updates...')\n",
         | 
| 145 | 
            +
                    "                commands = [\n",
         | 
| 146 | 
            +
                    "                    [\"git\", \"pull\"],\n",
         | 
| 147 | 
            +
                    "                ]\n",
         | 
| 148 | 
            +
                    "\n",
         | 
| 149 | 
            +
                    "                for cmd in commands:\n",
         | 
| 150 | 
            +
                    "                    subprocess.run(cmd)\n",
         | 
| 151 | 
            +
                    "        else:\n",
         | 
| 152 | 
            +
                    "            subprocess.run([\"git\", \"clone\", \"https://github.com/NaJeongMo/Colab-for-MDX_B.git\"])\n",
         | 
| 153 | 
            +
                    "            os.chdir(file_folder)\n",
         | 
| 154 | 
            +
                    "\n",
         | 
| 155 | 
            +
                    "    def use_uvr_without_saving():\n",
         | 
| 156 | 
            +
                    "        global mounting_path\n",
         | 
| 157 | 
            +
                    "        print(\"Notice: files won't be saved to personal drive.\")\n",
         | 
| 158 | 
            +
                    "        print(f\"Downloading {file_folder}...\", end=\" \")\n",
         | 
| 159 | 
            +
                    "        mounting_path = \"/content\"\n",
         | 
| 160 | 
            +
                    "        with hide_opt():\n",
         | 
| 161 | 
            +
                    "            os.chdir(mounting_path)\n",
         | 
| 162 | 
            +
                    "            subprocess.run([\"git\", \"clone\", \"https://github.com/NaJeongMo/Colab-for-MDX_B.git\"])\n",
         | 
| 163 | 
            +
                    "            os.chdir(file_folder)\n",
         | 
| 164 | 
            +
                    "\n",
         | 
| 165 | 
            +
                    "    if MountDrive:\n",
         | 
| 166 | 
            +
                    "        install_or_mount_drive()\n",
         | 
| 167 | 
            +
                    "    else:\n",
         | 
| 168 | 
            +
                    "        use_uvr_without_saving()\n",
         | 
| 169 | 
            +
                    "    print(\"done!\")\n",
         | 
| 170 | 
            +
                    "    if not os.path.exists(\"tracks\"):\n",
         | 
| 171 | 
            +
                    "        os.mkdir(\"tracks\")\n",
         | 
| 172 | 
            +
                    "\n",
         | 
| 173 | 
            +
                    "    print('Importing required libraries...',end=' ')\n",
         | 
| 174 | 
            +
                    "\n",
         | 
| 175 | 
            +
                    "    import os\n",
         | 
| 176 | 
            +
                    "    import mdx\n",
         | 
| 177 | 
            +
                    "    import librosa\n",
         | 
| 178 | 
            +
                    "    import torch\n",
         | 
| 179 | 
            +
                    "    import soundfile as sf\n",
         | 
| 180 | 
            +
                    "    import numpy as np\n",
         | 
| 181 | 
            +
                    "    import yt_dlp\n",
         | 
| 182 | 
            +
                    "\n",
         | 
| 183 | 
            +
                    "    from deezer import Deezer\n",
         | 
| 184 | 
            +
                    "    from deezer import TrackFormats\n",
         | 
| 185 | 
            +
                    "    import deemix\n",
         | 
| 186 | 
            +
                    "    from deemix.settings import load as loadSettings\n",
         | 
| 187 | 
            +
                    "    from deemix.downloader import Downloader\n",
         | 
| 188 | 
            +
                    "    from deemix import generateDownloadObject\n",
         | 
| 189 | 
            +
                    "\n",
         | 
| 190 | 
            +
                    "    logger = logging.getLogger(\"yt_dlp\")\n",
         | 
| 191 | 
            +
                    "    logger.setLevel(logging.ERROR)\n",
         | 
| 192 | 
            +
                    "\n",
         | 
| 193 | 
            +
                    "    def id_to_ptm(mkey):\n",
         | 
| 194 | 
            +
                    "        if mkey in model_ids:\n",
         | 
| 195 | 
            +
                    "            mpath = f\"/content/tmp_models/{mkey}\"\n",
         | 
| 196 | 
            +
                    "            if not os.path.exists(f'/content/tmp_models/{mkey}'):\n",
         | 
| 197 | 
            +
                    "                print('Downloading model...',end=' ')\n",
         | 
| 198 | 
            +
                    "                subprocess.run(\n",
         | 
| 199 | 
            +
                    "                    [\"wget\", _Models+mkey, \"-O\", mpath]\n",
         | 
| 200 | 
            +
                    "                )\n",
         | 
| 201 | 
            +
                    "                print(f'saved to {mpath}')\n",
         | 
| 202 | 
            +
                    "                # get_ipython().system(f'gdown {model_id} -O /content/tmp_models/{mkey}')\n",
         | 
| 203 | 
            +
                    "                return mpath\n",
         | 
| 204 | 
            +
                    "            else:\n",
         | 
| 205 | 
            +
                    "                return mpath\n",
         | 
| 206 | 
            +
                    "        else:\n",
         | 
| 207 | 
            +
                    "            mpath = f'models/{mkey}'\n",
         | 
| 208 | 
            +
                    "            return mpath\n",
         | 
| 209 | 
            +
                    "\n",
         | 
| 210 | 
            +
                    "    def prepare_mdx(custom_param=False, dim_f=None, dim_t=None, n_fft=None, stem_name=None, compensation=None):\n",
         | 
| 211 | 
            +
                    "        device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
         | 
| 212 | 
            +
                    "        if custom_param:\n",
         | 
| 213 | 
            +
                    "            assert not (dim_f is None or dim_t is None or n_fft is None or compensation is None), 'Custom parameter selected, but incomplete parameters are provided.'\n",
         | 
| 214 | 
            +
                    "            mdx_model = mdx.MDX_Model(\n",
         | 
| 215 | 
            +
                    "                device,\n",
         | 
| 216 | 
            +
                    "                dim_f = dim_f,\n",
         | 
| 217 | 
            +
                    "                dim_t = dim_t,\n",
         | 
| 218 | 
            +
                    "                n_fft = n_fft,\n",
         | 
| 219 | 
            +
                    "                stem_name=stem_name,\n",
         | 
| 220 | 
            +
                    "                compensation=compensation\n",
         | 
| 221 | 
            +
                    "            )\n",
         | 
| 222 | 
            +
                    "        else:\n",
         | 
| 223 | 
            +
                    "            model_hash = mdx.MDX.get_hash(onnx)\n",
         | 
| 224 | 
            +
                    "            if model_hash in model_params:\n",
         | 
| 225 | 
            +
                    "                mp = model_params.get(model_hash)\n",
         | 
| 226 | 
            +
                    "                mdx_model = mdx.MDX_Model(\n",
         | 
| 227 | 
            +
                    "                    device,\n",
         | 
| 228 | 
            +
                    "                    dim_f = mp[\"mdx_dim_f_set\"],\n",
         | 
| 229 | 
            +
                    "                    dim_t = 2**mp[\"mdx_dim_t_set\"],\n",
         | 
| 230 | 
            +
                    "                    n_fft = mp[\"mdx_n_fft_scale_set\"],\n",
         | 
| 231 | 
            +
                    "                    stem_name=mp[\"primary_stem\"],\n",
         | 
| 232 | 
            +
                    "                    compensation=compensation if not custom_param and compensation is not None else mp[\"compensate\"]\n",
         | 
| 233 | 
            +
                    "                )\n",
         | 
| 234 | 
            +
                    "        return mdx_model\n",
         | 
| 235 | 
            +
                    "\n",
         | 
| 236 | 
            +
                    "    def run_mdx(onnx, mdx_model,filename,diff=False,suffix=None,diff_suffix=None, denoise=False, m_threads=1):\n",
         | 
| 237 | 
            +
                    "        mdx_sess = mdx.MDX(onnx,mdx_model)\n",
         | 
| 238 | 
            +
                    "        print(f\"Processing: {filename}\")\n",
         | 
| 239 | 
            +
                    "        wave, sr = librosa.load(filename,mono=False, sr=44100)\n",
         | 
| 240 | 
            +
                    "        # normalizing input wave gives better output\n",
         | 
| 241 | 
            +
                    "        peak = max(np.max(wave), abs(np.min(wave)))\n",
         | 
| 242 | 
            +
                    "        wave /= peak\n",
         | 
| 243 | 
            +
                    "        if denoise:\n",
         | 
| 244 | 
            +
                    "            wave_processed = -(mdx_sess.process_wave(-wave, m_threads)) + (mdx_sess.process_wave(wave, m_threads))\n",
         | 
| 245 | 
            +
                    "            wave_processed *= 0.5\n",
         | 
| 246 | 
            +
                    "        else:\n",
         | 
| 247 | 
            +
                    "            wave_processed = mdx_sess.process_wave(wave, m_threads)\n",
         | 
| 248 | 
            +
                    "        # return to previous peak\n",
         | 
| 249 | 
            +
                    "        wave_processed *= peak\n",
         | 
| 250 | 
            +
                    "\n",
         | 
| 251 | 
            +
                    "        stem_name = mdx_model.stem_name if suffix is None else suffix # use suffix if provided\n",
         | 
| 252 | 
            +
                    "        save_path = f\"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav\"\n",
         | 
| 253 | 
            +
                    "        save_path = os.path.join(\n",
         | 
| 254 | 
            +
                    "                'separated',\n",
         | 
| 255 | 
            +
                    "                save_path\n",
         | 
| 256 | 
            +
                    "            )\n",
         | 
| 257 | 
            +
                    "        sf.write(\n",
         | 
| 258 | 
            +
                    "            save_path,\n",
         | 
| 259 | 
            +
                    "            wave_processed.T,\n",
         | 
| 260 | 
            +
                    "            sr\n",
         | 
| 261 | 
            +
                    "        )\n",
         | 
| 262 | 
            +
                    "\n",
         | 
| 263 | 
            +
                    "        print(f'done, saved to: {save_path}')\n",
         | 
| 264 | 
            +
                    "\n",
         | 
| 265 | 
            +
                    "        if diff:\n",
         | 
| 266 | 
            +
                    "            diff_stem_name = stem_naming.get(stem_name) if diff_suffix is None else diff_suffix # use suffix if provided\n",
         | 
| 267 | 
            +
                    "            stem_name = f\"{stem_name}_diff\" if diff_stem_name is None else diff_stem_name\n",
         | 
| 268 | 
            +
                    "            save_path = f\"{os.path.basename(os.path.splitext(filename)[0])}_{stem_name}.wav\"\n",
         | 
| 269 | 
            +
                    "            save_path = os.path.join(\n",
         | 
| 270 | 
            +
                    "                    'separated',\n",
         | 
| 271 | 
            +
                    "                    save_path\n",
         | 
| 272 | 
            +
                    "                )\n",
         | 
| 273 | 
            +
                    "            sf.write(\n",
         | 
| 274 | 
            +
                    "                save_path,\n",
         | 
| 275 | 
            +
                    "                (-wave_processed.T*mdx_model.compensation)+wave.T,\n",
         | 
| 276 | 
            +
                    "                sr\n",
         | 
| 277 | 
            +
                    "            )\n",
         | 
| 278 | 
            +
                    "            print(f'invert done, saved to: {save_path}')\n",
         | 
| 279 | 
            +
                    "        del mdx_sess, wave_processed, wave\n",
         | 
| 280 | 
            +
                    "        gc.collect()\n",
         | 
| 281 | 
            +
                    "\n",
         | 
| 282 | 
            +
                    "    def is_valid_url(url):\n",
         | 
| 283 | 
            +
                    "        import re\n",
         | 
| 284 | 
            +
                    "        regex = re.compile(\n",
         | 
| 285 | 
            +
                    "            r'^https?://'\n",
         | 
| 286 | 
            +
                    "            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\\.)+[A-Z]{2,6}\\.?|'\n",
         | 
| 287 | 
            +
                    "            r'localhost|'\n",
         | 
| 288 | 
            +
                    "            r'\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3})'\n",
         | 
| 289 | 
            +
                    "            r'(?::\\d+)?'\n",
         | 
| 290 | 
            +
                    "            r'(?:/?|[/?]\\S+)$', re.IGNORECASE)\n",
         | 
| 291 | 
            +
                    "        return url is not None and regex.search(url)\n",
         | 
| 292 | 
            +
                    "\n",
         | 
| 293 | 
            +
                    "    def download_deezer(link, arl, fmt='FLAC'):\n",
         | 
| 294 | 
            +
                    "        match fmt:\n",
         | 
| 295 | 
            +
                    "            case 'FLAC':\n",
         | 
| 296 | 
            +
                    "                bitrate = TrackFormats.FLAC\n",
         | 
| 297 | 
            +
                    "            case 'MP3_320':\n",
         | 
| 298 | 
            +
                    "                bitrate = TrackFormats.MP3_320\n",
         | 
| 299 | 
            +
                    "            case 'MP3_128':\n",
         | 
| 300 | 
            +
                    "                bitrate = TrackFormats.MP3_128\n",
         | 
| 301 | 
            +
                    "            case _:\n",
         | 
| 302 | 
            +
                    "                bitrate = TrackFormats.MP3_128\n",
         | 
| 303 | 
            +
                    "\n",
         | 
| 304 | 
            +
                    "        dz = Deezer()\n",
         | 
| 305 | 
            +
                    "        settings = loadSettings('dz_config')\n",
         | 
| 306 | 
            +
                    "        settings['downloadLocation'] = './tracks'\n",
         | 
| 307 | 
            +
                    "        if not dz.login_via_arl(arl.strip()):\n",
         | 
| 308 | 
            +
                    "            raise Exception('Error while logging in with provided ARL.')\n",
         | 
| 309 | 
            +
                    "        downloadObject = generateDownloadObject(dz, link, bitrate)\n",
         | 
| 310 | 
            +
                    "        print(f'Downloading {downloadObject.type}: \"{downloadObject.title}\" by {downloadObject.artist}...',end=' ',flush=True)\n",
         | 
| 311 | 
            +
                    "        Downloader(dz, downloadObject, settings).start()\n",
         | 
| 312 | 
            +
                    "        print(f'done.')\n",
         | 
| 313 | 
            +
                    "\n",
         | 
| 314 | 
            +
                    "        path_to_audio = []\n",
         | 
| 315 | 
            +
                    "        for file in downloadObject.files:\n",
         | 
| 316 | 
            +
                    "            path_to_audio.append(file[\"path\"])\n",
         | 
| 317 | 
            +
                    "\n",
         | 
| 318 | 
            +
                    "        return path_to_audio\n",
         | 
| 319 | 
            +
                    "\n",
         | 
| 320 | 
            +
                    "    def download_link(url):\n",
         | 
| 321 | 
            +
                    "        ydl_opts = {\n",
         | 
| 322 | 
            +
                    "            'format': 'bestvideo+bestaudio/best',\n",
         | 
| 323 | 
            +
                    "            'outtmpl': '%(title)s.%(ext)s',\n",
         | 
| 324 | 
            +
                    "            'nocheckcertificate': True,\n",
         | 
| 325 | 
            +
                    "            'ignoreerrors': True,\n",
         | 
| 326 | 
            +
                    "            'no_warnings': True,\n",
         | 
| 327 | 
            +
                    "            'extractaudio': True,\n",
         | 
| 328 | 
            +
                    "        }\n",
         | 
| 329 | 
            +
                    "        with yt_dlp.YoutubeDL(ydl_opts) as ydl:\n",
         | 
| 330 | 
            +
                    "            result = ydl.extract_info(url, download=True)\n",
         | 
| 331 | 
            +
                    "            download_path = ydl.prepare_filename(result)\n",
         | 
| 332 | 
            +
                    "        return download_path\n",
         | 
| 333 | 
            +
                    "\n",
         | 
| 334 | 
            +
                    "    print('finished setting up!')\n",
         | 
| 335 | 
            +
                    "    first_cell_ran = True"
         | 
| 336 | 
            +
                  ]
         | 
| 337 | 
            +
                },
         | 
| 338 | 
            +
                {
         | 
| 339 | 
            +
                  "cell_type": "code",
         | 
| 340 | 
            +
                  "execution_count": null,
         | 
| 341 | 
            +
                  "metadata": {
         | 
| 342 | 
            +
                    "id": "4hd1TzEGCiRo",
         | 
| 343 | 
            +
                    "cellView": "form"
         | 
| 344 | 
            +
                  },
         | 
| 345 | 
            +
                  "outputs": [],
         | 
| 346 | 
            +
                  "source": [
         | 
| 347 | 
            +
                    "if 'first_cell_ran' in locals():\n",
         | 
| 348 | 
            +
                    "    os.chdir(mounting_path + '/' + file_folder + '/')\n",
         | 
| 349 | 
            +
                    "    #parameter markdowns-----------------\n",
         | 
| 350 | 
            +
                    "    #@markdown ### Input files\n",
         | 
| 351 | 
            +
                    "    #@markdown track filename: Upload your songs to the \"tracks\" folder. You may provide multiple links/files by spliting them with ;\n",
         | 
| 352 | 
            +
                    "    filename = \"https://deezer.com/album/281108671\" #@param {type:\"string\"}\n",
         | 
| 353 | 
            +
                    "    #@markdown onnx model (if you have your own model, upload it in models folder)\n",
         | 
| 354 | 
            +
                    "    onnx = \"UVR-MDX-NET-Inst_HQ_3.onnx\" #@param [\"Kim_Inst.onnx\", \"Kim_Vocal_1.onnx\", \"Kim_Vocal_2.onnx\", \"kuielab_a_bass.onnx\", \"kuielab_a_drums.onnx\", \"kuielab_a_other.onnx\", \"kuielab_a_vocals.onnx\", \"kuielab_b_bass.onnx\", \"kuielab_b_drums.onnx\", \"kuielab_b_other.onnx\", \"kuielab_b_vocals.onnx\", \"Reverb_HQ_By_FoxJoy.onnx\", \"UVR-MDX-NET-Inst_1.onnx\", \"UVR-MDX-NET-Inst_2.onnx\", \"UVR-MDX-NET-Inst_3.onnx\", \"UVR-MDX-NET-Inst_HQ_1.onnx\", \"UVR-MDX-NET-Inst_HQ_2.onnx\", \"UVR-MDX-NET-Inst_Main.onnx\", \"UVR_MDXNET_1_9703.onnx\", \"UVR_MDXNET_2_9682.onnx\", \"UVR_MDXNET_3_9662.onnx\", \"UVR_MDXNET_9482.onnx\", \"UVR_MDXNET_KARA.onnx\", \"UVR_MDXNET_KARA_2.onnx\", \"UVR_MDXNET_Main.onnx\", \"UVR-MDX-NET-Inst_HQ_3.onnx\", \"UVR-MDX-NET-Voc_FT.onnx\"]{allow-input: true}\n",
         | 
| 355 | 
            +
                    "    #@markdown process all: processes all tracks inside tracks/ folder instead. (filename will be ignored!)\n",
         | 
| 356 | 
            +
                    "    process_all = False  # @param{type:\"boolean\"}\n",
         | 
| 357 | 
            +
                    "\n",
         | 
| 358 | 
            +
                    "\n",
         | 
| 359 | 
            +
                    "    #@markdown ### Settings\n",
         | 
| 360 | 
            +
                    "    #@markdown invert: get difference between input and output (e.g get Instrumental out of Vocals)\n",
         | 
| 361 | 
            +
                    "    invert = True  # @param{type:\"boolean\"}\n",
         | 
| 362 | 
            +
                    "    #@markdown denoise: get rid of MDX noise. (This processes input track twice)\n",
         | 
| 363 | 
            +
                    "    denoise = True  # @param{type:\"boolean\"}\n",
         | 
| 364 | 
            +
                    "    #@markdown m_threads: like batch size, processes input wave in n threads. (beneficial for CPU)\n",
         | 
| 365 | 
            +
                    "    m_threads = 2 #@param {type:\"slider\", min:1, max:8, step:1}\n",
         | 
| 366 | 
            +
                    "\n",
         | 
| 367 | 
            +
                    "    #@markdown ### Custom model parameters (Only use this if you're using new/unofficial/custom models)\n",
         | 
| 368 | 
            +
                    "    #@markdown Use custom model parameters. (Default: unchecked, or auto)\n",
         | 
| 369 | 
            +
                    "    use_custom_parameter = False  # @param{type:\"boolean\"}\n",
         | 
| 370 | 
            +
                    "    #@markdown Output file suffix (usually the stem name e.g Vocals)\n",
         | 
| 371 | 
            +
                    "    suffix = \"Vocals_custom\" #@param [\"Vocals\", \"Drums\", \"Bass\", \"Other\"]{allow-input: true}\n",
         | 
| 372 | 
            +
                    "    suffix_invert = \"Instrumental_custom\" #@param [\"Instrumental\", \"Drumless\", \"Bassless\", \"Instruments\"]{allow-input: true}\n",
         | 
| 373 | 
            +
                    "    #@markdown Model parameters\n",
         | 
| 374 | 
            +
                    "    dim_f = 3072 #@param {type: \"integer\"}\n",
         | 
| 375 | 
            +
                    "    dim_t = 256 #@param {type: \"integer\"}\n",
         | 
| 376 | 
            +
                    "    n_fft = 6144 #@param {type: \"integer\"}\n",
         | 
| 377 | 
            +
                    "    #@markdown use custom compensation: only if you have your own compensation value for your model. this still apply even if you don't have use_custom_parameter checked (Default: unchecked, or auto)\n",
         | 
| 378 | 
            +
                    "    use_custom_compensation = False  # @param{type:\"boolean\"}\n",
         | 
| 379 | 
            +
                    "    compensation = 1.000 #@param {type: \"number\"}\n",
         | 
| 380 | 
            +
                    "\n",
         | 
| 381 | 
            +
                    "    #@markdown ### Extras\n",
         | 
| 382 | 
            +
                    "    #@markdown Deezer arl: paste your ARL here for deezer tracks directly!\n",
         | 
| 383 | 
            +
                    "    arl = \"\" #@param {type:\"string\"}\n",
         | 
| 384 | 
            +
                    "    #@markdown Track format: select track quality/format\n",
         | 
| 385 | 
            +
                    "    track_format = \"FLAC\" #@param [\"FLAC\",\"MP3_320\",\"MP3_128\"]\n",
         | 
| 386 | 
            +
                    "    #@markdown Print settings being used in the run\n",
         | 
| 387 | 
            +
                    "    print_settings = True  # @param{type:\"boolean\"}\n",
         | 
| 388 | 
            +
                    "\n",
         | 
| 389 | 
            +
                    "\n",
         | 
| 390 | 
            +
                    "\n",
         | 
| 391 | 
            +
                    "    onnx = id_to_ptm(onnx)\n",
         | 
| 392 | 
            +
                    "    compensation = compensation if use_custom_compensation or use_custom_parameter else None\n",
         | 
| 393 | 
            +
                    "    mdx_model = prepare_mdx(use_custom_parameter, dim_f, dim_t, n_fft, compensation=compensation)\n",
         | 
| 394 | 
            +
                    "\n",
         | 
| 395 | 
            +
                    "    filename_split = filename.split(';')\n",
         | 
| 396 | 
            +
                    "\n",
         | 
| 397 | 
            +
                    "    usable_files = []\n",
         | 
| 398 | 
            +
                    "\n",
         | 
| 399 | 
            +
                    "    if not process_all:\n",
         | 
| 400 | 
            +
                    "        for fn in filename_split:\n",
         | 
| 401 | 
            +
                    "            fn = fn.strip()\n",
         | 
| 402 | 
            +
                    "            if is_valid_url(fn):\n",
         | 
| 403 | 
            +
                    "                dm, ltype, lid = deemix.parseLink(fn)\n",
         | 
| 404 | 
            +
                    "                if ltype and lid:\n",
         | 
| 405 | 
            +
                    "                    usable_files += download_deezer(fn, arl, track_format)\n",
         | 
| 406 | 
            +
                    "                else:\n",
         | 
| 407 | 
            +
                    "                    print('downloading link...',end=' ')\n",
         | 
| 408 | 
            +
                    "                    usable_files+=[download_link(fn)]\n",
         | 
| 409 | 
            +
                    "                    print('done')\n",
         | 
| 410 | 
            +
                    "            else:\n",
         | 
| 411 | 
            +
                    "                usable_files.append(os.path.join('tracks',fn))\n",
         | 
| 412 | 
            +
                    "    else:\n",
         | 
| 413 | 
            +
                    "        for fn in glob.glob('tracks/*'):\n",
         | 
| 414 | 
            +
                    "            usable_files.append(fn)\n",
         | 
| 415 | 
            +
                    "    for filename in usable_files:\n",
         | 
| 416 | 
            +
                    "        suffix_naming = suffix if use_custom_parameter else None\n",
         | 
| 417 | 
            +
                    "        diff_suffix_naming = suffix_invert if use_custom_parameter else None\n",
         | 
| 418 | 
            +
                    "        run_mdx(onnx, mdx_model, filename, diff=invert,suffix=suffix_naming,diff_suffix=diff_suffix_naming,denoise=denoise)\n",
         | 
| 419 | 
            +
                    "\n",
         | 
| 420 | 
            +
                    "    if print_settings:\n",
         | 
| 421 | 
            +
                    "        print()\n",
         | 
| 422 | 
            +
                    "        print('[MDX-Net_Colab settings used]')\n",
         | 
| 423 | 
            +
                    "        print(f'Model used: {onnx}')\n",
         | 
| 424 | 
            +
                    "        print(f'Model MD5: {mdx.MDX.get_hash(onnx)}')\n",
         | 
| 425 | 
            +
                    "        print(f'Using de-noise: {denoise}')\n",
         | 
| 426 | 
            +
                    "        print(f'Model parameters:')\n",
         | 
| 427 | 
            +
                    "        print(f'    -dim_f: {mdx_model.dim_f}')\n",
         | 
| 428 | 
            +
                    "        print(f'    -dim_t: {mdx_model.dim_t}')\n",
         | 
| 429 | 
            +
                    "        print(f'    -n_fft: {mdx_model.n_fft}')\n",
         | 
| 430 | 
            +
                    "        print(f'    -compensation: {mdx_model.compensation}')\n",
         | 
| 431 | 
            +
                    "        print()\n",
         | 
| 432 | 
            +
                    "        print('[Input file]')\n",
         | 
| 433 | 
            +
                    "        print('filename(s): ')\n",
         | 
| 434 | 
            +
                    "        for filename in usable_files:\n",
         | 
| 435 | 
            +
                    "            print(f'    -{filename}')\n",
         | 
| 436 | 
            +
                    "\n",
         | 
| 437 | 
            +
                    "    del mdx_model"
         | 
| 438 | 
            +
                  ]
         | 
| 439 | 
            +
                },
         | 
| 440 | 
            +
                {
         | 
| 441 | 
            +
                  "cell_type": "markdown",
         | 
| 442 | 
            +
                  "source": [
         | 
| 443 | 
            +
                    "# Guide\n",
         | 
| 444 | 
            +
                    "\n",
         | 
| 445 | 
            +
                    "This tutorial guide will walk you through the steps to use the features of this Colab notebook.\n",
         | 
| 446 | 
            +
                    "\n",
         | 
| 447 | 
            +
                    "## Mount Drive\n",
         | 
| 448 | 
            +
                    "\n",
         | 
| 449 | 
            +
                    "To mount your Google Drive, follow these steps:\n",
         | 
| 450 | 
            +
                    "\n",
         | 
| 451 | 
            +
                    "1. Check the box next to \"MountDrive\" if you want to mount Google Drive.\n",
         | 
| 452 | 
            +
                    "2. Modify the \"mounting_path\" if you want to specify a different path for the drive to be mounted. **Note:** Be cautious when modifying this path as it can cause issues if not done properly.\n",
         | 
| 453 | 
            +
                    "3. Check the box next to \"Force update and disregard local changes\" if you want to discard all local modifications in your repository and replace the files with the versions from the original commit.\n",
         | 
| 454 | 
            +
                    "4. Check the box next to \"Auto Update\" if you want to automatically update without discarding your changes. Leave it unchecked if you want to manually update.\n",
         | 
| 455 | 
            +
                    "\n",
         | 
| 456 | 
            +
                    "## Input Files\n",
         | 
| 457 | 
            +
                    "\n",
         | 
| 458 | 
            +
                    "To upload your songs, follow these steps:\n",
         | 
| 459 | 
            +
                    "\n",
         | 
| 460 | 
            +
                    "1. Specify the \"track filename\" for your songs. You can provide multiple links or files by separating them with a semicolon (;).\n",
         | 
| 461 | 
            +
                    "2. Upload your songs to the \"tracks\" folder.\n",
         | 
| 462 | 
            +
                    "\n",
         | 
| 463 | 
            +
                    "## ONNX Model\n",
         | 
| 464 | 
            +
                    "\n",
         | 
| 465 | 
            +
                    "If you have your own ONNX model, follow these steps:\n",
         | 
| 466 | 
            +
                    "\n",
         | 
| 467 | 
            +
                    "1. Upload your model to the \"models\" folder.\n",
         | 
| 468 | 
            +
                    "2. Specify the \"onnx\" filename for your model.\n",
         | 
| 469 | 
            +
                    "\n",
         | 
| 470 | 
            +
                    "## Processing\n",
         | 
| 471 | 
            +
                    "\n",
         | 
| 472 | 
            +
                    "To process your tracks, follow these steps:\n",
         | 
| 473 | 
            +
                    "\n",
         | 
| 474 | 
            +
                    "1. If you want to process all tracks inside the \"tracks\" folder, check the box next to \"process_all\" and ignore the \"filename\" field.\n",
         | 
| 475 | 
            +
                    "2. Specify any additional settings you want:\n",
         | 
| 476 | 
            +
                    "   - Check the box next to \"invert\" to get the difference between input and output (e.g., get Instrumental out of Vocals).\n",
         | 
| 477 | 
            +
                    "   - Check the box next to \"denoise\" to get rid of MDX noise. This processes the input track twice.\n",
         | 
| 478 | 
            +
                    "   - Specify custom model parameters only if you're using new/unofficial/custom models. Use the \"use_custom_parameter\" checkbox to enable this feature.\n",
         | 
| 479 | 
            +
                    "   - Specify the output file suffix, which is usually the stem name (e.g., Vocals). Use the \"suffix\" field to specify the suffix for normal processing and the \"suffix_invert\" field for inverted processing.\n",
         | 
| 480 | 
            +
                    "\n",
         | 
| 481 | 
            +
                    "## Model Parameters\n",
         | 
| 482 | 
            +
                    "\n",
         | 
| 483 | 
            +
                    "Specify the following custom model parameters if applicable:\n",
         | 
| 484 | 
            +
                    "\n",
         | 
| 485 | 
            +
                    "- \"dim_f\": The value for the `dim_f` parameter.\n",
         | 
| 486 | 
            +
                    "- \"dim_t\": The value for the `dim_t` parameter.\n",
         | 
| 487 | 
            +
                    "- \"n_fft\": The value for the `n_fft` parameter.\n",
         | 
| 488 | 
            +
                    "- Check the box next to \"use_custom_compensation\" if you have your own compensation value for your model. Specify the compensation value in the \"compensation\" field.\n",
         | 
| 489 | 
            +
                    "\n",
         | 
| 490 | 
            +
                    "## Extras\n",
         | 
| 491 | 
            +
                    "\n",
         | 
| 492 | 
            +
                    "If you're working with Deezer tracks, paste your ARL (Authentication Request Library) in the \"arl\" field to directly access the tracks.\n",
         | 
| 493 | 
            +
                    "\n",
         | 
| 494 | 
            +
                    "Specify the \"Track format\" by selecting the desired quality/format for the track.\n",
         | 
| 495 | 
            +
                    "\n",
         | 
| 496 | 
            +
                    "To print the settings being used in the run, check the box next to \"print_settings\".\n",
         | 
| 497 | 
            +
                    "\n",
         | 
| 498 | 
            +
                    "That's it! You're now ready to use this Colab notebook. Enjoy!\n",
         | 
| 499 | 
            +
                    "\n",
         | 
| 500 | 
            +
                    "## For more detailed guide, proceed to this <a href=\"https://docs.google.com/document/d/17fjNvJzj8ZGSer7c7OFe_CNfUKbAxEh_OBv94ZdRG5c\">link</a>.\n",
         | 
| 501 | 
            +
                    "credits: (discord) deton24"
         | 
| 502 | 
            +
                  ],
         | 
| 503 | 
            +
                  "metadata": {
         | 
| 504 | 
            +
                    "id": "tMVwX5RhZSRP"
         | 
| 505 | 
            +
                  }
         | 
| 506 | 
            +
                }
         | 
| 507 | 
            +
              ],
         | 
| 508 | 
            +
              "metadata": {
         | 
| 509 | 
            +
                "accelerator": "GPU",
         | 
| 510 | 
            +
                "colab": {
         | 
| 511 | 
            +
                  "gpuType": "T4",
         | 
| 512 | 
            +
                  "provenance": []
         | 
| 513 | 
            +
                },
         | 
| 514 | 
            +
                "kernelspec": {
         | 
| 515 | 
            +
                  "display_name": "Python 3",
         | 
| 516 | 
            +
                  "name": "python3"
         | 
| 517 | 
            +
                },
         | 
| 518 | 
            +
                "language_info": {
         | 
| 519 | 
            +
                  "name": "python"
         | 
| 520 | 
            +
                }
         | 
| 521 | 
            +
              },
         | 
| 522 | 
            +
              "nbformat": 4,
         | 
| 523 | 
            +
              "nbformat_minor": 0
         | 
| 524 | 
            +
            }
         | 
    	
        MDXNet.py
    ADDED
    
    | @@ -0,0 +1,272 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import soundfile as sf
         | 
| 2 | 
            +
            import torch, pdb, os, warnings, librosa
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import onnxruntime as ort
         | 
| 5 | 
            +
            from tqdm import tqdm
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            dim_c = 4
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class Conv_TDF_net_trim:
         | 
| 12 | 
            +
                def __init__(
         | 
| 13 | 
            +
                    self, device, model_name, target_name, L, dim_f, dim_t, n_fft, hop=1024
         | 
| 14 | 
            +
                ):
         | 
| 15 | 
            +
                    super(Conv_TDF_net_trim, self).__init__()
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                    self.dim_f = dim_f
         | 
| 18 | 
            +
                    self.dim_t = 2**dim_t
         | 
| 19 | 
            +
                    self.n_fft = n_fft
         | 
| 20 | 
            +
                    self.hop = hop
         | 
| 21 | 
            +
                    self.n_bins = self.n_fft // 2 + 1
         | 
| 22 | 
            +
                    self.chunk_size = hop * (self.dim_t - 1)
         | 
| 23 | 
            +
                    self.window = torch.hann_window(window_length=self.n_fft, periodic=True).to(
         | 
| 24 | 
            +
                        device
         | 
| 25 | 
            +
                    )
         | 
| 26 | 
            +
                    self.target_name = target_name
         | 
| 27 | 
            +
                    self.blender = "blender" in model_name
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    out_c = dim_c * 4 if target_name == "*" else dim_c
         | 
| 30 | 
            +
                    self.freq_pad = torch.zeros(
         | 
| 31 | 
            +
                        [1, out_c, self.n_bins - self.dim_f, self.dim_t]
         | 
| 32 | 
            +
                    ).to(device)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    self.n = L // 2
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                def stft(self, x):
         | 
| 37 | 
            +
                    x = x.reshape([-1, self.chunk_size])
         | 
| 38 | 
            +
                    x = torch.stft(
         | 
| 39 | 
            +
                        x,
         | 
| 40 | 
            +
                        n_fft=self.n_fft,
         | 
| 41 | 
            +
                        hop_length=self.hop,
         | 
| 42 | 
            +
                        window=self.window,
         | 
| 43 | 
            +
                        center=True,
         | 
| 44 | 
            +
                        return_complex=True,
         | 
| 45 | 
            +
                    )
         | 
| 46 | 
            +
                    x = torch.view_as_real(x)
         | 
| 47 | 
            +
                    x = x.permute([0, 3, 1, 2])
         | 
| 48 | 
            +
                    x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape(
         | 
| 49 | 
            +
                        [-1, dim_c, self.n_bins, self.dim_t]
         | 
| 50 | 
            +
                    )
         | 
| 51 | 
            +
                    return x[:, :, : self.dim_f]
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def istft(self, x, freq_pad=None):
         | 
| 54 | 
            +
                    freq_pad = (
         | 
| 55 | 
            +
                        self.freq_pad.repeat([x.shape[0], 1, 1, 1])
         | 
| 56 | 
            +
                        if freq_pad is None
         | 
| 57 | 
            +
                        else freq_pad
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
                    x = torch.cat([x, freq_pad], -2)
         | 
| 60 | 
            +
                    c = 4 * 2 if self.target_name == "*" else 2
         | 
| 61 | 
            +
                    x = x.reshape([-1, c, 2, self.n_bins, self.dim_t]).reshape(
         | 
| 62 | 
            +
                        [-1, 2, self.n_bins, self.dim_t]
         | 
| 63 | 
            +
                    )
         | 
| 64 | 
            +
                    x = x.permute([0, 2, 3, 1])
         | 
| 65 | 
            +
                    x = x.contiguous()
         | 
| 66 | 
            +
                    x = torch.view_as_complex(x)
         | 
| 67 | 
            +
                    x = torch.istft(
         | 
| 68 | 
            +
                        x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
                    return x.reshape([-1, c, self.chunk_size])
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            def get_models(device, dim_f, dim_t, n_fft):
         | 
| 74 | 
            +
                return Conv_TDF_net_trim(
         | 
| 75 | 
            +
                    device=device,
         | 
| 76 | 
            +
                    model_name="Conv-TDF",
         | 
| 77 | 
            +
                    target_name="vocals",
         | 
| 78 | 
            +
                    L=11,
         | 
| 79 | 
            +
                    dim_f=dim_f,
         | 
| 80 | 
            +
                    dim_t=dim_t,
         | 
| 81 | 
            +
                    n_fft=n_fft,
         | 
| 82 | 
            +
                )
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            warnings.filterwarnings("ignore")
         | 
| 86 | 
            +
            cpu = torch.device("cpu")
         | 
| 87 | 
            +
            if torch.cuda.is_available():
         | 
| 88 | 
            +
                device = torch.device("cuda:0")
         | 
| 89 | 
            +
            elif torch.backends.mps.is_available():
         | 
| 90 | 
            +
                device = torch.device("mps")
         | 
| 91 | 
            +
            else:
         | 
| 92 | 
            +
                device = torch.device("cpu")
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            class Predictor:
         | 
| 96 | 
            +
                def __init__(self, args):
         | 
| 97 | 
            +
                    self.args = args
         | 
| 98 | 
            +
                    self.model_ = get_models(
         | 
| 99 | 
            +
                        device=cpu, dim_f=args.dim_f, dim_t=args.dim_t, n_fft=args.n_fft
         | 
| 100 | 
            +
                    )
         | 
| 101 | 
            +
                    self.model = ort.InferenceSession(
         | 
| 102 | 
            +
                        os.path.join(args.onnx, self.model_.target_name + ".onnx"),
         | 
| 103 | 
            +
                        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
         | 
| 104 | 
            +
                    )
         | 
| 105 | 
            +
                    print("onnx load done")
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def demix(self, mix):
         | 
| 108 | 
            +
                    samples = mix.shape[-1]
         | 
| 109 | 
            +
                    margin = self.args.margin
         | 
| 110 | 
            +
                    chunk_size = self.args.chunks * 44100
         | 
| 111 | 
            +
                    assert not margin == 0, "margin cannot be zero!"
         | 
| 112 | 
            +
                    if margin > chunk_size:
         | 
| 113 | 
            +
                        margin = chunk_size
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    segmented_mix = {}
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                    if self.args.chunks == 0 or samples < chunk_size:
         | 
| 118 | 
            +
                        chunk_size = samples
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    counter = -1
         | 
| 121 | 
            +
                    for skip in range(0, samples, chunk_size):
         | 
| 122 | 
            +
                        counter += 1
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        s_margin = 0 if counter == 0 else margin
         | 
| 125 | 
            +
                        end = min(skip + chunk_size + margin, samples)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        start = skip - s_margin
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                        segmented_mix[skip] = mix[:, start:end].copy()
         | 
| 130 | 
            +
                        if end == samples:
         | 
| 131 | 
            +
                            break
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    sources = self.demix_base(segmented_mix, margin_size=margin)
         | 
| 134 | 
            +
                    """
         | 
| 135 | 
            +
                    mix:(2,big_sample)
         | 
| 136 | 
            +
                    segmented_mix:offset->(2,small_sample)
         | 
| 137 | 
            +
                    sources:(1,2,big_sample)
         | 
| 138 | 
            +
                    """
         | 
| 139 | 
            +
                    return sources
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                def demix_base(self, mixes, margin_size):
         | 
| 142 | 
            +
                    chunked_sources = []
         | 
| 143 | 
            +
                    progress_bar = tqdm(total=len(mixes))
         | 
| 144 | 
            +
                    progress_bar.set_description("Processing")
         | 
| 145 | 
            +
                    for mix in mixes:
         | 
| 146 | 
            +
                        cmix = mixes[mix]
         | 
| 147 | 
            +
                        sources = []
         | 
| 148 | 
            +
                        n_sample = cmix.shape[1]
         | 
| 149 | 
            +
                        model = self.model_
         | 
| 150 | 
            +
                        trim = model.n_fft // 2
         | 
| 151 | 
            +
                        gen_size = model.chunk_size - 2 * trim
         | 
| 152 | 
            +
                        pad = gen_size - n_sample % gen_size
         | 
| 153 | 
            +
                        mix_p = np.concatenate(
         | 
| 154 | 
            +
                            (np.zeros((2, trim)), cmix, np.zeros((2, pad)), np.zeros((2, trim))), 1
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
                        mix_waves = []
         | 
| 157 | 
            +
                        i = 0
         | 
| 158 | 
            +
                        while i < n_sample + pad:
         | 
| 159 | 
            +
                            waves = np.array(mix_p[:, i : i + model.chunk_size])
         | 
| 160 | 
            +
                            mix_waves.append(waves)
         | 
| 161 | 
            +
                            i += gen_size
         | 
| 162 | 
            +
                        mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(cpu)
         | 
| 163 | 
            +
                        with torch.no_grad():
         | 
| 164 | 
            +
                            _ort = self.model
         | 
| 165 | 
            +
                            spek = model.stft(mix_waves)
         | 
| 166 | 
            +
                            if self.args.denoise:
         | 
| 167 | 
            +
                                spec_pred = (
         | 
| 168 | 
            +
                                    -_ort.run(None, {"input": -spek.cpu().numpy()})[0] * 0.5
         | 
| 169 | 
            +
                                    + _ort.run(None, {"input": spek.cpu().numpy()})[0] * 0.5
         | 
| 170 | 
            +
                                )
         | 
| 171 | 
            +
                                tar_waves = model.istft(torch.tensor(spec_pred))
         | 
| 172 | 
            +
                            else:
         | 
| 173 | 
            +
                                tar_waves = model.istft(
         | 
| 174 | 
            +
                                    torch.tensor(_ort.run(None, {"input": spek.cpu().numpy()})[0])
         | 
| 175 | 
            +
                                )
         | 
| 176 | 
            +
                            tar_signal = (
         | 
| 177 | 
            +
                                tar_waves[:, :, trim:-trim]
         | 
| 178 | 
            +
                                .transpose(0, 1)
         | 
| 179 | 
            +
                                .reshape(2, -1)
         | 
| 180 | 
            +
                                .numpy()[:, :-pad]
         | 
| 181 | 
            +
                            )
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                            start = 0 if mix == 0 else margin_size
         | 
| 184 | 
            +
                            end = None if mix == list(mixes.keys())[::-1][0] else -margin_size
         | 
| 185 | 
            +
                            if margin_size == 0:
         | 
| 186 | 
            +
                                end = None
         | 
| 187 | 
            +
                            sources.append(tar_signal[:, start:end])
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                            progress_bar.update(1)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                        chunked_sources.append(sources)
         | 
| 192 | 
            +
                    _sources = np.concatenate(chunked_sources, axis=-1)
         | 
| 193 | 
            +
                    # del self.model
         | 
| 194 | 
            +
                    progress_bar.close()
         | 
| 195 | 
            +
                    return _sources
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def prediction(self, m, vocal_root, others_root, format):
         | 
| 198 | 
            +
                    os.makedirs(vocal_root, exist_ok=True)
         | 
| 199 | 
            +
                    os.makedirs(others_root, exist_ok=True)
         | 
| 200 | 
            +
                    basename = os.path.basename(m)
         | 
| 201 | 
            +
                    mix, rate = librosa.load(m, mono=False, sr=44100)
         | 
| 202 | 
            +
                    if mix.ndim == 1:
         | 
| 203 | 
            +
                        mix = np.asfortranarray([mix, mix])
         | 
| 204 | 
            +
                    mix = mix.T
         | 
| 205 | 
            +
                    sources = self.demix(mix.T)
         | 
| 206 | 
            +
                    opt = sources[0].T
         | 
| 207 | 
            +
                    if format in ["wav", "flac"]:
         | 
| 208 | 
            +
                        sf.write(
         | 
| 209 | 
            +
                            "%s/%s_main_vocal.%s" % (vocal_root, basename, format), mix - opt, rate
         | 
| 210 | 
            +
                        )
         | 
| 211 | 
            +
                        sf.write("%s/%s_others.%s" % (others_root, basename, format), opt, rate)
         | 
| 212 | 
            +
                    else:
         | 
| 213 | 
            +
                        path_vocal = "%s/%s_main_vocal.wav" % (vocal_root, basename)
         | 
| 214 | 
            +
                        path_other = "%s/%s_others.wav" % (others_root, basename)
         | 
| 215 | 
            +
                        sf.write(path_vocal, mix - opt, rate)
         | 
| 216 | 
            +
                        sf.write(path_other, opt, rate)
         | 
| 217 | 
            +
                        if os.path.exists(path_vocal):
         | 
| 218 | 
            +
                            os.system(
         | 
| 219 | 
            +
                                "ffmpeg -i %s -vn %s -q:a 2 -y"
         | 
| 220 | 
            +
                                % (path_vocal, path_vocal[:-4] + ".%s" % format)
         | 
| 221 | 
            +
                            )
         | 
| 222 | 
            +
                        if os.path.exists(path_other):
         | 
| 223 | 
            +
                            os.system(
         | 
| 224 | 
            +
                                "ffmpeg -i %s -vn %s -q:a 2 -y"
         | 
| 225 | 
            +
                                % (path_other, path_other[:-4] + ".%s" % format)
         | 
| 226 | 
            +
                            )
         | 
| 227 | 
            +
             | 
| 228 | 
            +
             | 
| 229 | 
            +
            class MDXNetDereverb:
         | 
| 230 | 
            +
                def __init__(self, chunks):
         | 
| 231 | 
            +
                    self.onnx = "uvr5_weights/onnx_dereverb_By_FoxJoy"
         | 
| 232 | 
            +
                    self.shifts = 10  #'Predict with randomised equivariant stabilisation'
         | 
| 233 | 
            +
                    self.mixing = "min_mag"  # ['default','min_mag','max_mag']
         | 
| 234 | 
            +
                    self.chunks = chunks
         | 
| 235 | 
            +
                    self.margin = 44100
         | 
| 236 | 
            +
                    self.dim_t = 9
         | 
| 237 | 
            +
                    self.dim_f = 3072
         | 
| 238 | 
            +
                    self.n_fft = 6144
         | 
| 239 | 
            +
                    self.denoise = True
         | 
| 240 | 
            +
                    self.pred = Predictor(self)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                def _path_audio_(self, input, vocal_root, others_root, format):
         | 
| 243 | 
            +
                    self.pred.prediction(input, vocal_root, others_root, format)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            if __name__ == "__main__":
         | 
| 247 | 
            +
                dereverb = MDXNetDereverb(15)
         | 
| 248 | 
            +
                from time import time as ttime
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                t0 = ttime()
         | 
| 251 | 
            +
                dereverb._path_audio_(
         | 
| 252 | 
            +
                    "雪雪伴奏对消HP5.wav",
         | 
| 253 | 
            +
                    "vocal",
         | 
| 254 | 
            +
                    "others",
         | 
| 255 | 
            +
                )
         | 
| 256 | 
            +
                t1 = ttime()
         | 
| 257 | 
            +
                print(t1 - t0)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
             | 
| 260 | 
            +
            """
         | 
| 261 | 
            +
             | 
| 262 | 
            +
            runtime\python.exe MDXNet.py 
         | 
| 263 | 
            +
             | 
| 264 | 
            +
            6G:
         | 
| 265 | 
            +
            15/9:0.8G->6.8G
         | 
| 266 | 
            +
            14:0.8G->6.5G
         | 
| 267 | 
            +
            25:炸
         | 
| 268 | 
            +
             | 
| 269 | 
            +
            half15:0.7G->6.6G,22.69s
         | 
| 270 | 
            +
            fp32-15:0.7G->6.6G,20.85s
         | 
| 271 | 
            +
             | 
| 272 | 
            +
            """
         | 
    	
        Makefile
    ADDED
    
    | @@ -0,0 +1,63 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .PHONY:
         | 
| 2 | 
            +
            .ONESHELL:
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            help: ## Show this help and exit
         | 
| 5 | 
            +
            	@grep -hE '^[A-Za-z0-9_ \-]*?:.*##.*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            install: ## Install dependencies (Do everytime you start up a paperspace machine)
         | 
| 8 | 
            +
            	apt-get -y install build-essential python3-dev ffmpeg
         | 
| 9 | 
            +
            	pip install --upgrade setuptools wheel
         | 
| 10 | 
            +
            	pip install --upgrade pip
         | 
| 11 | 
            +
            	pip install faiss-gpu fairseq gradio ffmpeg ffmpeg-python praat-parselmouth pyworld numpy==1.23.5 numba==0.56.4 librosa==0.9.1
         | 
| 12 | 
            +
            	pip install -r requirements.txt
         | 
| 13 | 
            +
            	pip install --upgrade lxml
         | 
| 14 | 
            +
            	apt-get update
         | 
| 15 | 
            +
            	apt -y install -qq aria2
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            basev1: ## Download version 1 pre-trained models (Do only once after cloning the fork)
         | 
| 18 | 
            +
            	mkdir -p pretrained uvr5_weights
         | 
| 19 | 
            +
            	git pull
         | 
| 20 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D32k.pth -d pretrained -o D32k.pth
         | 
| 21 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D40k.pth -d pretrained -o D40k.pth
         | 
| 22 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/D48k.pth -d pretrained -o D48k.pth
         | 
| 23 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G32k.pth -d pretrained -o G32k.pth
         | 
| 24 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G40k.pth -d pretrained -o G40k.pth
         | 
| 25 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/G48k.pth -d pretrained -o G48k.pth
         | 
| 26 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D32k.pth -d pretrained -o f0D32k.pth
         | 
| 27 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D40k.pth -d pretrained -o f0D40k.pth
         | 
| 28 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0D48k.pth -d pretrained -o f0D48k.pth
         | 
| 29 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G32k.pth -d pretrained -o f0G32k.pth
         | 
| 30 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G40k.pth -d pretrained -o f0G40k.pth
         | 
| 31 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained/f0G48k.pth -d pretrained -o f0G48k.pth
         | 
| 32 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d uvr5_weights -o HP2-人声vocals+非人声instrumentals.pth
         | 
| 33 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d uvr5_weights -o HP5-主旋律人声vocals+其他instrumentals.pth
         | 
| 34 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d ./ -o hubert_base.pt
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            basev2: ## Download version 2 pre-trained models (Do only once after cloning the fork)
         | 
| 37 | 
            +
            	mkdir -p pretrained_v2 uvr5_weights
         | 
| 38 | 
            +
            	git pull
         | 
| 39 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D32k.pth -d pretrained_v2 -o D32k.pth
         | 
| 40 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D40k.pth -d pretrained_v2 -o D40k.pth
         | 
| 41 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/D48k.pth -d pretrained_v2 -o D48k.pth
         | 
| 42 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G32k.pth -d pretrained_v2 -o G32k.pth
         | 
| 43 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G40k.pth -d pretrained_v2 -o G40k.pth
         | 
| 44 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/G48k.pth -d pretrained_v2 -o G48k.pth
         | 
| 45 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D32k.pth -d pretrained_v2 -o f0D32k.pth
         | 
| 46 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D40k.pth -d pretrained_v2 -o f0D40k.pth
         | 
| 47 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0D48k.pth -d pretrained_v2 -o f0D48k.pth
         | 
| 48 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G32k.pth -d pretrained_v2 -o f0G32k.pth
         | 
| 49 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G40k.pth -d pretrained_v2 -o f0G40k.pth
         | 
| 50 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/pretrained_v2/f0G48k.pth -d pretrained_v2 -o f0G48k.pth
         | 
| 51 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2-人声vocals+非人声instrumentals.pth -d uvr5_weights -o HP2-人声vocals+非人声instrumentals.pth
         | 
| 52 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5-主旋律人声vocals+其他instrumentals.pth -d uvr5_weights -o HP5-主旋律人声vocals+其他instrumentals.pth
         | 
| 53 | 
            +
            	aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/hubert_base.pt -d ./ -o hubert_base.pt
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            run-ui: ## Run the python GUI
         | 
| 56 | 
            +
            	python infer-web.py --paperspace --pycmd python
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            run-cli: ## Run the python CLI
         | 
| 59 | 
            +
            	python infer-web.py --pycmd python --is_cli
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            tensorboard: ## Start the tensorboard (Run on separate terminal)
         | 
| 62 | 
            +
            	echo https://tensorboard-$$(hostname).clg07azjl.paperspacegradient.com
         | 
| 63 | 
            +
            	tensorboard --logdir logs --bind_all
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,12 +1,222 @@ | |
| 1 | 
            -
             | 
| 2 | 
            -
             | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # 🍏 Applio-RVC-Fork
         | 
| 2 | 
            +
            Applio is a user-friendly fork of Mangio-RVC-Fork/RVC, designed to provide an intuitive interface, especially for newcomers.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            ## 📎 Links
         | 
| 5 | 
            +
            [](https://discord.gg/IAHispano)
         | 
| 6 | 
            +
            [](https://colab.research.google.com/drive/157pUQep6txJOYModYFqvz_5OJajeh7Ii)
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            ## 📚 Table of Contents
         | 
| 9 | 
            +
              1. [Improvements of Applio Over RVC](#-improvements-of-applio-over-rvc)
         | 
| 10 | 
            +
              2. [Additional Features of This Repository](#️-additional-features-of-this-repository)
         | 
| 11 | 
            +
              3. [Planned Features for Future Development](#️-planned-features-for-future-development)
         | 
| 12 | 
            +
              4. [Installation](#-installation)
         | 
| 13 | 
            +
              5. [Running the Web GUI (Inference & Train)](#-running-the-web-gui-inference--train)
         | 
| 14 | 
            +
              6. [Running the CLI (Inference & Train)](#-running-the-cli-inference--train)
         | 
| 15 | 
            +
              7. [Credits](#credits)
         | 
| 16 | 
            +
              8. [Thanks to all RVC and Mangio contributors](#thanks-to-all-rvc-and-mangio-contributors)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            ## 🎯 Improvements of Applio Over RVC
         | 
| 20 | 
            +
            ### f0 Inference Algorithm Overhaul
         | 
| 21 | 
            +
            - Applio features a comprehensive overhaul of the f0 inference algorithm, including:
         | 
| 22 | 
            +
              - Addition of the pyworld dio f0 method.
         | 
| 23 | 
            +
              - Alternative method for calculating crepe f0.
         | 
| 24 | 
            +
              - Introduction of the torchcrepe crepe-tiny model.
         | 
| 25 | 
            +
              - Customizable crepe_hop_length for the crepe algorithm via both the web GUI and CLI.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            ### f0 Crepe Pitch Extraction for Training
         | 
| 28 | 
            +
            - Works on paperspace machines but not local MacOS/Windows machines (Potential memory leak).
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            ### Paperspace Integration
         | 
| 31 | 
            +
            - Applio seamlessly integrates with Paperspace, providing the following features:
         | 
| 32 | 
            +
              - Paperspace argument on infer-web.py (--paperspace) for sharing a Gradio link.
         | 
| 33 | 
            +
              - A dedicated make file tailored for Paperspace users.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            ### Access to Tensorboard
         | 
| 36 | 
            +
            - Applio grants easy access to Tensorboard via a Makefile and a Python script.
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            ### CLI Functionality
         | 
| 39 | 
            +
            - Applio introduces command-line interface (CLI) functionality, with the addition of the --is_cli flag in infer-web.py for CLI system usage.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            ### f0 Hybrid Estimation Method
         | 
| 42 | 
            +
            - Applio offers a novel f0 hybrid estimation method by calculating nanmedian for a specified array of f0 methods, ensuring the best results from multiple methods (CLI exclusive).
         | 
| 43 | 
            +
            - This hybrid estimation method is also available for f0 feature extraction during training.
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            ### UI Changes
         | 
| 46 | 
            +
            #### Inference:
         | 
| 47 | 
            +
            - A complete interface redesign enhances user experience, with notable features such as:
         | 
| 48 | 
            +
              - Audio recording directly from the interface.
         | 
| 49 | 
            +
              - Convenient drop-down menus for audio and .index file selection.
         | 
| 50 | 
            +
              - An advanced settings section with new features like autotune and formant shifting.
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            #### Training:
         | 
| 53 | 
            +
            - Improved training features include:
         | 
| 54 | 
            +
              - A total epoch slider now limited to 10,000.
         | 
| 55 | 
            +
              - Increased save frequency limit to 100.
         | 
| 56 | 
            +
              - Default recommended options for smoother setup.
         | 
| 57 | 
            +
              - Better adaptation to high-resolution screens.
         | 
| 58 | 
            +
              - A drop-down menu for dataset selection.
         | 
| 59 | 
            +
              - Enhanced saving system options, including Save all files, Save G and D files, and Save model for inference.
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            #### UVR:
         | 
| 62 | 
            +
            - Applio ensures compatibility with all VR/MDX models for an extended range of possibilities.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            #### TTS (Text-to-Speech, New):
         | 
| 65 | 
            +
            - Introducing a new Text-to-Speech (TTS) feature using RVC models.
         | 
| 66 | 
            +
            - Support for multiple languages and Edge-tts/Bark-tts.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            #### Resources (New):
         | 
| 69 | 
            +
            - Users can now upload models, backups, datasets, and audios from various storage services like Drive, Huggingface, Discord, and more.
         | 
| 70 | 
            +
            - Download audios from YouTube with the ability to automatically separate instrumental and vocals, offering advanced options and UVR support.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            #### Extra (New):
         | 
| 73 | 
            +
            - Combine instrumental and vocals with ease, including independent volume control for each track and the option to add effects like reverb, compressor, and noise gate.
         | 
| 74 | 
            +
            - Significant improvements in the processing interface, allowing tasks such as merging models, modifying information, obtaining information, or extracting models effortlessly.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            ## ⚙️ Additional Features of This Repository
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            In addition to the aforementioned improvements, this repository offers the following features:
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            ### Enhanced Tone Leakage Reduction
         | 
| 81 | 
            +
            - Implements tone leakage reduction by replacing source features with training-set features using top1 retrieval. This helps in achieving cleaner audio results.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            ### Efficient Training
         | 
| 84 | 
            +
            - Provides a seamless and speedy training experience, even on relatively modest graphics cards. The system is optimized for efficient resource utilization.
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            ### Data Efficiency
         | 
| 87 | 
            +
            - Supports training with a small dataset, yielding commendable results, especially with audio clips of at least 10 minutes of low-noise speech.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            ## 🛠️ Planned Features for Future Development
         | 
| 90 | 
            +
            As part of the ongoing development of this fork, the following features are planned to be added:
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            - Incorporating an inference batcher script based on user feedback. This enhancement will allow for processing 30-second audio samples at a time, improving output quality and preventing memory errors during inference.
         | 
| 93 | 
            +
            - Implementing an automatic removal mechanism for old generations to optimize storage space usage. This feature ensures that the repository remains efficient and organized over time.
         | 
| 94 | 
            +
            - Streamlining the training process for Paperspace machines to further improve efficiency and resource utilization during training tasks.
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            ## Compatibility
         | 
| 97 | 
            +
            - AMD/Intel graphics cards acceleration supported.
         | 
| 98 | 
            +
            - Intel ARC graphics cards acceleration with IPEX supported.
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            ## ✨ Installation
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            ### Automatic installation (Windows):
         | 
| 103 | 
            +
            To quickly and effortlessly install Applio along with all the necessary models and configurations on Windows, you can use the [install_Applio.bat](https://github.com/IAHispano/Applio-RVC-Fork/releases) script available in the releases section.
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            ### Manual installation (Windows/MacOS):
         | 
| 106 | 
            +
            **Note for MacOS Users**: When using `faiss 1.7.2` under MacOS, you may encounter a Segmentation Fault: 11 error. To resolve this issue, install `faiss-cpu 1.7.0` using the following command if you're installing it manually with pip: 
         | 
| 107 | 
            +
             ```bash
         | 
| 108 | 
            +
            pip install faiss-cpu==1.7.0
         | 
| 109 | 
            +
            ```
         | 
| 110 | 
            +
            Additionally, you can install Swig on MacOS using brew:
         | 
| 111 | 
            +
            ```bash
         | 
| 112 | 
            +
            brew install swig
         | 
| 113 | 
            +
            ```
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            Install requirements:
         | 
| 116 | 
            +
            *Using pip (Python 3.9.8 is stable with this fork)*
         | 
| 117 | 
            +
            ```bash
         | 
| 118 | 
            +
            pip install -r requirements.txt
         | 
| 119 | 
            +
            ```
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            ### Manual installation (Paperspace):
         | 
| 122 | 
            +
            ```bash
         | 
| 123 | 
            +
            cd Applio-RVC-Fork
         | 
| 124 | 
            +
            make install # Do this everytime you start your paperspace machine
         | 
| 125 | 
            +
            ```
         | 
| 126 | 
            +
            ### You can also use pip to install them:
         | 
| 127 | 
            +
            ```bash
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            for Nvidia graphics cards
         | 
| 130 | 
            +
              pip install -r requirements.txt
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            for AMD/Intel graphics cards:
         | 
| 133 | 
            +
              pip install -r requirements-dml.txt
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            for Intel ARC graphics cards on Linux / WSL using Python 3.10: 
         | 
| 136 | 
            +
              pip install -r requirements-ipex.txt
         | 
| 137 | 
            +
             | 
| 138 | 
            +
            ```
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            ## 🪄 Running the Web GUI (Inference & Train) 
         | 
| 141 | 
            +
            *Use --paperspace or --colab if on cloud system.*
         | 
| 142 | 
            +
            ```bash
         | 
| 143 | 
            +
            python infer-web.py --pycmd python --port 3000
         | 
| 144 | 
            +
            ```
         | 
| 145 | 
            +
             | 
| 146 | 
            +
            ## 💻 Running the CLI (Inference & Train) 
         | 
| 147 | 
            +
            ```bash
         | 
| 148 | 
            +
            python infer-web.py --pycmd python --is_cli
         | 
| 149 | 
            +
            ```
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            ```bash
         | 
| 152 | 
            +
            Mangio-RVC-Fork v2 CLI App!
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            Welcome to the CLI version of RVC. Please read the documentation on https://github.com/Mangio621/Mangio-RVC-Fork (README.MD) to understand how to use this app.
         | 
| 155 | 
            +
             | 
| 156 | 
            +
            You are currently in 'HOME':
         | 
| 157 | 
            +
                go home            : Takes you back to home with a navigation list.
         | 
| 158 | 
            +
                go infer           : Takes you to inference command execution.
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                go pre-process     : Takes you to training step.1) pre-process command execution.
         | 
| 161 | 
            +
                go extract-feature : Takes you to training step.2) extract-feature command execution.
         | 
| 162 | 
            +
                go train           : Takes you to training step.3) being or continue training command execution.
         | 
| 163 | 
            +
                go train-feature   : Takes you to the train feature index command execution.
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                go extract-model   : Takes you to the extract small model command execution.
         | 
| 166 | 
            +
             | 
| 167 | 
            +
            HOME:
         | 
| 168 | 
            +
            ```
         | 
| 169 | 
            +
             | 
| 170 | 
            +
            Typing 'go infer' for example will take you to the infer page where you can then enter in your arguments that you wish to use for that specific page. For example typing 'go infer' will take you here:
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            ```bash
         | 
| 173 | 
            +
            HOME: go infer
         | 
| 174 | 
            +
            You are currently in 'INFER':
         | 
| 175 | 
            +
                arg 1) model name with .pth in ./weights: mi-test.pth
         | 
| 176 | 
            +
                arg 2) source audio path: myFolder\MySource.wav
         | 
| 177 | 
            +
                arg 3) output file name to be placed in './audio-outputs': MyTest.wav
         | 
| 178 | 
            +
                arg 4) feature index file path: logs/mi-test/added_IVF3042_Flat_nprobe_1.index
         | 
| 179 | 
            +
                arg 5) speaker id: 0
         | 
| 180 | 
            +
                arg 6) transposition: 0
         | 
| 181 | 
            +
                arg 7) f0 method: harvest (pm, harvest, crepe, crepe-tiny)
         | 
| 182 | 
            +
                arg 8) crepe hop length: 160
         | 
| 183 | 
            +
                arg 9) harvest median filter radius: 3 (0-7)
         | 
| 184 | 
            +
                arg 10) post resample rate: 0
         | 
| 185 | 
            +
                arg 11) mix volume envelope: 1
         | 
| 186 | 
            +
                arg 12) feature index ratio: 0.78 (0-1)
         | 
| 187 | 
            +
                arg 13) Voiceless Consonant Protection (Less Artifact): 0.33 (Smaller number = more protection. 0.50 means Dont Use.)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            Example: mi-test.pth saudio/Sidney.wav myTest.wav logs/mi-test/added_index.index 0 -2 harvest 160 3 0 1 0.95 0.33
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            INFER: <INSERT ARGUMENTS HERE OR COPY AND PASTE THE EXAMPLE>
         | 
| 192 | 
            +
            ```
         | 
| 193 | 
            +
            ## 🏆 Credits
         | 
| 194 | 
            +
            Applio owes its existence to the collaborative efforts of various repositories, including Mangio-RVC-Fork, and all the other credited contributors. Without their contributions, Applio would not have been possible. Therefore, we kindly request that if you appreciate the work we've accomplished, you consider exploring the projects mentioned in our credits.
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            Our goal is not to supplant RVC or Mangio; rather, we aim to provide a contemporary and up-to-date alternative for the entire community.
         | 
| 197 | 
            +
             | 
| 198 | 
            +
            + [Retrieval-based-Voice-Conversion-WebUI](Retrieval-based-Voice-Conversion-WebUI)
         | 
| 199 | 
            +
            + [Mangio-RVC-Fork](https://github.com/Mangio621/Mangio-RVC-Fork)
         | 
| 200 | 
            +
            + [RVG_tts](https://github.com/Foxify52/RVG_tts)
         | 
| 201 | 
            +
            + [ContentVec](https://github.com/auspicious3000/contentvec/)
         | 
| 202 | 
            +
            + [VITS](https://github.com/jaywalnut310/vits)
         | 
| 203 | 
            +
            + [HIFIGAN](https://github.com/jik876/hifi-gan)
         | 
| 204 | 
            +
            + [Gradio](https://github.com/gradio-app/gradio)
         | 
| 205 | 
            +
            + [FFmpeg](https://github.com/FFmpeg/FFmpeg)
         | 
| 206 | 
            +
            + [Ultimate Vocal Remover](https://github.com/Anjok07/ultimatevocalremovergui)
         | 
| 207 | 
            +
            + [audio-slicer](https://github.com/openvpi/audio-slicer)
         | 
| 208 | 
            +
            + [Vocal pitch extraction:RMVPE](https://github.com/Dream-High/RMVPE)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
             | 
| 211 | 
            +
            ## 🙏 Thanks to all RVC, Mangio and Applio contributors
         | 
| 212 | 
            +
            <a href="https://github.com/liujing04/Retrieval-based-Voice-Conversion-WebUI/graphs/contributors" target="_blank">
         | 
| 213 | 
            +
              <img src="https://contrib.rocks/image?repo=liujing04/Retrieval-based-Voice-Conversion-WebUI" />
         | 
| 214 | 
            +
            </a>
         | 
| 215 | 
            +
             | 
| 216 | 
            +
            <a href="https://github.com/Mangio621/Mangio-RVC-Fork/graphs/contributors" target="_blank">
         | 
| 217 | 
            +
              <img src="https://contrib.rocks/image?repo=Mangio621/Mangio-RVC-Fork" />
         | 
| 218 | 
            +
            </a>
         | 
| 219 | 
            +
             | 
| 220 | 
            +
            <a href="https://github.com/IAHispano/Applio-RVC-Fork/graphs/contributors" target="_blank">
         | 
| 221 | 
            +
              <img src="https://contrib.rocks/image?repo=IAHispano/Applio-RVC-Fork" />
         | 
| 222 | 
            +
            </a>
         | 
    	
        assets/hubert/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        assets/pretrained/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        assets/pretrained_v2/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        assets/rmvpe/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        assets/uvr5_weights/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        assets/weights/.gitignore
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *
         | 
| 2 | 
            +
            !.gitignore
         | 
    	
        audioEffects.py
    ADDED
    
    | @@ -0,0 +1,37 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from pedalboard import Pedalboard, Compressor, Reverb, NoiseGate
         | 
| 2 | 
            +
            from pedalboard.io import AudioFile
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            now_dir = os.getcwd()
         | 
| 6 | 
            +
            sys.path.append(now_dir)
         | 
| 7 | 
            +
            from i18n import I18nAuto
         | 
| 8 | 
            +
            i18n = I18nAuto()
         | 
| 9 | 
            +
            from pydub import AudioSegment
         | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import soundfile as sf
         | 
| 12 | 
            +
            from pydub.playback import play
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            def process_audio(input_path, output_path, reverb_enabled, compressor_enabled, noise_gate_enabled, ):
         | 
| 15 | 
            +
                print(reverb_enabled)
         | 
| 16 | 
            +
                print(compressor_enabled)
         | 
| 17 | 
            +
                print(noise_gate_enabled)
         | 
| 18 | 
            +
                effects = []
         | 
| 19 | 
            +
                if reverb_enabled:
         | 
| 20 | 
            +
                    effects.append(Reverb(room_size=0.01))
         | 
| 21 | 
            +
                if compressor_enabled:
         | 
| 22 | 
            +
                    effects.append(Compressor(threshold_db=-10, ratio=25))
         | 
| 23 | 
            +
                if noise_gate_enabled:
         | 
| 24 | 
            +
                    effects.append(NoiseGate(threshold_db=-16, ratio=1.5, release_ms=250))
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                board = Pedalboard(effects)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                with AudioFile(input_path) as f:
         | 
| 29 | 
            +
                    with AudioFile(output_path, 'w', f.samplerate, f.num_channels) as o:
         | 
| 30 | 
            +
                        while f.tell() < f.frames:
         | 
| 31 | 
            +
                            chunk = f.read(f.samplerate)
         | 
| 32 | 
            +
                            effected = board(chunk, f.samplerate, reset=False)
         | 
| 33 | 
            +
                            o.write(effected)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                result = i18n("Processed audio saved at: ") +  output_path
         | 
| 36 | 
            +
                print(result)
         | 
| 37 | 
            +
                return output_path
         | 
    	
        audios/.gitignore
    ADDED
    
    | 
            File without changes
         | 
    	
        colab_for_mdx.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import gc
         | 
| 4 | 
            +
            import psutil
         | 
| 5 | 
            +
            import requests
         | 
| 6 | 
            +
            import subprocess
         | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import logging
         | 
| 9 | 
            +
            import sys
         | 
| 10 | 
            +
            import shutil
         | 
| 11 | 
            +
            now_dir = os.getcwd()
         | 
| 12 | 
            +
            sys.path.append(now_dir)
         | 
| 13 | 
            +
            first_cell_executed = False
         | 
| 14 | 
            +
            file_folder = "Colab-for-MDX_B"
         | 
| 15 | 
            +
            def first_cell_ran():
         | 
| 16 | 
            +
                global first_cell_executed
         | 
| 17 | 
            +
                if first_cell_executed:
         | 
| 18 | 
            +
                    #print("The 'first_cell_ran' function has already been executed.")
         | 
| 19 | 
            +
                    return
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                first_cell_executed = True
         | 
| 24 | 
            +
                os.makedirs("tmp_models", exist_ok=True)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
                class hide_opt:  # hide outputs
         | 
| 29 | 
            +
                    def __enter__(self):
         | 
| 30 | 
            +
                        self._original_stdout = sys.stdout
         | 
| 31 | 
            +
                        sys.stdout = open(os.devnull, "w")
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    def __exit__(self, exc_type, exc_val, exc_tb):
         | 
| 34 | 
            +
                        sys.stdout.close()
         | 
| 35 | 
            +
                        sys.stdout = self._original_stdout
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def get_size(bytes, suffix="B"):  # read ram
         | 
| 38 | 
            +
                    global svmem
         | 
| 39 | 
            +
                    factor = 1024
         | 
| 40 | 
            +
                    for unit in ["", "K", "M", "G", "T", "P"]:
         | 
| 41 | 
            +
                        if bytes < factor:
         | 
| 42 | 
            +
                            return f"{bytes:.2f}{unit}{suffix}"
         | 
| 43 | 
            +
                        bytes /= factor
         | 
| 44 | 
            +
                    svmem = psutil.virtual_memory()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
                def use_uvr_without_saving():
         | 
| 48 | 
            +
                    print("Notice: files won't be saved to personal drive.")
         | 
| 49 | 
            +
                    print(f"Downloading {file_folder}...", end=" ")
         | 
| 50 | 
            +
                    with hide_opt():
         | 
| 51 | 
            +
                        #os.chdir(mounting_path)
         | 
| 52 | 
            +
                        items_to_move = ["demucs", "diffq","julius","model","separated","tracks","mdx.py","MDX-Net_Colab.ipynb"]
         | 
| 53 | 
            +
                        subprocess.run(["git", "clone", "https://github.com/NaJeongMo/Colab-for-MDX_B.git"])
         | 
| 54 | 
            +
                        for item_name in items_to_move:
         | 
| 55 | 
            +
                            item_path = os.path.join(file_folder, item_name)
         | 
| 56 | 
            +
                            if os.path.exists(item_path):
         | 
| 57 | 
            +
                                if os.path.isfile(item_path):
         | 
| 58 | 
            +
                                    shutil.move(item_path, now_dir)
         | 
| 59 | 
            +
                                elif os.path.isdir(item_path):
         | 
| 60 | 
            +
                                    shutil.move(item_path, now_dir)
         | 
| 61 | 
            +
                        try:
         | 
| 62 | 
            +
                            shutil.rmtree(file_folder)
         | 
| 63 | 
            +
                        except PermissionError:
         | 
| 64 | 
            +
                            print(f"No se pudo eliminar la carpeta {file_folder}. Puede estar relacionada con Git.")
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                
         | 
| 67 | 
            +
                use_uvr_without_saving()
         | 
| 68 | 
            +
                print("done!")
         | 
| 69 | 
            +
                if not os.path.exists("tracks"):
         | 
| 70 | 
            +
                    os.mkdir("tracks")
         | 
| 71 | 
            +
            first_cell_ran()
         | 
    	
        configs/32k.json
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": false,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 32000,
         | 
| 21 | 
            +
                "filter_length": 1024,
         | 
| 22 | 
            +
                "hop_length": 320,
         | 
| 23 | 
            +
                "win_length": 1024,
         | 
| 24 | 
            +
                "n_mel_channels": 80,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3, 7, 11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [
         | 
| 39 | 
            +
                  [1, 3, 5],
         | 
| 40 | 
            +
                  [1, 3, 5],
         | 
| 41 | 
            +
                  [1, 3, 5]
         | 
| 42 | 
            +
                ],
         | 
| 43 | 
            +
                "upsample_rates": [10, 4, 2, 2, 2],
         | 
| 44 | 
            +
                "upsample_initial_channel": 512,
         | 
| 45 | 
            +
                "upsample_kernel_sizes": [16, 16, 4, 4, 4],
         | 
| 46 | 
            +
                "use_spectral_norm": false,
         | 
| 47 | 
            +
                "gin_channels": 256,
         | 
| 48 | 
            +
                "spk_embed_dim": 109
         | 
| 49 | 
            +
              }
         | 
| 50 | 
            +
            }
         | 
    	
        configs/32k_v2.json
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 32000,
         | 
| 21 | 
            +
                "filter_length": 1024,
         | 
| 22 | 
            +
                "hop_length": 320,
         | 
| 23 | 
            +
                "win_length": 1024,
         | 
| 24 | 
            +
                "n_mel_channels": 80,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3, 7, 11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [
         | 
| 39 | 
            +
                  [1, 3, 5],
         | 
| 40 | 
            +
                  [1, 3, 5],
         | 
| 41 | 
            +
                  [1, 3, 5]
         | 
| 42 | 
            +
                ],
         | 
| 43 | 
            +
                "upsample_rates": [10, 8, 2, 2],
         | 
| 44 | 
            +
                "upsample_initial_channel": 512,
         | 
| 45 | 
            +
                "upsample_kernel_sizes": [20, 16, 4, 4],
         | 
| 46 | 
            +
                "use_spectral_norm": false,
         | 
| 47 | 
            +
                "gin_channels": 256,
         | 
| 48 | 
            +
                "spk_embed_dim": 109
         | 
| 49 | 
            +
              }
         | 
| 50 | 
            +
            }
         | 
    	
        configs/40k.json
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": false,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 40000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 400,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 125,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3, 7, 11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [
         | 
| 39 | 
            +
                  [1, 3, 5],
         | 
| 40 | 
            +
                  [1, 3, 5],
         | 
| 41 | 
            +
                  [1, 3, 5]
         | 
| 42 | 
            +
                ],
         | 
| 43 | 
            +
                "upsample_rates": [10, 10, 2, 2],
         | 
| 44 | 
            +
                "upsample_initial_channel": 512,
         | 
| 45 | 
            +
                "upsample_kernel_sizes": [16, 16, 4, 4],
         | 
| 46 | 
            +
                "use_spectral_norm": false,
         | 
| 47 | 
            +
                "gin_channels": 256,
         | 
| 48 | 
            +
                "spk_embed_dim": 109
         | 
| 49 | 
            +
              }
         | 
| 50 | 
            +
            }
         | 
    	
        configs/48k.json
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": false,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 11520,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 48000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 480,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 128,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3, 7, 11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [
         | 
| 39 | 
            +
                  [1, 3, 5],
         | 
| 40 | 
            +
                  [1, 3, 5],
         | 
| 41 | 
            +
                  [1, 3, 5]
         | 
| 42 | 
            +
                ],
         | 
| 43 | 
            +
                "upsample_rates": [10, 6, 2, 2, 2],
         | 
| 44 | 
            +
                "upsample_initial_channel": 512,
         | 
| 45 | 
            +
                "upsample_kernel_sizes": [16, 16, 4, 4, 4],
         | 
| 46 | 
            +
                "use_spectral_norm": false,
         | 
| 47 | 
            +
                "gin_channels": 256,
         | 
| 48 | 
            +
                "spk_embed_dim": 109
         | 
| 49 | 
            +
              }
         | 
| 50 | 
            +
            }
         | 
    	
        configs/48k_v2.json
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 17280,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 48000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 480,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 128,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3, 7, 11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [
         | 
| 39 | 
            +
                  [1, 3, 5],
         | 
| 40 | 
            +
                  [1, 3, 5],
         | 
| 41 | 
            +
                  [1, 3, 5]
         | 
| 42 | 
            +
                ],
         | 
| 43 | 
            +
                "upsample_rates": [12, 10, 2, 2],
         | 
| 44 | 
            +
                "upsample_initial_channel": 512,
         | 
| 45 | 
            +
                "upsample_kernel_sizes": [24, 20, 4, 4],
         | 
| 46 | 
            +
                "use_spectral_norm": false,
         | 
| 47 | 
            +
                "gin_channels": 256,
         | 
| 48 | 
            +
                "spk_embed_dim": 109
         | 
| 49 | 
            +
              }
         | 
| 50 | 
            +
            }
         | 
    	
        configs/config.json
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "pth_path": "assets/weights/kikiV1.pth",
         | 
| 3 | 
            +
              "index_path": "logs/kikiV1.index",
         | 
| 4 | 
            +
              "sg_input_device": "VoiceMeeter Output (VB-Audio Vo (MME)",
         | 
| 5 | 
            +
              "sg_output_device": "VoiceMeeter Aux Input (VB-Audio (MME)",
         | 
| 6 | 
            +
              "threhold": -45.0,
         | 
| 7 | 
            +
              "pitch": 12.0,
         | 
| 8 | 
            +
              "index_rate": 0.0,
         | 
| 9 | 
            +
              "rms_mix_rate": 0.0,
         | 
| 10 | 
            +
              "block_time": 0.25,
         | 
| 11 | 
            +
              "crossfade_length": 0.04,
         | 
| 12 | 
            +
              "extra_time": 2.0,
         | 
| 13 | 
            +
              "n_cpu": 6.0,
         | 
| 14 | 
            +
              "f0method": "rmvpe"
         | 
| 15 | 
            +
            }
         | 
    	
        configs/config.py
    ADDED
    
    | @@ -0,0 +1,265 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            from multiprocessing import cpu_count
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            try:
         | 
| 10 | 
            +
                import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
         | 
| 11 | 
            +
                if torch.xpu.is_available():
         | 
| 12 | 
            +
                    from infer.modules.ipex import ipex_init
         | 
| 13 | 
            +
                    ipex_init()
         | 
| 14 | 
            +
            except Exception:
         | 
| 15 | 
            +
                pass
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import logging
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            version_config_list = [
         | 
| 23 | 
            +
                "v1/32k.json",
         | 
| 24 | 
            +
                "v1/40k.json",
         | 
| 25 | 
            +
                "v1/48k.json",
         | 
| 26 | 
            +
                "v2/48k.json",
         | 
| 27 | 
            +
                "v2/32k.json",
         | 
| 28 | 
            +
            ]
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def singleton_variable(func):
         | 
| 32 | 
            +
                def wrapper(*args, **kwargs):
         | 
| 33 | 
            +
                    if not wrapper.instance:
         | 
| 34 | 
            +
                        wrapper.instance = func(*args, **kwargs)
         | 
| 35 | 
            +
                    return wrapper.instance
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                wrapper.instance = None
         | 
| 38 | 
            +
                return wrapper
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            @singleton_variable
         | 
| 42 | 
            +
            class Config:
         | 
| 43 | 
            +
                def __init__(self):
         | 
| 44 | 
            +
                    self.device = "cuda:0"
         | 
| 45 | 
            +
                    self.is_half = True
         | 
| 46 | 
            +
                    self.n_cpu = 0
         | 
| 47 | 
            +
                    self.gpu_name = None
         | 
| 48 | 
            +
                    self.json_config = self.load_config_json()
         | 
| 49 | 
            +
                    self.gpu_mem = None
         | 
| 50 | 
            +
                    (
         | 
| 51 | 
            +
                        self.python_cmd,
         | 
| 52 | 
            +
                        self.listen_port,
         | 
| 53 | 
            +
                        self.iscolab,
         | 
| 54 | 
            +
                        self.noparallel,
         | 
| 55 | 
            +
                        self.noautoopen,
         | 
| 56 | 
            +
                        self.paperspace,
         | 
| 57 | 
            +
                        self.is_cli,
         | 
| 58 | 
            +
                        self.grtheme,
         | 
| 59 | 
            +
                        self.dml,
         | 
| 60 | 
            +
                    ) = self.arg_parse()
         | 
| 61 | 
            +
                    self.instead = ""
         | 
| 62 | 
            +
                    self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                @staticmethod
         | 
| 65 | 
            +
                def load_config_json() -> dict:
         | 
| 66 | 
            +
                    d = {}
         | 
| 67 | 
            +
                    for config_file in version_config_list:
         | 
| 68 | 
            +
                        with open(f"configs/{config_file}", "r") as f:
         | 
| 69 | 
            +
                            d[config_file] = json.load(f)
         | 
| 70 | 
            +
                    return d
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                @staticmethod
         | 
| 73 | 
            +
                def arg_parse() -> tuple:
         | 
| 74 | 
            +
                    exe = sys.executable or "python"
         | 
| 75 | 
            +
                    parser = argparse.ArgumentParser()
         | 
| 76 | 
            +
                    parser.add_argument("--port", type=int, default=7865, help="Listen port")
         | 
| 77 | 
            +
                    parser.add_argument("--pycmd", type=str, default=exe, help="Python command")
         | 
| 78 | 
            +
                    parser.add_argument("--colab", action="store_true", help="Launch in colab")
         | 
| 79 | 
            +
                    parser.add_argument(
         | 
| 80 | 
            +
                        "--noparallel", action="store_true", help="Disable parallel processing"
         | 
| 81 | 
            +
                    )
         | 
| 82 | 
            +
                    parser.add_argument(
         | 
| 83 | 
            +
                        "--noautoopen",
         | 
| 84 | 
            +
                        action="store_true",
         | 
| 85 | 
            +
                        help="Do not open in browser automatically",
         | 
| 86 | 
            +
                    )
         | 
| 87 | 
            +
                    parser.add_argument(  
         | 
| 88 | 
            +
                        "--paperspace",
         | 
| 89 | 
            +
                        action="store_true",
         | 
| 90 | 
            +
                        help="Note that this argument just shares a gradio link for the web UI. Thus can be used on other non-local CLI systems.",
         | 
| 91 | 
            +
                    )
         | 
| 92 | 
            +
                    parser.add_argument(  
         | 
| 93 | 
            +
                        "--is_cli",
         | 
| 94 | 
            +
                        action="store_true",
         | 
| 95 | 
            +
                        help="Use the CLI instead of setting up a gradio UI. This flag will launch an RVC text interface where you can execute functions from infer-web.py!",
         | 
| 96 | 
            +
                    )
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    parser.add_argument(
         | 
| 99 | 
            +
                                "-t",
         | 
| 100 | 
            +
                                "--theme",
         | 
| 101 | 
            +
                        help    = "Theme for Gradio. Format - `JohnSmith9982/small_and_pretty` (no backticks)",
         | 
| 102 | 
            +
                        default = "JohnSmith9982/small_and_pretty",
         | 
| 103 | 
            +
                        type    = str
         | 
| 104 | 
            +
                    )
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    parser.add_argument(
         | 
| 107 | 
            +
                        "--dml",
         | 
| 108 | 
            +
                        action="store_true",
         | 
| 109 | 
            +
                        help="Use DirectML backend instead of CUDA."
         | 
| 110 | 
            +
                    )
         | 
| 111 | 
            +
                    
         | 
| 112 | 
            +
                    cmd_opts = parser.parse_args()
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    return (
         | 
| 117 | 
            +
                        cmd_opts.pycmd,
         | 
| 118 | 
            +
                        cmd_opts.port,
         | 
| 119 | 
            +
                        cmd_opts.colab,
         | 
| 120 | 
            +
                        cmd_opts.noparallel,
         | 
| 121 | 
            +
                        cmd_opts.noautoopen,
         | 
| 122 | 
            +
                        cmd_opts.paperspace,
         | 
| 123 | 
            +
                        cmd_opts.is_cli,
         | 
| 124 | 
            +
                        cmd_opts.theme,
         | 
| 125 | 
            +
                        cmd_opts.dml,
         | 
| 126 | 
            +
                    )
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
         | 
| 129 | 
            +
                # check `getattr` and try it for compatibility
         | 
| 130 | 
            +
                @staticmethod
         | 
| 131 | 
            +
                def has_mps() -> bool:
         | 
| 132 | 
            +
                    if not torch.backends.mps.is_available():
         | 
| 133 | 
            +
                        return False
         | 
| 134 | 
            +
                    try:
         | 
| 135 | 
            +
                        torch.zeros(1).to(torch.device("mps"))
         | 
| 136 | 
            +
                        return True
         | 
| 137 | 
            +
                    except Exception:
         | 
| 138 | 
            +
                        return False
         | 
| 139 | 
            +
                    
         | 
| 140 | 
            +
                @staticmethod
         | 
| 141 | 
            +
                def has_xpu() -> bool:
         | 
| 142 | 
            +
                    if hasattr(torch, "xpu") and torch.xpu.is_available():
         | 
| 143 | 
            +
                        return True
         | 
| 144 | 
            +
                    else:
         | 
| 145 | 
            +
                        return False
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def use_fp32_config(self):
         | 
| 148 | 
            +
                    for config_file in version_config_list:
         | 
| 149 | 
            +
                        self.json_config[config_file]["train"]["fp16_run"] = False
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def device_config(self) -> tuple:
         | 
| 152 | 
            +
                    if torch.cuda.is_available():
         | 
| 153 | 
            +
                        if self.has_xpu():
         | 
| 154 | 
            +
                            self.device = self.instead = "xpu:0"
         | 
| 155 | 
            +
                            self.is_half = True
         | 
| 156 | 
            +
                        i_device = int(self.device.split(":")[-1])
         | 
| 157 | 
            +
                        self.gpu_name = torch.cuda.get_device_name(i_device)
         | 
| 158 | 
            +
                        if (
         | 
| 159 | 
            +
                            ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
         | 
| 160 | 
            +
                            or "P40" in self.gpu_name.upper()
         | 
| 161 | 
            +
                            or "P10" in self.gpu_name.upper()
         | 
| 162 | 
            +
                            or "1060" in self.gpu_name
         | 
| 163 | 
            +
                            or "1070" in self.gpu_name
         | 
| 164 | 
            +
                            or "1080" in self.gpu_name
         | 
| 165 | 
            +
                        ):
         | 
| 166 | 
            +
                            logger.info("Found GPU %s, force to fp32", self.gpu_name)
         | 
| 167 | 
            +
                            self.is_half = False
         | 
| 168 | 
            +
                            self.use_fp32_config()
         | 
| 169 | 
            +
                        else:
         | 
| 170 | 
            +
                            logger.info("Found GPU %s", self.gpu_name)
         | 
| 171 | 
            +
                        self.gpu_mem = int(
         | 
| 172 | 
            +
                            torch.cuda.get_device_properties(i_device).total_memory
         | 
| 173 | 
            +
                            / 1024
         | 
| 174 | 
            +
                            / 1024
         | 
| 175 | 
            +
                            / 1024
         | 
| 176 | 
            +
                            + 0.4
         | 
| 177 | 
            +
                        )
         | 
| 178 | 
            +
                        if self.gpu_mem <= 4:
         | 
| 179 | 
            +
                            with open("infer/modules/train/preprocess.py", "r") as f:
         | 
| 180 | 
            +
                                strr = f.read().replace("3.7", "3.0")
         | 
| 181 | 
            +
                            with open("infer/modules/train/preprocess.py", "w") as f:
         | 
| 182 | 
            +
                                f.write(strr)
         | 
| 183 | 
            +
                    elif self.has_mps():
         | 
| 184 | 
            +
                        logger.info("No supported Nvidia GPU found")
         | 
| 185 | 
            +
                        self.device = self.instead = "mps"
         | 
| 186 | 
            +
                        self.is_half = False
         | 
| 187 | 
            +
                        self.use_fp32_config()
         | 
| 188 | 
            +
                    else:
         | 
| 189 | 
            +
                        logger.info("No supported Nvidia GPU found")
         | 
| 190 | 
            +
                        self.device = self.instead = "cpu"
         | 
| 191 | 
            +
                        self.is_half = False
         | 
| 192 | 
            +
                        self.use_fp32_config()
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    if self.n_cpu == 0:
         | 
| 195 | 
            +
                        self.n_cpu = cpu_count()
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    if self.is_half:
         | 
| 198 | 
            +
                        # 6G显存配置
         | 
| 199 | 
            +
                        x_pad = 3
         | 
| 200 | 
            +
                        x_query = 10
         | 
| 201 | 
            +
                        x_center = 60
         | 
| 202 | 
            +
                        x_max = 65
         | 
| 203 | 
            +
                    else:
         | 
| 204 | 
            +
                        # 5G显存配置
         | 
| 205 | 
            +
                        x_pad = 1
         | 
| 206 | 
            +
                        x_query = 6
         | 
| 207 | 
            +
                        x_center = 38
         | 
| 208 | 
            +
                        x_max = 41
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    if self.gpu_mem is not None and self.gpu_mem <= 4:
         | 
| 211 | 
            +
                        x_pad = 1
         | 
| 212 | 
            +
                        x_query = 5
         | 
| 213 | 
            +
                        x_center = 30
         | 
| 214 | 
            +
                        x_max = 32
         | 
| 215 | 
            +
                    if self.dml:
         | 
| 216 | 
            +
                        logger.info("Use DirectML instead")
         | 
| 217 | 
            +
                        if (
         | 
| 218 | 
            +
                            os.path.exists(
         | 
| 219 | 
            +
                                "runtime\Lib\site-packages\onnxruntime\capi\DirectML.dll"
         | 
| 220 | 
            +
                            )
         | 
| 221 | 
            +
                            == False
         | 
| 222 | 
            +
                        ):
         | 
| 223 | 
            +
                            try:
         | 
| 224 | 
            +
                                os.rename(
         | 
| 225 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime",
         | 
| 226 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime-cuda",
         | 
| 227 | 
            +
                                )
         | 
| 228 | 
            +
                            except:
         | 
| 229 | 
            +
                                pass
         | 
| 230 | 
            +
                            try:
         | 
| 231 | 
            +
                                os.rename(
         | 
| 232 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime-dml",
         | 
| 233 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime",
         | 
| 234 | 
            +
                                )
         | 
| 235 | 
            +
                            except:
         | 
| 236 | 
            +
                                pass
         | 
| 237 | 
            +
                        # if self.device != "cpu":
         | 
| 238 | 
            +
                        import torch_directml
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                        self.device = torch_directml.device(torch_directml.default_device())
         | 
| 241 | 
            +
                        self.is_half = False
         | 
| 242 | 
            +
                    else:
         | 
| 243 | 
            +
                        if self.instead:
         | 
| 244 | 
            +
                            logger.info(f"Use {self.instead} instead")
         | 
| 245 | 
            +
                        if (
         | 
| 246 | 
            +
                            os.path.exists(
         | 
| 247 | 
            +
                                "runtime\Lib\site-packages\onnxruntime\capi\onnxruntime_providers_cuda.dll"
         | 
| 248 | 
            +
                            )
         | 
| 249 | 
            +
                            == False
         | 
| 250 | 
            +
                        ):
         | 
| 251 | 
            +
                            try:
         | 
| 252 | 
            +
                                os.rename(
         | 
| 253 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime",
         | 
| 254 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime-dml",
         | 
| 255 | 
            +
                                )
         | 
| 256 | 
            +
                            except:
         | 
| 257 | 
            +
                                pass
         | 
| 258 | 
            +
                            try:
         | 
| 259 | 
            +
                                os.rename(
         | 
| 260 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime-cuda",
         | 
| 261 | 
            +
                                    "runtime\Lib\site-packages\onnxruntime",
         | 
| 262 | 
            +
                                )
         | 
| 263 | 
            +
                            except:
         | 
| 264 | 
            +
                                pass
         | 
| 265 | 
            +
                    return x_pad, x_query, x_center, x_max
         | 
    	
        configs/v1/32k.json
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 32000,
         | 
| 21 | 
            +
                "filter_length": 1024,
         | 
| 22 | 
            +
                "hop_length": 320,
         | 
| 23 | 
            +
                "win_length": 1024,
         | 
| 24 | 
            +
                "n_mel_channels": 80,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3,7,11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
         | 
| 39 | 
            +
                "upsample_rates": [10,4,2,2,2],
         | 
| 40 | 
            +
                "upsample_initial_channel": 512,
         | 
| 41 | 
            +
                "upsample_kernel_sizes": [16,16,4,4,4],
         | 
| 42 | 
            +
                "use_spectral_norm": false,
         | 
| 43 | 
            +
                "gin_channels": 256,
         | 
| 44 | 
            +
                "spk_embed_dim": 109
         | 
| 45 | 
            +
              }
         | 
| 46 | 
            +
            }
         | 
    	
        configs/v1/40k.json
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 40000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 400,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 125,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3,7,11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
         | 
| 39 | 
            +
                "upsample_rates": [10,10,2,2],
         | 
| 40 | 
            +
                "upsample_initial_channel": 512,
         | 
| 41 | 
            +
                "upsample_kernel_sizes": [16,16,4,4],
         | 
| 42 | 
            +
                "use_spectral_norm": false,
         | 
| 43 | 
            +
                "gin_channels": 256,
         | 
| 44 | 
            +
                "spk_embed_dim": 109
         | 
| 45 | 
            +
              }
         | 
| 46 | 
            +
            }
         | 
    	
        configs/v1/48k.json
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 11520,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 48000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 480,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 128,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3,7,11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
         | 
| 39 | 
            +
                "upsample_rates": [10,6,2,2,2],
         | 
| 40 | 
            +
                "upsample_initial_channel": 512,
         | 
| 41 | 
            +
                "upsample_kernel_sizes": [16,16,4,4,4],
         | 
| 42 | 
            +
                "use_spectral_norm": false,
         | 
| 43 | 
            +
                "gin_channels": 256,
         | 
| 44 | 
            +
                "spk_embed_dim": 109
         | 
| 45 | 
            +
              }
         | 
| 46 | 
            +
            }
         | 
    	
        configs/v2/32k.json
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 12800,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 32000,
         | 
| 21 | 
            +
                "filter_length": 1024,
         | 
| 22 | 
            +
                "hop_length": 320,
         | 
| 23 | 
            +
                "win_length": 1024,
         | 
| 24 | 
            +
                "n_mel_channels": 80,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3,7,11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
         | 
| 39 | 
            +
                "upsample_rates": [10,8,2,2],
         | 
| 40 | 
            +
                "upsample_initial_channel": 512,
         | 
| 41 | 
            +
                "upsample_kernel_sizes": [20,16,4,4],
         | 
| 42 | 
            +
                "use_spectral_norm": false,
         | 
| 43 | 
            +
                "gin_channels": 256,
         | 
| 44 | 
            +
                "spk_embed_dim": 109
         | 
| 45 | 
            +
              }
         | 
| 46 | 
            +
            }
         | 
    	
        configs/v2/48k.json
    ADDED
    
    | @@ -0,0 +1,46 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "train": {
         | 
| 3 | 
            +
                "log_interval": 200,
         | 
| 4 | 
            +
                "seed": 1234,
         | 
| 5 | 
            +
                "epochs": 20000,
         | 
| 6 | 
            +
                "learning_rate": 1e-4,
         | 
| 7 | 
            +
                "betas": [0.8, 0.99],
         | 
| 8 | 
            +
                "eps": 1e-9,
         | 
| 9 | 
            +
                "batch_size": 4,
         | 
| 10 | 
            +
                "fp16_run": true,
         | 
| 11 | 
            +
                "lr_decay": 0.999875,
         | 
| 12 | 
            +
                "segment_size": 17280,
         | 
| 13 | 
            +
                "init_lr_ratio": 1,
         | 
| 14 | 
            +
                "warmup_epochs": 0,
         | 
| 15 | 
            +
                "c_mel": 45,
         | 
| 16 | 
            +
                "c_kl": 1.0
         | 
| 17 | 
            +
              },
         | 
| 18 | 
            +
              "data": {
         | 
| 19 | 
            +
                "max_wav_value": 32768.0,
         | 
| 20 | 
            +
                "sampling_rate": 48000,
         | 
| 21 | 
            +
                "filter_length": 2048,
         | 
| 22 | 
            +
                "hop_length": 480,
         | 
| 23 | 
            +
                "win_length": 2048,
         | 
| 24 | 
            +
                "n_mel_channels": 128,
         | 
| 25 | 
            +
                "mel_fmin": 0.0,
         | 
| 26 | 
            +
                "mel_fmax": null
         | 
| 27 | 
            +
              },
         | 
| 28 | 
            +
              "model": {
         | 
| 29 | 
            +
                "inter_channels": 192,
         | 
| 30 | 
            +
                "hidden_channels": 192,
         | 
| 31 | 
            +
                "filter_channels": 768,
         | 
| 32 | 
            +
                "n_heads": 2,
         | 
| 33 | 
            +
                "n_layers": 6,
         | 
| 34 | 
            +
                "kernel_size": 3,
         | 
| 35 | 
            +
                "p_dropout": 0,
         | 
| 36 | 
            +
                "resblock": "1",
         | 
| 37 | 
            +
                "resblock_kernel_sizes": [3,7,11],
         | 
| 38 | 
            +
                "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
         | 
| 39 | 
            +
                "upsample_rates": [12,10,2,2],
         | 
| 40 | 
            +
                "upsample_initial_channel": 512,
         | 
| 41 | 
            +
                "upsample_kernel_sizes": [24,20,4,4],
         | 
| 42 | 
            +
                "use_spectral_norm": false,
         | 
| 43 | 
            +
                "gin_channels": 256,
         | 
| 44 | 
            +
                "spk_embed_dim": 109
         | 
| 45 | 
            +
              }
         | 
| 46 | 
            +
            }
         | 
    	
        csvdb/formanting.csv
    ADDED
    
    | 
            File without changes
         | 
    	
        csvdb/stop.csv
    ADDED
    
    | 
            File without changes
         | 
    	
        demucs/__init__.py
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            __version__ = "2.0.3"
         | 
    	
        demucs/__main__.py
    ADDED
    
    | @@ -0,0 +1,317 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            import math
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            import sys
         | 
| 11 | 
            +
            import time
         | 
| 12 | 
            +
            from dataclasses import dataclass, field
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch as th
         | 
| 15 | 
            +
            from torch import distributed, nn
         | 
| 16 | 
            +
            from torch.nn.parallel.distributed import DistributedDataParallel
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
         | 
| 19 | 
            +
            from .compressed import get_compressed_datasets
         | 
| 20 | 
            +
            from .model import Demucs
         | 
| 21 | 
            +
            from .parser import get_name, get_parser
         | 
| 22 | 
            +
            from .raw import Rawset
         | 
| 23 | 
            +
            from .repitch import RepitchedWrapper
         | 
| 24 | 
            +
            from .pretrained import load_pretrained, SOURCES
         | 
| 25 | 
            +
            from .tasnet import ConvTasNet
         | 
| 26 | 
            +
            from .test import evaluate
         | 
| 27 | 
            +
            from .train import train_model, validate_model
         | 
| 28 | 
            +
            from .utils import (human_seconds, load_model, save_model, get_state,
         | 
| 29 | 
            +
                                save_state, sizeof_fmt, get_quantizer)
         | 
| 30 | 
            +
            from .wav import get_wav_datasets, get_musdb_wav_datasets
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            @dataclass
         | 
| 34 | 
            +
            class SavedState:
         | 
| 35 | 
            +
                metrics: list = field(default_factory=list)
         | 
| 36 | 
            +
                last_state: dict = None
         | 
| 37 | 
            +
                best_state: dict = None
         | 
| 38 | 
            +
                optimizer: dict = None
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def main():
         | 
| 42 | 
            +
                parser = get_parser()
         | 
| 43 | 
            +
                args = parser.parse_args()
         | 
| 44 | 
            +
                name = get_name(parser, args)
         | 
| 45 | 
            +
                print(f"Experiment {name}")
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                if args.musdb is None and args.rank == 0:
         | 
| 48 | 
            +
                    print(
         | 
| 49 | 
            +
                        "You must provide the path to the MusDB dataset with the --musdb flag. "
         | 
| 50 | 
            +
                        "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
         | 
| 51 | 
            +
                        file=sys.stderr)
         | 
| 52 | 
            +
                    sys.exit(1)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                eval_folder = args.evals / name
         | 
| 55 | 
            +
                eval_folder.mkdir(exist_ok=True, parents=True)
         | 
| 56 | 
            +
                args.logs.mkdir(exist_ok=True)
         | 
| 57 | 
            +
                metrics_path = args.logs / f"{name}.json"
         | 
| 58 | 
            +
                eval_folder.mkdir(exist_ok=True, parents=True)
         | 
| 59 | 
            +
                args.checkpoints.mkdir(exist_ok=True, parents=True)
         | 
| 60 | 
            +
                args.models.mkdir(exist_ok=True, parents=True)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                if args.device is None:
         | 
| 63 | 
            +
                    device = "cpu"
         | 
| 64 | 
            +
                    if th.cuda.is_available():
         | 
| 65 | 
            +
                        device = "cuda"
         | 
| 66 | 
            +
                else:
         | 
| 67 | 
            +
                    device = args.device
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                th.manual_seed(args.seed)
         | 
| 70 | 
            +
                # Prevents too many threads to be started when running `museval` as it can be quite
         | 
| 71 | 
            +
                # inefficient on NUMA architectures.
         | 
| 72 | 
            +
                os.environ["OMP_NUM_THREADS"] = "1"
         | 
| 73 | 
            +
                os.environ["MKL_NUM_THREADS"] = "1"
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                if args.world_size > 1:
         | 
| 76 | 
            +
                    if device != "cuda" and args.rank == 0:
         | 
| 77 | 
            +
                        print("Error: distributed training is only available with cuda device", file=sys.stderr)
         | 
| 78 | 
            +
                        sys.exit(1)
         | 
| 79 | 
            +
                    th.cuda.set_device(args.rank % th.cuda.device_count())
         | 
| 80 | 
            +
                    distributed.init_process_group(backend="nccl",
         | 
| 81 | 
            +
                                                   init_method="tcp://" + args.master,
         | 
| 82 | 
            +
                                                   rank=args.rank,
         | 
| 83 | 
            +
                                                   world_size=args.world_size)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                checkpoint = args.checkpoints / f"{name}.th"
         | 
| 86 | 
            +
                checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
         | 
| 87 | 
            +
                if args.restart and checkpoint.exists() and args.rank == 0:
         | 
| 88 | 
            +
                    checkpoint.unlink()
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                if args.test or args.test_pretrained:
         | 
| 91 | 
            +
                    args.epochs = 1
         | 
| 92 | 
            +
                    args.repeat = 0
         | 
| 93 | 
            +
                    if args.test:
         | 
| 94 | 
            +
                        model = load_model(args.models / args.test)
         | 
| 95 | 
            +
                    else:
         | 
| 96 | 
            +
                        model = load_pretrained(args.test_pretrained)
         | 
| 97 | 
            +
                elif args.tasnet:
         | 
| 98 | 
            +
                    model = ConvTasNet(audio_channels=args.audio_channels,
         | 
| 99 | 
            +
                                       samplerate=args.samplerate, X=args.X,
         | 
| 100 | 
            +
                                       segment_length=4 * args.samples,
         | 
| 101 | 
            +
                                       sources=SOURCES)
         | 
| 102 | 
            +
                else:
         | 
| 103 | 
            +
                    model = Demucs(
         | 
| 104 | 
            +
                        audio_channels=args.audio_channels,
         | 
| 105 | 
            +
                        channels=args.channels,
         | 
| 106 | 
            +
                        context=args.context,
         | 
| 107 | 
            +
                        depth=args.depth,
         | 
| 108 | 
            +
                        glu=args.glu,
         | 
| 109 | 
            +
                        growth=args.growth,
         | 
| 110 | 
            +
                        kernel_size=args.kernel_size,
         | 
| 111 | 
            +
                        lstm_layers=args.lstm_layers,
         | 
| 112 | 
            +
                        rescale=args.rescale,
         | 
| 113 | 
            +
                        rewrite=args.rewrite,
         | 
| 114 | 
            +
                        stride=args.conv_stride,
         | 
| 115 | 
            +
                        resample=args.resample,
         | 
| 116 | 
            +
                        normalize=args.normalize,
         | 
| 117 | 
            +
                        samplerate=args.samplerate,
         | 
| 118 | 
            +
                        segment_length=4 * args.samples,
         | 
| 119 | 
            +
                        sources=SOURCES,
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
                model.to(device)
         | 
| 122 | 
            +
                if args.init:
         | 
| 123 | 
            +
                    model.load_state_dict(load_pretrained(args.init).state_dict())
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                if args.show:
         | 
| 126 | 
            +
                    print(model)
         | 
| 127 | 
            +
                    size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
         | 
| 128 | 
            +
                    print(f"Model size {size}")
         | 
| 129 | 
            +
                    return
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                try:
         | 
| 132 | 
            +
                    saved = th.load(checkpoint, map_location='cpu')
         | 
| 133 | 
            +
                except IOError:
         | 
| 134 | 
            +
                    saved = SavedState()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                quantizer = None
         | 
| 139 | 
            +
                quantizer = get_quantizer(model, args, optimizer)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                if saved.last_state is not None:
         | 
| 142 | 
            +
                    model.load_state_dict(saved.last_state, strict=False)
         | 
| 143 | 
            +
                if saved.optimizer is not None:
         | 
| 144 | 
            +
                    optimizer.load_state_dict(saved.optimizer)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                model_name = f"{name}.th"
         | 
| 147 | 
            +
                if args.save_model:
         | 
| 148 | 
            +
                    if args.rank == 0:
         | 
| 149 | 
            +
                        model.to("cpu")
         | 
| 150 | 
            +
                        model.load_state_dict(saved.best_state)
         | 
| 151 | 
            +
                        save_model(model, quantizer, args, args.models / model_name)
         | 
| 152 | 
            +
                    return
         | 
| 153 | 
            +
                elif args.save_state:
         | 
| 154 | 
            +
                    model_name = f"{args.save_state}.th"
         | 
| 155 | 
            +
                    if args.rank == 0:
         | 
| 156 | 
            +
                        model.to("cpu")
         | 
| 157 | 
            +
                        model.load_state_dict(saved.best_state)
         | 
| 158 | 
            +
                        state = get_state(model, quantizer)
         | 
| 159 | 
            +
                        save_state(state, args.models / model_name)
         | 
| 160 | 
            +
                    return
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                if args.rank == 0:
         | 
| 163 | 
            +
                    done = args.logs / f"{name}.done"
         | 
| 164 | 
            +
                    if done.exists():
         | 
| 165 | 
            +
                        done.unlink()
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                augment = [Shift(args.data_stride)]
         | 
| 168 | 
            +
                if args.augment:
         | 
| 169 | 
            +
                    augment += [FlipSign(), FlipChannels(), Scale(),
         | 
| 170 | 
            +
                                Remix(group_size=args.remix_group_size)]
         | 
| 171 | 
            +
                augment = nn.Sequential(*augment).to(device)
         | 
| 172 | 
            +
                print("Agumentation pipeline:", augment)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                if args.mse:
         | 
| 175 | 
            +
                    criterion = nn.MSELoss()
         | 
| 176 | 
            +
                else:
         | 
| 177 | 
            +
                    criterion = nn.L1Loss()
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                # Setting number of samples so that all convolution windows are full.
         | 
| 180 | 
            +
                # Prevents hard to debug mistake with the prediction being shifted compared
         | 
| 181 | 
            +
                # to the input mixture.
         | 
| 182 | 
            +
                samples = model.valid_length(args.samples)
         | 
| 183 | 
            +
                print(f"Number of training samples adjusted to {samples}")
         | 
| 184 | 
            +
                samples = samples + args.data_stride
         | 
| 185 | 
            +
                if args.repitch:
         | 
| 186 | 
            +
                    # We need a bit more audio samples, to account for potential
         | 
| 187 | 
            +
                    # tempo change.
         | 
| 188 | 
            +
                    samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                args.metadata.mkdir(exist_ok=True, parents=True)
         | 
| 191 | 
            +
                if args.raw:
         | 
| 192 | 
            +
                    train_set = Rawset(args.raw / "train",
         | 
| 193 | 
            +
                                       samples=samples,
         | 
| 194 | 
            +
                                       channels=args.audio_channels,
         | 
| 195 | 
            +
                                       streams=range(1, len(model.sources) + 1),
         | 
| 196 | 
            +
                                       stride=args.data_stride)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
         | 
| 199 | 
            +
                elif args.wav:
         | 
| 200 | 
            +
                    train_set, valid_set = get_wav_datasets(args, samples, model.sources)
         | 
| 201 | 
            +
                elif args.is_wav:
         | 
| 202 | 
            +
                    train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
         | 
| 203 | 
            +
                else:
         | 
| 204 | 
            +
                    train_set, valid_set = get_compressed_datasets(args, samples)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                if args.repitch:
         | 
| 207 | 
            +
                    train_set = RepitchedWrapper(
         | 
| 208 | 
            +
                        train_set,
         | 
| 209 | 
            +
                        proba=args.repitch,
         | 
| 210 | 
            +
                        max_tempo=args.max_tempo)
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                best_loss = float("inf")
         | 
| 213 | 
            +
                for epoch, metrics in enumerate(saved.metrics):
         | 
| 214 | 
            +
                    print(f"Epoch {epoch:03d}: "
         | 
| 215 | 
            +
                          f"train={metrics['train']:.8f} "
         | 
| 216 | 
            +
                          f"valid={metrics['valid']:.8f} "
         | 
| 217 | 
            +
                          f"best={metrics['best']:.4f} "
         | 
| 218 | 
            +
                          f"ms={metrics.get('true_model_size', 0):.2f}MB "
         | 
| 219 | 
            +
                          f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
         | 
| 220 | 
            +
                          f"duration={human_seconds(metrics['duration'])}")
         | 
| 221 | 
            +
                    best_loss = metrics['best']
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                if args.world_size > 1:
         | 
| 224 | 
            +
                    dmodel = DistributedDataParallel(model,
         | 
| 225 | 
            +
                                                     device_ids=[th.cuda.current_device()],
         | 
| 226 | 
            +
                                                     output_device=th.cuda.current_device())
         | 
| 227 | 
            +
                else:
         | 
| 228 | 
            +
                    dmodel = model
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                for epoch in range(len(saved.metrics), args.epochs):
         | 
| 231 | 
            +
                    begin = time.time()
         | 
| 232 | 
            +
                    model.train()
         | 
| 233 | 
            +
                    train_loss, model_size = train_model(
         | 
| 234 | 
            +
                        epoch, train_set, dmodel, criterion, optimizer, augment,
         | 
| 235 | 
            +
                        quantizer=quantizer,
         | 
| 236 | 
            +
                        batch_size=args.batch_size,
         | 
| 237 | 
            +
                        device=device,
         | 
| 238 | 
            +
                        repeat=args.repeat,
         | 
| 239 | 
            +
                        seed=args.seed,
         | 
| 240 | 
            +
                        diffq=args.diffq,
         | 
| 241 | 
            +
                        workers=args.workers,
         | 
| 242 | 
            +
                        world_size=args.world_size)
         | 
| 243 | 
            +
                    model.eval()
         | 
| 244 | 
            +
                    valid_loss = validate_model(
         | 
| 245 | 
            +
                        epoch, valid_set, model, criterion,
         | 
| 246 | 
            +
                        device=device,
         | 
| 247 | 
            +
                        rank=args.rank,
         | 
| 248 | 
            +
                        split=args.split_valid,
         | 
| 249 | 
            +
                        overlap=args.overlap,
         | 
| 250 | 
            +
                        world_size=args.world_size)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    ms = 0
         | 
| 253 | 
            +
                    cms = 0
         | 
| 254 | 
            +
                    if quantizer and args.rank == 0:
         | 
| 255 | 
            +
                        ms = quantizer.true_model_size()
         | 
| 256 | 
            +
                        cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10))
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    duration = time.time() - begin
         | 
| 259 | 
            +
                    if valid_loss < best_loss and ms <= args.ms_target:
         | 
| 260 | 
            +
                        best_loss = valid_loss
         | 
| 261 | 
            +
                        saved.best_state = {
         | 
| 262 | 
            +
                            key: value.to("cpu").clone()
         | 
| 263 | 
            +
                            for key, value in model.state_dict().items()
         | 
| 264 | 
            +
                        }
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    saved.metrics.append({
         | 
| 267 | 
            +
                        "train": train_loss,
         | 
| 268 | 
            +
                        "valid": valid_loss,
         | 
| 269 | 
            +
                        "best": best_loss,
         | 
| 270 | 
            +
                        "duration": duration,
         | 
| 271 | 
            +
                        "model_size": model_size,
         | 
| 272 | 
            +
                        "true_model_size": ms,
         | 
| 273 | 
            +
                        "compressed_model_size": cms,
         | 
| 274 | 
            +
                    })
         | 
| 275 | 
            +
                    if args.rank == 0:
         | 
| 276 | 
            +
                        json.dump(saved.metrics, open(metrics_path, "w"))
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    saved.last_state = model.state_dict()
         | 
| 279 | 
            +
                    saved.optimizer = optimizer.state_dict()
         | 
| 280 | 
            +
                    if args.rank == 0 and not args.test:
         | 
| 281 | 
            +
                        th.save(saved, checkpoint_tmp)
         | 
| 282 | 
            +
                        checkpoint_tmp.rename(checkpoint)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    print(f"Epoch {epoch:03d}: "
         | 
| 285 | 
            +
                          f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
         | 
| 286 | 
            +
                          f"cms={cms:.2f}MB "
         | 
| 287 | 
            +
                          f"duration={human_seconds(duration)}")
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                if args.world_size > 1:
         | 
| 290 | 
            +
                    distributed.barrier()
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                del dmodel
         | 
| 293 | 
            +
                model.load_state_dict(saved.best_state)
         | 
| 294 | 
            +
                if args.eval_cpu:
         | 
| 295 | 
            +
                    device = "cpu"
         | 
| 296 | 
            +
                    model.to(device)
         | 
| 297 | 
            +
                model.eval()
         | 
| 298 | 
            +
                evaluate(model, args.musdb, eval_folder,
         | 
| 299 | 
            +
                         is_wav=args.is_wav,
         | 
| 300 | 
            +
                         rank=args.rank,
         | 
| 301 | 
            +
                         world_size=args.world_size,
         | 
| 302 | 
            +
                         device=device,
         | 
| 303 | 
            +
                         save=args.save,
         | 
| 304 | 
            +
                         split=args.split_valid,
         | 
| 305 | 
            +
                         shifts=args.shifts,
         | 
| 306 | 
            +
                         overlap=args.overlap,
         | 
| 307 | 
            +
                         workers=args.eval_workers)
         | 
| 308 | 
            +
                model.to("cpu")
         | 
| 309 | 
            +
                if args.rank == 0:
         | 
| 310 | 
            +
                    if not (args.test or args.test_pretrained):
         | 
| 311 | 
            +
                        save_model(model, quantizer, args, args.models / model_name)
         | 
| 312 | 
            +
                    print("done")
         | 
| 313 | 
            +
                    done.write_text("done")
         | 
| 314 | 
            +
             | 
| 315 | 
            +
             | 
| 316 | 
            +
            if __name__ == "__main__":
         | 
| 317 | 
            +
                main()
         | 
    	
        demucs/audio.py
    ADDED
    
    | @@ -0,0 +1,172 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            import json
         | 
| 7 | 
            +
            import subprocess as sp
         | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import julius
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from .utils import temp_filenames
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def _read_info(path):
         | 
| 18 | 
            +
                stdout_data = sp.check_output([
         | 
| 19 | 
            +
                    'ffprobe', "-loglevel", "panic",
         | 
| 20 | 
            +
                    str(path), '-print_format', 'json', '-show_format', '-show_streams'
         | 
| 21 | 
            +
                ])
         | 
| 22 | 
            +
                return json.loads(stdout_data.decode('utf-8'))
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class AudioFile:
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                Allows to read audio from any format supported by ffmpeg, as well as resampling or
         | 
| 28 | 
            +
                converting to mono on the fly. See :method:`read` for more details.
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                def __init__(self, path: Path):
         | 
| 31 | 
            +
                    self.path = Path(path)
         | 
| 32 | 
            +
                    self._info = None
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __repr__(self):
         | 
| 35 | 
            +
                    features = [("path", self.path)]
         | 
| 36 | 
            +
                    features.append(("samplerate", self.samplerate()))
         | 
| 37 | 
            +
                    features.append(("channels", self.channels()))
         | 
| 38 | 
            +
                    features.append(("streams", len(self)))
         | 
| 39 | 
            +
                    features_str = ", ".join(f"{name}={value}" for name, value in features)
         | 
| 40 | 
            +
                    return f"AudioFile({features_str})"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                @property
         | 
| 43 | 
            +
                def info(self):
         | 
| 44 | 
            +
                    if self._info is None:
         | 
| 45 | 
            +
                        self._info = _read_info(self.path)
         | 
| 46 | 
            +
                    return self._info
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                @property
         | 
| 49 | 
            +
                def duration(self):
         | 
| 50 | 
            +
                    return float(self.info['format']['duration'])
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                @property
         | 
| 53 | 
            +
                def _audio_streams(self):
         | 
| 54 | 
            +
                    return [
         | 
| 55 | 
            +
                        index for index, stream in enumerate(self.info["streams"])
         | 
| 56 | 
            +
                        if stream["codec_type"] == "audio"
         | 
| 57 | 
            +
                    ]
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def __len__(self):
         | 
| 60 | 
            +
                    return len(self._audio_streams)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def channels(self, stream=0):
         | 
| 63 | 
            +
                    return int(self.info['streams'][self._audio_streams[stream]]['channels'])
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def samplerate(self, stream=0):
         | 
| 66 | 
            +
                    return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                def read(self,
         | 
| 69 | 
            +
                         seek_time=None,
         | 
| 70 | 
            +
                         duration=None,
         | 
| 71 | 
            +
                         streams=slice(None),
         | 
| 72 | 
            +
                         samplerate=None,
         | 
| 73 | 
            +
                         channels=None,
         | 
| 74 | 
            +
                         temp_folder=None):
         | 
| 75 | 
            +
                    """
         | 
| 76 | 
            +
                    Slightly more efficient implementation than stempeg,
         | 
| 77 | 
            +
                    in particular, this will extract all stems at once
         | 
| 78 | 
            +
                    rather than having to loop over one file multiple times
         | 
| 79 | 
            +
                    for each stream.
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    Args:
         | 
| 82 | 
            +
                        seek_time (float):  seek time in seconds or None if no seeking is needed.
         | 
| 83 | 
            +
                        duration (float): duration in seconds to extract or None to extract until the end.
         | 
| 84 | 
            +
                        streams (slice, int or list): streams to extract, can be a single int, a list or
         | 
| 85 | 
            +
                            a slice. If it is a slice or list, the output will be of size [S, C, T]
         | 
| 86 | 
            +
                            with S the number of streams, C the number of channels and T the number of samples.
         | 
| 87 | 
            +
                            If it is an int, the output will be [C, T].
         | 
| 88 | 
            +
                        samplerate (int): if provided, will resample on the fly. If None, no resampling will
         | 
| 89 | 
            +
                            be done. Original sampling rate can be obtained with :method:`samplerate`.
         | 
| 90 | 
            +
                        channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
         | 
| 91 | 
            +
                            as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
         | 
| 92 | 
            +
                            See https://sound.stackexchange.com/a/42710.
         | 
| 93 | 
            +
                            Our definition of mono is simply the average of the two channels. Any other
         | 
| 94 | 
            +
                            value will be ignored.
         | 
| 95 | 
            +
                        temp_folder (str or Path or None): temporary folder to use for decoding.
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
                    """
         | 
| 99 | 
            +
                    streams = np.array(range(len(self)))[streams]
         | 
| 100 | 
            +
                    single = not isinstance(streams, np.ndarray)
         | 
| 101 | 
            +
                    if single:
         | 
| 102 | 
            +
                        streams = [streams]
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    if duration is None:
         | 
| 105 | 
            +
                        target_size = None
         | 
| 106 | 
            +
                        query_duration = None
         | 
| 107 | 
            +
                    else:
         | 
| 108 | 
            +
                        target_size = int((samplerate or self.samplerate()) * duration)
         | 
| 109 | 
            +
                        query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    with temp_filenames(len(streams)) as filenames:
         | 
| 112 | 
            +
                        command = ['ffmpeg', '-y']
         | 
| 113 | 
            +
                        command += ['-loglevel', 'panic']
         | 
| 114 | 
            +
                        if seek_time:
         | 
| 115 | 
            +
                            command += ['-ss', str(seek_time)]
         | 
| 116 | 
            +
                        command += ['-i', str(self.path)]
         | 
| 117 | 
            +
                        for stream, filename in zip(streams, filenames):
         | 
| 118 | 
            +
                            command += ['-map', f'0:{self._audio_streams[stream]}']
         | 
| 119 | 
            +
                            if query_duration is not None:
         | 
| 120 | 
            +
                                command += ['-t', str(query_duration)]
         | 
| 121 | 
            +
                            command += ['-threads', '1']
         | 
| 122 | 
            +
                            command += ['-f', 'f32le']
         | 
| 123 | 
            +
                            if samplerate is not None:
         | 
| 124 | 
            +
                                command += ['-ar', str(samplerate)]
         | 
| 125 | 
            +
                            command += [filename]
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        sp.run(command, check=True)
         | 
| 128 | 
            +
                        wavs = []
         | 
| 129 | 
            +
                        for filename in filenames:
         | 
| 130 | 
            +
                            wav = np.fromfile(filename, dtype=np.float32)
         | 
| 131 | 
            +
                            wav = torch.from_numpy(wav)
         | 
| 132 | 
            +
                            wav = wav.view(-1, self.channels()).t()
         | 
| 133 | 
            +
                            if channels is not None:
         | 
| 134 | 
            +
                                wav = convert_audio_channels(wav, channels)
         | 
| 135 | 
            +
                            if target_size is not None:
         | 
| 136 | 
            +
                                wav = wav[..., :target_size]
         | 
| 137 | 
            +
                            wavs.append(wav)
         | 
| 138 | 
            +
                    wav = torch.stack(wavs, dim=0)
         | 
| 139 | 
            +
                    if single:
         | 
| 140 | 
            +
                        wav = wav[0]
         | 
| 141 | 
            +
                    return wav
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            def convert_audio_channels(wav, channels=2):
         | 
| 145 | 
            +
                """Convert audio to the given number of channels."""
         | 
| 146 | 
            +
                *shape, src_channels, length = wav.shape
         | 
| 147 | 
            +
                if src_channels == channels:
         | 
| 148 | 
            +
                    pass
         | 
| 149 | 
            +
                elif channels == 1:
         | 
| 150 | 
            +
                    # Case 1:
         | 
| 151 | 
            +
                    # The caller asked 1-channel audio, but the stream have multiple
         | 
| 152 | 
            +
                    # channels, downmix all channels.
         | 
| 153 | 
            +
                    wav = wav.mean(dim=-2, keepdim=True)
         | 
| 154 | 
            +
                elif src_channels == 1:
         | 
| 155 | 
            +
                    # Case 2:
         | 
| 156 | 
            +
                    # The caller asked for multiple channels, but the input file have
         | 
| 157 | 
            +
                    # one single channel, replicate the audio over all channels.
         | 
| 158 | 
            +
                    wav = wav.expand(*shape, channels, length)
         | 
| 159 | 
            +
                elif src_channels >= channels:
         | 
| 160 | 
            +
                    # Case 3:
         | 
| 161 | 
            +
                    # The caller asked for multiple channels, and the input file have
         | 
| 162 | 
            +
                    # more channels than requested. In that case return the first channels.
         | 
| 163 | 
            +
                    wav = wav[..., :channels, :]
         | 
| 164 | 
            +
                else:
         | 
| 165 | 
            +
                    # Case 4: What is a reasonable choice here?
         | 
| 166 | 
            +
                    raise ValueError('The audio file has less channels than requested but is not mono.')
         | 
| 167 | 
            +
                return wav
         | 
| 168 | 
            +
             | 
| 169 | 
            +
             | 
| 170 | 
            +
            def convert_audio(wav, from_samplerate, to_samplerate, channels):
         | 
| 171 | 
            +
                wav = convert_audio_channels(wav, channels)
         | 
| 172 | 
            +
                return julius.resample_frac(wav, from_samplerate, to_samplerate)
         | 
    	
        demucs/augment.py
    ADDED
    
    | @@ -0,0 +1,106 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            import torch as th
         | 
| 9 | 
            +
            from torch import nn
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Shift(nn.Module):
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                Randomly shift audio in time by up to `shift` samples.
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
                def __init__(self, shift=8192):
         | 
| 17 | 
            +
                    super().__init__()
         | 
| 18 | 
            +
                    self.shift = shift
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                def forward(self, wav):
         | 
| 21 | 
            +
                    batch, sources, channels, time = wav.size()
         | 
| 22 | 
            +
                    length = time - self.shift
         | 
| 23 | 
            +
                    if self.shift > 0:
         | 
| 24 | 
            +
                        if not self.training:
         | 
| 25 | 
            +
                            wav = wav[..., :length]
         | 
| 26 | 
            +
                        else:
         | 
| 27 | 
            +
                            offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device)
         | 
| 28 | 
            +
                            offsets = offsets.expand(-1, -1, channels, -1)
         | 
| 29 | 
            +
                            indexes = th.arange(length, device=wav.device)
         | 
| 30 | 
            +
                            wav = wav.gather(3, indexes + offsets)
         | 
| 31 | 
            +
                    return wav
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            class FlipChannels(nn.Module):
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                Flip left-right channels.
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                def forward(self, wav):
         | 
| 39 | 
            +
                    batch, sources, channels, time = wav.size()
         | 
| 40 | 
            +
                    if self.training and wav.size(2) == 2:
         | 
| 41 | 
            +
                        left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
         | 
| 42 | 
            +
                        left = left.expand(-1, -1, -1, time)
         | 
| 43 | 
            +
                        right = 1 - left
         | 
| 44 | 
            +
                        wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
         | 
| 45 | 
            +
                    return wav
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            class FlipSign(nn.Module):
         | 
| 49 | 
            +
                """
         | 
| 50 | 
            +
                Random sign flip.
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                def forward(self, wav):
         | 
| 53 | 
            +
                    batch, sources, channels, time = wav.size()
         | 
| 54 | 
            +
                    if self.training:
         | 
| 55 | 
            +
                        signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
         | 
| 56 | 
            +
                        wav = wav * (2 * signs - 1)
         | 
| 57 | 
            +
                    return wav
         | 
| 58 | 
            +
             | 
| 59 | 
            +
             | 
| 60 | 
            +
            class Remix(nn.Module):
         | 
| 61 | 
            +
                """
         | 
| 62 | 
            +
                Shuffle sources to make new mixes.
         | 
| 63 | 
            +
                """
         | 
| 64 | 
            +
                def __init__(self, group_size=4):
         | 
| 65 | 
            +
                    """
         | 
| 66 | 
            +
                    Shuffle sources within one batch.
         | 
| 67 | 
            +
                    Each batch is divided into groups of size `group_size` and shuffling is done within
         | 
| 68 | 
            +
                    each group separatly. This allow to keep the same probability distribution no matter
         | 
| 69 | 
            +
                    the number of GPUs. Without this grouping, using more GPUs would lead to a higher
         | 
| 70 | 
            +
                    probability of keeping two sources from the same track together which can impact
         | 
| 71 | 
            +
                    performance.
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
                    super().__init__()
         | 
| 74 | 
            +
                    self.group_size = group_size
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def forward(self, wav):
         | 
| 77 | 
            +
                    batch, streams, channels, time = wav.size()
         | 
| 78 | 
            +
                    device = wav.device
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    if self.training:
         | 
| 81 | 
            +
                        group_size = self.group_size or batch
         | 
| 82 | 
            +
                        if batch % group_size != 0:
         | 
| 83 | 
            +
                            raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
         | 
| 84 | 
            +
                        groups = batch // group_size
         | 
| 85 | 
            +
                        wav = wav.view(groups, group_size, streams, channels, time)
         | 
| 86 | 
            +
                        permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
         | 
| 87 | 
            +
                                                  dim=1)
         | 
| 88 | 
            +
                        wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
         | 
| 89 | 
            +
                        wav = wav.view(batch, streams, channels, time)
         | 
| 90 | 
            +
                    return wav
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            class Scale(nn.Module):
         | 
| 94 | 
            +
                def __init__(self, proba=1., min=0.25, max=1.25):
         | 
| 95 | 
            +
                    super().__init__()
         | 
| 96 | 
            +
                    self.proba = proba
         | 
| 97 | 
            +
                    self.min = min
         | 
| 98 | 
            +
                    self.max = max
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                def forward(self, wav):
         | 
| 101 | 
            +
                    batch, streams, channels, time = wav.size()
         | 
| 102 | 
            +
                    device = wav.device
         | 
| 103 | 
            +
                    if self.training and random.random() < self.proba:
         | 
| 104 | 
            +
                        scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
         | 
| 105 | 
            +
                        wav *= scales
         | 
| 106 | 
            +
                    return wav
         | 
    	
        demucs/compressed.py
    ADDED
    
    | @@ -0,0 +1,115 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import json
         | 
| 8 | 
            +
            from fractions import Fraction
         | 
| 9 | 
            +
            from concurrent import futures
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import musdb
         | 
| 12 | 
            +
            from torch import distributed
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from .audio import AudioFile
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def get_musdb_tracks(root, *args, **kwargs):
         | 
| 18 | 
            +
                mus = musdb.DB(root, *args, **kwargs)
         | 
| 19 | 
            +
                return {track.name: track.path for track in mus}
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class StemsSet:
         | 
| 23 | 
            +
                def __init__(self, tracks, metadata, duration=None, stride=1,
         | 
| 24 | 
            +
                             samplerate=44100, channels=2, streams=slice(None)):
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    self.metadata = []
         | 
| 27 | 
            +
                    for name, path in tracks.items():
         | 
| 28 | 
            +
                        meta = dict(metadata[name])
         | 
| 29 | 
            +
                        meta["path"] = path
         | 
| 30 | 
            +
                        meta["name"] = name
         | 
| 31 | 
            +
                        self.metadata.append(meta)
         | 
| 32 | 
            +
                        if duration is not None and meta["duration"] < duration:
         | 
| 33 | 
            +
                            raise ValueError(f"Track {name} duration is too small {meta['duration']}")
         | 
| 34 | 
            +
                    self.metadata.sort(key=lambda x: x["name"])
         | 
| 35 | 
            +
                    self.duration = duration
         | 
| 36 | 
            +
                    self.stride = stride
         | 
| 37 | 
            +
                    self.channels = channels
         | 
| 38 | 
            +
                    self.samplerate = samplerate
         | 
| 39 | 
            +
                    self.streams = streams
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __len__(self):
         | 
| 42 | 
            +
                    return sum(self._examples_count(m) for m in self.metadata)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def _examples_count(self, meta):
         | 
| 45 | 
            +
                    if self.duration is None:
         | 
| 46 | 
            +
                        return 1
         | 
| 47 | 
            +
                    else:
         | 
| 48 | 
            +
                        return int((meta["duration"] - self.duration) // self.stride + 1)
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def track_metadata(self, index):
         | 
| 51 | 
            +
                    for meta in self.metadata:
         | 
| 52 | 
            +
                        examples = self._examples_count(meta)
         | 
| 53 | 
            +
                        if index >= examples:
         | 
| 54 | 
            +
                            index -= examples
         | 
| 55 | 
            +
                            continue
         | 
| 56 | 
            +
                        return meta
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __getitem__(self, index):
         | 
| 59 | 
            +
                    for meta in self.metadata:
         | 
| 60 | 
            +
                        examples = self._examples_count(meta)
         | 
| 61 | 
            +
                        if index >= examples:
         | 
| 62 | 
            +
                            index -= examples
         | 
| 63 | 
            +
                            continue
         | 
| 64 | 
            +
                        streams = AudioFile(meta["path"]).read(seek_time=index * self.stride,
         | 
| 65 | 
            +
                                                               duration=self.duration,
         | 
| 66 | 
            +
                                                               channels=self.channels,
         | 
| 67 | 
            +
                                                               samplerate=self.samplerate,
         | 
| 68 | 
            +
                                                               streams=self.streams)
         | 
| 69 | 
            +
                        return (streams - meta["mean"]) / meta["std"]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def _get_track_metadata(path):
         | 
| 73 | 
            +
                # use mono at 44kHz as reference. For any other settings data won't be perfectly
         | 
| 74 | 
            +
                # normalized but it should be good enough.
         | 
| 75 | 
            +
                audio = AudioFile(path)
         | 
| 76 | 
            +
                mix = audio.read(streams=0, channels=1, samplerate=44100)
         | 
| 77 | 
            +
                return {"duration": audio.duration, "std": mix.std().item(), "mean": mix.mean().item()}
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def _build_metadata(tracks, workers=10):
         | 
| 81 | 
            +
                pendings = []
         | 
| 82 | 
            +
                with futures.ProcessPoolExecutor(workers) as pool:
         | 
| 83 | 
            +
                    for name, path in tracks.items():
         | 
| 84 | 
            +
                        pendings.append((name, pool.submit(_get_track_metadata, path)))
         | 
| 85 | 
            +
                return {name: p.result() for name, p in pendings}
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def _build_musdb_metadata(path, musdb, workers):
         | 
| 89 | 
            +
                tracks = get_musdb_tracks(musdb)
         | 
| 90 | 
            +
                metadata = _build_metadata(tracks, workers)
         | 
| 91 | 
            +
                path.parent.mkdir(exist_ok=True, parents=True)
         | 
| 92 | 
            +
                json.dump(metadata, open(path, "w"))
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def get_compressed_datasets(args, samples):
         | 
| 96 | 
            +
                metadata_file = args.metadata / "musdb.json"
         | 
| 97 | 
            +
                if not metadata_file.is_file() and args.rank == 0:
         | 
| 98 | 
            +
                    _build_musdb_metadata(metadata_file, args.musdb, args.workers)
         | 
| 99 | 
            +
                if args.world_size > 1:
         | 
| 100 | 
            +
                    distributed.barrier()
         | 
| 101 | 
            +
                metadata = json.load(open(metadata_file))
         | 
| 102 | 
            +
                duration = Fraction(samples, args.samplerate)
         | 
| 103 | 
            +
                stride = Fraction(args.data_stride, args.samplerate)
         | 
| 104 | 
            +
                train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
         | 
| 105 | 
            +
                                     metadata,
         | 
| 106 | 
            +
                                     duration=duration,
         | 
| 107 | 
            +
                                     stride=stride,
         | 
| 108 | 
            +
                                     streams=slice(1, None),
         | 
| 109 | 
            +
                                     samplerate=args.samplerate,
         | 
| 110 | 
            +
                                     channels=args.audio_channels)
         | 
| 111 | 
            +
                valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
         | 
| 112 | 
            +
                                     metadata,
         | 
| 113 | 
            +
                                     samplerate=args.samplerate,
         | 
| 114 | 
            +
                                     channels=args.audio_channels)
         | 
| 115 | 
            +
                return train_set, valid_set
         | 
    	
        demucs/model.py
    ADDED
    
    | @@ -0,0 +1,202 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import math
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import julius
         | 
| 10 | 
            +
            from torch import nn
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from .utils import capture_init, center_trim
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class BLSTM(nn.Module):
         | 
| 16 | 
            +
                def __init__(self, dim, layers=1):
         | 
| 17 | 
            +
                    super().__init__()
         | 
| 18 | 
            +
                    self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
         | 
| 19 | 
            +
                    self.linear = nn.Linear(2 * dim, dim)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def forward(self, x):
         | 
| 22 | 
            +
                    x = x.permute(2, 0, 1)
         | 
| 23 | 
            +
                    x = self.lstm(x)[0]
         | 
| 24 | 
            +
                    x = self.linear(x)
         | 
| 25 | 
            +
                    x = x.permute(1, 2, 0)
         | 
| 26 | 
            +
                    return x
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def rescale_conv(conv, reference):
         | 
| 30 | 
            +
                std = conv.weight.std().detach()
         | 
| 31 | 
            +
                scale = (std / reference)**0.5
         | 
| 32 | 
            +
                conv.weight.data /= scale
         | 
| 33 | 
            +
                if conv.bias is not None:
         | 
| 34 | 
            +
                    conv.bias.data /= scale
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def rescale_module(module, reference):
         | 
| 38 | 
            +
                for sub in module.modules():
         | 
| 39 | 
            +
                    if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
         | 
| 40 | 
            +
                        rescale_conv(sub, reference)
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            class Demucs(nn.Module):
         | 
| 44 | 
            +
                @capture_init
         | 
| 45 | 
            +
                def __init__(self,
         | 
| 46 | 
            +
                             sources,
         | 
| 47 | 
            +
                             audio_channels=2,
         | 
| 48 | 
            +
                             channels=64,
         | 
| 49 | 
            +
                             depth=6,
         | 
| 50 | 
            +
                             rewrite=True,
         | 
| 51 | 
            +
                             glu=True,
         | 
| 52 | 
            +
                             rescale=0.1,
         | 
| 53 | 
            +
                             resample=True,
         | 
| 54 | 
            +
                             kernel_size=8,
         | 
| 55 | 
            +
                             stride=4,
         | 
| 56 | 
            +
                             growth=2.,
         | 
| 57 | 
            +
                             lstm_layers=2,
         | 
| 58 | 
            +
                             context=3,
         | 
| 59 | 
            +
                             normalize=False,
         | 
| 60 | 
            +
                             samplerate=44100,
         | 
| 61 | 
            +
                             segment_length=4 * 10 * 44100):
         | 
| 62 | 
            +
                    """
         | 
| 63 | 
            +
                    Args:
         | 
| 64 | 
            +
                        sources (list[str]): list of source names
         | 
| 65 | 
            +
                        audio_channels (int): stereo or mono
         | 
| 66 | 
            +
                        channels (int): first convolution channels
         | 
| 67 | 
            +
                        depth (int): number of encoder/decoder layers
         | 
| 68 | 
            +
                        rewrite (bool): add 1x1 convolution to each encoder layer
         | 
| 69 | 
            +
                            and a convolution to each decoder layer.
         | 
| 70 | 
            +
                            For the decoder layer, `context` gives the kernel size.
         | 
| 71 | 
            +
                        glu (bool): use glu instead of ReLU
         | 
| 72 | 
            +
                        resample_input (bool): upsample x2 the input and downsample /2 the output.
         | 
| 73 | 
            +
                        rescale (int): rescale initial weights of convolutions
         | 
| 74 | 
            +
                            to get their standard deviation closer to `rescale`
         | 
| 75 | 
            +
                        kernel_size (int): kernel size for convolutions
         | 
| 76 | 
            +
                        stride (int): stride for convolutions
         | 
| 77 | 
            +
                        growth (float): multiply (resp divide) number of channels by that
         | 
| 78 | 
            +
                            for each layer of the encoder (resp decoder)
         | 
| 79 | 
            +
                        lstm_layers (int): number of lstm layers, 0 = no lstm
         | 
| 80 | 
            +
                        context (int): kernel size of the convolution in the
         | 
| 81 | 
            +
                            decoder before the transposed convolution. If > 1,
         | 
| 82 | 
            +
                            will provide some context from neighboring time
         | 
| 83 | 
            +
                            steps.
         | 
| 84 | 
            +
                        samplerate (int): stored as meta information for easing
         | 
| 85 | 
            +
                            future evaluations of the model.
         | 
| 86 | 
            +
                        segment_length (int): stored as meta information for easing
         | 
| 87 | 
            +
                            future evaluations of the model. Length of the segments on which
         | 
| 88 | 
            +
                            the model was trained.
         | 
| 89 | 
            +
                    """
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    super().__init__()
         | 
| 92 | 
            +
                    self.audio_channels = audio_channels
         | 
| 93 | 
            +
                    self.sources = sources
         | 
| 94 | 
            +
                    self.kernel_size = kernel_size
         | 
| 95 | 
            +
                    self.context = context
         | 
| 96 | 
            +
                    self.stride = stride
         | 
| 97 | 
            +
                    self.depth = depth
         | 
| 98 | 
            +
                    self.resample = resample
         | 
| 99 | 
            +
                    self.channels = channels
         | 
| 100 | 
            +
                    self.normalize = normalize
         | 
| 101 | 
            +
                    self.samplerate = samplerate
         | 
| 102 | 
            +
                    self.segment_length = segment_length
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.encoder = nn.ModuleList()
         | 
| 105 | 
            +
                    self.decoder = nn.ModuleList()
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    if glu:
         | 
| 108 | 
            +
                        activation = nn.GLU(dim=1)
         | 
| 109 | 
            +
                        ch_scale = 2
         | 
| 110 | 
            +
                    else:
         | 
| 111 | 
            +
                        activation = nn.ReLU()
         | 
| 112 | 
            +
                        ch_scale = 1
         | 
| 113 | 
            +
                    in_channels = audio_channels
         | 
| 114 | 
            +
                    for index in range(depth):
         | 
| 115 | 
            +
                        encode = []
         | 
| 116 | 
            +
                        encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), nn.ReLU()]
         | 
| 117 | 
            +
                        if rewrite:
         | 
| 118 | 
            +
                            encode += [nn.Conv1d(channels, ch_scale * channels, 1), activation]
         | 
| 119 | 
            +
                        self.encoder.append(nn.Sequential(*encode))
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        decode = []
         | 
| 122 | 
            +
                        if index > 0:
         | 
| 123 | 
            +
                            out_channels = in_channels
         | 
| 124 | 
            +
                        else:
         | 
| 125 | 
            +
                            out_channels = len(self.sources) * audio_channels
         | 
| 126 | 
            +
                        if rewrite:
         | 
| 127 | 
            +
                            decode += [nn.Conv1d(channels, ch_scale * channels, context), activation]
         | 
| 128 | 
            +
                        decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride)]
         | 
| 129 | 
            +
                        if index > 0:
         | 
| 130 | 
            +
                            decode.append(nn.ReLU())
         | 
| 131 | 
            +
                        self.decoder.insert(0, nn.Sequential(*decode))
         | 
| 132 | 
            +
                        in_channels = channels
         | 
| 133 | 
            +
                        channels = int(growth * channels)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    channels = in_channels
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    if lstm_layers:
         | 
| 138 | 
            +
                        self.lstm = BLSTM(channels, lstm_layers)
         | 
| 139 | 
            +
                    else:
         | 
| 140 | 
            +
                        self.lstm = None
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                    if rescale:
         | 
| 143 | 
            +
                        rescale_module(self, reference=rescale)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                def valid_length(self, length):
         | 
| 146 | 
            +
                    """
         | 
| 147 | 
            +
                    Return the nearest valid length to use with the model so that
         | 
| 148 | 
            +
                    there is no time steps left over in a convolutions, e.g. for all
         | 
| 149 | 
            +
                    layers, size of the input - kernel_size % stride = 0.
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    If the mixture has a valid length, the estimated sources
         | 
| 152 | 
            +
                    will have exactly the same length when context = 1. If context > 1,
         | 
| 153 | 
            +
                    the two signals can be center trimmed to match.
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    For training, extracts should have a valid length.For evaluation
         | 
| 156 | 
            +
                    on full tracks we recommend passing `pad = True` to :method:`forward`.
         | 
| 157 | 
            +
                    """
         | 
| 158 | 
            +
                    if self.resample:
         | 
| 159 | 
            +
                        length *= 2
         | 
| 160 | 
            +
                    for _ in range(self.depth):
         | 
| 161 | 
            +
                        length = math.ceil((length - self.kernel_size) / self.stride) + 1
         | 
| 162 | 
            +
                        length = max(1, length)
         | 
| 163 | 
            +
                        length += self.context - 1
         | 
| 164 | 
            +
                    for _ in range(self.depth):
         | 
| 165 | 
            +
                        length = (length - 1) * self.stride + self.kernel_size
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                    if self.resample:
         | 
| 168 | 
            +
                        length = math.ceil(length / 2)
         | 
| 169 | 
            +
                    return int(length)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def forward(self, mix):
         | 
| 172 | 
            +
                    x = mix
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    if self.normalize:
         | 
| 175 | 
            +
                        mono = mix.mean(dim=1, keepdim=True)
         | 
| 176 | 
            +
                        mean = mono.mean(dim=-1, keepdim=True)
         | 
| 177 | 
            +
                        std = mono.std(dim=-1, keepdim=True)
         | 
| 178 | 
            +
                    else:
         | 
| 179 | 
            +
                        mean = 0
         | 
| 180 | 
            +
                        std = 1
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    x = (x - mean) / (1e-5 + std)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if self.resample:
         | 
| 185 | 
            +
                        x = julius.resample_frac(x, 1, 2)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    saved = []
         | 
| 188 | 
            +
                    for encode in self.encoder:
         | 
| 189 | 
            +
                        x = encode(x)
         | 
| 190 | 
            +
                        saved.append(x)
         | 
| 191 | 
            +
                    if self.lstm:
         | 
| 192 | 
            +
                        x = self.lstm(x)
         | 
| 193 | 
            +
                    for decode in self.decoder:
         | 
| 194 | 
            +
                        skip = center_trim(saved.pop(-1), x)
         | 
| 195 | 
            +
                        x = x + skip
         | 
| 196 | 
            +
                        x = decode(x)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    if self.resample:
         | 
| 199 | 
            +
                        x = julius.resample_frac(x, 2, 1)
         | 
| 200 | 
            +
                    x = x * std + mean
         | 
| 201 | 
            +
                    x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
         | 
| 202 | 
            +
                    return x
         | 
    	
        demucs/parser.py
    ADDED
    
    | @@ -0,0 +1,244 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import argparse
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            from pathlib import Path
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def get_parser():
         | 
| 13 | 
            +
                parser = argparse.ArgumentParser("demucs", description="Train and evaluate Demucs.")
         | 
| 14 | 
            +
                default_raw = None
         | 
| 15 | 
            +
                default_musdb = None
         | 
| 16 | 
            +
                if 'DEMUCS_RAW' in os.environ:
         | 
| 17 | 
            +
                    default_raw = Path(os.environ['DEMUCS_RAW'])
         | 
| 18 | 
            +
                if 'DEMUCS_MUSDB' in os.environ:
         | 
| 19 | 
            +
                    default_musdb = Path(os.environ['DEMUCS_MUSDB'])
         | 
| 20 | 
            +
                parser.add_argument(
         | 
| 21 | 
            +
                    "--raw",
         | 
| 22 | 
            +
                    type=Path,
         | 
| 23 | 
            +
                    default=default_raw,
         | 
| 24 | 
            +
                    help="Path to raw audio, can be faster, see python3 -m demucs.raw to extract.")
         | 
| 25 | 
            +
                parser.add_argument("--no_raw", action="store_const", const=None, dest="raw")
         | 
| 26 | 
            +
                parser.add_argument("-m",
         | 
| 27 | 
            +
                                    "--musdb",
         | 
| 28 | 
            +
                                    type=Path,
         | 
| 29 | 
            +
                                    default=default_musdb,
         | 
| 30 | 
            +
                                    help="Path to musdb root")
         | 
| 31 | 
            +
                parser.add_argument("--is_wav", action="store_true",
         | 
| 32 | 
            +
                                    help="Indicate that the MusDB dataset is in wav format (i.e. MusDB-HQ).")
         | 
| 33 | 
            +
                parser.add_argument("--metadata", type=Path, default=Path("metadata/"),
         | 
| 34 | 
            +
                                    help="Folder where metadata information is stored.")
         | 
| 35 | 
            +
                parser.add_argument("--wav", type=Path,
         | 
| 36 | 
            +
                                    help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
         | 
| 37 | 
            +
                                         "subfolder.")
         | 
| 38 | 
            +
                parser.add_argument("--samplerate", type=int, default=44100)
         | 
| 39 | 
            +
                parser.add_argument("--audio_channels", type=int, default=2)
         | 
| 40 | 
            +
                parser.add_argument("--samples",
         | 
| 41 | 
            +
                                    default=44100 * 10,
         | 
| 42 | 
            +
                                    type=int,
         | 
| 43 | 
            +
                                    help="number of samples to feed in")
         | 
| 44 | 
            +
                parser.add_argument("--data_stride",
         | 
| 45 | 
            +
                                    default=44100,
         | 
| 46 | 
            +
                                    type=int,
         | 
| 47 | 
            +
                                    help="Stride for chunks, shorter = longer epochs")
         | 
| 48 | 
            +
                parser.add_argument("-w", "--workers", default=10, type=int, help="Loader workers")
         | 
| 49 | 
            +
                parser.add_argument("--eval_workers", default=2, type=int, help="Final evaluation workers")
         | 
| 50 | 
            +
                parser.add_argument("-d",
         | 
| 51 | 
            +
                                    "--device",
         | 
| 52 | 
            +
                                    help="Device to train on, default is cuda if available else cpu")
         | 
| 53 | 
            +
                parser.add_argument("--eval_cpu", action="store_true", help="Eval on test will be run on cpu.")
         | 
| 54 | 
            +
                parser.add_argument("--dummy", help="Dummy parameter, useful to create a new checkpoint file")
         | 
| 55 | 
            +
                parser.add_argument("--test", help="Just run the test pipeline + one validation. "
         | 
| 56 | 
            +
                                                   "This should be a filename relative to the models/ folder.")
         | 
| 57 | 
            +
                parser.add_argument("--test_pretrained", help="Just run the test pipeline + one validation, "
         | 
| 58 | 
            +
                                                              "on a pretrained model. ")
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                parser.add_argument("--rank", default=0, type=int)
         | 
| 61 | 
            +
                parser.add_argument("--world_size", default=1, type=int)
         | 
| 62 | 
            +
                parser.add_argument("--master")
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                parser.add_argument("--checkpoints",
         | 
| 65 | 
            +
                                    type=Path,
         | 
| 66 | 
            +
                                    default=Path("checkpoints"),
         | 
| 67 | 
            +
                                    help="Folder where to store checkpoints etc")
         | 
| 68 | 
            +
                parser.add_argument("--evals",
         | 
| 69 | 
            +
                                    type=Path,
         | 
| 70 | 
            +
                                    default=Path("evals"),
         | 
| 71 | 
            +
                                    help="Folder where to store evals and waveforms")
         | 
| 72 | 
            +
                parser.add_argument("--save",
         | 
| 73 | 
            +
                                    action="store_true",
         | 
| 74 | 
            +
                                    help="Save estimated for the test set waveforms")
         | 
| 75 | 
            +
                parser.add_argument("--logs",
         | 
| 76 | 
            +
                                    type=Path,
         | 
| 77 | 
            +
                                    default=Path("logs"),
         | 
| 78 | 
            +
                                    help="Folder where to store logs")
         | 
| 79 | 
            +
                parser.add_argument("--models",
         | 
| 80 | 
            +
                                    type=Path,
         | 
| 81 | 
            +
                                    default=Path("models"),
         | 
| 82 | 
            +
                                    help="Folder where to store trained models")
         | 
| 83 | 
            +
                parser.add_argument("-R",
         | 
| 84 | 
            +
                                    "--restart",
         | 
| 85 | 
            +
                                    action='store_true',
         | 
| 86 | 
            +
                                    help='Restart training, ignoring previous run')
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                parser.add_argument("--seed", type=int, default=42)
         | 
| 89 | 
            +
                parser.add_argument("-e", "--epochs", type=int, default=180, help="Number of epochs")
         | 
| 90 | 
            +
                parser.add_argument("-r",
         | 
| 91 | 
            +
                                    "--repeat",
         | 
| 92 | 
            +
                                    type=int,
         | 
| 93 | 
            +
                                    default=2,
         | 
| 94 | 
            +
                                    help="Repeat the train set, longer epochs")
         | 
| 95 | 
            +
                parser.add_argument("-b", "--batch_size", type=int, default=64)
         | 
| 96 | 
            +
                parser.add_argument("--lr", type=float, default=3e-4)
         | 
| 97 | 
            +
                parser.add_argument("--mse", action="store_true", help="Use MSE instead of L1")
         | 
| 98 | 
            +
                parser.add_argument("--init", help="Initialize from a pre-trained model.")
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                # Augmentation options
         | 
| 101 | 
            +
                parser.add_argument("--no_augment",
         | 
| 102 | 
            +
                                    action="store_false",
         | 
| 103 | 
            +
                                    dest="augment",
         | 
| 104 | 
            +
                                    default=True,
         | 
| 105 | 
            +
                                    help="No basic data augmentation.")
         | 
| 106 | 
            +
                parser.add_argument("--repitch", type=float, default=0.2,
         | 
| 107 | 
            +
                                    help="Probability to do tempo/pitch change")
         | 
| 108 | 
            +
                parser.add_argument("--max_tempo", type=float, default=12,
         | 
| 109 | 
            +
                                    help="Maximum relative tempo change in %% when using repitch.")
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                parser.add_argument("--remix_group_size",
         | 
| 112 | 
            +
                                    type=int,
         | 
| 113 | 
            +
                                    default=4,
         | 
| 114 | 
            +
                                    help="Shuffle sources using group of this size. Useful to somewhat "
         | 
| 115 | 
            +
                                    "replicate multi-gpu training "
         | 
| 116 | 
            +
                                    "on less GPUs.")
         | 
| 117 | 
            +
                parser.add_argument("--shifts",
         | 
| 118 | 
            +
                                    type=int,
         | 
| 119 | 
            +
                                    default=10,
         | 
| 120 | 
            +
                                    help="Number of random shifts used for the shift trick.")
         | 
| 121 | 
            +
                parser.add_argument("--overlap",
         | 
| 122 | 
            +
                                    type=float,
         | 
| 123 | 
            +
                                    default=0.25,
         | 
| 124 | 
            +
                                    help="Overlap when --split_valid is passed.")
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                # See model.py for doc
         | 
| 127 | 
            +
                parser.add_argument("--growth",
         | 
| 128 | 
            +
                                    type=float,
         | 
| 129 | 
            +
                                    default=2.,
         | 
| 130 | 
            +
                                    help="Number of channels between two layers will increase by this factor")
         | 
| 131 | 
            +
                parser.add_argument("--depth",
         | 
| 132 | 
            +
                                    type=int,
         | 
| 133 | 
            +
                                    default=6,
         | 
| 134 | 
            +
                                    help="Number of layers for the encoder and decoder")
         | 
| 135 | 
            +
                parser.add_argument("--lstm_layers", type=int, default=2, help="Number of layers for the LSTM")
         | 
| 136 | 
            +
                parser.add_argument("--channels",
         | 
| 137 | 
            +
                                    type=int,
         | 
| 138 | 
            +
                                    default=64,
         | 
| 139 | 
            +
                                    help="Number of channels for the first encoder layer")
         | 
| 140 | 
            +
                parser.add_argument("--kernel_size",
         | 
| 141 | 
            +
                                    type=int,
         | 
| 142 | 
            +
                                    default=8,
         | 
| 143 | 
            +
                                    help="Kernel size for the (transposed) convolutions")
         | 
| 144 | 
            +
                parser.add_argument("--conv_stride",
         | 
| 145 | 
            +
                                    type=int,
         | 
| 146 | 
            +
                                    default=4,
         | 
| 147 | 
            +
                                    help="Stride for the (transposed) convolutions")
         | 
| 148 | 
            +
                parser.add_argument("--context",
         | 
| 149 | 
            +
                                    type=int,
         | 
| 150 | 
            +
                                    default=3,
         | 
| 151 | 
            +
                                    help="Context size for the decoder convolutions "
         | 
| 152 | 
            +
                                    "before the transposed convolutions")
         | 
| 153 | 
            +
                parser.add_argument("--rescale",
         | 
| 154 | 
            +
                                    type=float,
         | 
| 155 | 
            +
                                    default=0.1,
         | 
| 156 | 
            +
                                    help="Initial weight rescale reference")
         | 
| 157 | 
            +
                parser.add_argument("--no_resample", action="store_false",
         | 
| 158 | 
            +
                                    default=True, dest="resample",
         | 
| 159 | 
            +
                                    help="No Resampling of the input/output x2")
         | 
| 160 | 
            +
                parser.add_argument("--no_glu",
         | 
| 161 | 
            +
                                    action="store_false",
         | 
| 162 | 
            +
                                    default=True,
         | 
| 163 | 
            +
                                    dest="glu",
         | 
| 164 | 
            +
                                    help="Replace all GLUs by ReLUs")
         | 
| 165 | 
            +
                parser.add_argument("--no_rewrite",
         | 
| 166 | 
            +
                                    action="store_false",
         | 
| 167 | 
            +
                                    default=True,
         | 
| 168 | 
            +
                                    dest="rewrite",
         | 
| 169 | 
            +
                                    help="No 1x1 rewrite convolutions")
         | 
| 170 | 
            +
                parser.add_argument("--normalize", action="store_true")
         | 
| 171 | 
            +
                parser.add_argument("--no_norm_wav", action="store_false", dest='norm_wav', default=True)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                # Tasnet options
         | 
| 174 | 
            +
                parser.add_argument("--tasnet", action="store_true")
         | 
| 175 | 
            +
                parser.add_argument("--split_valid",
         | 
| 176 | 
            +
                                    action="store_true",
         | 
| 177 | 
            +
                                    help="Predict chunks by chunks for valid and test. Required for tasnet")
         | 
| 178 | 
            +
                parser.add_argument("--X", type=int, default=8)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                # Other options
         | 
| 181 | 
            +
                parser.add_argument("--show",
         | 
| 182 | 
            +
                                    action="store_true",
         | 
| 183 | 
            +
                                    help="Show model architecture, size and exit")
         | 
| 184 | 
            +
                parser.add_argument("--save_model", action="store_true",
         | 
| 185 | 
            +
                                    help="Skip traning, just save final model "
         | 
| 186 | 
            +
                                         "for the current checkpoint value.")
         | 
| 187 | 
            +
                parser.add_argument("--save_state",
         | 
| 188 | 
            +
                                    help="Skip training, just save state "
         | 
| 189 | 
            +
                                         "for the current checkpoint value. You should "
         | 
| 190 | 
            +
                                         "provide a model name as argument.")
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                # Quantization options
         | 
| 193 | 
            +
                parser.add_argument("--q-min-size", type=float, default=1,
         | 
| 194 | 
            +
                                    help="Only quantize layers over this size (in MB)")
         | 
| 195 | 
            +
                parser.add_argument(
         | 
| 196 | 
            +
                    "--qat", type=int, help="If provided, use QAT training with that many bits.")
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                parser.add_argument("--diffq", type=float, default=0)
         | 
| 199 | 
            +
                parser.add_argument(
         | 
| 200 | 
            +
                    "--ms-target", type=float, default=162,
         | 
| 201 | 
            +
                    help="Model size target in MB, when using DiffQ. Best model will be kept "
         | 
| 202 | 
            +
                         "only if it is smaller than this target.")
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                return parser
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def get_name(parser, args):
         | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
                Return the name of an experiment given the args. Some parameters are ignored,
         | 
| 210 | 
            +
                for instance --workers, as they do not impact the final result.
         | 
| 211 | 
            +
                """
         | 
| 212 | 
            +
                ignore_args = set([
         | 
| 213 | 
            +
                    "checkpoints",
         | 
| 214 | 
            +
                    "deterministic",
         | 
| 215 | 
            +
                    "eval",
         | 
| 216 | 
            +
                    "evals",
         | 
| 217 | 
            +
                    "eval_cpu",
         | 
| 218 | 
            +
                    "eval_workers",
         | 
| 219 | 
            +
                    "logs",
         | 
| 220 | 
            +
                    "master",
         | 
| 221 | 
            +
                    "rank",
         | 
| 222 | 
            +
                    "restart",
         | 
| 223 | 
            +
                    "save",
         | 
| 224 | 
            +
                    "save_model",
         | 
| 225 | 
            +
                    "save_state",
         | 
| 226 | 
            +
                    "show",
         | 
| 227 | 
            +
                    "workers",
         | 
| 228 | 
            +
                    "world_size",
         | 
| 229 | 
            +
                ])
         | 
| 230 | 
            +
                parts = []
         | 
| 231 | 
            +
                name_args = dict(args.__dict__)
         | 
| 232 | 
            +
                for name, value in name_args.items():
         | 
| 233 | 
            +
                    if name in ignore_args:
         | 
| 234 | 
            +
                        continue
         | 
| 235 | 
            +
                    if value != parser.get_default(name):
         | 
| 236 | 
            +
                        if isinstance(value, Path):
         | 
| 237 | 
            +
                            parts.append(f"{name}={value.name}")
         | 
| 238 | 
            +
                        else:
         | 
| 239 | 
            +
                            parts.append(f"{name}={value}")
         | 
| 240 | 
            +
                if parts:
         | 
| 241 | 
            +
                    name = " ".join(parts)
         | 
| 242 | 
            +
                else:
         | 
| 243 | 
            +
                    name = "default"
         | 
| 244 | 
            +
                return name
         | 
    	
        demucs/pretrained.py
    ADDED
    
    | @@ -0,0 +1,107 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            # author: adefossez
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import logging
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from diffq import DiffQuantizer
         | 
| 11 | 
            +
            import torch.hub
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from .model import Demucs
         | 
| 14 | 
            +
            from .tasnet import ConvTasNet
         | 
| 15 | 
            +
            from .utils import set_state
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 18 | 
            +
            ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            PRETRAINED_MODELS = {
         | 
| 21 | 
            +
                'demucs': 'e07c671f',
         | 
| 22 | 
            +
                'demucs48_hq': '28a1282c',
         | 
| 23 | 
            +
                'demucs_extra': '3646af93',
         | 
| 24 | 
            +
                'demucs_quantized': '07afea75',
         | 
| 25 | 
            +
                'tasnet': 'beb46fac',
         | 
| 26 | 
            +
                'tasnet_extra': 'df3777b2',
         | 
| 27 | 
            +
                'demucs_unittest': '09ebc15f',
         | 
| 28 | 
            +
            }
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            SOURCES = ["drums", "bass", "other", "vocals"]
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def get_url(name):
         | 
| 34 | 
            +
                sig = PRETRAINED_MODELS[name]
         | 
| 35 | 
            +
                return ROOT + name + "-" + sig[:8] + ".th"
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            def is_pretrained(name):
         | 
| 39 | 
            +
                return name in PRETRAINED_MODELS
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            def load_pretrained(name):
         | 
| 43 | 
            +
                if name == "demucs":
         | 
| 44 | 
            +
                    return demucs(pretrained=True)
         | 
| 45 | 
            +
                elif name == "demucs48_hq":
         | 
| 46 | 
            +
                    return demucs(pretrained=True, hq=True, channels=48)
         | 
| 47 | 
            +
                elif name == "demucs_extra":
         | 
| 48 | 
            +
                    return demucs(pretrained=True, extra=True)
         | 
| 49 | 
            +
                elif name == "demucs_quantized":
         | 
| 50 | 
            +
                    return demucs(pretrained=True, quantized=True)
         | 
| 51 | 
            +
                elif name == "demucs_unittest":
         | 
| 52 | 
            +
                    return demucs_unittest(pretrained=True)
         | 
| 53 | 
            +
                elif name == "tasnet":
         | 
| 54 | 
            +
                    return tasnet(pretrained=True)
         | 
| 55 | 
            +
                elif name == "tasnet_extra":
         | 
| 56 | 
            +
                    return tasnet(pretrained=True, extra=True)
         | 
| 57 | 
            +
                else:
         | 
| 58 | 
            +
                    raise ValueError(f"Invalid pretrained name {name}")
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def _load_state(name, model, quantizer=None):
         | 
| 62 | 
            +
                url = get_url(name)
         | 
| 63 | 
            +
                state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
         | 
| 64 | 
            +
                set_state(model, quantizer, state)
         | 
| 65 | 
            +
                if quantizer:
         | 
| 66 | 
            +
                    quantizer.detach()
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            def demucs_unittest(pretrained=True):
         | 
| 70 | 
            +
                model = Demucs(channels=4, sources=SOURCES)
         | 
| 71 | 
            +
                if pretrained:
         | 
| 72 | 
            +
                    _load_state('demucs_unittest', model)
         | 
| 73 | 
            +
                return model
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
         | 
| 77 | 
            +
                if not pretrained and (extra or quantized or hq):
         | 
| 78 | 
            +
                    raise ValueError("if extra or quantized is True, pretrained must be True.")
         | 
| 79 | 
            +
                model = Demucs(sources=SOURCES, channels=channels)
         | 
| 80 | 
            +
                if pretrained:
         | 
| 81 | 
            +
                    name = 'demucs'
         | 
| 82 | 
            +
                    if channels != 64:
         | 
| 83 | 
            +
                        name += str(channels)
         | 
| 84 | 
            +
                    quantizer = None
         | 
| 85 | 
            +
                    if sum([extra, quantized, hq]) > 1:
         | 
| 86 | 
            +
                        raise ValueError("Only one of extra, quantized, hq, can be True.")
         | 
| 87 | 
            +
                    if quantized:
         | 
| 88 | 
            +
                        quantizer = DiffQuantizer(model, group_size=8, min_size=1)
         | 
| 89 | 
            +
                        name += '_quantized'
         | 
| 90 | 
            +
                    if extra:
         | 
| 91 | 
            +
                        name += '_extra'
         | 
| 92 | 
            +
                    if hq:
         | 
| 93 | 
            +
                        name += '_hq'
         | 
| 94 | 
            +
                    _load_state(name, model, quantizer)
         | 
| 95 | 
            +
                return model
         | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
            def tasnet(pretrained=True, extra=False):
         | 
| 99 | 
            +
                if not pretrained and extra:
         | 
| 100 | 
            +
                    raise ValueError("if extra is True, pretrained must be True.")
         | 
| 101 | 
            +
                model = ConvTasNet(X=10, sources=SOURCES)
         | 
| 102 | 
            +
                if pretrained:
         | 
| 103 | 
            +
                    name = 'tasnet'
         | 
| 104 | 
            +
                    if extra:
         | 
| 105 | 
            +
                        name = 'tasnet_extra'
         | 
| 106 | 
            +
                    _load_state(name, model)
         | 
| 107 | 
            +
                return model
         | 
    	
        demucs/raw.py
    ADDED
    
    | @@ -0,0 +1,173 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import argparse
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            from collections import defaultdict, namedtuple
         | 
| 10 | 
            +
            from pathlib import Path
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import musdb
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            import torch as th
         | 
| 15 | 
            +
            import tqdm
         | 
| 16 | 
            +
            from torch.utils.data import DataLoader
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .audio import AudioFile
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ChunkInfo = namedtuple("ChunkInfo", ["file_index", "offset", "local_index"])
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class Rawset:
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                Dataset of raw, normalized, float32 audio files
         | 
| 26 | 
            +
                """
         | 
| 27 | 
            +
                def __init__(self, path, samples=None, stride=None, channels=2, streams=None):
         | 
| 28 | 
            +
                    self.path = Path(path)
         | 
| 29 | 
            +
                    self.channels = channels
         | 
| 30 | 
            +
                    self.samples = samples
         | 
| 31 | 
            +
                    if stride is None:
         | 
| 32 | 
            +
                        stride = samples if samples is not None else 0
         | 
| 33 | 
            +
                    self.stride = stride
         | 
| 34 | 
            +
                    entries = defaultdict(list)
         | 
| 35 | 
            +
                    for root, folders, files in os.walk(self.path, followlinks=True):
         | 
| 36 | 
            +
                        folders.sort()
         | 
| 37 | 
            +
                        files.sort()
         | 
| 38 | 
            +
                        for file in files:
         | 
| 39 | 
            +
                            if file.endswith(".raw"):
         | 
| 40 | 
            +
                                path = Path(root) / file
         | 
| 41 | 
            +
                                name, stream = path.stem.rsplit('.', 1)
         | 
| 42 | 
            +
                                entries[(path.parent.relative_to(self.path), name)].append(int(stream))
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    self._entries = list(entries.keys())
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    sizes = []
         | 
| 47 | 
            +
                    self._lengths = []
         | 
| 48 | 
            +
                    ref_streams = sorted(entries[self._entries[0]])
         | 
| 49 | 
            +
                    assert ref_streams == list(range(len(ref_streams)))
         | 
| 50 | 
            +
                    if streams is None:
         | 
| 51 | 
            +
                        self.streams = ref_streams
         | 
| 52 | 
            +
                    else:
         | 
| 53 | 
            +
                        self.streams = streams
         | 
| 54 | 
            +
                    for entry in sorted(entries.keys()):
         | 
| 55 | 
            +
                        streams = entries[entry]
         | 
| 56 | 
            +
                        assert sorted(streams) == ref_streams
         | 
| 57 | 
            +
                        file = self._path(*entry)
         | 
| 58 | 
            +
                        length = file.stat().st_size // (4 * channels)
         | 
| 59 | 
            +
                        if samples is None:
         | 
| 60 | 
            +
                            sizes.append(1)
         | 
| 61 | 
            +
                        else:
         | 
| 62 | 
            +
                            if length < samples:
         | 
| 63 | 
            +
                                self._entries.remove(entry)
         | 
| 64 | 
            +
                                continue
         | 
| 65 | 
            +
                            sizes.append((length - samples) // stride + 1)
         | 
| 66 | 
            +
                        self._lengths.append(length)
         | 
| 67 | 
            +
                    if not sizes:
         | 
| 68 | 
            +
                        raise ValueError(f"Empty dataset {self.path}")
         | 
| 69 | 
            +
                    self._cumulative_sizes = np.cumsum(sizes)
         | 
| 70 | 
            +
                    self._sizes = sizes
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def __len__(self):
         | 
| 73 | 
            +
                    return self._cumulative_sizes[-1]
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                @property
         | 
| 76 | 
            +
                def total_length(self):
         | 
| 77 | 
            +
                    return sum(self._lengths)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def chunk_info(self, index):
         | 
| 80 | 
            +
                    file_index = np.searchsorted(self._cumulative_sizes, index, side='right')
         | 
| 81 | 
            +
                    if file_index == 0:
         | 
| 82 | 
            +
                        local_index = index
         | 
| 83 | 
            +
                    else:
         | 
| 84 | 
            +
                        local_index = index - self._cumulative_sizes[file_index - 1]
         | 
| 85 | 
            +
                    return ChunkInfo(offset=local_index * self.stride,
         | 
| 86 | 
            +
                                     file_index=file_index,
         | 
| 87 | 
            +
                                     local_index=local_index)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                def _path(self, folder, name, stream=0):
         | 
| 90 | 
            +
                    return self.path / folder / (name + f'.{stream}.raw')
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def __getitem__(self, index):
         | 
| 93 | 
            +
                    chunk = self.chunk_info(index)
         | 
| 94 | 
            +
                    entry = self._entries[chunk.file_index]
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    length = self.samples or self._lengths[chunk.file_index]
         | 
| 97 | 
            +
                    streams = []
         | 
| 98 | 
            +
                    to_read = length * self.channels * 4
         | 
| 99 | 
            +
                    for stream_index, stream in enumerate(self.streams):
         | 
| 100 | 
            +
                        offset = chunk.offset * 4 * self.channels
         | 
| 101 | 
            +
                        file = open(self._path(*entry, stream=stream), 'rb')
         | 
| 102 | 
            +
                        file.seek(offset)
         | 
| 103 | 
            +
                        content = file.read(to_read)
         | 
| 104 | 
            +
                        assert len(content) == to_read
         | 
| 105 | 
            +
                        content = np.frombuffer(content, dtype=np.float32)
         | 
| 106 | 
            +
                        content = content.copy()  # make writable
         | 
| 107 | 
            +
                        streams.append(th.from_numpy(content).view(length, self.channels).t())
         | 
| 108 | 
            +
                    return th.stack(streams, dim=0)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def name(self, index):
         | 
| 111 | 
            +
                    chunk = self.chunk_info(index)
         | 
| 112 | 
            +
                    folder, name = self._entries[chunk.file_index]
         | 
| 113 | 
            +
                    return folder / name
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            class MusDBSet:
         | 
| 117 | 
            +
                def __init__(self, mus, streams=slice(None), samplerate=44100, channels=2):
         | 
| 118 | 
            +
                    self.mus = mus
         | 
| 119 | 
            +
                    self.streams = streams
         | 
| 120 | 
            +
                    self.samplerate = samplerate
         | 
| 121 | 
            +
                    self.channels = channels
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                def __len__(self):
         | 
| 124 | 
            +
                    return len(self.mus.tracks)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                def __getitem__(self, index):
         | 
| 127 | 
            +
                    track = self.mus.tracks[index]
         | 
| 128 | 
            +
                    return (track.name, AudioFile(track.path).read(channels=self.channels,
         | 
| 129 | 
            +
                                                                   seek_time=0,
         | 
| 130 | 
            +
                                                                   streams=self.streams,
         | 
| 131 | 
            +
                                                                   samplerate=self.samplerate))
         | 
| 132 | 
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def build_raw(mus, destination, normalize, workers, samplerate, channels):
         | 
| 135 | 
            +
                destination.mkdir(parents=True, exist_ok=True)
         | 
| 136 | 
            +
                loader = DataLoader(MusDBSet(mus, channels=channels, samplerate=samplerate),
         | 
| 137 | 
            +
                                    batch_size=1,
         | 
| 138 | 
            +
                                    num_workers=workers,
         | 
| 139 | 
            +
                                    collate_fn=lambda x: x[0])
         | 
| 140 | 
            +
                for name, streams in tqdm.tqdm(loader):
         | 
| 141 | 
            +
                    if normalize:
         | 
| 142 | 
            +
                        ref = streams[0].mean(dim=0)  # use mono mixture as reference
         | 
| 143 | 
            +
                        streams = (streams - ref.mean()) / ref.std()
         | 
| 144 | 
            +
                    for index, stream in enumerate(streams):
         | 
| 145 | 
            +
                        open(destination / (name + f'.{index}.raw'), "wb").write(stream.t().numpy().tobytes())
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def main():
         | 
| 149 | 
            +
                parser = argparse.ArgumentParser('rawset')
         | 
| 150 | 
            +
                parser.add_argument('--workers', type=int, default=10)
         | 
| 151 | 
            +
                parser.add_argument('--samplerate', type=int, default=44100)
         | 
| 152 | 
            +
                parser.add_argument('--channels', type=int, default=2)
         | 
| 153 | 
            +
                parser.add_argument('musdb', type=Path)
         | 
| 154 | 
            +
                parser.add_argument('destination', type=Path)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                args = parser.parse_args()
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="train"),
         | 
| 159 | 
            +
                          args.destination / "train",
         | 
| 160 | 
            +
                          normalize=True,
         | 
| 161 | 
            +
                          channels=args.channels,
         | 
| 162 | 
            +
                          samplerate=args.samplerate,
         | 
| 163 | 
            +
                          workers=args.workers)
         | 
| 164 | 
            +
                build_raw(musdb.DB(root=args.musdb, subsets=["train"], split="valid"),
         | 
| 165 | 
            +
                          args.destination / "valid",
         | 
| 166 | 
            +
                          normalize=True,
         | 
| 167 | 
            +
                          samplerate=args.samplerate,
         | 
| 168 | 
            +
                          channels=args.channels,
         | 
| 169 | 
            +
                          workers=args.workers)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
            if __name__ == "__main__":
         | 
| 173 | 
            +
                main()
         | 
    	
        demucs/repitch.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import io
         | 
| 8 | 
            +
            import random
         | 
| 9 | 
            +
            import subprocess as sp
         | 
| 10 | 
            +
            import tempfile
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            from scipy.io import wavfile
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def i16_pcm(wav):
         | 
| 18 | 
            +
                if wav.dtype == np.int16:
         | 
| 19 | 
            +
                    return wav
         | 
| 20 | 
            +
                return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def f32_pcm(wav):
         | 
| 24 | 
            +
                if wav.dtype == np.float:
         | 
| 25 | 
            +
                    return wav
         | 
| 26 | 
            +
                return wav.float() / 2**15
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            class RepitchedWrapper:
         | 
| 30 | 
            +
                """
         | 
| 31 | 
            +
                Wrap a dataset to apply online change of pitch / tempo.
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, tempo_std=5, vocals=[3]):
         | 
| 34 | 
            +
                    self.dataset = dataset
         | 
| 35 | 
            +
                    self.proba = proba
         | 
| 36 | 
            +
                    self.max_pitch = max_pitch
         | 
| 37 | 
            +
                    self.max_tempo = max_tempo
         | 
| 38 | 
            +
                    self.tempo_std = tempo_std
         | 
| 39 | 
            +
                    self.vocals = vocals
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __len__(self):
         | 
| 42 | 
            +
                    return len(self.dataset)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def __getitem__(self, index):
         | 
| 45 | 
            +
                    streams = self.dataset[index]
         | 
| 46 | 
            +
                    in_length = streams.shape[-1]
         | 
| 47 | 
            +
                    out_length = int((1 - 0.01 * self.max_tempo) * in_length)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    if random.random() < self.proba:
         | 
| 50 | 
            +
                        delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
         | 
| 51 | 
            +
                        delta_tempo = random.gauss(0, self.tempo_std)
         | 
| 52 | 
            +
                        delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
         | 
| 53 | 
            +
                        outs = []
         | 
| 54 | 
            +
                        for idx, stream in enumerate(streams):
         | 
| 55 | 
            +
                            stream = repitch(
         | 
| 56 | 
            +
                                stream,
         | 
| 57 | 
            +
                                delta_pitch,
         | 
| 58 | 
            +
                                delta_tempo,
         | 
| 59 | 
            +
                                voice=idx in self.vocals)
         | 
| 60 | 
            +
                            outs.append(stream[:, :out_length])
         | 
| 61 | 
            +
                        streams = torch.stack(outs)
         | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        streams = streams[..., :out_length]
         | 
| 64 | 
            +
                    return streams
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
         | 
| 68 | 
            +
                """
         | 
| 69 | 
            +
                tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
         | 
| 70 | 
            +
                pitch is in semi tones.
         | 
| 71 | 
            +
                Requires `soundstretch` to be installed, see
         | 
| 72 | 
            +
                https://www.surina.net/soundtouch/soundstretch.html
         | 
| 73 | 
            +
                """
         | 
| 74 | 
            +
                outfile = tempfile.NamedTemporaryFile(suffix=".wav")
         | 
| 75 | 
            +
                in_ = io.BytesIO()
         | 
| 76 | 
            +
                wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())
         | 
| 77 | 
            +
                command = [
         | 
| 78 | 
            +
                    "soundstretch",
         | 
| 79 | 
            +
                    "stdin",
         | 
| 80 | 
            +
                    outfile.name,
         | 
| 81 | 
            +
                    f"-pitch={pitch}",
         | 
| 82 | 
            +
                    f"-tempo={tempo:.6f}",
         | 
| 83 | 
            +
                ]
         | 
| 84 | 
            +
                if quick:
         | 
| 85 | 
            +
                    command += ["-quick"]
         | 
| 86 | 
            +
                if voice:
         | 
| 87 | 
            +
                    command += ["-speech"]
         | 
| 88 | 
            +
                try:
         | 
| 89 | 
            +
                    sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
         | 
| 90 | 
            +
                except sp.CalledProcessError as error:
         | 
| 91 | 
            +
                    raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}")
         | 
| 92 | 
            +
                sr, wav = wavfile.read(outfile.name)
         | 
| 93 | 
            +
                wav = wav.copy()
         | 
| 94 | 
            +
                wav = f32_pcm(torch.from_numpy(wav).t())
         | 
| 95 | 
            +
                assert sr == samplerate
         | 
| 96 | 
            +
                return wav
         | 
    	
        demucs/separate.py
    ADDED
    
    | @@ -0,0 +1,185 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import argparse
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            from pathlib import Path
         | 
| 10 | 
            +
            import subprocess
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import julius
         | 
| 13 | 
            +
            import torch as th
         | 
| 14 | 
            +
            import torchaudio as ta
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from .audio import AudioFile, convert_audio_channels
         | 
| 17 | 
            +
            from .pretrained import is_pretrained, load_pretrained
         | 
| 18 | 
            +
            from .utils import apply_model, load_model
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def load_track(track, device, audio_channels, samplerate):
         | 
| 22 | 
            +
                errors = {}
         | 
| 23 | 
            +
                wav = None
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                try:
         | 
| 26 | 
            +
                    wav = AudioFile(track).read(
         | 
| 27 | 
            +
                        streams=0,
         | 
| 28 | 
            +
                        samplerate=samplerate,
         | 
| 29 | 
            +
                        channels=audio_channels).to(device)
         | 
| 30 | 
            +
                except FileNotFoundError:
         | 
| 31 | 
            +
                    errors['ffmpeg'] = 'Ffmpeg is not installed.'
         | 
| 32 | 
            +
                except subprocess.CalledProcessError:
         | 
| 33 | 
            +
                    errors['ffmpeg'] = 'FFmpeg could not read the file.'
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                if wav is None:
         | 
| 36 | 
            +
                    try:
         | 
| 37 | 
            +
                        wav, sr = ta.load(str(track))
         | 
| 38 | 
            +
                    except RuntimeError as err:
         | 
| 39 | 
            +
                        errors['torchaudio'] = err.args[0]
         | 
| 40 | 
            +
                    else:
         | 
| 41 | 
            +
                        wav = convert_audio_channels(wav, audio_channels)
         | 
| 42 | 
            +
                        wav = wav.to(device)
         | 
| 43 | 
            +
                        wav = julius.resample_frac(wav, sr, samplerate)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                if wav is None:
         | 
| 46 | 
            +
                    print(f"Could not load file {track}. "
         | 
| 47 | 
            +
                          "Maybe it is not a supported file format? ")
         | 
| 48 | 
            +
                    for backend, error in errors.items():
         | 
| 49 | 
            +
                        print(f"When trying to load using {backend}, got the following error: {error}")
         | 
| 50 | 
            +
                    sys.exit(1)
         | 
| 51 | 
            +
                return wav
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False):
         | 
| 55 | 
            +
                try:
         | 
| 56 | 
            +
                    import lameenc
         | 
| 57 | 
            +
                except ImportError:
         | 
| 58 | 
            +
                    print("Failed to call lame encoder. Maybe it is not installed? "
         | 
| 59 | 
            +
                          "On windows, run `python.exe -m pip install -U lameenc`, "
         | 
| 60 | 
            +
                          "on OSX/Linux, run `python3 -m pip install -U lameenc`, "
         | 
| 61 | 
            +
                          "then try again.", file=sys.stderr)
         | 
| 62 | 
            +
                    sys.exit(1)
         | 
| 63 | 
            +
                encoder = lameenc.Encoder()
         | 
| 64 | 
            +
                encoder.set_bit_rate(bitrate)
         | 
| 65 | 
            +
                encoder.set_in_sample_rate(samplerate)
         | 
| 66 | 
            +
                encoder.set_channels(channels)
         | 
| 67 | 
            +
                encoder.set_quality(2)  # 2-highest, 7-fastest
         | 
| 68 | 
            +
                if not verbose:
         | 
| 69 | 
            +
                    encoder.silence()
         | 
| 70 | 
            +
                wav = wav.transpose(0, 1).numpy()
         | 
| 71 | 
            +
                mp3_data = encoder.encode(wav.tobytes())
         | 
| 72 | 
            +
                mp3_data += encoder.flush()
         | 
| 73 | 
            +
                with open(path, "wb") as f:
         | 
| 74 | 
            +
                    f.write(mp3_data)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def main():
         | 
| 78 | 
            +
                parser = argparse.ArgumentParser("demucs.separate",
         | 
| 79 | 
            +
                                                 description="Separate the sources for the given tracks")
         | 
| 80 | 
            +
                parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
         | 
| 81 | 
            +
                parser.add_argument("-n",
         | 
| 82 | 
            +
                                    "--name",
         | 
| 83 | 
            +
                                    default="demucs_quantized",
         | 
| 84 | 
            +
                                    help="Model name. See README.md for the list of pretrained models. "
         | 
| 85 | 
            +
                                         "Default is demucs_quantized.")
         | 
| 86 | 
            +
                parser.add_argument("-v", "--verbose", action="store_true")
         | 
| 87 | 
            +
                parser.add_argument("-o",
         | 
| 88 | 
            +
                                    "--out",
         | 
| 89 | 
            +
                                    type=Path,
         | 
| 90 | 
            +
                                    default=Path("separated"),
         | 
| 91 | 
            +
                                    help="Folder where to put extracted tracks. A subfolder "
         | 
| 92 | 
            +
                                    "with the model name will be created.")
         | 
| 93 | 
            +
                parser.add_argument("--models",
         | 
| 94 | 
            +
                                    type=Path,
         | 
| 95 | 
            +
                                    default=Path("models"),
         | 
| 96 | 
            +
                                    help="Path to trained models. "
         | 
| 97 | 
            +
                                    "Also used to store downloaded pretrained models")
         | 
| 98 | 
            +
                parser.add_argument("-d",
         | 
| 99 | 
            +
                                    "--device",
         | 
| 100 | 
            +
                                    default="cuda" if th.cuda.is_available() else "cpu",
         | 
| 101 | 
            +
                                    help="Device to use, default is cuda if available else cpu")
         | 
| 102 | 
            +
                parser.add_argument("--shifts",
         | 
| 103 | 
            +
                                    default=0,
         | 
| 104 | 
            +
                                    type=int,
         | 
| 105 | 
            +
                                    help="Number of random shifts for equivariant stabilization."
         | 
| 106 | 
            +
                                    "Increase separation time but improves quality for Demucs. 10 was used "
         | 
| 107 | 
            +
                                    "in the original paper.")
         | 
| 108 | 
            +
                parser.add_argument("--overlap",
         | 
| 109 | 
            +
                                    default=0.25,
         | 
| 110 | 
            +
                                    type=float,
         | 
| 111 | 
            +
                                    help="Overlap between the splits.")
         | 
| 112 | 
            +
                parser.add_argument("--no-split",
         | 
| 113 | 
            +
                                    action="store_false",
         | 
| 114 | 
            +
                                    dest="split",
         | 
| 115 | 
            +
                                    default=True,
         | 
| 116 | 
            +
                                    help="Doesn't split audio in chunks. This can use large amounts of memory.")
         | 
| 117 | 
            +
                parser.add_argument("--float32",
         | 
| 118 | 
            +
                                    action="store_true",
         | 
| 119 | 
            +
                                    help="Convert the output wavefile to use pcm f32 format instead of s16. "
         | 
| 120 | 
            +
                                    "This should not make a difference if you just plan on listening to the "
         | 
| 121 | 
            +
                                    "audio but might be needed to compute exactly metrics like SDR etc.")
         | 
| 122 | 
            +
                parser.add_argument("--int16",
         | 
| 123 | 
            +
                                    action="store_false",
         | 
| 124 | 
            +
                                    dest="float32",
         | 
| 125 | 
            +
                                    help="Opposite of --float32, here for compatibility.")
         | 
| 126 | 
            +
                parser.add_argument("--mp3", action="store_true",
         | 
| 127 | 
            +
                                    help="Convert the output wavs to mp3.")
         | 
| 128 | 
            +
                parser.add_argument("--mp3-bitrate",
         | 
| 129 | 
            +
                                    default=320,
         | 
| 130 | 
            +
                                    type=int,
         | 
| 131 | 
            +
                                    help="Bitrate of converted mp3.")
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                args = parser.parse_args()
         | 
| 134 | 
            +
                name = args.name + ".th"
         | 
| 135 | 
            +
                model_path = args.models / name
         | 
| 136 | 
            +
                if model_path.is_file():
         | 
| 137 | 
            +
                    model = load_model(model_path)
         | 
| 138 | 
            +
                else:
         | 
| 139 | 
            +
                    if is_pretrained(args.name):
         | 
| 140 | 
            +
                        model = load_pretrained(args.name)
         | 
| 141 | 
            +
                    else:
         | 
| 142 | 
            +
                        print(f"No pre-trained model {args.name}", file=sys.stderr)
         | 
| 143 | 
            +
                        sys.exit(1)
         | 
| 144 | 
            +
                model.to(args.device)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                out = args.out / args.name
         | 
| 147 | 
            +
                out.mkdir(parents=True, exist_ok=True)
         | 
| 148 | 
            +
                print(f"Separated tracks will be stored in {out.resolve()}")
         | 
| 149 | 
            +
                for track in args.tracks:
         | 
| 150 | 
            +
                    if not track.exists():
         | 
| 151 | 
            +
                        print(
         | 
| 152 | 
            +
                            f"File {track} does not exist. If the path contains spaces, "
         | 
| 153 | 
            +
                            "please try again after surrounding the entire path with quotes \"\".",
         | 
| 154 | 
            +
                            file=sys.stderr)
         | 
| 155 | 
            +
                        continue
         | 
| 156 | 
            +
                    print(f"Separating track {track}")
         | 
| 157 | 
            +
                    wav = load_track(track, args.device, model.audio_channels, model.samplerate)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    ref = wav.mean(0)
         | 
| 160 | 
            +
                    wav = (wav - ref.mean()) / ref.std()
         | 
| 161 | 
            +
                    sources = apply_model(model, wav, shifts=args.shifts, split=args.split,
         | 
| 162 | 
            +
                                          overlap=args.overlap, progress=True)
         | 
| 163 | 
            +
                    sources = sources * ref.std() + ref.mean()
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    track_folder = out / track.name.rsplit(".", 1)[0]
         | 
| 166 | 
            +
                    track_folder.mkdir(exist_ok=True)
         | 
| 167 | 
            +
                    for source, name in zip(sources, model.sources):
         | 
| 168 | 
            +
                        source = source / max(1.01 * source.abs().max(), 1)
         | 
| 169 | 
            +
                        if args.mp3 or not args.float32:
         | 
| 170 | 
            +
                            source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
         | 
| 171 | 
            +
                        source = source.cpu()
         | 
| 172 | 
            +
                        stem = str(track_folder / name)
         | 
| 173 | 
            +
                        if args.mp3:
         | 
| 174 | 
            +
                            encode_mp3(source, stem + ".mp3",
         | 
| 175 | 
            +
                                       bitrate=args.mp3_bitrate,
         | 
| 176 | 
            +
                                       samplerate=model.samplerate,
         | 
| 177 | 
            +
                                       channels=model.audio_channels,
         | 
| 178 | 
            +
                                       verbose=args.verbose)
         | 
| 179 | 
            +
                        else:
         | 
| 180 | 
            +
                            wavname = str(track_folder / f"{name}.wav")
         | 
| 181 | 
            +
                            ta.save(wavname, source, sample_rate=model.samplerate)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
            if __name__ == "__main__":
         | 
| 185 | 
            +
                main()
         | 
    	
        demucs/tasnet.py
    ADDED
    
    | @@ -0,0 +1,452 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            # Created on 2018/12
         | 
| 8 | 
            +
            # Author: Kaituo XU
         | 
| 9 | 
            +
            # Modified on 2019/11 by Alexandre Defossez, added support for multiple output channels
         | 
| 10 | 
            +
            # Here is the original license:
         | 
| 11 | 
            +
            # The MIT License (MIT)
         | 
| 12 | 
            +
            #
         | 
| 13 | 
            +
            # Copyright (c) 2018 Kaituo XU
         | 
| 14 | 
            +
            #
         | 
| 15 | 
            +
            # Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 16 | 
            +
            # of this software and associated documentation files (the "Software"), to deal
         | 
| 17 | 
            +
            # in the Software without restriction, including without limitation the rights
         | 
| 18 | 
            +
            # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 19 | 
            +
            # copies of the Software, and to permit persons to whom the Software is
         | 
| 20 | 
            +
            # furnished to do so, subject to the following conditions:
         | 
| 21 | 
            +
            #
         | 
| 22 | 
            +
            # The above copyright notice and this permission notice shall be included in all
         | 
| 23 | 
            +
            # copies or substantial portions of the Software.
         | 
| 24 | 
            +
            #
         | 
| 25 | 
            +
            # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 26 | 
            +
            # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 27 | 
            +
            # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 28 | 
            +
            # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 29 | 
            +
            # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 30 | 
            +
            # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 31 | 
            +
            # SOFTWARE.
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            import math
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            import torch
         | 
| 36 | 
            +
            import torch.nn as nn
         | 
| 37 | 
            +
            import torch.nn.functional as F
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            from .utils import capture_init
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            EPS = 1e-8
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def overlap_and_add(signal, frame_step):
         | 
| 45 | 
            +
                outer_dimensions = signal.size()[:-2]
         | 
| 46 | 
            +
                frames, frame_length = signal.size()[-2:]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                subframe_length = math.gcd(frame_length, frame_step)  # gcd=Greatest Common Divisor
         | 
| 49 | 
            +
                subframe_step = frame_step // subframe_length
         | 
| 50 | 
            +
                subframes_per_frame = frame_length // subframe_length
         | 
| 51 | 
            +
                output_size = frame_step * (frames - 1) + frame_length
         | 
| 52 | 
            +
                output_subframes = output_size // subframe_length
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                frame = torch.arange(0, output_subframes,
         | 
| 57 | 
            +
                                     device=signal.device).unfold(0, subframes_per_frame, subframe_step)
         | 
| 58 | 
            +
                frame = frame.long()  # signal may in GPU or CPU
         | 
| 59 | 
            +
                frame = frame.contiguous().view(-1)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
         | 
| 62 | 
            +
                result.index_add_(-2, frame, subframe_signal)
         | 
| 63 | 
            +
                result = result.view(*outer_dimensions, -1)
         | 
| 64 | 
            +
                return result
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class ConvTasNet(nn.Module):
         | 
| 68 | 
            +
                @capture_init
         | 
| 69 | 
            +
                def __init__(self,
         | 
| 70 | 
            +
                             sources,
         | 
| 71 | 
            +
                             N=256,
         | 
| 72 | 
            +
                             L=20,
         | 
| 73 | 
            +
                             B=256,
         | 
| 74 | 
            +
                             H=512,
         | 
| 75 | 
            +
                             P=3,
         | 
| 76 | 
            +
                             X=8,
         | 
| 77 | 
            +
                             R=4,
         | 
| 78 | 
            +
                             audio_channels=2,
         | 
| 79 | 
            +
                             norm_type="gLN",
         | 
| 80 | 
            +
                             causal=False,
         | 
| 81 | 
            +
                             mask_nonlinear='relu',
         | 
| 82 | 
            +
                             samplerate=44100,
         | 
| 83 | 
            +
                             segment_length=44100 * 2 * 4):
         | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
                    Args:
         | 
| 86 | 
            +
                        sources: list of sources
         | 
| 87 | 
            +
                        N: Number of filters in autoencoder
         | 
| 88 | 
            +
                        L: Length of the filters (in samples)
         | 
| 89 | 
            +
                        B: Number of channels in bottleneck 1 × 1-conv block
         | 
| 90 | 
            +
                        H: Number of channels in convolutional blocks
         | 
| 91 | 
            +
                        P: Kernel size in convolutional blocks
         | 
| 92 | 
            +
                        X: Number of convolutional blocks in each repeat
         | 
| 93 | 
            +
                        R: Number of repeats
         | 
| 94 | 
            +
                        norm_type: BN, gLN, cLN
         | 
| 95 | 
            +
                        causal: causal or non-causal
         | 
| 96 | 
            +
                        mask_nonlinear: use which non-linear function to generate mask
         | 
| 97 | 
            +
                    """
         | 
| 98 | 
            +
                    super(ConvTasNet, self).__init__()
         | 
| 99 | 
            +
                    # Hyper-parameter
         | 
| 100 | 
            +
                    self.sources = sources
         | 
| 101 | 
            +
                    self.C = len(sources)
         | 
| 102 | 
            +
                    self.N, self.L, self.B, self.H, self.P, self.X, self.R = N, L, B, H, P, X, R
         | 
| 103 | 
            +
                    self.norm_type = norm_type
         | 
| 104 | 
            +
                    self.causal = causal
         | 
| 105 | 
            +
                    self.mask_nonlinear = mask_nonlinear
         | 
| 106 | 
            +
                    self.audio_channels = audio_channels
         | 
| 107 | 
            +
                    self.samplerate = samplerate
         | 
| 108 | 
            +
                    self.segment_length = segment_length
         | 
| 109 | 
            +
                    # Components
         | 
| 110 | 
            +
                    self.encoder = Encoder(L, N, audio_channels)
         | 
| 111 | 
            +
                    self.separator = TemporalConvNet(
         | 
| 112 | 
            +
                        N, B, H, P, X, R, self.C, norm_type, causal, mask_nonlinear)
         | 
| 113 | 
            +
                    self.decoder = Decoder(N, L, audio_channels)
         | 
| 114 | 
            +
                    # init
         | 
| 115 | 
            +
                    for p in self.parameters():
         | 
| 116 | 
            +
                        if p.dim() > 1:
         | 
| 117 | 
            +
                            nn.init.xavier_normal_(p)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                def valid_length(self, length):
         | 
| 120 | 
            +
                    return length
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def forward(self, mixture):
         | 
| 123 | 
            +
                    """
         | 
| 124 | 
            +
                    Args:
         | 
| 125 | 
            +
                        mixture: [M, T], M is batch size, T is #samples
         | 
| 126 | 
            +
                    Returns:
         | 
| 127 | 
            +
                        est_source: [M, C, T]
         | 
| 128 | 
            +
                    """
         | 
| 129 | 
            +
                    mixture_w = self.encoder(mixture)
         | 
| 130 | 
            +
                    est_mask = self.separator(mixture_w)
         | 
| 131 | 
            +
                    est_source = self.decoder(mixture_w, est_mask)
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # T changed after conv1d in encoder, fix it here
         | 
| 134 | 
            +
                    T_origin = mixture.size(-1)
         | 
| 135 | 
            +
                    T_conv = est_source.size(-1)
         | 
| 136 | 
            +
                    est_source = F.pad(est_source, (0, T_origin - T_conv))
         | 
| 137 | 
            +
                    return est_source
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            class Encoder(nn.Module):
         | 
| 141 | 
            +
                """Estimation of the nonnegative mixture weight by a 1-D conv layer.
         | 
| 142 | 
            +
                """
         | 
| 143 | 
            +
                def __init__(self, L, N, audio_channels):
         | 
| 144 | 
            +
                    super(Encoder, self).__init__()
         | 
| 145 | 
            +
                    # Hyper-parameter
         | 
| 146 | 
            +
                    self.L, self.N = L, N
         | 
| 147 | 
            +
                    # Components
         | 
| 148 | 
            +
                    # 50% overlap
         | 
| 149 | 
            +
                    self.conv1d_U = nn.Conv1d(audio_channels, N, kernel_size=L, stride=L // 2, bias=False)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                def forward(self, mixture):
         | 
| 152 | 
            +
                    """
         | 
| 153 | 
            +
                    Args:
         | 
| 154 | 
            +
                        mixture: [M, T], M is batch size, T is #samples
         | 
| 155 | 
            +
                    Returns:
         | 
| 156 | 
            +
                        mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
         | 
| 157 | 
            +
                    """
         | 
| 158 | 
            +
                    mixture_w = F.relu(self.conv1d_U(mixture))  # [M, N, K]
         | 
| 159 | 
            +
                    return mixture_w
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class Decoder(nn.Module):
         | 
| 163 | 
            +
                def __init__(self, N, L, audio_channels):
         | 
| 164 | 
            +
                    super(Decoder, self).__init__()
         | 
| 165 | 
            +
                    # Hyper-parameter
         | 
| 166 | 
            +
                    self.N, self.L = N, L
         | 
| 167 | 
            +
                    self.audio_channels = audio_channels
         | 
| 168 | 
            +
                    # Components
         | 
| 169 | 
            +
                    self.basis_signals = nn.Linear(N, audio_channels * L, bias=False)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                def forward(self, mixture_w, est_mask):
         | 
| 172 | 
            +
                    """
         | 
| 173 | 
            +
                    Args:
         | 
| 174 | 
            +
                        mixture_w: [M, N, K]
         | 
| 175 | 
            +
                        est_mask: [M, C, N, K]
         | 
| 176 | 
            +
                    Returns:
         | 
| 177 | 
            +
                        est_source: [M, C, T]
         | 
| 178 | 
            +
                    """
         | 
| 179 | 
            +
                    # D = W * M
         | 
| 180 | 
            +
                    source_w = torch.unsqueeze(mixture_w, 1) * est_mask  # [M, C, N, K]
         | 
| 181 | 
            +
                    source_w = torch.transpose(source_w, 2, 3)  # [M, C, K, N]
         | 
| 182 | 
            +
                    # S = DV
         | 
| 183 | 
            +
                    est_source = self.basis_signals(source_w)  # [M, C, K, ac * L]
         | 
| 184 | 
            +
                    m, c, k, _ = est_source.size()
         | 
| 185 | 
            +
                    est_source = est_source.view(m, c, k, self.audio_channels, -1).transpose(2, 3).contiguous()
         | 
| 186 | 
            +
                    est_source = overlap_and_add(est_source, self.L // 2)  # M x C x ac x T
         | 
| 187 | 
            +
                    return est_source
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            class TemporalConvNet(nn.Module):
         | 
| 191 | 
            +
                def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear='relu'):
         | 
| 192 | 
            +
                    """
         | 
| 193 | 
            +
                    Args:
         | 
| 194 | 
            +
                        N: Number of filters in autoencoder
         | 
| 195 | 
            +
                        B: Number of channels in bottleneck 1 × 1-conv block
         | 
| 196 | 
            +
                        H: Number of channels in convolutional blocks
         | 
| 197 | 
            +
                        P: Kernel size in convolutional blocks
         | 
| 198 | 
            +
                        X: Number of convolutional blocks in each repeat
         | 
| 199 | 
            +
                        R: Number of repeats
         | 
| 200 | 
            +
                        C: Number of speakers
         | 
| 201 | 
            +
                        norm_type: BN, gLN, cLN
         | 
| 202 | 
            +
                        causal: causal or non-causal
         | 
| 203 | 
            +
                        mask_nonlinear: use which non-linear function to generate mask
         | 
| 204 | 
            +
                    """
         | 
| 205 | 
            +
                    super(TemporalConvNet, self).__init__()
         | 
| 206 | 
            +
                    # Hyper-parameter
         | 
| 207 | 
            +
                    self.C = C
         | 
| 208 | 
            +
                    self.mask_nonlinear = mask_nonlinear
         | 
| 209 | 
            +
                    # Components
         | 
| 210 | 
            +
                    # [M, N, K] -> [M, N, K]
         | 
| 211 | 
            +
                    layer_norm = ChannelwiseLayerNorm(N)
         | 
| 212 | 
            +
                    # [M, N, K] -> [M, B, K]
         | 
| 213 | 
            +
                    bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
         | 
| 214 | 
            +
                    # [M, B, K] -> [M, B, K]
         | 
| 215 | 
            +
                    repeats = []
         | 
| 216 | 
            +
                    for r in range(R):
         | 
| 217 | 
            +
                        blocks = []
         | 
| 218 | 
            +
                        for x in range(X):
         | 
| 219 | 
            +
                            dilation = 2**x
         | 
| 220 | 
            +
                            padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
         | 
| 221 | 
            +
                            blocks += [
         | 
| 222 | 
            +
                                TemporalBlock(B,
         | 
| 223 | 
            +
                                              H,
         | 
| 224 | 
            +
                                              P,
         | 
| 225 | 
            +
                                              stride=1,
         | 
| 226 | 
            +
                                              padding=padding,
         | 
| 227 | 
            +
                                              dilation=dilation,
         | 
| 228 | 
            +
                                              norm_type=norm_type,
         | 
| 229 | 
            +
                                              causal=causal)
         | 
| 230 | 
            +
                            ]
         | 
| 231 | 
            +
                        repeats += [nn.Sequential(*blocks)]
         | 
| 232 | 
            +
                    temporal_conv_net = nn.Sequential(*repeats)
         | 
| 233 | 
            +
                    # [M, B, K] -> [M, C*N, K]
         | 
| 234 | 
            +
                    mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
         | 
| 235 | 
            +
                    # Put together
         | 
| 236 | 
            +
                    self.network = nn.Sequential(layer_norm, bottleneck_conv1x1, temporal_conv_net,
         | 
| 237 | 
            +
                                                 mask_conv1x1)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                def forward(self, mixture_w):
         | 
| 240 | 
            +
                    """
         | 
| 241 | 
            +
                    Keep this API same with TasNet
         | 
| 242 | 
            +
                    Args:
         | 
| 243 | 
            +
                        mixture_w: [M, N, K], M is batch size
         | 
| 244 | 
            +
                    returns:
         | 
| 245 | 
            +
                        est_mask: [M, C, N, K]
         | 
| 246 | 
            +
                    """
         | 
| 247 | 
            +
                    M, N, K = mixture_w.size()
         | 
| 248 | 
            +
                    score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
         | 
| 249 | 
            +
                    score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
         | 
| 250 | 
            +
                    if self.mask_nonlinear == 'softmax':
         | 
| 251 | 
            +
                        est_mask = F.softmax(score, dim=1)
         | 
| 252 | 
            +
                    elif self.mask_nonlinear == 'relu':
         | 
| 253 | 
            +
                        est_mask = F.relu(score)
         | 
| 254 | 
            +
                    else:
         | 
| 255 | 
            +
                        raise ValueError("Unsupported mask non-linear function")
         | 
| 256 | 
            +
                    return est_mask
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            class TemporalBlock(nn.Module):
         | 
| 260 | 
            +
                def __init__(self,
         | 
| 261 | 
            +
                             in_channels,
         | 
| 262 | 
            +
                             out_channels,
         | 
| 263 | 
            +
                             kernel_size,
         | 
| 264 | 
            +
                             stride,
         | 
| 265 | 
            +
                             padding,
         | 
| 266 | 
            +
                             dilation,
         | 
| 267 | 
            +
                             norm_type="gLN",
         | 
| 268 | 
            +
                             causal=False):
         | 
| 269 | 
            +
                    super(TemporalBlock, self).__init__()
         | 
| 270 | 
            +
                    # [M, B, K] -> [M, H, K]
         | 
| 271 | 
            +
                    conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
         | 
| 272 | 
            +
                    prelu = nn.PReLU()
         | 
| 273 | 
            +
                    norm = chose_norm(norm_type, out_channels)
         | 
| 274 | 
            +
                    # [M, H, K] -> [M, B, K]
         | 
| 275 | 
            +
                    dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, stride, padding,
         | 
| 276 | 
            +
                                                    dilation, norm_type, causal)
         | 
| 277 | 
            +
                    # Put together
         | 
| 278 | 
            +
                    self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def forward(self, x):
         | 
| 281 | 
            +
                    """
         | 
| 282 | 
            +
                    Args:
         | 
| 283 | 
            +
                        x: [M, B, K]
         | 
| 284 | 
            +
                    Returns:
         | 
| 285 | 
            +
                        [M, B, K]
         | 
| 286 | 
            +
                    """
         | 
| 287 | 
            +
                    residual = x
         | 
| 288 | 
            +
                    out = self.net(x)
         | 
| 289 | 
            +
                    # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad?
         | 
| 290 | 
            +
                    return out + residual  # look like w/o F.relu is better than w/ F.relu
         | 
| 291 | 
            +
                    # return F.relu(out + residual)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
             | 
| 294 | 
            +
            class DepthwiseSeparableConv(nn.Module):
         | 
| 295 | 
            +
                def __init__(self,
         | 
| 296 | 
            +
                             in_channels,
         | 
| 297 | 
            +
                             out_channels,
         | 
| 298 | 
            +
                             kernel_size,
         | 
| 299 | 
            +
                             stride,
         | 
| 300 | 
            +
                             padding,
         | 
| 301 | 
            +
                             dilation,
         | 
| 302 | 
            +
                             norm_type="gLN",
         | 
| 303 | 
            +
                             causal=False):
         | 
| 304 | 
            +
                    super(DepthwiseSeparableConv, self).__init__()
         | 
| 305 | 
            +
                    # Use `groups` option to implement depthwise convolution
         | 
| 306 | 
            +
                    # [M, H, K] -> [M, H, K]
         | 
| 307 | 
            +
                    depthwise_conv = nn.Conv1d(in_channels,
         | 
| 308 | 
            +
                                               in_channels,
         | 
| 309 | 
            +
                                               kernel_size,
         | 
| 310 | 
            +
                                               stride=stride,
         | 
| 311 | 
            +
                                               padding=padding,
         | 
| 312 | 
            +
                                               dilation=dilation,
         | 
| 313 | 
            +
                                               groups=in_channels,
         | 
| 314 | 
            +
                                               bias=False)
         | 
| 315 | 
            +
                    if causal:
         | 
| 316 | 
            +
                        chomp = Chomp1d(padding)
         | 
| 317 | 
            +
                    prelu = nn.PReLU()
         | 
| 318 | 
            +
                    norm = chose_norm(norm_type, in_channels)
         | 
| 319 | 
            +
                    # [M, H, K] -> [M, B, K]
         | 
| 320 | 
            +
                    pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
         | 
| 321 | 
            +
                    # Put together
         | 
| 322 | 
            +
                    if causal:
         | 
| 323 | 
            +
                        self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv)
         | 
| 324 | 
            +
                    else:
         | 
| 325 | 
            +
                        self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv)
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                def forward(self, x):
         | 
| 328 | 
            +
                    """
         | 
| 329 | 
            +
                    Args:
         | 
| 330 | 
            +
                        x: [M, H, K]
         | 
| 331 | 
            +
                    Returns:
         | 
| 332 | 
            +
                        result: [M, B, K]
         | 
| 333 | 
            +
                    """
         | 
| 334 | 
            +
                    return self.net(x)
         | 
| 335 | 
            +
             | 
| 336 | 
            +
             | 
| 337 | 
            +
            class Chomp1d(nn.Module):
         | 
| 338 | 
            +
                """To ensure the output length is the same as the input.
         | 
| 339 | 
            +
                """
         | 
| 340 | 
            +
                def __init__(self, chomp_size):
         | 
| 341 | 
            +
                    super(Chomp1d, self).__init__()
         | 
| 342 | 
            +
                    self.chomp_size = chomp_size
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                def forward(self, x):
         | 
| 345 | 
            +
                    """
         | 
| 346 | 
            +
                    Args:
         | 
| 347 | 
            +
                        x: [M, H, Kpad]
         | 
| 348 | 
            +
                    Returns:
         | 
| 349 | 
            +
                        [M, H, K]
         | 
| 350 | 
            +
                    """
         | 
| 351 | 
            +
                    return x[:, :, :-self.chomp_size].contiguous()
         | 
| 352 | 
            +
             | 
| 353 | 
            +
             | 
| 354 | 
            +
            def chose_norm(norm_type, channel_size):
         | 
| 355 | 
            +
                """The input of normlization will be (M, C, K), where M is batch size,
         | 
| 356 | 
            +
                   C is channel size and K is sequence length.
         | 
| 357 | 
            +
                """
         | 
| 358 | 
            +
                if norm_type == "gLN":
         | 
| 359 | 
            +
                    return GlobalLayerNorm(channel_size)
         | 
| 360 | 
            +
                elif norm_type == "cLN":
         | 
| 361 | 
            +
                    return ChannelwiseLayerNorm(channel_size)
         | 
| 362 | 
            +
                elif norm_type == "id":
         | 
| 363 | 
            +
                    return nn.Identity()
         | 
| 364 | 
            +
                else:  # norm_type == "BN":
         | 
| 365 | 
            +
                    # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics
         | 
| 366 | 
            +
                    # along M and K, so this BN usage is right.
         | 
| 367 | 
            +
                    return nn.BatchNorm1d(channel_size)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
             | 
| 370 | 
            +
            # TODO: Use nn.LayerNorm to impl cLN to speed up
         | 
| 371 | 
            +
            class ChannelwiseLayerNorm(nn.Module):
         | 
| 372 | 
            +
                """Channel-wise Layer Normalization (cLN)"""
         | 
| 373 | 
            +
                def __init__(self, channel_size):
         | 
| 374 | 
            +
                    super(ChannelwiseLayerNorm, self).__init__()
         | 
| 375 | 
            +
                    self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
         | 
| 376 | 
            +
                    self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
         | 
| 377 | 
            +
                    self.reset_parameters()
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                def reset_parameters(self):
         | 
| 380 | 
            +
                    self.gamma.data.fill_(1)
         | 
| 381 | 
            +
                    self.beta.data.zero_()
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                def forward(self, y):
         | 
| 384 | 
            +
                    """
         | 
| 385 | 
            +
                    Args:
         | 
| 386 | 
            +
                        y: [M, N, K], M is batch size, N is channel size, K is length
         | 
| 387 | 
            +
                    Returns:
         | 
| 388 | 
            +
                        cLN_y: [M, N, K]
         | 
| 389 | 
            +
                    """
         | 
| 390 | 
            +
                    mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
         | 
| 391 | 
            +
                    var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
         | 
| 392 | 
            +
                    cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
         | 
| 393 | 
            +
                    return cLN_y
         | 
| 394 | 
            +
             | 
| 395 | 
            +
             | 
| 396 | 
            +
            class GlobalLayerNorm(nn.Module):
         | 
| 397 | 
            +
                """Global Layer Normalization (gLN)"""
         | 
| 398 | 
            +
                def __init__(self, channel_size):
         | 
| 399 | 
            +
                    super(GlobalLayerNorm, self).__init__()
         | 
| 400 | 
            +
                    self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
         | 
| 401 | 
            +
                    self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
         | 
| 402 | 
            +
                    self.reset_parameters()
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                def reset_parameters(self):
         | 
| 405 | 
            +
                    self.gamma.data.fill_(1)
         | 
| 406 | 
            +
                    self.beta.data.zero_()
         | 
| 407 | 
            +
             | 
| 408 | 
            +
                def forward(self, y):
         | 
| 409 | 
            +
                    """
         | 
| 410 | 
            +
                    Args:
         | 
| 411 | 
            +
                        y: [M, N, K], M is batch size, N is channel size, K is length
         | 
| 412 | 
            +
                    Returns:
         | 
| 413 | 
            +
                        gLN_y: [M, N, K]
         | 
| 414 | 
            +
                    """
         | 
| 415 | 
            +
                    # TODO: in torch 1.0, torch.mean() support dim list
         | 
| 416 | 
            +
                    mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
         | 
| 417 | 
            +
                    var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
         | 
| 418 | 
            +
                    gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
         | 
| 419 | 
            +
                    return gLN_y
         | 
| 420 | 
            +
             | 
| 421 | 
            +
             | 
| 422 | 
            +
            if __name__ == "__main__":
         | 
| 423 | 
            +
                torch.manual_seed(123)
         | 
| 424 | 
            +
                M, N, L, T = 2, 3, 4, 12
         | 
| 425 | 
            +
                K = 2 * T // L - 1
         | 
| 426 | 
            +
                B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
         | 
| 427 | 
            +
                mixture = torch.randint(3, (M, T))
         | 
| 428 | 
            +
                # test Encoder
         | 
| 429 | 
            +
                encoder = Encoder(L, N)
         | 
| 430 | 
            +
                encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
         | 
| 431 | 
            +
                mixture_w = encoder(mixture)
         | 
| 432 | 
            +
                print('mixture', mixture)
         | 
| 433 | 
            +
                print('U', encoder.conv1d_U.weight)
         | 
| 434 | 
            +
                print('mixture_w', mixture_w)
         | 
| 435 | 
            +
                print('mixture_w size', mixture_w.size())
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                # test TemporalConvNet
         | 
| 438 | 
            +
                separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
         | 
| 439 | 
            +
                est_mask = separator(mixture_w)
         | 
| 440 | 
            +
                print('est_mask', est_mask)
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                # test Decoder
         | 
| 443 | 
            +
                decoder = Decoder(N, L)
         | 
| 444 | 
            +
                est_mask = torch.randint(2, (B, K, C, N))
         | 
| 445 | 
            +
                est_source = decoder(mixture_w, est_mask)
         | 
| 446 | 
            +
                print('est_source', est_source)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                # test Conv-TasNet
         | 
| 449 | 
            +
                conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
         | 
| 450 | 
            +
                est_source = conv_tasnet(mixture)
         | 
| 451 | 
            +
                print('est_source', est_source)
         | 
| 452 | 
            +
                print('est_source size', est_source.size())
         | 
    	
        demucs/test.py
    ADDED
    
    | @@ -0,0 +1,109 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import gzip
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            from concurrent import futures
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import musdb
         | 
| 12 | 
            +
            import museval
         | 
| 13 | 
            +
            import torch as th
         | 
| 14 | 
            +
            import tqdm
         | 
| 15 | 
            +
            from scipy.io import wavfile
         | 
| 16 | 
            +
            from torch import distributed
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from .audio import convert_audio
         | 
| 19 | 
            +
            from .utils import apply_model
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def evaluate(model,
         | 
| 23 | 
            +
                         musdb_path,
         | 
| 24 | 
            +
                         eval_folder,
         | 
| 25 | 
            +
                         workers=2,
         | 
| 26 | 
            +
                         device="cpu",
         | 
| 27 | 
            +
                         rank=0,
         | 
| 28 | 
            +
                         save=False,
         | 
| 29 | 
            +
                         shifts=0,
         | 
| 30 | 
            +
                         split=False,
         | 
| 31 | 
            +
                         overlap=0.25,
         | 
| 32 | 
            +
                         is_wav=False,
         | 
| 33 | 
            +
                         world_size=1):
         | 
| 34 | 
            +
                """
         | 
| 35 | 
            +
                Evaluate model using museval. Run the model
         | 
| 36 | 
            +
                on a single GPU, the bottleneck being the call to museval.
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                output_dir = eval_folder / "results"
         | 
| 40 | 
            +
                output_dir.mkdir(exist_ok=True, parents=True)
         | 
| 41 | 
            +
                json_folder = eval_folder / "results/test"
         | 
| 42 | 
            +
                json_folder.mkdir(exist_ok=True, parents=True)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # we load tracks from the original musdb set
         | 
| 45 | 
            +
                test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
         | 
| 46 | 
            +
                src_rate = 44100  # hardcoded for now...
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                for p in model.parameters():
         | 
| 49 | 
            +
                    p.requires_grad = False
         | 
| 50 | 
            +
                    p.grad = None
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                pendings = []
         | 
| 53 | 
            +
                with futures.ProcessPoolExecutor(workers or 1) as pool:
         | 
| 54 | 
            +
                    for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
         | 
| 55 | 
            +
                        track = test_set.tracks[index]
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                        out = json_folder / f"{track.name}.json.gz"
         | 
| 58 | 
            +
                        if out.exists():
         | 
| 59 | 
            +
                            continue
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                        mix = th.from_numpy(track.audio).t().float()
         | 
| 62 | 
            +
                        ref = mix.mean(dim=0)  # mono mixture
         | 
| 63 | 
            +
                        mix = (mix - ref.mean()) / ref.std()
         | 
| 64 | 
            +
                        mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
         | 
| 65 | 
            +
                        estimates = apply_model(model, mix.to(device),
         | 
| 66 | 
            +
                                                shifts=shifts, split=split, overlap=overlap)
         | 
| 67 | 
            +
                        estimates = estimates * ref.std() + ref.mean()
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                        estimates = estimates.transpose(1, 2)
         | 
| 70 | 
            +
                        references = th.stack(
         | 
| 71 | 
            +
                            [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
         | 
| 72 | 
            +
                        references = convert_audio(references, src_rate,
         | 
| 73 | 
            +
                                                   model.samplerate, model.audio_channels)
         | 
| 74 | 
            +
                        references = references.transpose(1, 2).numpy()
         | 
| 75 | 
            +
                        estimates = estimates.cpu().numpy()
         | 
| 76 | 
            +
                        win = int(1. * model.samplerate)
         | 
| 77 | 
            +
                        hop = int(1. * model.samplerate)
         | 
| 78 | 
            +
                        if save:
         | 
| 79 | 
            +
                            folder = eval_folder / "wav/test" / track.name
         | 
| 80 | 
            +
                            folder.mkdir(exist_ok=True, parents=True)
         | 
| 81 | 
            +
                            for name, estimate in zip(model.sources, estimates):
         | 
| 82 | 
            +
                                wavfile.write(str(folder / (name + ".wav")), 44100, estimate)
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        if workers:
         | 
| 85 | 
            +
                            pendings.append((track.name, pool.submit(
         | 
| 86 | 
            +
                                museval.evaluate, references, estimates, win=win, hop=hop)))
         | 
| 87 | 
            +
                        else:
         | 
| 88 | 
            +
                            pendings.append((track.name, museval.evaluate(
         | 
| 89 | 
            +
                                references, estimates, win=win, hop=hop)))
         | 
| 90 | 
            +
                        del references, mix, estimates, track
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
         | 
| 93 | 
            +
                        if workers:
         | 
| 94 | 
            +
                            pending = pending.result()
         | 
| 95 | 
            +
                        sdr, isr, sir, sar = pending
         | 
| 96 | 
            +
                        track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
         | 
| 97 | 
            +
                        for idx, target in enumerate(model.sources):
         | 
| 98 | 
            +
                            values = {
         | 
| 99 | 
            +
                                "SDR": sdr[idx].tolist(),
         | 
| 100 | 
            +
                                "SIR": sir[idx].tolist(),
         | 
| 101 | 
            +
                                "ISR": isr[idx].tolist(),
         | 
| 102 | 
            +
                                "SAR": sar[idx].tolist()
         | 
| 103 | 
            +
                            }
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                            track_store.add_target(target_name=target, values=values)
         | 
| 106 | 
            +
                            json_path = json_folder / f"{track_name}.json.gz"
         | 
| 107 | 
            +
                            gzip.open(json_path, "w").write(track_store.json.encode('utf-8'))
         | 
| 108 | 
            +
                if world_size > 1:
         | 
| 109 | 
            +
                    distributed.barrier()
         | 
    	
        demucs/train.py
    ADDED
    
    | @@ -0,0 +1,127 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            # All rights reserved.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import tqdm
         | 
| 10 | 
            +
            from torch.utils.data import DataLoader
         | 
| 11 | 
            +
            from torch.utils.data.distributed import DistributedSampler
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from .utils import apply_model, average_metric, center_trim
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            def train_model(epoch,
         | 
| 17 | 
            +
                            dataset,
         | 
| 18 | 
            +
                            model,
         | 
| 19 | 
            +
                            criterion,
         | 
| 20 | 
            +
                            optimizer,
         | 
| 21 | 
            +
                            augment,
         | 
| 22 | 
            +
                            quantizer=None,
         | 
| 23 | 
            +
                            diffq=0,
         | 
| 24 | 
            +
                            repeat=1,
         | 
| 25 | 
            +
                            device="cpu",
         | 
| 26 | 
            +
                            seed=None,
         | 
| 27 | 
            +
                            workers=4,
         | 
| 28 | 
            +
                            world_size=1,
         | 
| 29 | 
            +
                            batch_size=16):
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                if world_size > 1:
         | 
| 32 | 
            +
                    sampler = DistributedSampler(dataset)
         | 
| 33 | 
            +
                    sampler_epoch = epoch * repeat
         | 
| 34 | 
            +
                    if seed is not None:
         | 
| 35 | 
            +
                        sampler_epoch += seed * 1000
         | 
| 36 | 
            +
                    sampler.set_epoch(sampler_epoch)
         | 
| 37 | 
            +
                    batch_size //= world_size
         | 
| 38 | 
            +
                    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=workers)
         | 
| 39 | 
            +
                else:
         | 
| 40 | 
            +
                    loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, shuffle=True)
         | 
| 41 | 
            +
                current_loss = 0
         | 
| 42 | 
            +
                model_size = 0
         | 
| 43 | 
            +
                for repetition in range(repeat):
         | 
| 44 | 
            +
                    tq = tqdm.tqdm(loader,
         | 
| 45 | 
            +
                                   ncols=120,
         | 
| 46 | 
            +
                                   desc=f"[{epoch:03d}] train ({repetition + 1}/{repeat})",
         | 
| 47 | 
            +
                                   leave=False,
         | 
| 48 | 
            +
                                   file=sys.stdout,
         | 
| 49 | 
            +
                                   unit=" batch")
         | 
| 50 | 
            +
                    total_loss = 0
         | 
| 51 | 
            +
                    for idx, sources in enumerate(tq):
         | 
| 52 | 
            +
                        if len(sources) < batch_size:
         | 
| 53 | 
            +
                            # skip uncomplete batch for augment.Remix to work properly
         | 
| 54 | 
            +
                            continue
         | 
| 55 | 
            +
                        sources = sources.to(device)
         | 
| 56 | 
            +
                        sources = augment(sources)
         | 
| 57 | 
            +
                        mix = sources.sum(dim=1)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                        estimates = model(mix)
         | 
| 60 | 
            +
                        sources = center_trim(sources, estimates)
         | 
| 61 | 
            +
                        loss = criterion(estimates, sources)
         | 
| 62 | 
            +
                        model_size = 0
         | 
| 63 | 
            +
                        if quantizer is not None:
         | 
| 64 | 
            +
                            model_size = quantizer.model_size()
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                        train_loss = loss + diffq * model_size
         | 
| 67 | 
            +
                        train_loss.backward()
         | 
| 68 | 
            +
                        grad_norm = 0
         | 
| 69 | 
            +
                        for p in model.parameters():
         | 
| 70 | 
            +
                            if p.grad is not None:
         | 
| 71 | 
            +
                                grad_norm += p.grad.data.norm()**2
         | 
| 72 | 
            +
                        grad_norm = grad_norm**0.5
         | 
| 73 | 
            +
                        optimizer.step()
         | 
| 74 | 
            +
                        optimizer.zero_grad()
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                        if quantizer is not None:
         | 
| 77 | 
            +
                            model_size = model_size.item()
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        total_loss += loss.item()
         | 
| 80 | 
            +
                        current_loss = total_loss / (1 + idx)
         | 
| 81 | 
            +
                        tq.set_postfix(loss=f"{current_loss:.4f}", ms=f"{model_size:.2f}",
         | 
| 82 | 
            +
                                       grad=f"{grad_norm:.5f}")
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        # free some space before next round
         | 
| 85 | 
            +
                        del sources, mix, estimates, loss, train_loss
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    if world_size > 1:
         | 
| 88 | 
            +
                        sampler.epoch += 1
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                if world_size > 1:
         | 
| 91 | 
            +
                    current_loss = average_metric(current_loss)
         | 
| 92 | 
            +
                return current_loss, model_size
         | 
| 93 | 
            +
             | 
| 94 | 
            +
             | 
| 95 | 
            +
            def validate_model(epoch,
         | 
| 96 | 
            +
                               dataset,
         | 
| 97 | 
            +
                               model,
         | 
| 98 | 
            +
                               criterion,
         | 
| 99 | 
            +
                               device="cpu",
         | 
| 100 | 
            +
                               rank=0,
         | 
| 101 | 
            +
                               world_size=1,
         | 
| 102 | 
            +
                               shifts=0,
         | 
| 103 | 
            +
                               overlap=0.25,
         | 
| 104 | 
            +
                               split=False):
         | 
| 105 | 
            +
                indexes = range(rank, len(dataset), world_size)
         | 
| 106 | 
            +
                tq = tqdm.tqdm(indexes,
         | 
| 107 | 
            +
                               ncols=120,
         | 
| 108 | 
            +
                               desc=f"[{epoch:03d}] valid",
         | 
| 109 | 
            +
                               leave=False,
         | 
| 110 | 
            +
                               file=sys.stdout,
         | 
| 111 | 
            +
                               unit=" track")
         | 
| 112 | 
            +
                current_loss = 0
         | 
| 113 | 
            +
                for index in tq:
         | 
| 114 | 
            +
                    streams = dataset[index]
         | 
| 115 | 
            +
                    # first five minutes to avoid OOM on --upsample models
         | 
| 116 | 
            +
                    streams = streams[..., :15_000_000]
         | 
| 117 | 
            +
                    streams = streams.to(device)
         | 
| 118 | 
            +
                    sources = streams[1:]
         | 
| 119 | 
            +
                    mix = streams[0]
         | 
| 120 | 
            +
                    estimates = apply_model(model, mix, shifts=shifts, split=split, overlap=overlap)
         | 
| 121 | 
            +
                    loss = criterion(estimates, sources)
         | 
| 122 | 
            +
                    current_loss += loss.item() / len(indexes)
         | 
| 123 | 
            +
                    del estimates, streams, sources
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                if world_size > 1:
         | 
| 126 | 
            +
                    current_loss = average_metric(current_loss, len(indexes))
         | 
| 127 | 
            +
                return current_loss
         | 
 
			
