diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..b05955480ac724da2f81bd0d93455ded07377668 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+gui/Poppins[[:space:]]Bold[[:space:]]700.ttf filter=lfs diff=lfs merge=lfs -text
+gui/Poppins[[:space:]]Regular[[:space:]]400.ttf filter=lfs diff=lfs merge=lfs -text
+input/APT.[[:space:]][[:space:]]YOR[[:space:]]X[[:space:]]LOID[[:space:]][[:space:]]AMV[[:space:]]4K.mp3 filter=lfs diff=lfs merge=lfs -text
+old_output/APT.[[:space:]][[:space:]]YOR[[:space:]]X[[:space:]]LOID[[:space:]][[:space:]]AMV[[:space:]]4K.mp3_Instrumental_Inst_GaboxV7_(by[[:space:]]Gabox)_old.wav filter=lfs diff=lfs merge=lfs -text
+output/APT.[[:space:]][[:space:]]YOR[[:space:]]X[[:space:]]LOID[[:space:]][[:space:]]AMV[[:space:]]4K.mp3_Instrumental_Inst_GaboxV7_(by[[:space:]]Gabox).wav filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fc3be620893c26bfb77941c6b8cb8e66228abca8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,76 @@
+__pycache__
+.DS_Store
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+Lib/site-packages/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+share/man/man1/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+share/jupyter
+etc/jupyter
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pyvenv.cfg
+Scripts/
+
+*.code-workspace
+
+results/
+wandb/
\ No newline at end of file
diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem
new file mode 100644
index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3
--- /dev/null
+++ b/.gradio/certificate.pem
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9d7186e88bca9975edd65956cd499fa60bd04251
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Roman Solovyev (ZFTurbo)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 430374a5a8644bb8b5ee12999e38a68920406cb7..fd36742673feb7db78938cb5cb5a2aa5bda86fbd 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
+---
+title: Gecekondu Dubbing Production Studio
+emoji: 🎙️
+colorFrom: red
+colorTo: yellow # yellow used instead of gold
+sdk: gradio
+sdk_version: "4.44.1"
+app_file: app.py
+pinned: false
+---
+
# Gecekondu Dubbing Production Space
-Bu Space, ses ayrıştırma ve dublaj işlemleri için profesyonel bir arayüz sunar.
+This Space provides a professional interface for audio separation and dubbing work. Developed by the Gecekondu team.
diff --git a/WebUi2.py b/WebUi2.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f39c1d0fd36b516170791049729159750014a4
--- /dev/null
+++ b/WebUi2.py
@@ -0,0 +1,3244 @@
+import os
+
+os.chdir('/content/Music-Source-Separation-Training')
+import torch
+import yaml
+import gradio as gr
+import subprocess
+import threading
+import random
+import time
+import shutil
+import librosa
+import soundfile as sf
+import numpy as np
+import requests
+import json
+import locale
+from datetime import datetime
+import glob
+import yt_dlp
+import validators
+from pytube import YouTube
+from google.colab import auth
+from googleapiclient.discovery import build
+from googleapiclient.http import MediaIoBaseDownload
+import io
+import math
+import hashlib
+import re
+import gc
+import psutil
+import concurrent.futures
+from tqdm import tqdm
+from google.oauth2.credentials import Credentials
+import tempfile
+from urllib.parse import urlparse
+from urllib.parse import quote
+import gdown
+
+
+os.makedirs('/content/Music-Source-Separation-Training/input', exist_ok=True)
+os.makedirs('/content/Music-Source-Separation-Training/output', exist_ok=True)
+os.makedirs('/content/drive/MyDrive/output', exist_ok=True)
+os.makedirs('/content/drive/MyDrive/ensemble_folder', exist_ok=True)
+os.makedirs('/content/Music-Source-Separation-Training/old_output', exist_ok=True)
+os.makedirs('/content/Music-Source-Separation-Training/auto_ensemble_temp', exist_ok=True)
+os.makedirs('/content/Music-Source-Separation-Training/wav_folder', exist_ok=True)
+shutil.rmtree('/content/Music-Source-Separation-Training/ensemble', ignore_errors=True)
+shutil.rmtree('/content/Music-Source-Separation-Training/auto_ensemble_temp', ignore_errors=True)
+
+
+def clear_old_output():
+ old_output_folder = os.path.join(BASE_PATH, 'old_output')
+ try:
+ if not os.path.exists(old_output_folder):
+ return "❌ Old output folder does not exist"
+
+        # Delete all files and subfolders
+ shutil.rmtree(old_output_folder)
+ os.makedirs(old_output_folder, exist_ok=True)
+
+ return "✅ Old outputs successfully cleared!"
+
+ except Exception as e:
+ error_msg = f"🔥 Error: {str(e)}"
+ print(error_msg)
+ return error_msg
+
+
+def shorten_filename(filename, max_length=30):
+ """
+ Shortens a filename to a specified maximum length
+
+ Args:
+ filename (str): The filename to be shortened
+ max_length (int): Maximum allowed length for the filename
+
+ Returns:
+ str: Shortened filename
+ """
+ base, ext = os.path.splitext(filename)
+ if len(base) <= max_length:
+ return filename
+
+ # Take first 15 and last 10 characters
+ shortened = base[:15] + "..." + base[-10:] + ext
+ return shortened
+
+def update_progress(progress=gr.Progress()):
+    def track_progress(percent):
+        progress(percent / 100)
+    return track_progress
+
+# For cleaning special characters from titles
+def clean_filename(title):
+ return re.sub(r'[^\w\-_\. ]', '', title).strip()
+
+def download_callback(url, download_type='direct', cookie_file=None):
+ try:
+        # 1. CLEANUP AND FOLDER PREPARATION
+ BASE_PATH = "/content/Music-Source-Separation-Training"
+ INPUT_DIR = os.path.join(BASE_PATH, "input")
+ COOKIE_PATH = os.path.join(BASE_PATH, "cookies.txt")
+
+        # Clear and recreate the input folder
+ clear_temp_folder(
+ "/tmp",
+ exclude_items=["gradio", "config.json"]
+ )
+ clear_directory(INPUT_DIR)
+ os.makedirs(INPUT_DIR, exist_ok=True)
+
+        # 2. URL VALIDATION
+ if not validators.url(url):
+ return None, "❌ Invalid URL", None, None, None, None
+
+        # 3. COOKIE MANAGEMENT
+ if cookie_file is not None:
+ try:
+ with open(cookie_file.name, "rb") as f:
+ cookie_content = f.read()
+ with open(COOKIE_PATH, "wb") as f:
+ f.write(cookie_content)
+ print("✅ Cookie file updated!")
+ except Exception as e:
+ print(f"⚠️ Cookie installation error: {str(e)}")
+
+ wav_path = None
+ download_success = False
+
+        # 4. HANDLE THE DOWNLOAD BY TYPE
+ if download_type == 'drive':
+            # GOOGLE DRIVE DOWNLOAD
+ try:
+ file_id = re.search(r'/d/([^/]+)', url).group(1) if '/d/' in url else url.split('id=')[-1]
+ original_filename = "drive_download.wav"
+
+                # Download with gdown
+ output_path = os.path.join(INPUT_DIR, original_filename)
+ gdown.download(
+ f'https://drive.google.com/uc?id={file_id}',
+ output_path,
+ quiet=True,
+ fuzzy=True
+ )
+
+ if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
+ wav_path = output_path
+ download_success = True
+ print(f"✅ Downloaded from Google Drive: {wav_path}")
+ else:
+ raise Exception("File size zero or file not created")
+
+ except Exception as e:
+ error_msg = f"❌ Google Drive download error: {str(e)}"
+ print(error_msg)
+ return None, error_msg, None, None, None, None
+
+ else:
+            # YOUTUBE / DIRECT LINK DOWNLOAD
+ ydl_opts = {
+ 'format': 'bestaudio/best',
+ 'outtmpl': os.path.join(INPUT_DIR, '%(title)s.%(ext)s'),
+ 'postprocessors': [{
+ 'key': 'FFmpegExtractAudio',
+ 'preferredcodec': 'wav',
+ 'preferredquality': '0'
+ }],
+ 'cookiefile': COOKIE_PATH if os.path.exists(COOKIE_PATH) else None,
+ 'nocheckcertificate': True,
+ 'ignoreerrors': True,
+ 'retries': 3
+ }
+
+ try:
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ info_dict = ydl.extract_info(url, download=True)
+ temp_path = ydl.prepare_filename(info_dict)
+ wav_path = os.path.splitext(temp_path)[0] + '.wav'
+
+ if os.path.exists(wav_path):
+ download_success = True
+ print(f"✅ Downloaded successfully: {wav_path}")
+ else:
+ raise Exception("WAV conversion failed")
+
+ except Exception as e:
+ error_msg = f"❌ Download error: {str(e)}"
+ print(error_msg)
+ return None, error_msg, None, None, None, None
+
+        # 5. FINAL CHECKS AND CLEANUP
+ if download_success and wav_path:
+            # Remove leftover files from the input folder
+ for f in os.listdir(INPUT_DIR):
+ if f != os.path.basename(wav_path):
+ os.remove(os.path.join(INPUT_DIR, f))
+
+ return (
+ gr.File(value=wav_path),
+ "🎉 Downloaded successfully!",
+ gr.File(value=wav_path),
+ gr.File(value=wav_path),
+ gr.Audio(value=wav_path),
+ gr.Audio(value=wav_path)
+ )
+
+ return None, "❌ Download failed", None, None, None, None
+
+ except Exception as e:
+ error_msg = f"🔥 Critical Error: {str(e)}"
+ print(error_msg)
+ return None, error_msg, None, None, None, None
+
+
+# Hook function to track download progress
+def download_progress_hook(d):
+ if d['status'] == 'finished':
+ print('Download complete, conversion in progress...')
+ elif d['status'] == 'downloading':
+ downloaded_bytes = d.get('downloaded_bytes', 0)
+ total_bytes = d.get('total_bytes') or d.get('total_bytes_estimate', 0)
+ if total_bytes > 0:
+ percent = downloaded_bytes * 100. / total_bytes
+ print(f'Downloading: {percent:.1f}%')
+
+# Define the global variable at the top
+INPUT_DIR = "/content/Music-Source-Separation-Training/input"
+
+def download_file(url):
+ # Encode the URL to handle spaces and special characters
+ encoded_url = quote(url, safe=':/')
+
+ path = 'ckpts'
+ os.makedirs(path, exist_ok=True)
+ filename = os.path.basename(encoded_url)
+ file_path = os.path.join(path, filename)
+
+ if os.path.exists(file_path):
+ print(f"File '{filename}' already exists at '{path}'.")
+ return
+
+ try:
+ response = torch.hub.download_url_to_file(encoded_url, file_path)
+ print(f"File '{filename}' downloaded successfully")
+ except Exception as e:
+ print(f"Error downloading file '{filename}' from '{url}': {e}")
+
+
+
+def generate_random_port():
+ return random.randint(1000, 9000)
+
+
+# Markdown annotations
+markdown_intro = """
+# Voice Parsing Tool
+
+This tool is used to parse audio files.
+"""
+
+class IndentDumper(yaml.Dumper):
+ def increase_indent(self, flow=False, indentless=False):
+ return super(IndentDumper, self).increase_indent(flow, False)
+
+
+def tuple_constructor(loader, node):
+ # Load the sequence of values from the YAML node
+ values = loader.construct_sequence(node)
+ # Return a tuple constructed from the sequence
+ return tuple(values)
+
+# Register the constructor with PyYAML
+yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple', tuple_constructor)
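+# Illustrative example: with the constructor registered above, yaml.SafeLoader can parse the
+# python/tuple tags some model configs contain; without it, the load below would raise a
+# ConstructorError. The key name 'hop' is made up purely for demonstration.
+_tuple_demo = yaml.load('hop: !!python/tuple [1024, 441]', Loader=yaml.SafeLoader)
+assert isinstance(_tuple_demo['hop'], tuple)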
+
+
+
+def conf_edit(config_path, chunk_size, overlap):
+ with open(config_path, 'r') as f:
+ data = yaml.load(f, Loader=yaml.SafeLoader)
+
+    # handle cases where 'use_amp' is missing from the config's training section:
+    if 'use_amp' not in data['training'].keys():
+        data['training']['use_amp'] = True
+
+ data['audio']['chunk_size'] = chunk_size
+ data['inference']['num_overlap'] = overlap
+
+ if data['inference']['batch_size'] == 1:
+ data['inference']['batch_size'] = 2
+
+ print("Using custom overlap and chunk_size values:")
+ print(f"overlap = {data['inference']['num_overlap']}")
+ print(f"chunk_size = {data['audio']['chunk_size']}")
+ print(f"batch_size = {data['inference']['batch_size']}")
+
+
+ with open(config_path, 'w') as f:
+ yaml.dump(data, f, default_flow_style=False, sort_keys=False, Dumper=IndentDumper, allow_unicode=True)
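+
+    # Config layout this function assumes (illustrative sketch; based on the
+    # Music-Source-Separation-Training configs downloaded below, not guaranteed for every model):
+    #   audio.chunk_size      <- chunk_size argument
+    #   inference.num_overlap <- overlap argument
+    #   inference.batch_size  <- bumped from 1 to 2 when needed
+    #   training.use_amp      <- added as True when missing
+    # Example call: conf_edit('ckpts/config_vocals_mdx23c.yaml', 485100, 2)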
+
+def save_uploaded_file(uploaded_file, is_input=False, target_dir=None):
+ try:
+        # Media file extensions
+ media_extensions = ['.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a', '.mp4']
+
+        # Determine the target directory
+ if target_dir is None:
+ target_dir = INPUT_DIR if is_input else VİDEO_TEMP
+
+        # Timestamp patterns
+ timestamp_patterns = [
+ r'_\d{8}_\d{6}_\d{6}$', # _20231215_123456_123456
+ r'_\d{14}$', # _20231215123456
+ r'_\d{10}$', # _1702658400
+            r'_\d+$'  # Any trailing number
+ ]
+
+        # Get the filename
+ if hasattr(uploaded_file, 'name'):
+ original_filename = os.path.basename(uploaded_file.name)
+ else:
+ original_filename = os.path.basename(str(uploaded_file))
+
+        # Clean the filename (inputs only)
+ if is_input:
+ base_filename = original_filename
+            # Strip timestamps
+ for pattern in timestamp_patterns:
+ base_filename = re.sub(pattern, '', base_filename)
+            # Strip any embedded extensions
+ for ext in media_extensions:
+ base_filename = base_filename.replace(ext, '')
+
+            # Determine the file extension
+ file_ext = next(
+ (ext for ext in media_extensions if original_filename.lower().endswith(ext)),
+ '.wav'
+ )
+ clean_filename = f"{base_filename.strip('_- ')}{file_ext}"
+ else:
+ clean_filename = original_filename
+
+        # Determine the target directory (FIX HERE)
+        target_directory = INPUT_DIR if is_input else OUTPUT_DIR
+        target_path = os.path.join(target_directory, clean_filename)
+
+        # Create the directory if it does not exist
+ os.makedirs(target_directory, exist_ok=True)
+
+        # Delete ALL previous files in the directory
+ for filename in os.listdir(target_directory):
+ file_path = os.path.join(target_directory, filename)
+ try:
+ if os.path.isfile(file_path) or os.path.islink(file_path):
+ os.unlink(file_path)
+ elif os.path.isdir(file_path):
+ shutil.rmtree(file_path)
+ except Exception as e:
+ print(f"{file_path} Not deleted: {e}")
+
+        # Save the new file
+ if hasattr(uploaded_file, 'read'):
+ with open(target_path, "wb") as f:
+ f.write(uploaded_file.read())
+ else:
+ shutil.copy(uploaded_file, target_path)
+
+ print(f"File saved successfully: {os.path.basename(target_path)}")
+ return target_path
+
+ except Exception as e:
+ print(f"File save error: {e}")
+ return None
+
+
+def save_uploaded_file(uploaded_file, is_input=False, target_dir=None):
+ try:
+        # Media file extensions
+ media_extensions = ['.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a', '.mp4']
+
+        # Determine the target directory
+ if target_dir is None:
+ target_dir = INPUT_DIR if is_input else OUTPUT_DIR
+
+        # Timestamp patterns
+ timestamp_patterns = [
+ r'_\d{8}_\d{6}_\d{6}$', # _20231215_123456_123456
+ r'_\d{14}$', # _20231215123456
+ r'_\d{10}$', # _1702658400
+            r'_\d+$'  # Any trailing number
+ ]
+
+        # Get the filename
+ if hasattr(uploaded_file, 'name'):
+ original_filename = os.path.basename(uploaded_file.name)
+ else:
+ original_filename = os.path.basename(str(uploaded_file))
+
+        # Clean the filename (inputs only)
+ if is_input:
+ base_filename = original_filename
+            # Strip timestamps
+ for pattern in timestamp_patterns:
+ base_filename = re.sub(pattern, '', base_filename)
+            # Strip any embedded extensions
+ for ext in media_extensions:
+ base_filename = base_filename.replace(ext, '')
+
+            # Determine the file extension
+ file_ext = next(
+ (ext for ext in media_extensions if original_filename.lower().endswith(ext)),
+ '.wav'
+ )
+ clean_filename = f"{base_filename.strip('_- ')}{file_ext}"
+ else:
+ clean_filename = original_filename
+
+        # Determine the target directory
+ target_directory = INPUT_DIR if is_input else OUTPUT_DIR
+ target_path = os.path.join(target_directory, clean_filename)
+
+        # Create the directory if it does not exist
+ os.makedirs(target_directory, exist_ok=True)
+
+        # If the file already exists, delete it
+ if os.path.exists(target_path):
+ os.remove(target_path)
+
+        # Save the new file
+ if hasattr(uploaded_file, 'read'):
+ with open(target_path, "wb") as f:
+ f.write(uploaded_file.read())
+ else:
+ shutil.copy(uploaded_file, target_path)
+
+ print(f"File saved successfully: {os.path.basename(target_path)}")
+ return target_path
+
+ except Exception as e:
+ print(f"File save error: {e}")
+ return None
+
+
+def clear_temp_folder(folder_path, exclude_items=None):
+ """
+ Safely clears contents of a directory while preserving specified items
+
+ Args:
+ folder_path (str): Path to directory to clean
+ exclude_items (list): Items to preserve (e.g., ['gradio', 'important.log'])
+
+ Returns:
+ bool: True if successful, False if failed
+ """
+ try:
+ # Validate directory existence
+ if not os.path.exists(folder_path):
+ print(f"⚠️ Directory does not exist: {folder_path}")
+ return False
+
+ if not os.path.isdir(folder_path):
+ print(f"⚠️ Path is not a directory: {folder_path}")
+ return False
+
+ # Initialize exclusion list
+ exclude_items = exclude_items or []
+ preserved_items = []
+
+ # Process directory contents
+ for item_name in os.listdir(folder_path):
+ item_path = os.path.join(folder_path, item_name)
+
+ # Skip excluded items
+ if item_name in exclude_items:
+ preserved_items.append(item_path)
+ continue
+
+ try:
+ # Delete files and symlinks
+ if os.path.isfile(item_path) or os.path.islink(item_path):
+ os.unlink(item_path)
+ print(f"🗑️ File removed: {item_path}")
+
+ # Delete directories
+ elif os.path.isdir(item_path):
+ shutil.rmtree(item_path)
+ print(f"🗂️ Directory removed: {item_path}")
+
+ except PermissionError as pe:
+ print(f"🔒 Permission denied: {item_path} ({str(pe)})")
+ continue
+
+ except Exception as e:
+ print(f"⚠️ Error deleting {item_path}: {str(e)}")
+ continue
+
+ # Print summary
+ print(f"\n✅ Cleaning completed: {folder_path}")
+ print(f"Total preserved items: {len(preserved_items)}")
+ if preserved_items:
+ print("Preserved items:")
+ for item in preserved_items:
+ print(f" - {item}")
+
+ return True
+
+ except Exception as e:
+ print(f"❌ Critical error: {str(e)}")
+ return False
+
+
+def handle_file_upload(file_obj, file_path_input, is_auto_ensemble=False):
+ try:
+ BASE_PATH = "/content/Music-Source-Separation-Training"
+ INPUT_DIR = os.path.join(BASE_PATH, "input")
+
+        # New: check for previously uploaded files
+ existing_files = os.listdir(INPUT_DIR)
+ new_file = None
+
+        # If a file path was provided
+ if file_path_input and os.path.exists(file_path_input):
+ new_file = file_path_input
+        # If a file was uploaded
+ elif file_obj:
+            new_file = file_obj.name  # Gradio's temporary file path
+
+        # If there is no new file, keep the existing one
+ if not new_file and existing_files:
+ kept_file = os.path.join(INPUT_DIR, existing_files[0])
+ return [
+ gr.File(value=kept_file),
+ gr.Audio(value=kept_file)
+ ]
+
+        # If there is a new file, clear the folder and load it
+ if new_file:
+            clear_directory(INPUT_DIR)  # Clear only when a new file arrives
+ saved_path = save_uploaded_file(new_file, is_input=True)
+ return [
+ gr.File(value=saved_path),
+ gr.Audio(value=saved_path)
+ ]
+
+ return [None, None]
+
+ except Exception as e:
+ print(f"Error: {str(e)}")
+ return [None, None]
+
+def move_old_files(output_folder):
+ old_output_folder = os.path.join(BASE_PATH, 'old_output')
+ os.makedirs(old_output_folder, exist_ok=True)
+
+    # Move old files and append "_old" to their names
+ for filename in os.listdir(output_folder):
+ file_path = os.path.join(output_folder, filename)
+ if os.path.isfile(file_path):
+            # Build the new filename
+ new_filename = f"{os.path.splitext(filename)[0]}_old{os.path.splitext(filename)[1]}"
+ new_file_path = os.path.join(old_output_folder, new_filename)
+ shutil.move(file_path, new_file_path)
+
+def move_wav_files2(INPUT_DIR):
+ ENSEMBLE_DIR = os.path.join(BASE_PATH, 'ensemble')
+ os.makedirs(ENSEMBLE_DIR, exist_ok=True)
+
+    # Move input files to the ensemble folder and append "_ensemble" to their names
+ for filename in os.listdir(INPUT_DIR):
+ file_path = os.path.join(INPUT_DIR, filename)
+ if os.path.isfile(file_path):
+            # Build the new filename
+ new_filename = f"{os.path.splitext(filename)[0]}_ensemble{os.path.splitext(filename)[1]}"
+ new_file_path = os.path.join(ENSEMBLE_DIR, new_filename)
+ shutil.move(file_path, new_file_path)
+
+
+def extract_model_name(full_model_string):
+ """
+ Function to clear model name
+ """
+ if not full_model_string:
+ return ""
+
+ cleaned = str(full_model_string)
+
+ # Remove the description
+ if ' - ' in cleaned:
+ cleaned = cleaned.split(' - ')[0]
+
+ # Remove emoji prefixes
+ emoji_prefixes = ['✅ ', '👥 ', '🗣️ ', '🏛️ ', '🔇 ', '🔉 ', '🎬 ', '🎼 ', '✅(?) ']
+ for prefix in emoji_prefixes:
+ if cleaned.startswith(prefix):
+ cleaned = cleaned[len(prefix):]
+
+ return cleaned.strip()
+
+COOKIE_PATH = '/content/'
+BASE_PATH = '/content/Music-Source-Separation-Training'
+INPUT_DIR = os.path.join(BASE_PATH, 'input')
+AUTO_ENSEMBLE_TEMP = os.path.join(BASE_PATH, 'auto_ensemble_temp')
+OUTPUT_DIR = '/content/drive/MyDrive/output'
+OLD_OUTPUT_DIR = '/content/drive/MyDrive/old_output'
+AUTO_ENSEMBLE_OUTPUT = '/content/drive/MyDrive/ensemble_folder'
+INFERENCE_SCRIPT_PATH = '/content/Music-Source-Separation-Training/inference.py'
+VİDEO_TEMP = '/content/Music-Source-Separation-Training/video_temp'
+ENSEMBLE_DIR = '/content/Music-Source-Separation-Training/ensemble'
+os.makedirs(VİDEO_TEMP, exist_ok=True)  # Make sure the folder exists
+Backup = '/content/backup'
+
+def clear_directory(directory):
+ """Deletes all files in the given directory."""
+    files = glob.glob(os.path.join(directory, '*'))  # Collect everything in the directory
+ for f in files:
+ try:
+ os.remove(f) # remove files
+ except Exception as e:
+ print(f"{f} could not be deleted: {e}")
+
+def create_directory(directory):
+ """Creates the given directory (if it exists, if not)."""
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+ print(f"{directory} directory created.")
+ else:
+ print(f"{directory} directory already exists.")
+
+def convert_to_wav(file_path):
+ """Converts the audio file to WAV format and moves it to the ensemble directory."""
+
+ BASE_DIR = "/content/Music-Source-Separation-Training"
+ ENSEMBLE_DIR = os.path.join(BASE_DIR, "ensemble") # Define the ensemble directory
+ os.makedirs(ENSEMBLE_DIR, exist_ok=True) # Create the ensemble directory if it doesn't exist
+
+ original_filename = os.path.basename(file_path)
+ filename, ext = os.path.splitext(original_filename)
+
+ # If already a WAV file, return its path directly
+ if ext.lower() == '.wav':
+ return file_path # Return the original path if it's already a WAV file
+
+ try:
+ # Prepare for WAV conversion
+ wav_output = os.path.join(ENSEMBLE_DIR, f"{filename}.wav") # Save to ensemble directory
+
+ # Run FFmpeg command to convert to WAV
+ command = [
+ 'ffmpeg', '-y', '-i', file_path,
+ '-acodec', 'pcm_s16le', '-ar', '44100', wav_output
+ ]
+ subprocess.run(command, check=True, capture_output=True)
+
+ return wav_output # Return the path of the converted WAV file
+
+ except subprocess.CalledProcessError as e:
+ error_msg = f"FFmpeg Error ({e.returncode}): {e.stderr.decode()}"
+ print(error_msg)
+ return None
+ except Exception as e:
+ print(f"Error during conversion: {str(e)}")
+ return None
+
+def send_audio_file(file_path):
+ try:
+ if not os.path.exists(file_path):
+ print(f"File not found: {file_path}")
+ return None, "File not found"
+
+ with open(file_path, 'rb') as f:
+ data = f.read()
+ print(f"Sending file: {file_path}, Size: {len(data)} bytes")
+ return data, "Success"
+ except Exception as e:
+ print(f"Error sending file: {e}")
+ return None, str(e)
+
+
+def process_audio(input_audio_file, model, chunk_size, overlap, export_format, use_tta, demud_phaseremix_inst, extract_instrumental, clean_model, *args, **kwargs):
+ # Determine the audio path
+ if input_audio_file is not None:
+ audio_path = input_audio_file.name
+ else:
+ # Check for existing files in INPUT_DIR
+ create_directory(INPUT_DIR) # Ensure the directory exists
+ existing_files = os.listdir(INPUT_DIR)
+ if existing_files:
+ # Use the first existing file
+ audio_path = os.path.join(INPUT_DIR, existing_files[0])
+ else:
+ print("No audio file provided and no existing file in input directory.")
+ return [None] * 14 # Error case
+
+ # Create necessary directories
+ create_directory(OUTPUT_DIR)
+ create_directory(OLD_OUTPUT_DIR)
+
+ # Move old files to the OLD_OUTPUT_DIR
+ move_old_files(OUTPUT_DIR)
+
+ # Clean model name
+ clean_model = extract_model_name(model)
+ print(f"Processing audio from: {audio_path} using model: {clean_model}")
+
+ # Model configuration (remaining code)
+ model_type, config_path, start_check_point = "", "", ""
+
+ if clean_model == 'VOCALS-InstVocHQ':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/config_vocals_mdx23c.yaml'
+ start_check_point = 'ckpts/model_vocals_mdx23c_sdr_10.17.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_mdx23c.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt')
+
+ elif clean_model == 'VOCALS-MelBand-Roformer (by KimberleyJSN)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_vocals_mel_band_roformer_kj.yaml'
+ start_check_point = 'ckpts/MelBandRoformer.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml')
+ download_file('https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-Roformer_1297 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_317_sdr_12.9755.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_317_sdr_12.9755.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml')
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_317_sdr_12.9755.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-Roformer_1296 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_368_sdr_12.9628.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_368_sdr_12.9628.ckpt'
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_368_sdr_12.9628.ckpt')
+ download_file('https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_368_sdr_12.9628.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-RoformerLargev1 (by unwa)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_bsrofoL.yaml'
+ start_check_point = 'ckpts/BS-Roformer_LargeV1.ckpt'
+ download_file('https://huggingface.co/jarredou/unwa_bs_roformer/resolve/main/BS-Roformer_LargeV1.ckpt')
+ download_file('https://huggingface.co/jarredou/unwa_bs_roformer/raw/main/config_bsrofoL.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-Mel-Roformer big beta 4 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_big_beta4.yaml'
+ start_check_point = 'ckpts/melband_roformer_big_beta4.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta4.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/raw/main/config_melbandroformer_big_beta4.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-Melband-Roformer BigBeta5e (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/big_beta5e.yaml'
+ start_check_point = 'ckpts/big_beta5e.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-Mel-Roformer v1 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst.yaml'
+ start_check_point = 'ckpts/melband_roformer_inst_v1.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v1.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/raw/main/config_melbandroformer_inst.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-Mel-Roformer v2 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst_v2.yaml'
+ start_check_point = 'ckpts/melband_roformer_inst_v2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v2.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/raw/main/config_melbandroformer_inst_v2.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_instvoc_duality.yaml'
+ start_check_point = 'ckpts/melband_roformer_instvoc_duality_v1.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvoc_duality_v1.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/raw/main/config_melbandroformer_instvoc_duality.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_instvoc_duality.yaml'
+ start_check_point = 'ckpts/melband_roformer_instvox_duality_v2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvox_duality_v2.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/raw/main/config_melbandroformer_instvoc_duality.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'KARAOKE-MelBand-Roformer (by aufr33 & viperx)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_mel_band_roformer_karaoke.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'OTHER-BS-Roformer_1053 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_937_sdr_10.5309.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_937_sdr_10.5309.ckpt'
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt')
+ download_file('https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_937_sdr_10.5309.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'CROWD-REMOVAL-MelBand-Roformer (by aufr33)':
+ model_type = 'mel_band_roformer'
+        config_path = 'ckpts/model_mel_band_roformer_crowd.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-VitLarge23 (by ZFTurbo)':
+ model_type = 'segm_models'
+ config_path = 'ckpts/config_vocals_segm_models.yaml'
+ start_check_point = 'ckpts/model_vocals_segm_models_sdr_9.77.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/refs/heads/main/configs/config_vocals_segm_models.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_segm_models_sdr_9.77.ckpt')
+
+ elif clean_model == 'CINEMATIC-BandIt_Plus (by kwatcharasupat)':
+ model_type = 'bandit'
+ config_path = 'ckpts/config_dnr_bandit_bsrnn_multi_mus64.yaml'
+ start_check_point = 'ckpts/model_bandit_plus_dnr_sdr_11.47.chpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt')
+
+ elif clean_model == 'DRUMSEP-MDX23C_DrumSep_6stem (by aufr33 & jarredou)':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml'
+ start_check_point = 'ckpts/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt'
+ download_file('https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt')
+ download_file('https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml')
+
+ elif clean_model == '4STEMS-SCNet_MUSDB18 (by starrytong)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet.yaml'
+ start_check_point = 'ckpts/scnet_checkpoint_musdb18.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt')
+
+ elif clean_model == 'DE-REVERB-MDX23C (by aufr33 & jarredou)':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/config_dereverb_mdx23c.yaml'
+ start_check_point = 'ckpts/dereverb_mdx23c_sdr_6.9096.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/dereverb_mdx23c_sdr_6.9096.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/config_dereverb_mdx23c.yaml')
+
+ elif clean_model == 'DENOISE-MelBand-Roformer-1 (by aufr33)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DENOISE-MelBand-Roformer-2 (by aufr33)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT (by Unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft.ckpt')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_v1e (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst.yaml'
+ start_check_point = 'ckpts/inst_v1e.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'bleed_suppressor_v1 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_bleed_suppressor_v1.yaml'
+ start_check_point = 'ckpts/bleed_suppressor_v1.ckpt'
+ download_file('https://huggingface.co/ASesYusuf1/MODELS/resolve/main/bleed_suppressor_v1.ckpt')
+ download_file('https://huggingface.co/ASesYusuf1/MODELS/resolve/main/config_bleed_suppressor_v1.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer (by Becruily)':
+ model_type = 'mel_band_roformer'
+        config_path = 'ckpts/config_vocals_becruily.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_vocals_becruily.ckpt'
+ download_file('https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/config_vocals_becruily.yaml')
+ download_file('https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/mel_band_roformer_vocals_becruily.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-MelBand-Roformer (by Becruily)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_instrumental_becruily.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_instrumental_becruily.ckpt'
+ download_file('https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml')
+ download_file('https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-SCNet_XL_MUSDB18 (by ZFTurbo)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet_xl.yaml'
+ start_check_point = 'ckpts/model_scnet_ep_54_sdr_9.8051.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-SCNet_Large (by starrytong)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet_large_starrytong.yaml'
+ start_check_point = 'ckpts/SCNet-large_starrytong_fixed.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-BS-Roformer_MUSDB18 (by ZFTurbo)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_bs_roformer_384_8_2_485100.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_17_sdr_9.6568.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DE-REVERB-MelBand-Roformer aggr./v2/19.1729 (by anvuew)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DE-REVERB-Echo-MelBand-Roformer (by Sucial)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_dereverb-echo_mel_band_roformer.yaml'
+ start_check_point = 'ckpts/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt'
+ download_file('https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt')
+ download_file('https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_less_aggressive_anvuew':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_anvuew':
+ model_type = 'mel_band_roformer'
+ config_path = 'dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gabox (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxBV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxBv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxBV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxBv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxBFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+        start_check_point = 'ckpts/inst_gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxFV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'VOCALS-Male Female-BS-RoFormer Male Female Beta 7_2889 (by aufr33)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_chorus_male_female_bs_roformer.yaml'
+ start_check_point = 'ckpts/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt'
+ download_file('https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt')
+ download_file('https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxBSroformer (by Gabox)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/voc_gaboxBSroformer.yaml'
+ start_check_point = 'ckpts/voc_gaboxBSR.ckpt'
+ download_file('https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSroformer.yaml')
+ download_file('https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSR.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelReformer (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelReformerFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelReformerFV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gaboxFv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_GaboxFv3 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv3.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Intrumental_Gabox (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/intrumental_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/intrumental_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_Fv4Noise (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_Fv4Noise.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4Noise.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_V5 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV5.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model1 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model2 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model2.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model3 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model3.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft2_bleedless.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2_bleedless.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxFV6 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV6.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'denoisedebleed (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoisedebleed.ckpt'
+ download_file('https://huggingface.co/poiqazwsx/melband-roformer-denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/denoisedebleed.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INSTV5N (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV5N.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5N.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Voc_Fv3 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_Fv3.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_Fv3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'MelBandRoformer4StemFTLarge (SYH99999)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/MelBandRoformer4StemFTLarge.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/MelBandRoformer4StemFTLarge.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_mono (by anvuew)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INSTV6N (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV6N.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6N.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'KaraokeGabox':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_mel_band_roformer_karaoke.yaml'
+ start_check_point = 'ckpts/KaraokeGabox.ckpt'
+ download_file('https://github.com/deton24/Colab-for-new-MDX_UVR_models/releases/download/v1.0.0/config_mel_band_roformer_karaoke.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/KaraokeGabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'FullnessVocalModel (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/FullnessVocalModel.ckpt'
+ download_file('https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml')
+ download_file('https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/FullnessVocalModel.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Inst_GaboxV7 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/Inst_GaboxV7.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxV7.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ else:
+ print(f"Unsupported model: {clean_model}")
+ return [None] * 14 # Error case
+
+ result = run_command_and_process_files(model_type, config_path, start_check_point, INPUT_DIR, OUTPUT_DIR, extract_instrumental, use_tta, demud_phaseremix_inst, clean_model)
+
+ # Clean up the input directory once processing is complete
+
+ return result
+
+
+def clean_model_name(model):
+ """
+ Clean and standardize model names for filename
+ """
+ # Mapping of complex model names to simpler, filename-friendly versions
+ model_name_mapping = {
+ 'VOCALS-InstVocHQ': 'InstVocHQ',
+ 'VOCALS-MelBand-Roformer (by KimberleyJSN)': 'KimberleyJSN',
+ 'VOCALS-BS-Roformer_1297 (by viperx)': 'VOCALS_BS_Roformer1297',
+ 'VOCALS-BS-Roformer_1296 (by viperx)': 'VOCALS-BS-Roformer_1296',
+ 'VOCALS-BS-RoformerLargev1 (by unwa)': 'UnwaLargeV1',
+ 'VOCALS-Mel-Roformer big beta 4 (by unwa)': 'UnwaBigBeta4',
+ 'VOCALS-Melband-Roformer BigBeta5e (by unwa)': 'UnwaBigBeta5e',
+ 'INST-Mel-Roformer v1 (by unwa)': 'UnwaInstV1',
+ 'INST-Mel-Roformer v2 (by unwa)': 'UnwaInstV2',
+ 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)': 'UnwaDualityV1',
+ 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)': 'UnwaDualityV2',
+ 'KARAOKE-MelBand-Roformer (by aufr33 & viperx)': 'KaraokeMelBandRoformer',
+ 'VOCALS-VitLarge23 (by ZFTurbo)': 'VitLarge23',
+ 'VOCALS-MelBand-Roformer (by Becruily)': 'BecruilyVocals',
+ 'INST-MelBand-Roformer (by Becruily)': 'BecruilyInst',
+ 'VOCALS-MelBand-Roformer Kim FT (by Unwa)': 'KimFT',
+ 'INST-MelBand-Roformer Kim FT (by Unwa)': 'KimFTInst',
+ 'OTHER-BS-Roformer_1053 (by viperx)': 'OtherViperx1053',
+ 'CROWD-REMOVAL-MelBand-Roformer (by aufr33)': 'CrowdRemovalRoformer',
+ 'CINEMATIC-BandIt_Plus (by kwatcharasupat)': 'CinematicBandItPlus',
+ 'DRUMSEP-MDX23C_DrumSep_6stem (by aufr33 & jarredou)': 'DrumSepMDX23C',
+ '4STEMS-SCNet_MUSDB18 (by starrytong)': 'FourStemsSCNet',
+ 'DE-REVERB-MDX23C (by aufr33 & jarredou)': 'DeReverbMDX23C',
+ 'DENOISE-MelBand-Roformer-1 (by aufr33)': 'DenoiseMelBand1',
+ 'DENOISE-MelBand-Roformer-2 (by aufr33)': 'DenoiseMelBand2',
+ '4STEMS-SCNet_XL_MUSDB18 (by ZFTurbo)': 'FourStemsSCNetXL',
+ '4STEMS-SCNet_Large (by starrytong)': 'FourStemsSCNetLarge',
+ '4STEMS-BS-Roformer_MUSDB18 (by ZFTurbo)': 'FourStemsBSRoformer',
+ 'DE-REVERB-MelBand-Roformer aggr./v2/19.1729 (by anvuew)': 'DeReverbMelBandAggr',
+ 'DE-REVERB-Echo-MelBand-Roformer (by Sucial)': 'DeReverbEchoMelBand',
+ 'bleed_suppressor_v1 (by unwa)': 'BleedSuppressorV1',
+ 'inst_v1e (by unwa)': 'InstV1E',
+ 'inst_gabox (by Gabox)': 'InstGabox',
+ 'inst_gaboxBV1 (by Gabox)': 'InstGaboxBV1',
+ 'inst_gaboxBV2 (by Gabox)': 'InstGaboxBV2',
+ 'inst_gaboxBFV1 (by Gabox)': 'InstGaboxBFV1',
+ 'inst_gaboxFV2 (by Gabox)': 'InstGaboxFV2',
+ 'inst_gaboxFV1 (by Gabox)': 'InstGaboxFV1',
+ 'dereverb_mel_band_roformer_less_aggressive_anvuew': 'DereverbMelBandRoformerLessAggressive',
+ 'dereverb_mel_band_roformer_anvuew': 'DereverbMelBandRoformer',
+ 'VOCALS-Male Female-BS-RoFormer Male Female Beta 7_2889 (by aufr33)': 'MaleFemale-BS-RoFormer-(by aufr33)',
+ 'VOCALS-MelBand-Roformer (by Becruily)': 'Vocals-MelBand-Roformer-(by Becruily)',
+ 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)': 'Vocals-MelBand-Roformer-KIM-FT-2(by Unwa)',
+ 'voc_gaboxMelRoformer (by Gabox)': 'voc_gaboxMelRoformer',
+ 'voc_gaboxBSroformer (by Gabox)': 'voc_gaboxBSroformer',
+ 'voc_gaboxMelRoformerFV1 (by Gabox)': 'voc_gaboxMelRoformerFV1',
+ 'voc_gaboxMelRoformerFV2 (by Gabox)': 'voc_gaboxMelRoformerFV2',
+ 'SYH99999/MelBandRoformerSYHFTB1(by Amane)': 'MelBandRoformerSYHFTB1',
+ 'inst_V5 (by Gabox)': 'INSTV5-(by Gabox)',
+ 'inst_Fv4Noise (by Gabox)': 'Inst_Fv4Noise-(by Gabox)',
+ 'Intrumental_Gabox (by Gabox)': 'Intrumental_Gabox-(by Gabox)',
+ 'inst_GaboxFv3 (by Gabox)': 'INST_GaboxFv3-(by Gabox)',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model1 (by Amane)': 'MelBandRoformerSYHFTB1_model1',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model2 (by Amane)': 'MelBandRoformerSYHFTB1_model2',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model3 (by Amane)': 'MelBandRoformerSYHFTB1_model3',
+ 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)': 'VOCALS-MelBand-Roformer-Kim-FT-2-Blendless-(by unwa)',
+ 'inst_gaboxFV6 (by Gabox)': 'inst_gaboxFV6-(by Gabox)',
+ 'denoisedebleed (by Gabox)': 'denoisedebleed-(by Gabox)',
+ 'INSTV5N (by Gabox)': 'INSTV5N_(by Gabox)',
+ 'Voc_Fv3 (by Gabox)': 'Voc_Fv3_(by Gabox)',
+ 'MelBandRoformer4StemFTLarge (SYH99999)': 'MelBandRoformer4StemFTLarge_(SYH99999)',
+ 'dereverb_mel_band_roformer_mono (by anvuew)': 'dereverb_mel_band_roformer_mono_(by anvuew)',
+ 'INSTV6N (by Gabox)': 'INSTV6N_(by Gabox)',
+ 'KaraokeGabox': 'KaraokeGabox',
+ 'FullnessVocalModel (by Amane)': 'FullnessVocalModel',
+ 'Inst_GaboxV7 (by Gabox)': 'Inst_GaboxV7_(by Gabox)',
+
+ # Add more mappings as needed
+ }
+
+ # Use mapping if exists, otherwise clean the model name
+ if model in model_name_mapping:
+ return model_name_mapping[model]
+
+ # General cleaning if not in mapping
+ cleaned = re.sub(r'\s*\(.*?\)', '', model) # Remove parenthetical info
+ cleaned = cleaned.replace('-', '_')
+ cleaned = ''.join(char for char in cleaned if char.isalnum() or char == '_')
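+ # Example of the fallback path (hypothetical name): clean_model_name('Custom-Model v2 (by someone)') -> 'Custom_Modelv2'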
+
+ return cleaned
+
+def shorten_filename(filename, max_length=30):
+ """
+ Shortens a filename to a specified maximum length
+
+ Args:
+ filename (str): The filename to be shortened
+ max_length (int): Maximum allowed length for the filename
+
+ Returns:
+ str: Shortened filename
+ """
+ base, ext = os.path.splitext(filename)
+ if len(base) <= max_length:
+ return filename
+
+ # Take first 15 and last 10 characters
+ shortened = base[:15] + "..." + base[-10:] + ext
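+ # e.g. "a_very_long_recording_name_from_session.wav" -> "a_very_long_rec...om_session.wav"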
+ return shortened
+
+def clean_filename(filename):
+ """
+ Returns the cleaned base filename, the detected stem type and the extension
+ """
+ # Patterns for stripping timestamps and other trailing noise
+ cleanup_patterns = [
+ r'_\d{8}_\d{6}_\d{6}$', # _20231215_123456_123456
+ r'_\d{14}$', # _20231215123456
+ r'_\d{10}$', # _1702658400
+ r'_\d+$' # any other trailing number
+ ]
+
+ # Split the filename and extension
+ base, ext = os.path.splitext(filename)
+
+ # Strip the timestamps
+ for pattern in cleanup_patterns:
+ base = re.sub(pattern, '', base)
+
+ # Strip stem-type labels
+ file_types = ['vocals', 'instrumental', 'drum', 'bass', 'other', 'effects', 'speech', 'music', 'dry', 'male', 'female']
+ for type_keyword in file_types:
+ base = base.replace(f'_{type_keyword}', '')
+
+ # Detect the stem type
+ detected_type = None
+ for type_keyword in file_types:
+ if type_keyword in base.lower():
+ detected_type = type_keyword
+ break
+
+ # Base name with timestamps and extra labels removed
+ clean_base = base.strip('_- ')
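+ # e.g. clean_filename('song_vocals_20231215_123456_123456.wav') -> ('song', None, '.wav');
+ # note the '_vocals' label is stripped before detection, so detected_type is only set
+ # when the keyword also appears elsewhere in the name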
+
+ return clean_base, detected_type, ext
+
+def run_command_and_process_files(model_type, config_path, start_check_point, INPUT_DIR, OUTPUT_DIR, extract_instrumental, use_tta, demud_phaseremix_inst, clean_model):
+ try:
+ # Build the command pieces
+ cmd_parts = [
+ "python", "inference.py",
+ "--model_type", model_type,
+ "--config_path", config_path,
+ "--start_check_point", start_check_point,
+ "--input_folder", INPUT_DIR,
+ "--store_dir", OUTPUT_DIR, # İşlenecek ses dosyasının yolu
+ ]
+
+ # Add optional flags
+ if extract_instrumental:
+ cmd_parts.append("--extract_instrumental")
+
+ if use_tta:
+ cmd_parts.append("--use_tta")
+
+ if demud_phaseremix_inst:
+ cmd_parts.append("--demud_phaseremix_inst")
+
+ # Run the command
+ process = subprocess.Popen(
+ cmd_parts,
+ cwd=BASE_PATH,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ bufsize=1,
+ universal_newlines=True
+ )
+
+ # Print output as it arrives (stderr is drained after stdout finishes)
+ for line in process.stdout:
+ print(line.strip())
+
+ for line in process.stderr:
+ print(line.strip())
+
+ process.wait()
+
+ # Clean the model name for use in filenames
+ filename_model = clean_model_name(clean_model)
+
+ # Collect the output files
+ output_files = os.listdir(OUTPUT_DIR)
+
+ # Helper: rename output files to include the model name
+ def rename_files_with_model(folder, filename_model):
+ for filename in sorted(os.listdir(folder)):
+ file_path = os.path.join(folder, filename)
+
+ # Skip anything that is not an audio file
+ if not any(filename.lower().endswith(ext) for ext in ['.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a']):
+ continue
+
+ base, ext = os.path.splitext(filename)
+
+ # Clean base name
+ clean_base = base.strip('_- ')
+
+ # Build the new filename
+ new_filename = f"{clean_base}_{filename_model}{ext}"
+
+ new_file_path = os.path.join(folder, new_filename)
+ os.rename(file_path, new_file_path)
+
+ # Rename the files
+ rename_files_with_model(OUTPUT_DIR, filename_model)
+
+ # Refresh the output file list
+ output_files = os.listdir(OUTPUT_DIR)
+
+ # Helper: find the first file whose name contains a keyword
+ def find_file(keyword):
+ matching_files = [
+ os.path.join(OUTPUT_DIR, f) for f in output_files
+ if keyword in f.lower()
+ ]
+ return matching_files[0] if matching_files else None
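+ # e.g. find_file('vocals') -> first output file whose name contains "vocals", or None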
+
+ # Locate the different stem types
+ vocal_file = find_file('vocals')
+ instrumental_file = find_file('instrumental')
+ phaseremix_file = find_file('phaseremix')
+ drum_file = find_file('drum')
+ bass_file = find_file('bass')
+ other_file = find_file('other')
+ effects_file = find_file('effects')
+ speech_file = find_file('speech')
+ music_file = find_file('music')
+ dry_file = find_file('dry')
+ male_file = find_file('male')
+ female_file = find_file('female')
+ bleed_file = find_file('bleed')
+ karaoke_file = find_file('karaoke')
+
+
+ # Return whatever was found (missing stems stay None)
+ return (
+ vocal_file or None,
+ instrumental_file or None,
+ phaseremix_file or None,
+ drum_file or None,
+ bass_file or None,
+ other_file or None,
+ effects_file or None,
+ speech_file or None,
+ music_file or None,
+ dry_file or None,
+ male_file or None,
+ female_file or None,
+ bleed_file or None,
+ karaoke_file or None
+
+ )
+
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ return (None,) * 14
+
+
+
+def create_interface():
+ # Let's define the model options in advance
+ model_choices = {
+ "Vocal Separation": [
+ 'FullnessVocalModel (by Amane)',
+ 'Voc_Fv3 (by Gabox)',
+ 'VOCALS-BS-Roformer_1297 (by viperx)',
+ 'VOCALS-BS-Roformer_1296 (by viperx)',
+ '✅ VOCALS-Mel-Roformer big beta 4 (by unwa) - Melspectrogram based high performance',
+ 'VOCALS-BS-RoformerLargev1 (by unwa) - Comprehensive model',
+ 'VOCALS-InstVocHQ - General purpose model',
+ 'VOCALS-MelBand-Roformer (by KimberleyJSN) - Alternative model',
+ 'VOCALS-VitLarge23 (by ZFTurbo) - Transformer-based model',
+ 'VOCALS-MelBand-Roformer Kim FT (by Unwa)',
+ 'VOCALS-MelBand-Roformer (by Becruily)',
+ '✅ VOCALS-Melband-Roformer BigBeta5e (by unwa)',
+ 'VOCALS-Male Female-BS-RoFormer Male Female Beta 7_2889 (by aufr33)',
+ 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)',
+ 'voc_gaboxMelRoformer (by Gabox)',
+ 'voc_gaboxBSroformer (by Gabox)',
+ 'voc_gaboxMelRoformerFV1 (by Gabox)',
+ 'voc_gaboxMelRoformerFV2 (by Gabox)',
+ 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)'
+ ],
+ "Instrumental Separation": [
+ 'Inst_GaboxV7 (by Gabox)',
+ 'INSTV5N (by Gabox)',
+ 'inst_gaboxFV6 (by Gabox)',
+ '✅ INST-Mel-Roformer v2 (by unwa) - Most recent instrumental separation model',
+ '✅ inst_v1e (by unwa)',
+ '✅ INST-Mel-Roformer v1 (by unwa) - Old instrumental separation model',
+ 'INST-MelBand-Roformer (by Becruily)',
+ 'inst_gaboxFV2 (by Gabox)',
+ 'inst_gaboxFV1 (by Gabox)',
+ 'inst_gaboxBV2 (by Gabox)',
+ 'inst_gaboxBV1 (by Gabox)',
+ 'inst_gabox (by Gabox)',
+ '✅(?) inst_GaboxFv3 (by Gabox)',
+ 'Intrumental_Gabox (by Gabox)',
+ '✅(?) inst_Fv4Noise (by Gabox)',
+ '✅(?) inst_V5 (by Gabox)',
+ 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa) - Latest version instrumental separation',
+ 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa) - Previous version',
+ 'INST-Separator MDX23C (by aufr33) - Alternative instrumental separation',
+ 'INSTV6N (by Gabox)'
+ ],
+ "Karaoke & Accompaniment": [
+ '✅ KARAOKE-MelBand-Roformer (by aufr33 & viperx) - Advanced karaoke separation',
+ 'KaraokeGabox'
+ ],
+ "Noise & Effect Removal": [
+ 'denoisedebleed (by Gabox)',
+ '🔇 DENOISE-MelBand-Roformer-1 (by aufr33) - Basic noise reduction',
+ '🔉 DENOISE-MelBand-Roformer-2 (by aufr33) - Advanced noise reduction',
+ "bleed_suppressor_v1 (by unwa) - don't use it if you don't know what you're doing",
+ 'dereverb_mel_band_roformer_mono (by anvuew)',
+ '👥 CROWD-REMOVAL-MelBand-Roformer (by aufr33) - Crowd noise removal',
+ '🏛️ DE-REVERB-MDX23C (by aufr33 & jarredou) - Reverb reduction',
+ '🏛️ DE-REVERB-MelBand-Roformer aggr./v2/19.1729 (by anvuew)',
+ '🗣️ DE-REVERB-Echo-MelBand-Roformer (by Sucial)',
+ 'dereverb_mel_band_roformer_less_aggressive_anvuew',
+ 'dereverb_mel_band_roformer_anvuew'
+
+
+ ],
+ "Drum Separation": [
+ '✅ DRUMSEP-MDX23C_DrumSep_6stem (by aufr33 & jarredou) - Detailed drum separation'
+ ],
+ "Multi-Stem & Other Models": [
+ 'MelBandRoformer4StemFTLarge (SYH99999)',
+ '🎬 4STEMS-SCNet_MUSDB18 (by starrytong) - Multi-stem separation',
+ '🎼 CINEMATIC-BandIt_Plus (by kwatcharasupat) - Cinematic music analysis',
+ 'OTHER-BS-Roformer_1053 (by viperx) - Other special models',
+ '4STEMS-SCNet_XL_MUSDB18 (by ZFTurbo)',
+ '4STEMS-SCNet_Large (by starrytong)',
+ '4STEMS-BS-Roformer_MUSDB18 (by ZFTurbo)',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model1 (by Amane)',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model2 (by Amane)',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model3 (by Amane)'
+ ],
+ }
+
+
+ def update_models(category):
+ models = model_choices.get(category, [])
+ return gr.Dropdown(
+ label="Select Model",
+ choices=models,
+ value=models[0] if models else None
+ )
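+ # Assumed wiring elsewhere in the UI (not shown here): category_dropdown.change(update_models, inputs=category_dropdown, outputs=model_dropdown)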
+
+
+ def ensemble_files(args):
+ """
+ Ensemble audio files using the external script
+
+ Args:
+ args (list): Command-line arguments for ensemble script
+ """
+ try:
+
+ script_path = "/content/Music-Source-Separation-Training/ensemble.py"
+
+
+ full_command = ["python", script_path] + args
+
+
+ result = subprocess.run(
+ full_command,
+ capture_output=True,
+ text=True,
+ check=True
+ )
+
+ print("Ensemble successful:")
+ print(result.stdout)
+ return result.stdout
+
+ except subprocess.CalledProcessError as e:
+ print(f"Ensemble error: {e}")
+ print(f"Error output: {e.stderr}")
+ raise
+ except Exception as e:
+ print(f"Unexpected error during ensemble: {e}")
+ raise
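+ # Example call, using the same flags as the ensemble command built further below
+ # (values are placeholders): ensemble_files(["--files", "a.wav", "b.wav", "--type", "<ensemble_type>", "--output", "out.wav"])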
+
+ def refresh_audio_files(directory):
+ """
+ Refreshes and lists audio files in the specified directory and old_output directory.
+
+ Args:
+ directory (str): Path of the directory to be scanned.
+
+ Returns:
+ list: List of discovered audio files.
+ """
+ try:
+ audio_extensions = ['.wav', '.mp3', '.flac', '.ogg']
+ audio_files = [
+ f for f in os.listdir(directory)
+ if os.path.isfile(os.path.join(directory, f))
+ and os.path.splitext(f)[1].lower() in audio_extensions
+ ]
+
+ # Also include files from the old_output directory
+ old_output_directory = os.path.join(BASE_PATH, 'old_output')
+ old_audio_files = [
+ f for f in os.listdir(old_output_directory)
+ if os.path.isfile(os.path.join(old_output_directory, f))
+ and os.path.splitext(f)[1].lower() in audio_extensions
+ ]
+
+ return sorted(audio_files + old_audio_files)
+ except Exception as e:
+ print(f"Audio file listing error: {e}")
+ return []
+
+
+ # Global variable definitions
+ BASE_PATH = '/content/Music-Source-Separation-Training'
+ AUTO_ENSEMBLE_TEMP = os.path.join(BASE_PATH, 'auto_ensemble_temp')
+ model_output_dir = os.path.join(BASE_PATH, 'auto_ensemble_temp')
+
+ def auto_ensemble_process(audio_input, selected_models, chunk_size, overlap, export_format2,
+ use_tta, extract_instrumental, ensemble_type,
+ progress=gr.Progress(), *args, **kwargs):
+ try:
+ # Ensure the ensemble directory exists
+ move_wav_files2(INPUT_DIR)
+ create_directory(ENSEMBLE_DIR)
+
+ # Handle audio input
+ if audio_input is not None:
+ temp_path = audio_input.name # Gradio's temporary file path
+ audio_path = os.path.join(ENSEMBLE_DIR, os.path.basename(temp_path))
+ else:
+ existing_files = os.listdir(ENSEMBLE_DIR)
+ if not existing_files:
+ return None, "❌ No audio file found"
+ audio_path = os.path.join(ENSEMBLE_DIR, existing_files[0])
+
+ # Model processing
+ all_outputs = []
+ total_models = len(selected_models)
+
+ for idx, model in enumerate(selected_models):
+ progress((idx + 1) / total_models, f"Processing {model}...")
+
+ clean_model = extract_model_name(model)
+ print(f"Processing using model: {clean_model}")
+
+ # Model output directory
+ model_output_dir = os.path.join(AUTO_ENSEMBLE_TEMP, clean_model)
+ os.makedirs(model_output_dir, exist_ok=True)
+
+ model_type, config_path, start_check_point = "", "", ""
+
+ if clean_model == 'VOCALS-InstVocHQ':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/config_vocals_mdx23c.yaml'
+ start_check_point = 'ckpts/model_vocals_mdx23c_sdr_10.17.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_mdx23c.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt')
+
+ elif clean_model == 'VOCALS-MelBand-Roformer (by KimberleyJSN)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_vocals_mel_band_roformer_kj.yaml'
+ start_check_point = 'ckpts/MelBandRoformer.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml')
+ download_file('https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-Roformer_1297 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_317_sdr_12.9755.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_317_sdr_12.9755.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml')
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_317_sdr_12.9755.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-Roformer_1296 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_368_sdr_12.9628.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_368_sdr_12.9628.ckpt'
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_368_sdr_12.9628.ckpt')
+ download_file('https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_368_sdr_12.9628.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-BS-RoformerLargev1 (by unwa)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_bsrofoL.yaml'
+ start_check_point = 'ckpts/BS-Roformer_LargeV1.ckpt'
+ download_file('https://huggingface.co/jarredou/unwa_bs_roformer/resolve/main/BS-Roformer_LargeV1.ckpt')
+ download_file('https://huggingface.co/jarredou/unwa_bs_roformer/raw/main/config_bsrofoL.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-Mel-Roformer big beta 4 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_big_beta4.yaml'
+ start_check_point = 'ckpts/melband_roformer_big_beta4.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/melband_roformer_big_beta4.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/raw/main/config_melbandroformer_big_beta4.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-Melband-Roformer BigBeta5e (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/big_beta5e.yaml'
+ start_check_point = 'ckpts/big_beta5e.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-big/resolve/main/big_beta5e.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-Mel-Roformer v1 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst.yaml'
+ start_check_point = 'ckpts/melband_roformer_inst_v1.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v1.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/raw/main/config_melbandroformer_inst.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-Mel-Roformer v2 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst_v2.yaml'
+ start_check_point = 'ckpts/melband_roformer_inst_v2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/melband_roformer_inst_v2.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/raw/main/config_melbandroformer_inst_v2.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_instvoc_duality.yaml'
+ start_check_point = 'ckpts/melband_roformer_instvoc_duality_v1.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvoc_duality_v1.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/raw/main/config_melbandroformer_instvoc_duality.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_instvoc_duality.yaml'
+ start_check_point = 'ckpts/melband_roformer_instvox_duality_v2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/resolve/main/melband_roformer_instvox_duality_v2.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-InstVoc-Duality/raw/main/config_melbandroformer_instvoc_duality.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'KARAOKE-MelBand-Roformer (by aufr33 & viperx)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_mel_band_roformer_karaoke.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/mel_band_roformer_karaoke_aufr33_viperx_sdr_10.1956.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33-viperx-karaoke-melroformer-model/resolve/main/config_mel_band_roformer_karaoke.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'OTHER-BS-Roformer_1053 (by viperx)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/model_bs_roformer_ep_937_sdr_10.5309.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_937_sdr_10.5309.ckpt'
+ download_file('https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt')
+ download_file('https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/mdx_c_configs/model_bs_roformer_ep_937_sdr_10.5309.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'CROWD-REMOVAL-MelBand-Roformer (by aufr33)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_crowd.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-VitLarge23 (by ZFTurbo)':
+ model_type = 'segm_models'
+ config_path = 'ckpts/config_vocals_segm_models.yaml'
+ start_check_point = 'ckpts/model_vocals_segm_models_sdr_9.77.ckpt'
+ download_file('https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/refs/heads/main/configs/config_vocals_segm_models.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_segm_models_sdr_9.77.ckpt')
+
+ elif clean_model == 'CINEMATIC-BandIt_Plus (by kwatcharasupat)':
+ model_type = 'bandit'
+ config_path = 'ckpts/config_dnr_bandit_bsrnn_multi_mus64.yaml'
+ start_check_point = 'ckpts/model_bandit_plus_dnr_sdr_11.47.chpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt')
+
+ elif clean_model == 'DRUMSEP-MDX23C_DrumSep_6stem (by aufr33 & jarredou)':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml'
+ start_check_point = 'ckpts/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt'
+ download_file('https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt')
+ download_file('https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml')
+
+ elif clean_model == '4STEMS-SCNet_MUSDB18 (by starrytong)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet.yaml'
+ start_check_point = 'ckpts/scnet_checkpoint_musdb18.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt')
+
+ elif clean_model == 'DE-REVERB-MDX23C (by aufr33 & jarredou)':
+ model_type = 'mdx23c'
+ config_path = 'ckpts/config_dereverb_mdx23c.yaml'
+ start_check_point = 'ckpts/dereverb_mdx23c_sdr_6.9096.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/dereverb_mdx23c_sdr_6.9096.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_jarredou_MDXv3_DeReverb/resolve/main/config_dereverb_mdx23c.yaml')
+
+ elif clean_model == 'DENOISE-MelBand-Roformer-1 (by aufr33)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DENOISE-MelBand-Roformer-2 (by aufr33)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt'
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt')
+ download_file('https://huggingface.co/jarredou/aufr33_MelBand_Denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT (by Unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft.ckpt')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_v1e (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_melbandroformer_inst.yaml'
+ start_check_point = 'ckpts/inst_v1e.ckpt'
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/inst_v1e.ckpt')
+ download_file('https://huggingface.co/pcunwa/Mel-Band-Roformer-Inst/resolve/main/config_melbandroformer_inst.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'bleed_suppressor_v1 (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_bleed_suppressor_v1.yaml'
+ start_check_point = 'ckpts/bleed_suppressor_v1.ckpt'
+ download_file('https://huggingface.co/ASesYusuf1/MODELS/resolve/main/bleed_suppressor_v1.ckpt')
+ download_file('https://huggingface.co/ASesYusuf1/MODELS/resolve/main/config_bleed_suppressor_v1.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer (by Becruily)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_instrumental_becruily.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_vocals_becruily.ckpt'
+ download_file('https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/config_vocals_becruily.yaml')
+ download_file('https://huggingface.co/becruily/mel-band-roformer-vocals/resolve/main/mel_band_roformer_vocals_becruily.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INST-MelBand-Roformer (by Becruily)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_instrumental_becruily.yaml'
+ start_check_point = 'ckpts/mel_band_roformer_instrumental_becruily.ckpt'
+ download_file('https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/config_instrumental_becruily.yaml')
+ download_file('https://huggingface.co/becruily/mel-band-roformer-instrumental/resolve/main/mel_band_roformer_instrumental_becruily.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-SCNet_XL_MUSDB18 (by ZFTurbo)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet_xl.yaml'
+ start_check_point = 'ckpts/model_scnet_ep_54_sdr_9.8051.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-SCNet_Large (by starrytong)':
+ model_type = 'scnet'
+ config_path = 'ckpts/config_musdb18_scnet_large_starrytong.yaml'
+ start_check_point = 'ckpts/SCNet-large_starrytong_fixed.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == '4STEMS-BS-Roformer_MUSDB18 (by ZFTurbo)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_bs_roformer_384_8_2_485100.yaml'
+ start_check_point = 'ckpts/model_bs_roformer_ep_17_sdr_9.6568.ckpt'
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml')
+ download_file('https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DE-REVERB-MelBand-Roformer aggr./v2/19.1729 (by anvuew)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'DE-REVERB-Echo-MelBand-Roformer (by Sucial)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_dereverb-echo_mel_band_roformer.yaml'
+ start_check_point = 'ckpts/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt'
+ download_file('https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt')
+ download_file('https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_less_aggressive_anvuew':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_less_aggressive_anvuew_sdr_18.8050.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_anvuew':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gabox (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxBV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxBv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxBV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxBv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxBv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxBFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'inst_gaboxFV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'VOCALS-Male Female-BS-RoFormer Male Female Beta 7_2889 (by aufr33)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/config_chorus_male_female_bs_roformer.yaml'
+ start_check_point = 'ckpts/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt'
+ download_file('https://huggingface.co/RareSirMix/AIModelRehosting/resolve/main/bs_roformer_male_female_by_aufr33_sdr_7.2889.ckpt')
+ download_file('https://huggingface.co/Sucial/Chorus_Male_Female_BS_Roformer/resolve/main/config_chorus_male_female_bs_roformer.yaml')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft2.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxBSroformer (by Gabox)':
+ model_type = 'bs_roformer'
+ config_path = 'ckpts/voc_gaboxBSroformer.yaml'
+ start_check_point = 'ckpts/voc_gaboxBSR.ckpt'
+ download_file('https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSroformer.yaml')
+ download_file('https://huggingface.co/GaboxR67/BSRoformerVocTest/resolve/main/voc_gaboxBSR.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelRoformer (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelRoformerFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'voc_gaboxMelRoformerFV2 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_gaboxFv2.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gaboxFv2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_GaboxFv3 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv3.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Intrumental_Gabox (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/intrumental_gabox.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/intrumental_gabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_Fv4Noise (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_Fv4Noise.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_Fv4Noise.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_V5 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV5.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model1 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model2 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model2.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model2.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'SYH99999/MelBandRoformerSYHFTB1_Model3 (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/model3.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformerSYHFTB1/resolve/main/model3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_kimmel_unwa_ft.yaml'
+ start_check_point = 'ckpts/kimmel_unwa_ft2_bleedless.ckpt'
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/config_kimmel_unwa_ft.yaml')
+ download_file('https://huggingface.co/pcunwa/Kim-Mel-Band-Roformer-FT/resolve/main/kimmel_unwa_ft2_bleedless.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxFV1 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/inst_gaboxFv1.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gaboxFv1.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'inst_gaboxFV6 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV6.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'denoisedebleed (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/model_mel_band_roformer_denoise.yaml'
+ start_check_point = 'ckpts/denoisedebleed.ckpt'
+ download_file('https://huggingface.co/poiqazwsx/melband-roformer-denoise/resolve/main/model_mel_band_roformer_denoise.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/denoisedebleed.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INSTV5N (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV5N.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV5N.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Voc_Fv3 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/voc_gabox.yaml'
+ start_check_point = 'ckpts/voc_Fv3.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/vocals/voc_Fv3.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'MelBandRoformer4StemFTLarge (SYH99999)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/MelBandRoformer4StemFTLarge.ckpt'
+ download_file('https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/config.yaml')
+ download_file('https://huggingface.co/SYH99999/MelBandRoformer4StemFTLarge/resolve/main/MelBandRoformer4StemFTLarge.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'dereverb_mel_band_roformer_mono (by anvuew)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/dereverb_mel_band_roformer_anvuew.yaml'
+ start_check_point = 'ckpts/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt'
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml')
+ download_file('https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_mono_anvuew_sdr_20.4029.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'INSTV6N (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/INSTV6N.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/INSTV6N.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'KaraokeGabox':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config_mel_band_roformer_karaoke.yaml'
+ start_check_point = 'ckpts/KaraokeGabox.ckpt'
+ download_file('https://github.com/deton24/Colab-for-new-MDX_UVR_models/releases/download/v1.0.0/config_mel_band_roformer_karaoke.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/experimental/KaraokeGabox.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'FullnessVocalModel (by Amane)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/config.yaml'
+ start_check_point = 'ckpts/FullnessVocalModel.ckpt'
+ download_file('https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/config.yaml')
+ download_file('https://huggingface.co/Aname-Tommy/MelBandRoformers/resolve/main/FullnessVocalModel.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+ elif clean_model == 'Inst_GaboxV7 (by Gabox)':
+ model_type = 'mel_band_roformer'
+ config_path = 'ckpts/inst_gabox.yaml'
+ start_check_point = 'ckpts/Inst_GaboxV7.ckpt'
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/inst_gabox.yaml')
+ download_file('https://huggingface.co/GaboxR67/MelBandRoformers/resolve/main/melbandroformers/instrumental/Inst_GaboxV7.ckpt')
+ conf_edit(config_path, chunk_size, overlap)
+
+
+
+ # Reuse the command structure from the main tab
+ cmd = [
+ "python",
+ "inference.py",
+ "--model_type", model_type,
+ "--config_path", config_path,
+ "--start_check_point", start_check_point,
+ "--input_folder", ENSEMBLE_DIR,
+ "--store_dir", model_output_dir,
+ ]
+
+ if use_tta:
+ cmd.append("--use_tta")
+ if extract_instrumental:
+ cmd.append("--extract_instrumental")
+
+ print(f"Running command: {' '.join(cmd)}")
+
+ # Run with error handling
+ try:
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ print(result.stdout)
+ if result.returncode != 0:
+ print(f"Error: {result.stderr}")
+ return None, f"Model {model} failed: {result.stderr}"
+ except Exception as e:
+ return None, f"Critical error with {model}: {str(e)}"
+
+ # Collect this model's output files and make sure it actually produced something
+ model_outputs = glob.glob(os.path.join(model_output_dir, "*.wav"))
+ if not model_outputs:
+ raise FileNotFoundError(f"{model} failed to produce output")
+ all_outputs.extend(model_outputs)
+
+
+ # Wait for the expected files to appear on disk
+ def wait_for_files(files, timeout=300):
+ start = time.time()
+ while time.time() - start < timeout:
+ missing = [f for f in files if not os.path.exists(f)]
+ if not missing: return True
+ time.sleep(5)
+ raise TimeoutError(f"Missing files: {missing[:3]}...")
+
+ wait_for_files(all_outputs)
+
+ # Build the ensemble command safely
+ quoted_files = [f'"{f}"' for f in all_outputs]
+ timestamp = str(int(time.time()))
+ output_path = os.path.join(AUTO_ENSEMBLE_OUTPUT, f"ensemble_{timestamp}.wav")
+
+ ensemble_cmd = [
+ "python", "ensemble.py",
+ "--files", *quoted_files,
+ "--type", ensemble_type,
+ "--output", f'"{output_path}"'
+ ]
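+ # Paths are pre-quoted because the command is joined into a single string and run with shell=True below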
+
+ # Run the ensemble command
+ result = subprocess.run(
+ " ".join(ensemble_cmd),
+ shell=True,
+ capture_output=True,
+ text=True,
+ check=True
+ )
+
+ # Final check
+ if not os.path.exists(output_path):
+ raise RuntimeError("The ensemble file was not created")
+
+ return output_path, "✅ Success!"
+
+ except Exception as e:
+ return None, f"❌ Error: {str(e)}"
+
+ finally:
+ # Cleanup
+ shutil.rmtree('/content/Music-Source-Separation-Training/ensemble', ignore_errors=True)
+ clear_directory(VİDEO_TEMP)
+ clear_directory(ENSEMBLE_DIR)
+ gc.collect()
+
+
+ main_input_key = "shared_audio_input"
+ # Global components
+ input_audio_file = gr.File(visible=True)
+ auto_input_audio_file = gr.File(visible=True)
+ original_audio = gr.Audio(visible=True)
+
+
+ css = """
+ /* General theme */
+ body {
+ background: url('/content/logo.jpg') no-repeat center center fixed;
+ background-size: cover;
+ background-color: #2d0b0b; /* Dark red, fitting the dubbing-studio theme */
+ min-height: 100vh;
+ margin: 0;
+ padding: 1rem;
+ font-family: 'Poppins', sans-serif;
+ color: #C0C0C0; /* Metallic silver text for a professional look */
+ }
+
+ body::after {
+ content: '';
+ position: fixed;
+ top: 0;
+ left: 0;
+ width: 100%;
+ height: 100%;
+ background: rgba(45, 11, 11, 0.9); /* Darker red overlay */
+ z-index: -1;
+ }
+
+ /* Logo styles */
+ .logo-container {
+ position: absolute;
+ top: 1rem;
+ left: 50%;
+ transform: translateX(-50%);
+ display: flex;
+ align-items: center;
+ z-index: 2000; /* Above everything else so the logo is always visible */
+ }
+
+ .logo-img {
+ width: 120px;
+ height: auto;
+ }
+
+ /* Header styles */
+ .header-text {
+ text-align: center;
+ padding: 80px 20px 20px; /* Leave room for the logo */
+ color: #ff4040; /* Red, matching the dubbing theme */
+ font-size: 2.5rem; /* Larger, more striking heading */
+ font-weight: 900; /* Bolder and more dramatic */
+ text-shadow: 0 0 10px rgba(255, 64, 64, 0.5); /* Red glow effect */
+ z-index: 1500; /* Above the tabs, below the logo */
+ }
+
+ /* Metallic red shine animation */
+ @keyframes metallic-red-shine {
+ 0% { filter: brightness(1) saturate(1) drop-shadow(0 0 5px #ff4040); }
+ 50% { filter: brightness(1.3) saturate(1.7) drop-shadow(0 0 15px #ff6b6b); }
+ 100% { filter: brightness(1) saturate(1) drop-shadow(0 0 5px #ff4040); }
+ }
+
+ /* Dubbing-themed panel */
+ .dubbing-theme {
+ background: linear-gradient(to bottom, #800000, #2d0b0b); /* Dark red gradient */
+ border-radius: 15px;
+ padding: 1rem;
+ box-shadow: 0 10px 20px rgba(255, 64, 64, 0.3); /* Red shadow */
+ }
+
+ /* Footer styles (above the tabs, transparent) */
+ .footer {
+ text-align: center;
+ padding: 10px;
+ color: #ff4040; /* Red text, matching the dubbing theme */
+ font-size: 14px;
+ margin-top: 20px;
+ position: relative;
+ z-index: 1001; /* Above the tabs, below the logo */
+ }
+
+ /* Button and upload area styles */
+ button {
+ transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1) !important;
+ background: #800000 !important; /* Dark red, matching the dubbing theme */
+ border: 1px solid #ff4040 !important; /* Red border */
+ color: #C0C0C0 !important; /* Metallic silver text */
+ border-radius: 8px !important;
+ padding: 8px 16px !important;
+ position: relative;
+ overflow: hidden !important;
+ font-size: 0.9rem !important;
+ }
+
+ button:hover {
+ transform: scale(1.05) !important;
+ box-shadow: 0 10px 40px rgba(255, 64, 64, 0.7) !important; /* More pronounced red shadow */
+ background: #ff4040 !important; /* Lighter red hover effect */
+ }
+
+ button::before {
+ content: '';
+ position: absolute;
+ top: -50%;
+ left: -50%;
+ width: 200%;
+ height: 200%;
+ background: linear-gradient(45deg,
+ transparent 20%,
+ rgba(192, 192, 192, 0.3) 50%, /* Metallic silver tone */
+ transparent 80%);
+ animation: button-shine 3s infinite linear;
+ }
+
+ /* Image and audio upload area style */
+ .compact-upload.horizontal {
+ display: inline-flex !important;
+ align-items: center !important;
+ gap: 8px !important;
+ max-width: 400px !important;
+ height: 40px !important;
+ padding: 0 12px !important;
+ border: 1px solid #ff4040 !important; /* Red border */
+ background: rgba(128, 0, 0, 0.5) !important; /* Dark red, translucent */
+ border-radius: 8px !important;
+ transition: all 0.2s ease !important;
+ color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ .compact-upload.horizontal:hover {
+ border-color: #ff6b6b !important; /* Lighter red */
+ background: rgba(128, 0, 0, 0.7) !important; /* Darker red on hover */
+ }
+
+ .compact-upload.horizontal .w-full {
+ flex: 1 1 auto !important;
+ min-width: 120px !important;
+ margin: 0 !important;
+ color: #C0C0C0 !important; /* Metallic silver */
+ }
+
+ .compact-upload.horizontal button {
+ padding: 4px 12px !important;
+ font-size: 0.75em !important;
+ height: 28px !important;
+ min-width: 80px !important;
+ border-radius: 4px !important;
+ background: #800000 !important; /* Dark red */
+ border: 1px solid #ff4040 !important; /* Red border */
+ color: #C0C0C0 !important; /* Metallic silver */
+ }
+
+ .compact-upload.horizontal .text-gray-500 {
+ font-size: 0.7em !important;
+ color: rgba(192, 192, 192, 0.6) !important; /* Transparent metallic silver */
+ white-space: nowrap !important;
+ overflow: hidden !important;
+ text-overflow: ellipsis !important;
+ max-width: 180px !important;
+ }
+
+ /* Extra-narrow version */
+ .compact-upload.horizontal.x-narrow {
+ max-width: 320px !important;
+ height: 36px !important;
+ padding: 0 10px !important;
+ gap: 6px !important;
+ }
+
+ .compact-upload.horizontal.x-narrow button {
+ padding: 3px 10px !important;
+ font-size: 0.7em !important;
+ height: 26px !important;
+ min-width: 70px !important;
+ }
+
+ .compact-upload.horizontal.x-narrow .text-gray-500 {
+ font-size: 0.65em !important;
+ max-width: 140px !important;
+ }
+
+ /* Shared tab styles */
+ .gr-tab {
+ background: rgba(128, 0, 0, 0.5) !important; /* Dark red, transparent */
+ border-radius: 12px 12px 0 0 !important;
+ margin: 0 5px !important;
+ color: #C0C0C0 !important; /* Metallic silver */
+ border: 1px solid #ff4040 !important; /* Red border */
+ z-index: 1500; /* Below the logo, above other elements */
+ }
+
+ .gr-tab-selected {
+ background: #800000 !important; /* Dark red */
+ box-shadow: 0 4px 12px rgba(255, 64, 64, 0.7) !important; /* More pronounced red shadow */
+ color: #ffffff !important; /* White text for contrast on the selected tab */
+ border: 1px solid #ff6b6b !important; /* Lighter red */
+ }
+
+ /* Manual Ensemble specific styles */
+ .compact-header {
+ font-size: 0.95em !important;
+ margin: 0.8rem 0 0.5rem 0 !important;
+ color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ .compact-grid {
+ gap: 0.4rem !important;
+ max-height: 50vh;
+ overflow-y: auto;
+ padding: 10px;
+ background: rgba(128, 0, 0, 0.3) !important; /* Dark red, transparent */
+ border-radius: 12px;
+ border: 1px solid #ff4040 !important; /* Red border */
+ }
+
+ .compact-dropdown {
+ --padding: 8px 12px !important;
+ --radius: 10px !important;
+ border: 1px solid #ff4040 !important; /* Red border */
+ background: rgba(128, 0, 0, 0.5) !important; /* Dark red, transparent */
+ color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ .tooltip-icon {
+ font-size: 1.4em !important;
+ color: #C0C0C0 !important; /* Metallic silver */
+ cursor: help;
+ margin-left: 0.5rem !important;
+ }
+
+ .log-box {
+ font-family: 'Fira Code', monospace !important;
+ font-size: 0.85em !important;
+ background-color: rgba(128, 0, 0, 0.3) !important; /* Dark red, transparent */
+ border: 1px solid #ff4040 !important; /* Red border */
+ border-radius: 8px;
+ padding: 1rem !important;
+ color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ /* Animations */
+ @keyframes text-glow {
+ 0% { text-shadow: 0 0 5px rgba(192, 192, 192, 0); }
+ 50% { text-shadow: 0 0 15px rgba(192, 192, 192, 1); }
+ 100% { text-shadow: 0 0 5px rgba(192, 192, 192, 0); }
+ }
+
+ @keyframes button-shine {
+ 0% { transform: rotate(0deg) translateX(-50%); }
+ 100% { transform: rotate(360deg) translateX(-50%); }
+ }
+
+ /* Responsive settings */
+ @media (max-width: 768px) {
+ .compact-grid {
+ max-height: 40vh;
+ }
+
+ .compact-upload.horizontal {
+ max-width: 100% !important;
+ width: 100% !important;
+ }
+
+ .compact-upload.horizontal .text-gray-500 {
+ max-width: 100px !important;
+ }
+
+ .compact-upload.horizontal.x-narrow {
+ height: 40px !important;
+ padding: 0 8px !important;
+ }
+
+ .logo-container {
+ width: 80px; /* Smaller logo on mobile devices */
+ top: 1rem;
+ left: 50%;
+ transform: translateX(-50%);
+ }
+
+ .header-text {
+ padding: 60px 20px 20px; /* Less padding on mobile */
+ font-size: 1.8rem; /* Slightly smaller title on mobile */
+ }
+ }
+ """
+
+ # Interface layout
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+ with gr.Column():
+ # Logo (as a PNG, matching the dubbing theme)
+ logo_html = """
+
+
+
+
+ """
+ gr.HTML(logo_html)
+
+ # Title (striking, dubbing-themed)
+ gr.HTML("""
+
+ """)
+
+ with gr.Tabs():
+ with gr.Tab("Audio Separation", elem_id="separation_tab"):
+ with gr.Row(equal_height=True):
+ # Left panel - controls
+ with gr.Column(scale=1, min_width=380):
+ with gr.Accordion("📥 Input & Model", open=True):
+ with gr.Tabs():
+ with gr.Tab("🖥 Upload"):
+ input_audio_file = gr.File(
+ file_types=[".wav", ".mp3", ".m4a", ".mp4", ".mkv", ".flac"],
+ elem_classes=["compact-upload", "horizontal", "x-narrow"],
+ label="",
+ scale=1
+ )
+
+ with gr.Tab("📂 Path"):
+ file_path_input = gr.Textbox(placeholder="/path/to/audio.wav")
+
+
+ with gr.Row():
+ model_category = gr.Dropdown(
+ label="Category",
+ choices=list(model_choices.keys()),
+ value="Vocal Separation"
+ )
+ model_dropdown = gr.Dropdown(label="Model")
+
+ with gr.Accordion("⚙ Settings", open=False):
+ with gr.Row():
+ export_format = gr.Dropdown(
+ label="Format",
+ choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'],
+ value='wav FLOAT'
+ )
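+ # Note: the two chunk sizes below correspond to 8 s (352800 / 44100) and
+ # 11 s (485100 / 44100) of audio at 44.1 kHz, matching the chunk_size
+ # values used in the bundled model configs.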
+ chunk_size = gr.Dropdown(
+ label="Chunk Size",
+ choices=[352800, 485100],
+ value=352800,
+ info="Don't change unless you have specific requirements"
+ )
+
+ with gr.Row():
+ overlap = gr.Slider(2, 50, step=1, label="Overlap", info="Recommended: 2-10 (Higher values increase quality but require more VRAM)")
+ use_tta = gr.Checkbox(label="TTA Boost", info="Improves quality but increases processing time")
+
+ with gr.Row():
+ use_demud_phaseremix_inst = gr.Checkbox(label="Phase Fix", info="Advanced phase correction for instrumental tracks")
+ extract_instrumental = gr.Checkbox(label="Instrumental")
+
+ with gr.Row():
+ process_btn = gr.Button("🚀 Process", variant="primary")
+ clear_old_output_btn = gr.Button("🧹 Reset", variant="secondary")
+ clear_old_output_status = gr.Textbox(label="Status", interactive=False)
+
+ # Right panel - results
+ with gr.Column(scale=2, min_width=800):
+ with gr.Tabs():
+ with gr.Tab("🎧 Main"):
+ with gr.Column():
+ original_audio = gr.Audio(label="Original", interactive=False)
+ with gr.Row():
+ vocals_audio = gr.Audio(label="Vocals", show_download_button=True)
+ instrumental_audio = gr.Audio(label="Instrumental", show_download_button=True)
+
+ with gr.Tab("🔍 Details"):
+ with gr.Column():
+ with gr.Row():
+ male_audio = gr.Audio(label="Male")
+ female_audio = gr.Audio(label="Female")
+ speech_audio = gr.Audio(label="Speech")
+ with gr.Row():
+ drum_audio = gr.Audio(label="Drums")
+ bass_audio = gr.Audio(label="Bass")
+ with gr.Row():
+ other_audio = gr.Audio(label="Other")
+ effects_audio = gr.Audio(label="Effects")
+
+ with gr.Tab("⚙ Advanced"):
+ with gr.Column():
+ with gr.Row():
+ phaseremix_audio = gr.Audio(label="Phase Remix")
+ dry_audio = gr.Audio(label="Dry")
+ with gr.Row():
+ music_audio = gr.Audio(label="Music")
+ karaoke_audio = gr.Audio(label="Karaoke")
+ bleed_audio = gr.Audio(label="Bleed")
+
+ with gr.Row():
+
+ gr.Markdown("""
+
+ 🔈 Processing Tip: For noisy results, use bleed_suppressor_v1
+ or denoisedebleed
models in the "Denoise & Effect Removal"
+ category to clean the output
+
+ """)
+
+
+
+
+ # Auto Ensemble tab
+ with gr.Tab("Auto Ensemble"):
+ with gr.Row():
+ with gr.Column():
+ with gr.Group():
+ auto_input_audio_file = gr.File(label="Upload file")
+ auto_file_path_input = gr.Textbox(
+ label="Or enter file path",
+ placeholder="Enter full path to audio file",
+ interactive=True
+ )
+
+ with gr.Accordion("⚙️ Advanced Settings", open=False):
+ with gr.Row():
+ auto_use_tta = gr.Checkbox(label="Use TTA", value=False)
+ auto_extract_instrumental = gr.Checkbox(label="Instrumental Only")
+
+ with gr.Row():
+ auto_overlap = gr.Slider(
+ label="Overlap",
+ minimum=2,
+ maximum=50,
+ value=2,
+ step=1
+ )
+ auto_chunk_size = gr.Dropdown(
+ label="Chunk Size",
+ choices=[352800, 485100],
+ value=352800
+ )
+ export_format2 = gr.Dropdown(
+ label="Output Format",
+ choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'],
+ value='wav FLOAT'
+ )
+
+ # Model selection section
+ with gr.Group():
+ gr.Markdown("### 🧠 Model Selection")
+ with gr.Row():
+ auto_category_dropdown = gr.Dropdown(
+ label="Model Category",
+ choices=list(model_choices.keys()),
+ value="Vocal Separation"
+ )
+
+ # Model selection (multi-select at once)
+ auto_model_dropdown = gr.Dropdown(
+ label="Select Models from Category",
+ choices=model_choices["Vocal Separation"],
+ multiselect=True,
+ max_choices=50,
+ interactive=True
+ )
+
+ # List of selected models (separate box)
+ selected_models = gr.Dropdown(
+ label="Selected Models",
+ choices=[],
+ multiselect=True,
+ interactive=False  # users cannot select here directly
+ )
+
+
+ with gr.Row():
+ add_btn = gr.Button("➕ Add Selected", variant="secondary")
+ clear_btn = gr.Button("🗑️ Clear All", variant="stop")
+
+ # Ensemble settings
+ with gr.Group():
+ gr.Markdown("### ⚡ Ensemble Settings")
+ with gr.Row():
+ auto_ensemble_type = gr.Dropdown(
+ label="Method",
+ choices=['avg_wave', 'median_wave', 'min_wave', 'max_wave',
+ 'avg_fft', 'median_fft', 'min_fft', 'max_fft'],
+ value='avg_wave'
+ )
+
+ gr.Markdown("**Recommendation:** avg_wave and max_fft best results")
+
+ auto_process_btn = gr.Button("🚀 Start Processing", variant="primary")
+
+ with gr.Column():
+ with gr.Tabs():
+ with gr.Tab("🔊 Original Audio"):
+ original_audio2 = gr.Audio(
+ label=" Original Audio",
+ interactive=False,
+ every=1, # Her 1 saniyede bir güncelle
+ elem_id="original_audio_player"
+ )
+ with gr.Tab("🎚️ Ensemble Result"):
+ auto_output_audio = gr.Audio(
+ label="Output Preview",
+ show_download_button=True,
+ interactive=False
+ )
+
+ auto_status = gr.Textbox(
+ label="Processing Status",
+ interactive=False,
+ placeholder="Waiting for processing...",
+ elem_classes="status-box"
+ )
+
+ gr.Markdown("""
+
+
+
⚠️
+
+
+ Model Selection Guidelines
+
+
+ - Avoid cross-category mixing: Combining vocal and instrumental models may create unwanted blends
+ - Special model notes:
+
+ - Duality models (v1/v2) - Output both stems
+ - MDX23C Separator - Hybrid results
+
+
+ - Best practice: Use 3-5 similar models from same category
+
+
+ 💡 Pro Tip: Start with "VOCALS-MelBand-Roformer BigBeta5e" + "VOCALS-BS-Roformer_1297" combination
+
+
+
+
+ """)
+
+ # Update the model list when the category changes
+ def update_models(category):
+ return gr.Dropdown(choices=model_choices[category])
+
+ def add_models(new_models, existing_models):
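+ # set() is used to drop duplicate selections; note it does not preserve
+ # the order in which the models were picked.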
+ updated = list(set(existing_models + new_models))
+ return gr.Dropdown(choices=updated, value=updated)
+
+ def clear_models():
+ return gr.Dropdown(choices=[], value=[])
+
+ # Interactions
+ def update_category(target):
+ category_map = {
+ "Only Vocals": "Vocal Separation",
+ "Only Instrumental": "Instrumental Separation"
+ }
+ return category_map.get(target, "Vocal Separation")
+
+ # Bind upload/path events for automatic refresh
+ input_audio_file.upload(
+ fn=lambda x, y: handle_file_upload(x, y, is_auto_ensemble=False),
+ inputs=[input_audio_file, file_path_input],
+ outputs=[input_audio_file, original_audio]
+ )
+
+ file_path_input.change(
+ fn=lambda x, y: handle_file_upload(x, y, is_auto_ensemble=False),
+ inputs=[input_audio_file, file_path_input],
+ outputs=[input_audio_file, original_audio]
+ )
+
+ auto_input_audio_file.upload(
+ fn=lambda x, y: handle_file_upload(x, y, is_auto_ensemble=True),
+ inputs=[auto_input_audio_file, auto_file_path_input],
+ outputs=[auto_input_audio_file, original_audio2]
+ )
+
+ auto_file_path_input.change(
+ fn=lambda x, y: handle_file_upload(x, y, is_auto_ensemble=True),
+ inputs=[auto_input_audio_file, auto_file_path_input],
+ outputs=[auto_input_audio_file, original_audio2]
+ )
+
+ auto_category_dropdown.change(
+ fn=update_models,
+ inputs=auto_category_dropdown,
+ outputs=auto_model_dropdown
+ )
+
+ add_btn.click(
+ fn=add_models,
+ inputs=[auto_model_dropdown, selected_models],
+ outputs=selected_models
+ )
+
+ clear_btn.click(
+ fn=clear_models,
+ inputs=[],
+ outputs=selected_models
+ )
+
+ auto_process_btn.click(
+ fn=auto_ensemble_process,
+ inputs=[
+ auto_input_audio_file,
+ selected_models,
+ auto_chunk_size,
+ auto_overlap,
+ export_format2,
+ auto_use_tta,
+ auto_extract_instrumental,
+ auto_ensemble_type,
+ gr.State(None)
+ ],
+ outputs=[auto_output_audio, auto_status]
+ )
+
+ # Download tab
+ with gr.Tab("Download Sources"):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("### 🗂️ Cloud Storage")
+ drive_url_input = gr.Textbox(label="Google Drive Shareable Link")
+ drive_download_btn = gr.Button("⬇️ Download from Drive", variant="secondary")
+ drive_download_status = gr.Textbox(label="Download Status")
+ drive_download_output = gr.File(label="Downloaded File", interactive=False)
+
+ with gr.Column():
+ gr.Markdown("### 🌐 Direct Links")
+ direct_url_input = gr.Textbox(label="Audio File URL")
+ direct_download_btn = gr.Button("⬇️ Download from URL", variant="secondary")
+ direct_download_status = gr.Textbox(label="Download Status")
+ direct_download_output = gr.File(label="Downloaded File", interactive=False)
+
+ with gr.Column():
+ gr.Markdown("### 🍪 Cookie Management")
+ cookie_file = gr.File(
+ label="Upload Cookies.txt",
+ file_types=[".txt"],
+ interactive=True,
+ elem_id="cookie_upload"
+ )
+ gr.Markdown("""
+
+
+ **📌 Why Needed?**
+ - Access age-restricted content
+ - Download private/unlisted videos
+ - Bypass regional restrictions
+ - Avoid YouTube download limits
+
+ **⚠️ Important Notes**
+ - NEVER share your cookie files!
+ - Refresh cookies when:
+ • Getting "403 Forbidden" errors
+ • Downloads suddenly stop
+ • Seeing "Session expired" messages
+
+ **🔄 Renewal Steps**
+ 1. Install this Chrome extension
+ 2. Login to YouTube in Chrome
+ 3. Click extension icon → "Export"
+ 4. Upload the downloaded file here
+
+ **⏳ Cookie Lifespan**
+ - Normal sessions: 24 hours
+ - Sensitive operations: 1 hour
+ - Password changes: Immediate invalidation
+
+
+ """)
+
+
+ # Event handlers
+ model_category.change(
+ fn=update_models,
+ inputs=model_category,
+ outputs=model_dropdown
+ )
+
+ clear_old_output_btn.click(
+ fn=clear_old_output,
+ outputs=clear_old_output_status
+ )
+
+ process_btn.click(
+ fn=process_audio,
+ inputs=[
+ input_audio_file,
+ model_dropdown,
+ chunk_size,
+ overlap,
+ export_format,
+ use_tta,
+ use_demud_phaseremix_inst,
+ extract_instrumental,
+ gr.State(None),
+ gr.State(None)
+ ],
+ outputs=[
+ vocals_audio, instrumental_audio, phaseremix_audio,
+ drum_audio, karaoke_audio, bass_audio, other_audio, effects_audio,
+ speech_audio, bleed_audio, music_audio, dry_audio, male_audio, female_audio
+ ]
+ )
+
+ drive_download_btn.click(
+ fn=download_callback,
+ inputs=[drive_url_input, gr.State('drive')],
+ outputs=[
+ drive_download_output, # 0. File output
+ drive_download_status, # 1. Status message
+ input_audio_file, # 2. Main audio file input
+ auto_input_audio_file, # 3. Auto ensemble input
+ original_audio, # 4. Original audio output
+ original_audio2
+ ]
+ )
+
+ direct_download_btn.click(
+ fn=download_callback,
+ inputs=[direct_url_input, gr.State('direct'), cookie_file],
+ outputs=[
+ direct_download_output, # 0. File output
+ direct_download_status, # 1. Status message
+ input_audio_file, # 2. Main audio file input
+ auto_input_audio_file, # 3. Auto ensemble input
+ original_audio, # 4. Original audio output
+ original_audio2
+ ]
+ )
+
+
+ with gr.Tab("🎚️ Manuel Ensemble"):
+ with gr.Row(equal_height=True):
+ # Left panel - input and settings
+ with gr.Column(scale=1, min_width=400):
+ with gr.Accordion("📂 Input Sources", open=True):
+ with gr.Row():
+ refresh_btn = gr.Button("🔄 Refresh", variant="secondary", size="sm")
+ ensemble_type = gr.Dropdown(
+ label="Ensemble Algorithm",
+ choices=[
+ 'avg_wave',
+ 'median_wave',
+ 'min_wave',
+ 'max_wave',
+ 'avg_fft',
+ 'median_fft',
+ 'min_fft',
+ 'max_fft'
+ ],
+ value='avg_wave'
+ )
+
+ # Get the file list from a fixed path
+ file_path = "/content/drive/MyDrive/output"  # fixed path
+ initial_files = glob.glob(f"{file_path}/*.wav") + glob.glob("/content/Music-Source-Separation-Training/old_output/*.wav")
+
+ gr.Markdown("### Select Audio Files")
+ file_dropdown = gr.Dropdown(
+ choices=initial_files,
+ label="Available Files",
+ multiselect=True,
+ interactive=True,
+ elem_id="file-dropdown"
+ )
+
+ weights_input = gr.Textbox(
+ label="Custom Weights (comma separated)",
+ placeholder="Example: 0.8, 1.2, 1.0, ...",
+ info="Leave empty for equal weights"
+ )
+
+ # Right panel - results
+ with gr.Column(scale=2, min_width=800):
+ with gr.Tabs():
+ with gr.Tab("🎧 Result Preview"):
+ ensemble_output_audio = gr.Audio(
+ label="Ensembled Output",
+ interactive=False,
+ show_download_button=True,
+ elem_id="output-audio"
+ )
+
+ with gr.Tab("📋 Processing Log"):
+ ensemble_status = gr.Textbox(
+ label="Processing Details",
+ interactive=False,
+ elem_id="log-box"
+ )
+
+ with gr.Row():
+
+ ensemble_process_btn = gr.Button(
+ "⚡ Process Ensemble",
+ variant="primary",
+ size="sm", # Boyutu küçülttüm
+ elem_id="process-btn"
+ )
+
+ # Interactions
+ def update_file_list():
+ files = glob.glob(f"{file_path}/*.wav") + glob.glob("/content/Music-Source-Separation-Training/old_output/*.wav")
+ return gr.Dropdown(choices=files)
+
+ refresh_btn.click(
+ fn=update_file_list,
+ outputs=file_dropdown
+ )
+
+ def ensemble_audio_fn(files, method, weights):
+ try:
+ if not files or len(files) < 2:
+ return None, "⚠️ Minimum 2 files required"
+
+ # Check that the file paths exist
+ valid_files = [f for f in files if os.path.exists(f)]
+
+ if len(valid_files) < 2:
+ return None, "❌ Valid files not found"
+
+ # Create output directory if needed
+ output_dir = "/content/drive/MyDrive/ensembles"
+ os.makedirs(output_dir, exist_ok=True)  # ensure the output directory exists
+
+ # Create output path
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_path = f"{output_dir}/ensemble_{timestamp}.wav"
+
+ # Run the ensemble
+ ensemble_args = [
+ "--files", *valid_files,
+ "--type", method.lower().replace(' ', '_'),
+ "--output", output_path
+ ]
+
+ if weights and weights.strip():
+ weights_list = [str(w) for w in map(float, weights.split(','))]
+ ensemble_args += ["--weights", *weights_list]
+
+ result = subprocess.run(
+ ["python", "ensemble.py"] + ensemble_args,
+ capture_output=True,
+ text=True
+ )
+
+ log = f"✅ Success!\n{result.stdout}" if not result.stderr else f"❌ Error!\n{result.stderr}"
+ return output_path, log
+
+ except Exception as e:
+ return None, f"⛔ Critical Error: {str(e)}"
+
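+ # ensemble_audio_fn above shells out to ensemble.py; an equivalent manual
+ # call (a sketch, assuming the same CLI flags built in ensemble_args) is:
+ #   python ensemble.py --files a.wav b.wav --type avg_wave --output out.wav --weights 0.8 1.2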
+ ensemble_process_btn.click(
+ fn=ensemble_audio_fn,
+ inputs=[file_dropdown, ensemble_type, weights_input],
+ outputs=[ensemble_output_audio, ensemble_status]
+ )
+
+ gr.HTML("""
+
+ """)
+
+ return demo
+
+def launch_with_share():
+ try:
+ port = generate_random_port()
+ demo = create_interface()
+
+ share_link = demo.launch(
+ share=True,
+ server_port=port,
+ server_name='0.0.0.0',
+ inline=False,
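+ # allowed_paths lets Gradio serve files from these directories
+ # (e.g. separated stems written to Drive) to the browser.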
+ allowed_paths=[
+ '/content',
+ '/content/drive/MyDrive/output',
+ '/tmp',
+ '/model_output_dir',
+ 'model_output_dir'
+ ]
+ )
+
+ print(f"🌐 Gradio Share Link: {share_link}")
+ print(f"🔌 Local Server Port: {port}")
+
+ while True:
+ time.sleep(1)
+
+ except KeyboardInterrupt:
+ print("🛑 Server stopped by user.")
+ except Exception as e:
+ print(f"❌ Error during server launch: {e}")
+ finally:
+ try:
+ demo.close()
+ except:
+ pass
+
+if __name__ == "__main__":
+ launch_with_share()
diff --git a/ckpts/inst_gabox.yaml b/ckpts/inst_gabox.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e71de76d0ad743bc1807aeb8a2ae6f851e86af8f
--- /dev/null
+++ b/ckpts/inst_gabox.yaml
@@ -0,0 +1,48 @@
+audio:
+ chunk_size: 352800
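+ # 352800 samples = 8 seconds at the 44.1 kHz sample_rate below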
+ dim_f: 1024
+ dim_t: 1101
+ hop_length: 441
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.0
+model:
+ dim: 384
+ depth: 6
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0
+ ff_dropout: 0
+ flash_attn: true
+ dim_freqs_in: 1025
+ sample_rate: 44100
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: false
+training:
+ instruments:
+ - Instrumental
+ - Vocals
+ target_instrument: Instrumental
+ use_amp: true
+inference:
+ batch_size: 2
+ dim_t: 1101
+ num_overlap: 2
diff --git a/clean_model.py b/clean_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee1be091f8cfeb4155654fa3421c07d592c1a3b
--- /dev/null
+++ b/clean_model.py
@@ -0,0 +1,157 @@
+import os
+import glob
+import subprocess
+import time
+import gc
+import shutil
+import sys
+from datetime import datetime
+import torch
+import yaml
+import gradio as gr
+import threading
+import random
+import librosa
+import soundfile as sf
+import numpy as np
+import requests
+import json
+import locale
+import re
+import psutil
+import concurrent.futures
+from tqdm import tqdm
+from google.oauth2.credentials import Credentials
+import tempfile
+from urllib.parse import urlparse, quote
+import gdown
+
+import warnings
+warnings.filterwarnings("ignore")
+
+# Set BASE_DIR dynamically to the directory containing this script
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory of this file
+INFERENCE_PATH = os.path.join(BASE_DIR, "inference.py")  # full path to inference.py
+OUTPUT_DIR = os.path.join(BASE_DIR, "output")  # output directory (BASE_DIR/output)
+AUTO_ENSEMBLE_OUTPUT = os.path.join(BASE_DIR, "ensemble_output")  # ensemble output directory
+
+def clean_model_name(model):
+ """
+ Clean and standardize model names for filename
+ """
+ model_name_mapping = {
+ 'VOCALS-InstVocHQ': 'InstVocHQ',
+ 'VOCALS-MelBand-Roformer (by KimberleyJSN)': 'KimberleyJSN',
+ 'VOCALS-BS-Roformer_1297 (by viperx)': 'VOCALS_BS_Roformer1297',
+ 'VOCALS-BS-Roformer_1296 (by viperx)': 'VOCALS-BS-Roformer_1296',
+ 'VOCALS-BS-RoformerLargev1 (by unwa)': 'UnwaLargeV1',
+ 'VOCALS-Mel-Roformer big beta 4 (by unwa)': 'UnwaBigBeta4',
+ 'VOCALS-Melband-Roformer BigBeta5e (by unwa)': 'UnwaBigBeta5e',
+ 'INST-Mel-Roformer v1 (by unwa)': 'UnwaInstV1',
+ 'INST-Mel-Roformer v2 (by unwa)': 'UnwaInstV2',
+ 'INST-VOC-Mel-Roformer a.k.a. duality (by unwa)': 'UnwaDualityV1',
+ 'INST-VOC-Mel-Roformer a.k.a. duality v2 (by unwa)': 'UnwaDualityV2',
+ 'KARAOKE-MelBand-Roformer (by aufr33 & viperx)': 'KaraokeMelBandRoformer',
+ 'VOCALS-VitLarge23 (by ZFTurbo)': 'VitLarge23',
+ 'VOCALS-MelBand-Roformer (by Becruily)': 'BecruilyVocals',
+ 'INST-MelBand-Roformer (by Becruily)': 'BecruilyInst',
+ 'VOCALS-MelBand-Roformer Kim FT (by Unwa)': 'KimFT',
+ 'INST-MelBand-Roformer Kim FT (by Unwa)': 'KimFTInst',
+ 'OTHER-BS-Roformer_1053 (by viperx)': 'OtherViperx1053',
+ 'CROWD-REMOVAL-MelBand-Roformer (by aufr33)': 'CrowdRemovalRoformer',
+ 'CINEMATIC-BandIt_Plus (by kwatcharasupat)': 'CinematicBandItPlus',
+ 'DRUMSEP-MDX23C_DrumSep_6stem (by aufr33 & jarredou)': 'DrumSepMDX23C',
+ '4STEMS-SCNet_MUSDB18 (by starrytong)': 'FourStemsSCNet',
+ 'DE-REVERB-MDX23C (by aufr33 & jarredou)': 'DeReverbMDX23C',
+ 'DENOISE-MelBand-Roformer-1 (by aufr33)': 'DenoiseMelBand1',
+ 'DENOISE-MelBand-Roformer-2 (by aufr33)': 'DenoiseMelBand2',
+ '4STEMS-SCNet_XL_MUSDB18 (by ZFTurbo)': 'FourStemsSCNetXL',
+ '4STEMS-SCNet_Large (by starrytong)': 'FourStemsSCNetLarge',
+ '4STEMS-BS-Roformer_MUSDB18 (by ZFTurbo)': 'FourStemsBSRoformer',
+ 'DE-REVERB-MelBand-Roformer aggr./v2/19.1729 (by anvuew)': 'DeReverbMelBandAggr',
+ 'DE-REVERB-Echo-MelBand-Roformer (by Sucial)': 'DeReverbEchoMelBand',
+ 'bleed_suppressor_v1 (by unwa)': 'BleedSuppressorV1',
+ 'inst_v1e (by unwa)': 'InstV1E',
+ 'inst_gabox (by Gabox)': 'InstGabox',
+ 'inst_gaboxBV1 (by Gabox)': 'InstGaboxBV1',
+ 'inst_gaboxBV2 (by Gabox)': 'InstGaboxBV2',
+ 'inst_gaboxBFV1 (by Gabox)': 'InstGaboxBFV1',
+ 'inst_gaboxFV2 (by Gabox)': 'InstGaboxFV2',
+ 'inst_gaboxFV1 (by Gabox)': 'InstGaboxFV1',
+ 'dereverb_mel_band_roformer_less_aggressive_anvuew': 'DereverbMelBandRoformerLessAggressive',
+ 'dereverb_mel_band_roformer_anvuew': 'DereverbMelBandRoformer',
+ 'VOCALS-Male Female-BS-RoFormer Male Female Beta 7_2889 (by aufr33)': 'MaleFemale-BS-RoFormer-(by aufr33)',
+ 'VOCALS-MelBand-Roformer (by Becruily)': 'Vocals-MelBand-Roformer-(by Becruily)',
+ 'VOCALS-MelBand-Roformer Kim FT 2 (by Unwa)': 'Vocals-MelBand-Roformer-KİM-FT-2(by Unwa)',
+ 'voc_gaboxMelRoformer (by Gabox)': 'voc_gaboxMelRoformer',
+ 'voc_gaboxBSroformer (by Gabox)': 'voc_gaboxBSroformer',
+ 'voc_gaboxMelRoformerFV1 (by Gabox)': 'voc_gaboxMelRoformerFV1',
+ 'voc_gaboxMelRoformerFV2 (by Gabox)': 'voc_gaboxMelRoformerFV2',
+ 'SYH99999/MelBandRoformerSYHFTB1(by Amane)': 'MelBandRoformerSYHFTB1',
+ 'inst_V5 (by Gabox)': 'INSTV5-(by Gabox)',
+ 'inst_Fv4Noise (by Gabox)': 'Inst_Fv4Noise-(by Gabox)',
+ 'Intrumental_Gabox (by Gabox)': 'Intrumental_Gabox-(by Gabox)',
+ 'inst_GaboxFv3 (by Gabox)': 'INST_GaboxFv3-(by Gabox)',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model1 (by Amane)': 'MelBandRoformerSYHFTB1_model1',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model2 (by Amane)': 'MelBandRoformerSYHFTB1_model2',
+ 'SYH99999/MelBandRoformerSYHFTB1_Model3 (by Amane)': 'MelBandRoformerSYHFTB1_model3',
+ 'VOCALS-MelBand-Roformer Kim FT 2 Blendless (by unwa)': 'VOCALS-MelBand-Roformer-Kim-FT-2-Blendless-(by unwa)',
+ 'inst_gaboxFV6 (by Gabox)': 'inst_gaboxFV6-(by Gabox)',
+ 'denoisedebleed (by Gabox)': 'denoisedebleed-(by Gabox)',
+ 'INSTV5N (by Gabox)': 'INSTV5N_(by Gabox)',
+ 'Voc_Fv3 (by Gabox)': 'Voc_Fv3_(by Gabox)',
+ 'MelBandRoformer4StemFTLarge (SYH99999)': 'MelBandRoformer4StemFTLarge_(SYH99999)',
+ 'dereverb_mel_band_roformer_mono (by anvuew)': 'dereverb_mel_band_roformer_mono_(by anvuew)',
+ 'INSTV6N (by Gabox)': 'INSTV6N_(by Gabox)',
+ 'KaraokeGabox': 'KaraokeGabox',
+ 'FullnessVocalModel (by Amane)': 'FullnessVocalModel',
+ 'Inst_GaboxV7 (by Gabox)': 'Inst_GaboxV7_(by Gabox)',
+ }
+
+ if model in model_name_mapping:
+ return model_name_mapping[model]
+
+ cleaned = re.sub(r'\s*\(.*?\)', '', model) # Remove parenthetical info
+ cleaned = cleaned.replace('-', '_')
+ cleaned = ''.join(char for char in cleaned if char.isalnum() or char == '_')
+
+ return cleaned
+
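+# Illustrative examples of the mapping/fallback above (hypothetical inputs):
+#   clean_model_name('VOCALS-InstVocHQ')             -> 'InstVocHQ'       (mapped)
+#   clean_model_name('My-Custom Model (by someone)') -> 'My_CustomModel'  (fallback)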
+def shorten_filename(filename, max_length=30):
+ """
+ Shortens a filename to a specified maximum length
+ """
+ base, ext = os.path.splitext(filename)
+ if len(base) <= max_length:
+ return filename
+ shortened = base[:15] + "..." + base[-10:] + ext
+ return shortened
+
+def clean_filename(filename):
+ """
+ Return the cleaned base name, the detected stem type, and the file extension
+ """
+ cleanup_patterns = [
+ r'_\d{8}_\d{6}_\d{6}$', # _20231215_123456_123456
+ r'_\d{14}$', # _20231215123456
+ r'_\d{10}$', # _1702658400
+ r'_\d+$' # any trailing number
+ ]
+
+ base, ext = os.path.splitext(filename)
+ for pattern in cleanup_patterns:
+ base = re.sub(pattern, '', base)
+
+ file_types = ['vocals', 'instrumental', 'drum', 'bass', 'other', 'effects', 'speech', 'music', 'dry', 'male', 'female']
+ for type_keyword in file_types:
+ base = base.replace(f'_{type_keyword}', '')
+
+ detected_type = None
+ for type_keyword in file_types:
+ if type_keyword in base.lower():
+ detected_type = type_keyword
+ break
+
+ clean_base = base.strip('_- ')
+ return clean_base, detected_type, ext
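+
+# Illustrative examples (hypothetical filenames):
+#   shorten_filename('a_very_long_recording_name_for_testing_purposes.wav')
+#       -> 'a_very_long_rec...g_purposes.wav'
+#   clean_filename('vocalsMix_123.wav') -> ('vocalsMix', 'vocals', '.wav')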
diff --git a/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml b/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f7f511025573acaa7a53f75c957b878c0aaa8205
--- /dev/null
+++ b/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml
@@ -0,0 +1,72 @@
+audio:
+ chunk_size: 352800
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 441
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dim: 384
+ depth: 6
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0
+ ff_dropout: 0
+ flash_attn: True
+ dim_freqs_in: 1025
+ sample_rate: 44100 # needed for mel filter bank from librosa
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: False
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+
+training:
+ batch_size: 4
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 1.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ augmentation: false # enable augmentations by audiomentations and pedalboard
+ augmentation_type: null
+ use_mp3_compress: false # Deprecated
+ augmentation_mix: false # Mix several stems of the same type with some probability
+ augmentation_loudness: false # randomly change loudness of each stem
+ augmentation_loudness_type: 1 # Type 1 or 2
+ augmentation_loudness_min: 0
+ augmentation_loudness_max: 0
+ q: 0.95
+ coarse_loss_clip: false
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+inference:
+ batch_size: 4
+ dim_t: 256
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/config_apollo.yaml b/configs/config_apollo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..290547b44d0c5d62d3d3315e7c2e444728139f22
--- /dev/null
+++ b/configs/config_apollo.yaml
@@ -0,0 +1,33 @@
+audio:
+ chunk_size: 132300
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.0
+
+model:
+ sr: 44100
+ win: 20
+ feature_dim: 256
+ layer: 6
+
+training:
+ instruments: ['restored', 'addition']
+ target_instrument: 'restored'
+ batch_size: 2
+ num_steps: 1000
+ num_epochs: 1000
+ optimizer: 'prodigy'
+ lr: 1.0
+ patience: 2
+ reduce_factor: 0.95
+ coarse_loss_clip: true
+ grad_clip: 0
+ q: 0.95
+ use_amp: true
+
+augmentations:
+ enable: false # enable or disable all augmentations (to fast disable if needed)
+
+inference:
+ batch_size: 4
+ num_overlap: 4
diff --git a/configs/config_dnr_bandit_bsrnn_multi_mus64.yaml b/configs/config_dnr_bandit_bsrnn_multi_mus64.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2392ca496e498e57a99d70f1f28f73fe3dd7c432
--- /dev/null
+++ b/configs/config_dnr_bandit_bsrnn_multi_mus64.yaml
@@ -0,0 +1,78 @@
+name: "MultiMaskMultiSourceBandSplitRNN"
+audio:
+ chunk_size: 264600
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ in_channel: 1
+ stems: ['speech', 'music', 'effects']
+ band_specs: "musical"
+ n_bands: 64
+ fs: 44100
+ require_no_overlap: false
+ require_no_gap: true
+ normalize_channel_independently: false
+ treat_channel_as_feature: true
+ n_sqm_modules: 8
+ emb_dim: 128
+ rnn_dim: 256
+ bidirectional: true
+ rnn_type: "GRU"
+ mlp_dim: 512
+ hidden_activation: "Tanh"
+ hidden_activation_kwargs: null
+ complex_mask: true
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ window_fn: "hann_window"
+ wkwargs: null
+ power: null
+ center: true
+ normalized: true
+ pad_mode: "constant"
+ onesided: true
+
+training:
+ batch_size: 4
+ gradient_accumulation_steps: 4
+ grad_clip: 0
+ instruments:
+ - speech
+ - music
+ - effects
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_dnr_bandit_v2_mus64.yaml b/configs/config_dnr_bandit_v2_mus64.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..db74fee27426b6e2204d459070603abcf846e3a6
--- /dev/null
+++ b/configs/config_dnr_bandit_v2_mus64.yaml
@@ -0,0 +1,77 @@
+cls: Bandit
+
+audio:
+ chunk_size: 384000
+ num_channels: 2
+ sample_rate: 48000
+ min_mean_abs: 0.000
+
+kwargs:
+ in_channels: 1
+ stems: ['speech', 'music', 'sfx']
+ band_type: musical
+ n_bands: 64
+ normalize_channel_independently: false
+ treat_channel_as_feature: true
+ n_sqm_modules: 8
+ emb_dim: 128
+ rnn_dim: 256
+ bidirectional: true
+ rnn_type: "GRU"
+ mlp_dim: 512
+ hidden_activation: "Tanh"
+ hidden_activation_kwargs: null
+ complex_mask: true
+ use_freq_weights: true
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ window_fn: "hann_window"
+ wkwargs: null
+ power: null
+ center: true
+ normalized: true
+ pad_mode: "reflect"
+ onesided: true
+
+training:
+ batch_size: 4
+ gradient_accumulation_steps: 4
+ grad_clip: 0
+ instruments:
+ - speech
+ - music
+ - sfx
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_drumsep.yaml b/configs/config_drumsep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..687b5ff0639a476a57a6e24552964759c8ce1ff5
--- /dev/null
+++ b/configs/config_drumsep.yaml
@@ -0,0 +1,72 @@
+audio:
+ chunk_size: 1764000 # samplerate * segment
+ min_mean_abs: 0.000
+ hop_length: 1024
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 40
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['kick', 'snare', 'cymbals', 'toms']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: adam
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: false # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+
+inference:
+ num_overlap: 4
+ batch_size: 8
+
+model: hdemucs
+
+hdemucs: # see demucs/hdemucs.py for a detailed description
+ channels: 48
+ channels_time: null
+ growth: 2
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: False
+ cac: True
+ depth: 6
+ rewrite: True
+ hybrid: True
+ hybrid_old: False
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: True
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ norm_starts: 4
+ norm_groups: 4
+ dconv_mode: 1
+ dconv_depth: 2
+ dconv_comp: 4
+ dconv_attn: 4
+ dconv_lstm: 4
+ dconv_init: 0.001
+ rescale: 0.1
diff --git a/configs/config_htdemucs_6stems.yaml b/configs/config_htdemucs_6stems.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d06a489ec66794414dedd4c143f6e937b26ce666
--- /dev/null
+++ b/configs/config_htdemucs_6stems.yaml
@@ -0,0 +1,127 @@
+audio:
+ chunk_size: 485100 # samplerate * segment
+ min_mean_abs: 0.001
+ hop_length: 1024
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['drums', 'bass', 'other', 'vocals', 'guitar', 'piano']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: adam
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: [0.2, 0.02]
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+
+inference:
+ num_overlap: 4
+ batch_size: 8
+
+model: htdemucs
+
+htdemucs: # see demucs/htdemucs.py for a detailed description
+ # Channels
+ channels: 48
+ channels_time:
+ growth: 2
+ # STFT
+ num_subbands: 1
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: false
+ cac: true
+ # Main structure
+ depth: 4
+ rewrite: true
+ # Frequency Branch
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_mode: 3
+ dconv_depth: 2
+ dconv_comp: 8
+ dconv_init: 1e-3
+ # Before the Transformer
+ bottom_channels: 0
+ # CrossTransformer
+ # ------ Common to all
+ # Regular parameters
+ t_layers: 5
+ t_hidden_scale: 4.0
+ t_heads: 8
+ t_dropout: 0.0
+ t_layer_scale: True
+ t_gelu: True
+ # ------------- Positional Embedding
+ t_emb: sin
+ t_max_positions: 10000 # for the scaled embedding
+ t_max_period: 10000.0
+ t_weight_pos_embed: 1.0
+ t_cape_mean_normalize: True
+ t_cape_augment: True
+ t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
+ t_sin_random_shift: 0
+ # ------------- norm before a transformer encoder
+ t_norm_in: True
+ t_norm_in_group: False
+ # ------------- norm inside the encoder
+ t_group_norm: False
+ t_norm_first: True
+ t_norm_out: True
+ # ------------- optim
+ t_weight_decay: 0.0
+ t_lr:
+ # ------------- sparsity
+ t_sparse_self_attn: False
+ t_sparse_cross_attn: False
+ t_mask_type: diag
+ t_mask_random_seed: 42
+ t_sparse_attn_window: 400
+ t_global_window: 100
+ t_sparsity: 0.95
+ t_auto_sparsity: False
+ # Cross Encoder First (False)
+ t_cross_first: False
+ # Weight init
+ rescale: 0.1
+
diff --git a/configs/config_musdb18_bs_mamba2.yaml b/configs/config_musdb18_bs_mamba2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..89154451ca213849592ff4aa4a2d21e132b13cf2
--- /dev/null
+++ b/configs/config_musdb18_bs_mamba2.yaml
@@ -0,0 +1,58 @@
+audio:
+ chunk_size: 132300 # samplerate * segment
+ hop_length: 1024
+ min_mean_abs: 0.0
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['drums', 'bass', 'other', 'vocals']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: prodigy
+ lr: 1.0
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ read_metadata_procs: 8
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+model:
+ sr: 44100
+ win: 2048
+ stride: 512
+ feature_dim: 128
+ num_repeat_mask: 8
+ num_repeat_map: 4
+ num_output: 4
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+
+inference:
+ num_overlap: 2
+ batch_size: 8
\ No newline at end of file
diff --git a/configs/config_musdb18_bs_roformer.yaml b/configs/config_musdb18_bs_roformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ff17998d201e7f3d894cdc80671b9ac330023541
--- /dev/null
+++ b/configs/config_musdb18_bs_roformer.yaml
@@ -0,0 +1,137 @@
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ dim: 192
+ depth: 6
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: true
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 4 # Probably too big (requires a lot of memory for weights)
+ use_torch_checkpoint: False # it allows to greatly reduce GPU memory consumption during training (not fully tested)
+ skip_connection: False # Enable skip connection between transformer blocks - can solve problem with gradients and probably faster training
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_musdb18_bs_roformer_with_lora.yaml b/configs/config_musdb18_bs_roformer_with_lora.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6c6dcdbcf7ca2a1c820d6649e3be056c270b8788
--- /dev/null
+++ b/configs/config_musdb18_bs_roformer_with_lora.yaml
@@ -0,0 +1,205 @@
+audio:
+ chunk_size: 485100
+ dim_f: 1024
+ dim_t: 801 # not used here (the value in the model section is used)
+ hop_length: 441 # not used here (the value in the model section is used)
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+lora:
+ r: 8
+ lora_alpha: 16 # alpha / rank > 1
+ lora_dropout: 0.05
+ merge_weights: False
+ fan_in_fan_out: False
+ enable_lora: [True, False, True] # This for QKV
+ # enable_lora: [True] # For non-Roformers architectures
+
+model:
+ dim: 384
+ depth: 8
+ stereo: true
+ num_stems: 4
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: true
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 2
+ use_torch_checkpoint: False # it allows to greatly reduce GPU memory consumption during training (not fully tested)
+ skip_connection: False # Enable skip connection between transformer blocks - can solve problem with gradients and probably faster training
+
+training:
+ batch_size: 1
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments: ['drums', 'bass', 'other', 'vocals']
+ patience: 3
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ augmentation: false # enable augmentations by audiomentations and pedalboard
+ augmentation_type: simple1
+ use_mp3_compress: false # Deprecated
+ augmentation_mix: true # Mix several stems of the same type with some probability
+ augmentation_loudness: true # randomly change loudness of each stem
+ augmentation_loudness_type: 1 # Type 1 or 2
+ augmentation_loudness_min: 0.5
+ augmentation_loudness_max: 1.5
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ # optimizer: prodigy
+ optimizer: adam
+ # lr: 1.0
+ lr: 1.0e-5
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.1
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ bass:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -2
+ pitch_shift_max_semitones: 2
+ seven_band_parametric_eq: 0.1
+ seven_band_parametric_eq_min_gain_db: -3
+ seven_band_parametric_eq_max_gain_db: 6
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.5
+ drums:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.1
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.6
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.1
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+
+inference:
+ batch_size: 2
+ dim_t: 1101
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/config_musdb18_demucs3_mmi.yaml b/configs/config_musdb18_demucs3_mmi.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..08c25c50f8f747d0e4af7acae68b1e47a01f3d0c
--- /dev/null
+++ b/configs/config_musdb18_demucs3_mmi.yaml
@@ -0,0 +1,72 @@
+audio:
+ chunk_size: 485100 # samplerate * segment
+ min_mean_abs: 0.000
+ hop_length: 1024
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['drums', 'bass', 'other', 'vocals']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: adam
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: false # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+
+inference:
+ num_overlap: 4
+ batch_size: 8
+
+model: hdemucs
+
+hdemucs: # see demucs/hdemucs.py for a detailed description
+ channels: 48
+ channels_time: null
+ growth: 2
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: False
+ cac: True
+ depth: 6
+ rewrite: True
+ hybrid: True
+ hybrid_old: False
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: True
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ norm_starts: 4
+ norm_groups: 4
+ dconv_mode: 1
+ dconv_depth: 2
+ dconv_comp: 4
+ dconv_attn: 4
+ dconv_lstm: 4
+ dconv_init: 0.001
+ rescale: 0.1
diff --git a/configs/config_musdb18_htdemucs.yaml b/configs/config_musdb18_htdemucs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba635367baca0b58a977fa4bb38a1cec99579ca9
--- /dev/null
+++ b/configs/config_musdb18_htdemucs.yaml
@@ -0,0 +1,119 @@
+audio:
+ chunk_size: 485100 # samplerate * segment
+ min_mean_abs: 0.001
+ hop_length: 1024
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['drums', 'bass', 'other', 'vocals']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: adam
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+
+inference:
+ num_overlap: 4
+ batch_size: 8
+
+model: htdemucs
+
+htdemucs: # see demucs/htdemucs.py for a detailed description
+ # Channels
+ channels: 48
+ channels_time:
+ growth: 2
+ # STFT
+ num_subbands: 1
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: false
+ cac: true
+ # Main structure
+ depth: 4
+ rewrite: true
+ # Frequency Branch
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_mode: 3
+ dconv_depth: 2
+ dconv_comp: 8
+ dconv_init: 1e-3
+ # Before the Transformer
+ bottom_channels: 512
+ # CrossTransformer
+ # ------ Common to all
+ # Regular parameters
+ t_layers: 5
+ t_hidden_scale: 4.0
+ t_heads: 8
+ t_dropout: 0.0
+ t_layer_scale: True
+ t_gelu: True
+ # ------------- Positional Embedding
+ t_emb: sin
+ t_max_positions: 10000 # for the scaled embedding
+ t_max_period: 10000.0
+ t_weight_pos_embed: 1.0
+ t_cape_mean_normalize: True
+ t_cape_augment: True
+ t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
+ t_sin_random_shift: 0
+ # ------------- norm before a transformer encoder
+ t_norm_in: True
+ t_norm_in_group: False
+ # ------------- norm inside the encoder
+ t_group_norm: False
+ t_norm_first: True
+ t_norm_out: True
+ # ------------- optim
+ t_weight_decay: 0.0
+ t_lr:
+ # ------------- sparsity
+ t_sparse_self_attn: False
+ t_sparse_cross_attn: False
+ t_mask_type: diag
+ t_mask_random_seed: 42
+ t_sparse_attn_window: 400
+ t_global_window: 100
+ t_sparsity: 0.95
+ t_auto_sparsity: False
+ # Cross Encoder First (False)
+ t_cross_first: False
+ # Weight init
+ rescale: 0.1
+
diff --git a/configs/config_musdb18_mdx23c.yaml b/configs/config_musdb18_mdx23c.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..73631f7293c8db94c55c7e1db9fdfc79c712d6e0
--- /dev/null
+++ b/configs/config_musdb18_mdx23c.yaml
@@ -0,0 +1,182 @@
+audio:
+ chunk_size: 261120
+ dim_f: 4096
+ dim_t: 256
+ hop_length: 1024
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ act: gelu
+ bottleneck_factor: 4
+ growth: 128
+ norm: InstanceNorm
+ num_blocks_per_scale: 2
+ num_channels: 128
+ num_scales: 5
+ num_subbands: 4
+ scale:
+ - 2
+ - 2
+
+training:
+ batch_size: 6
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ # pedalboard reverb block
+ pedalboard_reverb: 0.01
+ pedalboard_reverb_room_size_min: 0.1
+ pedalboard_reverb_room_size_max: 0.9
+ pedalboard_reverb_damping_min: 0.1
+ pedalboard_reverb_damping_max: 0.9
+ pedalboard_reverb_wet_level_min: 0.1
+ pedalboard_reverb_wet_level_max: 0.9
+ pedalboard_reverb_dry_level_min: 0.1
+ pedalboard_reverb_dry_level_max: 0.9
+ pedalboard_reverb_width_min: 0.9
+ pedalboard_reverb_width_max: 1.0
+
+ # pedalboard chorus block
+ pedalboard_chorus: 0.01
+ pedalboard_chorus_rate_hz_min: 1.0
+ pedalboard_chorus_rate_hz_max: 7.0
+ pedalboard_chorus_depth_min: 0.25
+ pedalboard_chorus_depth_max: 0.95
+ pedalboard_chorus_centre_delay_ms_min: 3
+ pedalboard_chorus_centre_delay_ms_max: 10
+ pedalboard_chorus_feedback_min: 0.0
+ pedalboard_chorus_feedback_max: 0.5
+ pedalboard_chorus_mix_min: 0.1
+ pedalboard_chorus_mix_max: 0.9
+
+ # pedalboard phazer block
+ pedalboard_phazer: 0.01
+ pedalboard_phazer_rate_hz_min: 1.0
+ pedalboard_phazer_rate_hz_max: 10.0
+ pedalboard_phazer_depth_min: 0.25
+ pedalboard_phazer_depth_max: 0.95
+ pedalboard_phazer_centre_frequency_hz_min: 200
+ pedalboard_phazer_centre_frequency_hz_max: 12000
+ pedalboard_phazer_feedback_min: 0.0
+ pedalboard_phazer_feedback_max: 0.5
+ pedalboard_phazer_mix_min: 0.1
+ pedalboard_phazer_mix_max: 0.9
+
+ # pedalboard distortion block
+ pedalboard_distortion: 0.01
+ pedalboard_distortion_drive_db_min: 1.0
+ pedalboard_distortion_drive_db_max: 25.0
+
+ # pedalboard pitch shift block
+ pedalboard_pitch_shift: 0.01
+ pedalboard_pitch_shift_semitones_min: -7
+ pedalboard_pitch_shift_semitones_max: 7
+
+ # pedalboard resample block
+ pedalboard_resample: 0.01
+ pedalboard_resample_target_sample_rate_min: 4000
+ pedalboard_resample_target_sample_rate_max: 44100
+
+ # pedalboard bitcrash block
+ pedalboard_bitcrash: 0.01
+ pedalboard_bitcrash_bit_depth_min: 4
+ pedalboard_bitcrash_bit_depth_max: 16
+
+ # pedalboard mp3 compressor block
+ pedalboard_mp3_compressor: 0.01
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_min: 0
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_max: 9.999
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ bass:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -2
+ pitch_shift_max_semitones: 2
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -3
+ seven_band_parametric_eq_max_gain_db: 6
+ tanh_distortion: 0.2
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.5
+ drums:
+ pitch_shift: 0.33
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.33
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.6
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
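
A practical note on loading: config_musdb18_mdx23c.yaml (and several of the configs below) tags mixup_probs with `!!python/tuple`, which PyYAML's safe loader rejects. A hedged sketch of one way to read such a file (the repository may use its own loader; only do this for config files you trust, since unsafe_load can construct arbitrary Python objects):

```python
# Sketch: yaml.safe_load raises on the !!python/tuple tag above, so use a loader
# that understands Python-object tags. Only for trusted, locally authored configs.
import yaml

with open("configs/config_musdb18_mdx23c.yaml") as f:
    cfg = yaml.unsafe_load(f)

print(cfg["augmentations"]["mixup_probs"])   # (0.2, 0.02) as a Python tuple
print(cfg["training"]["instruments"])        # ['vocals', 'bass', 'drums', 'other']
```
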
diff --git a/configs/config_musdb18_mdx23c_stht.yaml b/configs/config_musdb18_mdx23c_stht.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..73631f7293c8db94c55c7e1db9fdfc79c712d6e0
--- /dev/null
+++ b/configs/config_musdb18_mdx23c_stht.yaml
@@ -0,0 +1,182 @@
+audio:
+ chunk_size: 261120
+ dim_f: 4096
+ dim_t: 256
+ hop_length: 1024
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ act: gelu
+ bottleneck_factor: 4
+ growth: 128
+ norm: InstanceNorm
+ num_blocks_per_scale: 2
+ num_channels: 128
+ num_scales: 5
+ num_subbands: 4
+ scale:
+ - 2
+ - 2
+
+training:
+ batch_size: 6
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ # pedalboard reverb block
+ pedalboard_reverb: 0.01
+ pedalboard_reverb_room_size_min: 0.1
+ pedalboard_reverb_room_size_max: 0.9
+ pedalboard_reverb_damping_min: 0.1
+ pedalboard_reverb_damping_max: 0.9
+ pedalboard_reverb_wet_level_min: 0.1
+ pedalboard_reverb_wet_level_max: 0.9
+ pedalboard_reverb_dry_level_min: 0.1
+ pedalboard_reverb_dry_level_max: 0.9
+ pedalboard_reverb_width_min: 0.9
+ pedalboard_reverb_width_max: 1.0
+
+ # pedalboard chorus block
+ pedalboard_chorus: 0.01
+ pedalboard_chorus_rate_hz_min: 1.0
+ pedalboard_chorus_rate_hz_max: 7.0
+ pedalboard_chorus_depth_min: 0.25
+ pedalboard_chorus_depth_max: 0.95
+ pedalboard_chorus_centre_delay_ms_min: 3
+ pedalboard_chorus_centre_delay_ms_max: 10
+ pedalboard_chorus_feedback_min: 0.0
+ pedalboard_chorus_feedback_max: 0.5
+ pedalboard_chorus_mix_min: 0.1
+ pedalboard_chorus_mix_max: 0.9
+
+ # pedalboard phazer block
+ pedalboard_phazer: 0.01
+ pedalboard_phazer_rate_hz_min: 1.0
+ pedalboard_phazer_rate_hz_max: 10.0
+ pedalboard_phazer_depth_min: 0.25
+ pedalboard_phazer_depth_max: 0.95
+ pedalboard_phazer_centre_frequency_hz_min: 200
+ pedalboard_phazer_centre_frequency_hz_max: 12000
+ pedalboard_phazer_feedback_min: 0.0
+ pedalboard_phazer_feedback_max: 0.5
+ pedalboard_phazer_mix_min: 0.1
+ pedalboard_phazer_mix_max: 0.9
+
+ # pedalboard distortion block
+ pedalboard_distortion: 0.01
+ pedalboard_distortion_drive_db_min: 1.0
+ pedalboard_distortion_drive_db_max: 25.0
+
+ # pedalboard pitch shift block
+ pedalboard_pitch_shift: 0.01
+ pedalboard_pitch_shift_semitones_min: -7
+ pedalboard_pitch_shift_semitones_max: 7
+
+ # pedalboard resample block
+ pedalboard_resample: 0.01
+ pedalboard_resample_target_sample_rate_min: 4000
+ pedalboard_resample_target_sample_rate_max: 44100
+
+ # pedalboard bitcrash block
+ pedalboard_bitcrash: 0.01
+ pedalboard_bitcrash_bit_depth_min: 4
+ pedalboard_bitcrash_bit_depth_max: 16
+
+ # pedalboard mp3 compressor block
+ pedalboard_mp3_compressor: 0.01
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_min: 0
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_max: 9.999
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ bass:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -2
+ pitch_shift_max_semitones: 2
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -3
+ seven_band_parametric_eq_max_gain_db: 6
+ tanh_distortion: 0.2
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.5
+ drums:
+ pitch_shift: 0.33
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.33
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.6
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_musdb18_mel_band_roformer.yaml b/configs/config_musdb18_mel_band_roformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f5b9e05544f74b53a61dd8256b29d91704ca4fc
--- /dev/null
+++ b/configs/config_musdb18_mel_band_roformer.yaml
@@ -0,0 +1,76 @@
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ dim: 192
+ depth: 8
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: True
+ dim_freqs_in: 1025
+ sample_rate: 44100 # needed for mel filter bank from librosa
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: False
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 4 # Probably too big (requires a lot of memory for weights)
+ use_torch_checkpoint: False # allows greatly reducing GPU memory consumption during training (not fully tested)
+ skip_connection: False # enable skip connections between transformer blocks - can help with gradient flow and possibly speed up training
+
+training:
+ batch_size: 7
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
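
The inference block pairs audio.chunk_size (131584 samples) with num_overlap: 4. I read num_overlap as the number of overlapping windows covering each sample, i.e. a hop of chunk_size // num_overlap with averaging of the overlapping estimates; that interpretation is an assumption, and the sketch below only illustrates the idea rather than reproducing the repository's inference code:

```python
# Illustrative overlap-and-average chunking, assuming hop = chunk_size // num_overlap.
import numpy as np

def chunked_apply(model_fn, mixture, chunk_size=131584, num_overlap=4):
    """mixture: (channels, samples) -> estimate of the same shape."""
    hop = chunk_size // num_overlap
    n = mixture.shape[-1]
    out = np.zeros_like(mixture, dtype=np.float64)
    weight = np.zeros(n)
    for start in range(0, n, hop):
        end = min(start + chunk_size, n)
        chunk = mixture[:, start:end]
        if chunk.shape[-1] < chunk_size:                  # zero-pad the final chunk
            chunk = np.pad(chunk, ((0, 0), (0, chunk_size - chunk.shape[-1])))
        est = model_fn(chunk)                             # placeholder separation model
        out[:, start:end] += est[:, :end - start]
        weight[start:end] += 1.0
    return out / np.maximum(weight, 1e-8)

# e.g. chunked_apply(lambda x: x * 0.5, np.random.randn(2, 44100 * 10))
```
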
diff --git a/configs/config_musdb18_mel_band_roformer_all_stems.yaml b/configs/config_musdb18_mel_band_roformer_all_stems.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9f7f323d40ca96c8359089d3258745078a6da2a9
--- /dev/null
+++ b/configs/config_musdb18_mel_band_roformer_all_stems.yaml
@@ -0,0 +1,97 @@
+audio:
+ chunk_size: 352800
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 441
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dim: 384
+ depth: 6
+ stereo: true
+ num_stems: 4
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0
+ ff_dropout: 0
+ flash_attn: True
+ dim_freqs_in: 1025
+ sample_rate: 44100 # needed for mel filter bank from librosa
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: False
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 4 # Probably too big (requires a lot of memory for weights)
+ use_torch_checkpoint: False # allows greatly reducing GPU memory consumption during training (not fully tested)
+ skip_connection: False # enable skip connections between transformer blocks - can help with gradient flow and possibly speed up training
+
+training:
+ batch_size: 1
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - drums
+ - bass
+ - other
+ - vocals
+ lr: 1.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ augmentation: false # enable augmentations by audiomentations and pedalboard
+ augmentation_type: null
+ use_mp3_compress: false # Deprecated
+ augmentation_mix: false # Mix several stems of the same type with some probability
+ augmentation_loudness: false # randomly change loudness of each stem
+ augmentation_loudness_type: 1 # Type 1 or 2
+ augmentation_loudness_min: 0
+ augmentation_loudness_max: 0
+ q: 0.95
+ coarse_loss_clip: false
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+
+
+inference:
+ batch_size: 4
+ dim_t: 256
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/config_musdb18_scnet.yaml b/configs/config_musdb18_scnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7dcafdd8d023938f3f680c8e107a18dba6c892b
--- /dev/null
+++ b/configs/config_musdb18_scnet.yaml
@@ -0,0 +1,83 @@
+audio:
+ chunk_size: 485100 # 44100 * 11
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ sources:
+ - drums
+ - bass
+ - other
+ - vocals
+ audio_channels: 2
+ dims:
+ - 4
+ - 32
+ - 64
+ - 128
+ nfft: 4096
+ hop_size: 1024
+ win_size: 4096
+ normalized: True
+ band_SR:
+ - 0.175
+ - 0.392
+ - 0.433
+ band_stride:
+ - 1
+ - 4
+ - 16
+ band_kernel:
+ - 3
+ - 4
+ - 16
+ conv_depths:
+ - 3
+ - 2
+ - 1
+ compress: 4
+ conv_kernel: 3
+ num_dplayer: 6
+ expand: 1
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - drums
+ - bass
+ - other
+ - vocals
+ lr: 5.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
+ normalize: true
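
SCNet's band_SR values (0.175, 0.392, 0.433) sum to 1.0 and read as the share of the frequency axis given to the low, mid, and high bands, each of which then gets its own band_stride / band_kernel. A small arithmetic sketch of that reading follows; whether SCNet counts nfft // 2 or nfft // 2 + 1 bins is an implementation detail I am not asserting:

```python
# Turning the band_SR ratios into approximate per-band bin counts for nfft = 4096.
nfft = 4096
n_bins = nfft // 2 + 1                    # 2049 STFT bins (assumption, see note above)
band_sr = [0.175, 0.392, 0.433]
band_stride = [1, 4, 16]

bins = [round(r * n_bins) for r in band_sr]
print(bins, sum(bins))                    # [359, 803, 887] 2049
for b, s in zip(bins, band_stride):
    print(f"{b} bins, stride {s} -> roughly {b // s} downsampled features")
```
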
diff --git a/configs/config_musdb18_scnet_large.yaml b/configs/config_musdb18_scnet_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..939ba190c8bb18ea782326c5c90b1c26f460cd36
--- /dev/null
+++ b/configs/config_musdb18_scnet_large.yaml
@@ -0,0 +1,83 @@
+audio:
+ chunk_size: 485100 # 44100 * 11
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ sources:
+ - drums
+ - bass
+ - other
+ - vocals
+ audio_channels: 2
+ dims:
+ - 4
+ - 64
+ - 128
+ - 256
+ nfft: 4096
+ hop_size: 1024
+ win_size: 4096
+ normalized: True
+ band_SR:
+ - 0.225
+ - 0.372
+ - 0.403
+ band_stride:
+ - 1
+ - 4
+ - 16
+ band_kernel:
+ - 3
+ - 4
+ - 16
+ conv_depths:
+ - 3
+ - 2
+ - 1
+ compress: 4
+ conv_kernel: 3
+ num_dplayer: 6
+ expand: 1
+
+training:
+ batch_size: 6
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - drums
+ - bass
+ - other
+ - vocals
+ lr: 5.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
+ normalize: false
diff --git a/configs/config_musdb18_segm_models.yaml b/configs/config_musdb18_segm_models.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cbec03910a628bd83c6f42f3984f5d9ba732a9fd
--- /dev/null
+++ b/configs/config_musdb18_segm_models.yaml
@@ -0,0 +1,92 @@
+audio:
+ chunk_size: 261632
+ dim_f: 4096
+ dim_t: 512
+ hop_length: 512
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ encoder_name: tu-maxvit_large_tf_512 # look here for possibilities: https://github.com/qubvel/segmentation_models.pytorch#encoders-
+ decoder_type: unet # unet, fpn
+ act: gelu
+ num_channels: 128
+ num_subbands: 8
+
+training:
+ batch_size: 7
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 2000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adamw
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+
+inference:
+ batch_size: 1
+ dim_t: 512
+ num_overlap: 4
\ No newline at end of file
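
The encoder_name here uses the segmentation_models.pytorch convention referenced in the comment (a `tu-` prefix pulls the encoder from timm), and decoder_type selects between the library's U-Net and FPN decoders. A hedged sketch of what that maps to; the in_channels / classes values are placeholders, not the spectrogram layout this repo actually feeds the network:

```python
# Hedged sketch of the encoder/decoder choice above using segmentation_models_pytorch.
import segmentation_models_pytorch as smp

decoders = {"unet": smp.Unet, "fpn": smp.FPN}     # matches "decoder_type: unet # unet, fpn"

net = decoders["unet"](
    encoder_name="tu-maxvit_large_tf_512",        # "tu-" = encoder taken from timm
    encoder_weights=None,                         # or pretrained weights where available
    in_channels=32,                               # placeholder, not the repo's real value
    classes=8,                                    # placeholder, not the repo's real value
)
print(sum(p.numel() for p in net.parameters()))
```
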
diff --git a/configs/config_musdb18_torchseg.yaml b/configs/config_musdb18_torchseg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8de81fccd55e4946d6180a1382603cad27e2a7c0
--- /dev/null
+++ b/configs/config_musdb18_torchseg.yaml
@@ -0,0 +1,92 @@
+audio:
+ chunk_size: 261632
+ dim_f: 4096
+ dim_t: 512
+ hop_length: 512
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ encoder_name: maxvit_tiny_tf_512 # look with torchseg.list_encoders(). Currently 858 available
+ decoder_type: unet # unet, fpn
+ act: gelu
+ num_channels: 128
+ num_subbands: 8
+
+training:
+ batch_size: 18
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - bass
+ - drums
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 2000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adamw
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+
+inference:
+ batch_size: 1
+ dim_t: 512
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_vocals_bandit_bsrnn_multi_mus64.yaml b/configs/config_vocals_bandit_bsrnn_multi_mus64.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..432ae32c19e6136806a718ca882afc516f2aa1f4
--- /dev/null
+++ b/configs/config_vocals_bandit_bsrnn_multi_mus64.yaml
@@ -0,0 +1,73 @@
+name: "MultiMaskMultiSourceBandSplitRNN"
+audio:
+ chunk_size: 264600
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ in_channel: 1
+ stems: ['vocals', 'other']
+ band_specs: "musical"
+ n_bands: 64
+ fs: 44100
+ require_no_overlap: false
+ require_no_gap: true
+ normalize_channel_independently: false
+ treat_channel_as_feature: true
+ n_sqm_modules: 8
+ emb_dim: 128
+ rnn_dim: 256
+ bidirectional: true
+ rnn_type: "GRU"
+ mlp_dim: 512
+ hidden_activation: "Tanh"
+ hidden_activation_kwargs: null
+ complex_mask: true
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ window_fn: "hann_window"
+ wkwargs: null
+ power: null
+ center: true
+ normalized: true
+ pad_mode: "constant"
+ onesided: true
+
+training:
+ batch_size: 4
+ gradient_accumulation_steps: 4
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_vocals_bs_mamba2.yaml b/configs/config_vocals_bs_mamba2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..967e01172c1f99306f61053238100f8e18d34963
--- /dev/null
+++ b/configs/config_vocals_bs_mamba2.yaml
@@ -0,0 +1,51 @@
+audio:
+ chunk_size: 132300 # samplerate * segment
+ hop_length: 1024
+ min_mean_abs: 0.0
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['vocals', 'other']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: prodigy
+ lr: 1.0
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ read_metadata_procs: 8
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+model:
+ sr: 44100
+ win: 2048
+ stride: 512
+ feature_dim: 128
+ num_repeat_mask: 8
+ num_repeat_map: 4
+ num_output: 2
+
+augmentations:
+ enable: false # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: [0.2, 0.02]
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ num_overlap: 2
+ batch_size: 4
\ No newline at end of file
diff --git a/configs/config_vocals_bs_roformer.yaml b/configs/config_vocals_bs_roformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98fcb290f86ff37f81f7e45b98c2ed1c14c02c2d
--- /dev/null
+++ b/configs/config_vocals_bs_roformer.yaml
@@ -0,0 +1,141 @@
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ dim: 192
+ depth: 6
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: true
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 4 # Probably too big (requires a lot of memory for weights)
+ use_torch_checkpoint: False # allows greatly reducing GPU memory consumption during training (not fully tested)
+ skip_connection: False # enable skip connections between transformer blocks - can help with gradient flow and possibly speed up training
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
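
A quick consistency check on the freqs_per_bands tuple above: its 62 entries (24 twos, 12 fours, 8 twelves, 8 twenty-fours, 8 forty-eights, then 128 and 129) partition exactly dim_freqs_in = n_fft // 2 + 1 = 1025 frequency bins:

```python
# Sanity check that the band widths above cover every STFT bin exactly once.
freqs_per_bands = (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)

assert len(freqs_per_bands) == 62
assert sum(freqs_per_bands) == 2048 // 2 + 1      # 1025 == dim_freqs_in
```
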
diff --git a/configs/config_vocals_htdemucs.yaml b/configs/config_vocals_htdemucs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..796004a5cbd8a841963b5b41616ffd5cf8b247ea
--- /dev/null
+++ b/configs/config_vocals_htdemucs.yaml
@@ -0,0 +1,123 @@
+audio:
+ chunk_size: 485100 # samplerate * segment
+ min_mean_abs: 0.001
+ hop_length: 1024
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ segment: 11
+ shift: 1
+ samplerate: 44100
+ channels: 2
+ normalize: true
+ instruments: ['vocals', 'other']
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ optimizer: adam
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: [0.2, 0.02]
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ num_overlap: 2
+ batch_size: 8
+
+model: htdemucs
+
+htdemucs: # see demucs/htdemucs.py for a detailed description
+ # Channels
+ channels: 48
+ channels_time:
+ growth: 2
+ # STFT
+ num_subbands: 1
+ nfft: 4096
+ wiener_iters: 0
+ end_iters: 0
+ wiener_residual: false
+ cac: true
+ # Main structure
+ depth: 4
+ rewrite: true
+ # Frequency Branch
+ multi_freqs: []
+ multi_freqs_depth: 3
+ freq_emb: 0.2
+ emb_scale: 10
+ emb_smooth: true
+ # Convolutions
+ kernel_size: 8
+ stride: 4
+ time_stride: 2
+ context: 1
+ context_enc: 0
+ # normalization
+ norm_starts: 4
+ norm_groups: 4
+ # DConv residual branch
+ dconv_mode: 3
+ dconv_depth: 2
+ dconv_comp: 8
+ dconv_init: 1e-3
+ # Before the Transformer
+ bottom_channels: 512
+ # CrossTransformer
+ # ------ Common to all
+ # Regular parameters
+ t_layers: 5
+ t_hidden_scale: 4.0
+ t_heads: 8
+ t_dropout: 0.0
+ t_layer_scale: True
+ t_gelu: True
+ # ------------- Positional Embedding
+ t_emb: sin
+ t_max_positions: 10000 # for the scaled embedding
+ t_max_period: 10000.0
+ t_weight_pos_embed: 1.0
+ t_cape_mean_normalize: True
+ t_cape_augment: True
+ t_cape_glob_loc_scale: [5000.0, 1.0, 1.4]
+ t_sin_random_shift: 0
+ # ------------- norm before a transformer encoder
+ t_norm_in: True
+ t_norm_in_group: False
+ # ------------- norm inside the encoder
+ t_group_norm: False
+ t_norm_first: True
+ t_norm_out: True
+ # ------------- optim
+ t_weight_decay: 0.0
+ t_lr:
+ # ------------- sparsity
+ t_sparse_self_attn: False
+ t_sparse_cross_attn: False
+ t_mask_type: diag
+ t_mask_random_seed: 42
+ t_sparse_attn_window: 400
+ t_global_window: 100
+ t_sparsity: 0.95
+ t_auto_sparsity: False
+ # Cross Encoder First (False)
+ t_cross_first: False
+ # Weight init
+ rescale: 0.1
+
diff --git a/configs/config_vocals_mdx23c.yaml b/configs/config_vocals_mdx23c.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b241ad5da8e4e4cdeca43a3f09ec64961321ce13
--- /dev/null
+++ b/configs/config_vocals_mdx23c.yaml
@@ -0,0 +1,96 @@
+audio:
+ chunk_size: 261120
+ dim_f: 4096
+ dim_t: 256
+ hop_length: 1024
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ act: gelu
+ bottleneck_factor: 4
+ growth: 128
+ norm: InstanceNorm
+ num_blocks_per_scale: 2
+ num_channels: 128
+ num_scales: 5
+ num_subbands: 4
+ scale:
+ - 2
+ - 2
+
+training:
+ batch_size: 6
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 9.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ read_metadata_procs: 8 # Number of processes to use during metadata reading for dataset. Can speed up metadata generation
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
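
The per-stem augmentation blocks (vocals/other here, plus bass and drums in the MUSDB18 configs) list a probability and a parameter range for each effect, and other configs in this diff mention audiomentations and pedalboard as the backing libraries. A hedged sketch of how the vocals block above could map onto audiomentations transforms; the repository's own wiring may differ, and the mono test signal is just a placeholder:

```python
# Illustrative mapping of the "vocals:" block onto audiomentations transforms:
# each YAML probability becomes the transform's p, each min/max pair its range.
import numpy as np
from audiomentations import Compose, PitchShift, SevenBandParametricEQ, TanhDistortion

vocals_aug = Compose([
    PitchShift(min_semitones=-5, max_semitones=5, p=0.1),
    SevenBandParametricEQ(min_gain_db=-9, max_gain_db=9, p=0.25),
    TanhDistortion(min_distortion=0.1, max_distortion=0.7, p=0.1),
])

stem = np.random.randn(44100).astype(np.float32)      # one second of placeholder mono audio
augmented = vocals_aug(samples=stem, sample_rate=44100)
```
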
diff --git a/configs/config_vocals_mel_band_roformer.yaml b/configs/config_vocals_mel_band_roformer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4fa4385aab64f0cb2c6a424bb6c63bd954f780e0
--- /dev/null
+++ b/configs/config_vocals_mel_band_roformer.yaml
@@ -0,0 +1,80 @@
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ dim: 192
+ depth: 8
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: True
+ dim_freqs_in: 1025
+ sample_rate: 44100 # needed for mel filter bank from librosa
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: False
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+ mlp_expansion_factor: 4 # Probably too big (requires a lot of memory for weights)
+ use_torch_checkpoint: False # allows greatly reducing GPU memory consumption during training (not fully tested)
+ skip_connection: False # enable skip connections between transformer blocks - can help with gradient flow and possibly speed up training
+
+training:
+ batch_size: 7
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 256
+ num_overlap: 4
\ No newline at end of file
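
The mel_band_roformer configs carry sample_rate inside the model block because, as the comment says, it is needed to build the mel filter bank from librosa; with num_bands: 60 and stft_n_fft: 2048 that filter bank spans the 1025 STFT bins. A small illustrative sketch (how the model actually consumes the filter bank is not shown here):

```python
# Building the 60-band mel filter bank the comment above refers to (illustrative only).
import librosa

mel_fb = librosa.filters.mel(sr=44100, n_fft=2048, n_mels=60)
print(mel_fb.shape)   # (60, 1025): one weight row per mel band over the STFT bins
```
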
diff --git a/configs/config_vocals_scnet.yaml b/configs/config_vocals_scnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1210f9654a6e09b0965937b31e9dcaeaaf2257a0
--- /dev/null
+++ b/configs/config_vocals_scnet.yaml
@@ -0,0 +1,79 @@
+audio:
+ chunk_size: 485100 # 44100 * 11
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ sources:
+ - vocals
+ - other
+ audio_channels: 2
+ dims:
+ - 4
+ - 32
+ - 64
+ - 128
+ nfft: 4096
+ hop_size: 1024
+ win_size: 4096
+ normalized: True
+ band_SR:
+ - 0.175
+ - 0.392
+ - 0.433
+ band_stride:
+ - 1
+ - 4
+ - 16
+ band_kernel:
+ - 3
+ - 4
+ - 16
+ conv_depths:
+ - 3
+ - 2
+ - 1
+ compress: 4
+ conv_kernel: 3
+ num_dplayer: 6
+ expand: 1
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 10
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
+ normalize: false
diff --git a/configs/config_vocals_scnet_large.yaml b/configs/config_vocals_scnet_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f34450eb84d2d2b072577504edbf2948efb94158
--- /dev/null
+++ b/configs/config_vocals_scnet_large.yaml
@@ -0,0 +1,79 @@
+audio:
+ chunk_size: 485100 # 44100 * 11
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ sources:
+ - vocals
+ - other
+ audio_channels: 2
+ dims:
+ - 4
+ - 64
+ - 128
+ - 256
+ nfft: 4096
+ hop_size: 1024
+ win_size: 4096
+ normalized: True
+ band_SR:
+ - 0.225
+ - 0.372
+ - 0.403
+ band_stride:
+ - 1
+ - 4
+ - 16
+ band_kernel:
+ - 3
+ - 4
+ - 16
+ conv_depths:
+ - 3
+ - 2
+ - 1
+ compress: 4
+ conv_kernel: 3
+ num_dplayer: 6
+ expand: 1
+
+training:
+ batch_size: 6
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 1.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: false # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
+ normalize: false
diff --git a/configs/config_vocals_scnet_unofficial.yaml b/configs/config_vocals_scnet_unofficial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d3e604e4992e1f4090da91227c9ecc5e66e9117
--- /dev/null
+++ b/configs/config_vocals_scnet_unofficial.yaml
@@ -0,0 +1,62 @@
+audio:
+ chunk_size: 264600
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dims: [4, 32, 64, 128]
+ bandsplit_ratios: [.175, .392, .433]
+ downsample_strides: [1, 4, 16]
+ n_conv_modules: [3, 2, 1]
+ n_rnn_layers: 6
+ rnn_hidden_dim: 128
+ n_sources: 2
+
+ n_fft: 4096
+ hop_length: 1024
+ win_length: 4096
+ stft_normalized: false
+
+ use_mamba: false
+ d_state: 16
+ d_conv: 4
+ d_expand: 2
+
+training:
+ batch_size: 10
+ gradient_accumulation_steps: 2
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs:
+ !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 256
+ num_overlap: 4
diff --git a/configs/config_vocals_segm_models.yaml b/configs/config_vocals_segm_models.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..44711a0658a95289c8d3745a6d78114b937df1fa
--- /dev/null
+++ b/configs/config_vocals_segm_models.yaml
@@ -0,0 +1,78 @@
+audio:
+ chunk_size: 261632
+ dim_f: 4096
+ dim_t: 512
+ hop_length: 512
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ encoder_name: tu-maxvit_large_tf_512 # look here for possibilities: https://github.com/qubvel/segmentation_models.pytorch#encoders-
+ decoder_type: unet # unet, fpn
+ act: gelu
+ num_channels: 128
+ num_subbands: 8
+
+loss_multistft:
+ fft_sizes:
+ - 1024
+ - 2048
+ - 4096
+ hop_sizes:
+ - 512
+ - 1024
+ - 2048
+ win_lengths:
+ - 1024
+ - 2048
+ - 4096
+ window: "hann_window"
+ scale: "mel"
+ n_bins: 128
+ sample_rate: 44100
+ perceptual_weighting: true
+ w_sc: 1.0
+ w_log_mag: 1.0
+ w_lin_mag: 0.0
+ w_phs: 0.0
+ mag_distance: "L1"
+
+
+training:
+ batch_size: 8
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 2000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adamw
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 512
+ num_overlap: 4
\ No newline at end of file
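
The loss_multistft block lines up almost one-to-one with the keyword arguments of auraloss.freq.MultiResolutionSTFTLoss (fft_sizes, hop_sizes, win_lengths, window, scale, n_bins, sample_rate, perceptual_weighting, the w_* weights, mag_distance). A hedged sketch of constructing the loss straight from those values; whether this config family actually routes the block through auraloss is an assumption on my part:

```python
# Sketch: building a multi-resolution STFT loss from the loss_multistft block above.
import torch
import auraloss

loss_cfg = dict(
    fft_sizes=[1024, 2048, 4096],
    hop_sizes=[512, 1024, 2048],
    win_lengths=[1024, 2048, 4096],
    window="hann_window",
    scale="mel",
    n_bins=128,
    sample_rate=44100,
    perceptual_weighting=True,
    w_sc=1.0,
    w_log_mag=1.0,
    w_lin_mag=0.0,
    w_phs=0.0,
    mag_distance="L1",
)
criterion = auraloss.freq.MultiResolutionSTFTLoss(**loss_cfg)

pred, target = torch.randn(2, 2, 44100), torch.randn(2, 2, 44100)  # (batch, channels, samples)
print(criterion(pred, target))
```
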
diff --git a/configs/config_vocals_swin_upernet.yaml b/configs/config_vocals_swin_upernet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..28a837346ea70fa02a1df5245c5048a4676b63c7
--- /dev/null
+++ b/configs/config_vocals_swin_upernet.yaml
@@ -0,0 +1,51 @@
+audio:
+ chunk_size: 261632
+ dim_f: 4096
+ dim_t: 512
+ hop_length: 512
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ act: gelu
+ num_channels: 16
+ num_subbands: 8
+
+training:
+ batch_size: 14
+ gradient_accumulation_steps: 4
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 3.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adamw
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 1
+ dim_t: 512
+ num_overlap: 4
\ No newline at end of file
diff --git a/configs/config_vocals_torchseg.yaml b/configs/config_vocals_torchseg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ebbeae2770d68c4b2dbdcfd6125a5ca387e6d9b
--- /dev/null
+++ b/configs/config_vocals_torchseg.yaml
@@ -0,0 +1,58 @@
+audio:
+ chunk_size: 261632
+ dim_f: 4096
+ dim_t: 512
+ hop_length: 512
+ n_fft: 8192
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ encoder_name: maxvit_tiny_tf_512 # look with torchseg.list_encoders(). Currently 858 available
+ decoder_type: unet # unet, fpn
+ act: gelu
+ num_channels: 128
+ num_subbands: 8
+
+training:
+ batch_size: 18
+ gradient_accumulation_steps: 1
+ grad_clip: 1.0
+ instruments:
+ - vocals
+ - other
+ lr: 1.0e-04
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: null
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: radam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: false # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # invert the track (works better with a lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform by -1)
+
+inference:
+ batch_size: 8
+ dim_t: 512
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml b/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..135a051897dee27285ac46ee350afe1e1ec02011
--- /dev/null
+++ b/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml
@@ -0,0 +1,126 @@
+audio:
+ chunk_size: 352800
+ dim_f: 1024
+ dim_t: 801 # not used here (the model section defines its own STFT settings)
+ hop_length: 441 # not used here (the model section defines its own STFT settings)
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dim: 512
+ depth: 12
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: true
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+
+training:
+ batch_size: 2
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 1.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+inference:
+ batch_size: 4
+ dim_t: 801
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml b/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d5e9a0b670759dd378af60e09e0a5e3c650cbf7c
--- /dev/null
+++ b/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml
@@ -0,0 +1,138 @@
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 256
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.001
+
+model:
+ dim: 384
+ depth: 12
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: true
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+
+training:
+ batch_size: 4
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 5.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: other
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+ other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental
+ use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true
+
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+inference:
+ batch_size: 8
+ dim_t: 512
+ num_overlap: 2
\ No newline at end of file
diff --git a/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml b/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7cb922c9c06076e382826decc017ca9d760b9623
--- /dev/null
+++ b/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml
@@ -0,0 +1,65 @@
+audio:
+ chunk_size: 352800
+ dim_f: 1024
+ dim_t: 801 # not used here (the model section defines its own STFT settings)
+ hop_length: 441 # not used here (the model section defines its own STFT settings)
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dim: 384
+ depth: 12
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ num_bands: 60
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: True
+ dim_freqs_in: 1025
+ sample_rate: 44100 # needed for mel filter bank from librosa
+ stft_n_fft: 2048
+ stft_hop_length: 441
+ stft_win_length: 2048
+ stft_normalized: False
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+
+training:
+ batch_size: 1
+ gradient_accumulation_steps: 8
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 4.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+  other_fix: false # needed to check on the multisong dataset whether 'other' is actually instrumental
+  use_amp: true # enable or disable mixed precision (float16) - usually it should be true
+
+inference:
+ batch_size: 4
+ dim_t: 801
+ num_overlap: 2
\ No newline at end of file
diff --git a/cookies.txt b/cookies.txt
new file mode 100644
index 0000000000000000000000000000000000000000..98d46a84e6470fde4e12aba797faf50b57431ce1
--- /dev/null
+++ b/cookies.txt
@@ -0,0 +1,28 @@
+# Netscape HTTP Cookie File
+# This file is generated by yt-dlp. Do not edit.
+
+.youtube.com TRUE / FALSE 1756513080 HSID AkpR1gV80KfDyBAeq
+.youtube.com TRUE / TRUE 1756513080 SSID AUkQz9BsAZ9dihvT7
+.youtube.com TRUE / FALSE 1756513080 APISID FPGwyoC5hlxA_ztn/ADM7Q4t2tMF9LolFH
+.youtube.com TRUE / TRUE 1756513080 SAPISID 4yc1vubX-H2x2gTg/A4eb_29p67eyBKNwo
+.youtube.com TRUE / TRUE 1756513080 __Secure-1PAPISID 4yc1vubX-H2x2gTg/A4eb_29p67eyBKNwo
+.youtube.com TRUE / TRUE 1756513080 __Secure-3PAPISID 4yc1vubX-H2x2gTg/A4eb_29p67eyBKNwo
+.youtube.com TRUE / FALSE 0 PREF f4=4000000&tz=UTC&f7=100&f6=40000000&hl=en
+.youtube.com TRUE / FALSE 1756513080 SID g.a000uQhpMC7F759FwOg4eAHYr_VFV7qLJJzrVdnrbB1Gg1ruHpzr7Q7JXHasofNz_IFpc8N2LgACgYKAfYSARISFQHGX2Miyesv7_oABGm-5jwErW1A3BoVAUF8yKrc9rgHmp5qJT6VRm79tW1A0076
+.youtube.com TRUE / TRUE 1756513080 __Secure-1PSID g.a000uQhpMC7F759FwOg4eAHYr_VFV7qLJJzrVdnrbB1Gg1ruHpzrr6x28jzJl8SymGUnS601CAACgYKAVESARISFQHGX2MiMrjFi53JrIU9q__AtJkTHhoVAUF8yKq9HTb9EMf-IuIKrE24vlao0076
+.youtube.com TRUE / TRUE 1756513080 __Secure-3PSID g.a000uQhpMC7F759FwOg4eAHYr_VFV7qLJJzrVdnrbB1Gg1ruHpzrggFtS3EfdibObQagLMZPwgACgYKAS0SARISFQHGX2MiVtkRAB_snp0m6Ci8U8_KdxoVAUF8yKpL1TslRsnn1zHR9IM89xyI0076
+.youtube.com TRUE / TRUE 0 wide 0
+.youtube.com TRUE / TRUE 1756638167 __Secure-1PSIDTS sidts-CjIBEJ3XV-avYWfDaATyg0Nhkmwux6CKyFaF1gYPa-AjJzR_e3PPij4K2ft8TRk6khgu2xAA
+.youtube.com TRUE / TRUE 1756638167 __Secure-3PSIDTS sidts-CjIBEJ3XV-avYWfDaATyg0Nhkmwux6CKyFaF1gYPa-AjJzR_e3PPij4K2ft8TRk6khgu2xAA
+.youtube.com TRUE / TRUE 1756638341 LOGIN_INFO AFmmF2swRgIhAKDOVmKULP27JwVcR_zerOJpO9GmXntRWiR4zWAazwz_AiEAwvt5os697PYAjWwVLGwA5oN3mFBrA1kh_4AlSuvoE-M:QUQ3MjNmem5mN3p3NVNiM2hzMEJ0R1EwUzI2SFNDXzhlQlNXemF2Z2IyZER1cmt1VXZSbk5EbkpaekFySGw4a09MQ3Z6c3RhSFhoXzFUNE9mdHB4ZEdPTFgzNEZrMTB2SWd2azlTdi1SUTdZWGczMEpRb3otemhHZ08wZDlzc0dhRE1sM2tJUTBfSkNiSmpuOTBXdG13eDhsR2JTVVlVQWZB
+.youtube.com TRUE / TRUE 1741086952 CONSISTENCY AKreu9tDDatJYzjErs5c0WuYZjTQFRMZu7GKaDYzvFqROdgjqcrkvrsWqoTI1zZioac6yVWq7BCSzc1y0Pk0j8ikhC_l9YEyMmQs14Kg3IHcli61swZK3uWn
+.youtube.com TRUE / FALSE 1756638353 SIDCC AKEyXzWXY14_kKbTxyT3AyRMPsKsEGUsuIHbGutwC42o1YlZS06ch-ug7SyZAYQ7jEDVx5EDfw
+.youtube.com TRUE / TRUE 1756638353 __Secure-1PSIDCC AKEyXzX0UJz8MZ6u_9s7hOlSPGjbu-JwY0Q1l77e5oO5CJTXNIDO95oxQyCdFaP2D-4qbJrCI1I
+.youtube.com TRUE / TRUE 1756638353 __Secure-3PSIDCC AKEyXzU5TOnpQc0o7qGCur58CMCcshJ2tsoLi9rVwsER2dK2P22VqU3jYG0yMz0LsNkMxjbXeg
+.youtube.com TRUE / TRUE 0 YSC tBB8nN6HoE0
+.youtube.com TRUE / TRUE 1756638392 __Secure-ROLLOUT_TOKEN CMe15b6p2vn47QEQsKafn6TwiwMYru_Yn6TwiwM%3D
+.youtube.com TRUE / TRUE 1756638392 VISITOR_INFO1_LIVE 54yW_8GrQNM
+.youtube.com TRUE / TRUE 1756638392 VISITOR_PRIVACY_METADATA CgJVUxIEGgAgDA%3D%3D
+.youtube.com TRUE / TRUE 1756638392 YT_DEVICE_MEASUREMENT_ID 3rgPq1s=
+.youtube.com TRUE / TRUE 1804158392 __Secure-YT_TVFAS t=483635&s=2
+.youtube.com TRUE / TRUE 1756638392 DEVICE_INFO ChxOelEzTnprd09URXhNemMzTXpRNE1qZ3lNQT09ELi9m74GGLi9m74G
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..444a97ad1b3069931cd6c76ad775723c50ffbb89
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,669 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+
+
+import os
+import random
+import numpy as np
+import torch
+import soundfile as sf
+import pickle
+import time
+import itertools
+import multiprocessing
+from tqdm.auto import tqdm
+from glob import glob
+import audiomentations as AU
+import pedalboard as PB
+import warnings
+warnings.filterwarnings("ignore")
+
+
+def load_chunk(path, length, chunk_size, offset=None):
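+    # Read a chunk of `chunk_size` samples (at a random offset unless one is given);
+    # files shorter than chunk_size are zero-padded. Returns an array of shape (channels, chunk_size).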
+ if chunk_size <= length:
+ if offset is None:
+ offset = np.random.randint(length - chunk_size + 1)
+ x = sf.read(path, dtype='float32', start=offset, frames=chunk_size)[0]
+ else:
+ x = sf.read(path, dtype='float32')[0]
+ if len(x.shape) == 1:
+ # Mono case
+ pad = np.zeros((chunk_size - length))
+ else:
+ pad = np.zeros([chunk_size - length, x.shape[-1]])
+ x = np.concatenate([x, pad], axis=0)
+ # Mono fix
+ if len(x.shape) == 1:
+ x = np.expand_dims(x, axis=1)
+ return x.T
+
+
+def get_track_set_length(params):
+ path, instruments, file_types = params
+ # Check lengths of all instruments (it can be different in some cases)
+ lengths_arr = []
+ for instr in instruments:
+ length = -1
+ for extension in file_types:
+ path_to_audio_file = path + '/{}.{}'.format(instr, extension)
+ if os.path.isfile(path_to_audio_file):
+ length = len(sf.read(path_to_audio_file)[0])
+ break
+ if length == -1:
+            print("Can't find file '{}' in folder {}".format(instr, path))
+ continue
+ lengths_arr.append(length)
+ lengths_arr = np.array(lengths_arr)
+ if lengths_arr.min() != lengths_arr.max():
+ print('Warning: lengths of stems are different for path: {}. ({} != {})'.format(
+ path,
+ lengths_arr.min(),
+ lengths_arr.max())
+ )
+    # Use the minimum length so soundfile reads never run past the shortest stem when lengths differ
+ return path, lengths_arr.min()
+
+
+# For multiprocessing
+def get_track_length(params):
+ path = params
+ length = len(sf.read(path)[0])
+ return (path, length)
+
+
+class MSSDataset(torch.utils.data.Dataset):
+ def __init__(self, config, data_path, metadata_path="metadata.pkl", dataset_type=1, batch_size=None, verbose=True):
+ self.verbose = verbose
+ self.config = config
+ self.dataset_type = dataset_type # 1, 2, 3 or 4
+ self.data_path = data_path
+ self.instruments = instruments = config.training.instruments
+ if batch_size is None:
+ batch_size = config.training.batch_size
+ self.batch_size = batch_size
+ self.file_types = ['wav', 'flac']
+ self.metadata_path = metadata_path
+
+ # Augmentation block
+ self.aug = False
+ if 'augmentations' in config:
+ if config['augmentations'].enable is True:
+ if self.verbose:
+ print('Use augmentation for training')
+ self.aug = True
+ else:
+ if self.verbose:
+ print('There is no augmentations block in config. Augmentations disabled for training...')
+
+ metadata = self.get_metadata()
+
+ if self.dataset_type in [1, 4]:
+ if len(metadata) > 0:
+ if self.verbose:
+ print('Found tracks in dataset: {}'.format(len(metadata)))
+ else:
+ print('No tracks found for training. Check paths you provided!')
+ exit()
+ else:
+ for instr in self.instruments:
+ if self.verbose:
+ print('Found tracks for {} in dataset: {}'.format(instr, len(metadata[instr])))
+ self.metadata = metadata
+ self.chunk_size = config.audio.chunk_size
+ self.min_mean_abs = config.audio.min_mean_abs
+
+ def __len__(self):
+ return self.config.training.num_steps * self.batch_size
+
+ def read_from_metadata_cache(self, track_paths, instr=None):
+ metadata = []
+ if os.path.isfile(self.metadata_path):
+ if self.verbose:
+ print('Found metadata cache file: {}'.format(self.metadata_path))
+ old_metadata = pickle.load(open(self.metadata_path, 'rb'))
+ else:
+ return track_paths, metadata
+
+ if instr:
+ old_metadata = old_metadata[instr]
+
+        # Don't re-read tracks that already exist in the old metadata file
+ track_paths_set = set(track_paths)
+ for old_path, file_size in old_metadata:
+ if old_path in track_paths_set:
+ metadata.append([old_path, file_size])
+ track_paths_set.remove(old_path)
+ track_paths = list(track_paths_set)
+ if len(metadata) > 0:
+ print('Old metadata was used for {} tracks.'.format(len(metadata)))
+ return track_paths, metadata
+
+
+ def get_metadata(self):
+ read_metadata_procs = multiprocessing.cpu_count()
+ if 'read_metadata_procs' in self.config['training']:
+ read_metadata_procs = int(self.config['training']['read_metadata_procs'])
+
+ if self.verbose:
+ print(
+ 'Dataset type:', self.dataset_type,
+ 'Processes to use:', read_metadata_procs,
+ '\nCollecting metadata for', str(self.data_path),
+ )
+
+ if self.dataset_type in [1, 4]:
+ track_paths = []
+ if type(self.data_path) == list:
+ for tp in self.data_path:
+ tracks_for_folder = sorted(glob(tp + '/*'))
+ if len(tracks_for_folder) == 0:
+ print('Warning: no tracks found in folder \'{}\'. Please check it!'.format(tp))
+ track_paths += tracks_for_folder
+ else:
+ track_paths += sorted(glob(self.data_path + '/*'))
+
+ track_paths = [path for path in track_paths if os.path.basename(path)[0] != '.' and os.path.isdir(path)]
+ track_paths, metadata = self.read_from_metadata_cache(track_paths, None)
+
+ if read_metadata_procs <= 1:
+ for path in tqdm(track_paths):
+ track_path, track_length = get_track_set_length((path, self.instruments, self.file_types))
+ metadata.append((track_path, track_length))
+ else:
+ p = multiprocessing.Pool(processes=read_metadata_procs)
+ with tqdm(total=len(track_paths)) as pbar:
+ track_iter = p.imap(
+ get_track_set_length,
+ zip(track_paths, itertools.repeat(self.instruments), itertools.repeat(self.file_types))
+ )
+ for track_path, track_length in track_iter:
+ metadata.append((track_path, track_length))
+ pbar.update()
+ p.close()
+
+ elif self.dataset_type == 2:
+ metadata = dict()
+ for instr in self.instruments:
+ metadata[instr] = []
+ track_paths = []
+ if type(self.data_path) == list:
+ for tp in self.data_path:
+ track_paths += sorted(glob(tp + '/{}/*.wav'.format(instr)))
+ track_paths += sorted(glob(tp + '/{}/*.flac'.format(instr)))
+ else:
+ track_paths += sorted(glob(self.data_path + '/{}/*.wav'.format(instr)))
+ track_paths += sorted(glob(self.data_path + '/{}/*.flac'.format(instr)))
+
+ track_paths, metadata[instr] = self.read_from_metadata_cache(track_paths, instr)
+
+ if read_metadata_procs <= 1:
+ for path in tqdm(track_paths):
+ length = len(sf.read(path)[0])
+ metadata[instr].append((path, length))
+ else:
+ p = multiprocessing.Pool(processes=read_metadata_procs)
+ for out in tqdm(p.imap(get_track_length, track_paths), total=len(track_paths)):
+ metadata[instr].append(out)
+
+ elif self.dataset_type == 3:
+ import pandas as pd
+            if not isinstance(self.data_path, list):
+                self.data_path = [self.data_path]  # normalize to a list so the loop below can index it
+
+ metadata = dict()
+ for i in range(len(self.data_path)):
+ if self.verbose:
+ print('Reading tracks from: {}'.format(self.data_path[i]))
+ df = pd.read_csv(self.data_path[i])
+
+ skipped = 0
+ for instr in self.instruments:
+ part = df[df['instrum'] == instr].copy()
+ print('Tracks found for {}: {}'.format(instr, len(part)))
+ for instr in self.instruments:
+ part = df[df['instrum'] == instr].copy()
+ metadata[instr] = []
+ track_paths = list(part['path'].values)
+ track_paths, metadata[instr] = self.read_from_metadata_cache(track_paths, instr)
+
+ for path in tqdm(track_paths):
+ if not os.path.isfile(path):
+                        print("Can't find track: {}".format(path))
+ skipped += 1
+ continue
+ # print(path)
+ try:
+ length = len(sf.read(path)[0])
+                    except Exception:
+ print('Problem with path: {}'.format(path))
+ skipped += 1
+ continue
+ metadata[instr].append((path, length))
+ if skipped > 0:
+ print('Missing tracks: {} from {}'.format(skipped, len(df)))
+ else:
+ print('Unknown dataset type: {}. Must be 1, 2, 3 or 4'.format(self.dataset_type))
+ exit()
+
+ # Save metadata
+ pickle.dump(metadata, open(self.metadata_path, 'wb'))
+ return metadata
+
+ def load_source(self, metadata, instr):
+ while True:
+ if self.dataset_type in [1, 4]:
+ track_path, track_length = random.choice(metadata)
+ for extension in self.file_types:
+ path_to_audio_file = track_path + '/{}.{}'.format(instr, extension)
+ if os.path.isfile(path_to_audio_file):
+ try:
+ source = load_chunk(path_to_audio_file, track_length, self.chunk_size)
+ except Exception as e:
+                            # FLAC reading sometimes fails; catch the error and use a silent (zero) stem
+ print('Error: {} Path: {}'.format(e, path_to_audio_file))
+ source = np.zeros((2, self.chunk_size), dtype=np.float32)
+ break
+ else:
+ track_path, track_length = random.choice(metadata[instr])
+ try:
+ source = load_chunk(track_path, track_length, self.chunk_size)
+ except Exception as e:
+                    # FLAC reading sometimes fails; catch the error and use a silent (zero) stem
+ print('Error: {} Path: {}'.format(e, track_path))
+ source = np.zeros((2, self.chunk_size), dtype=np.float32)
+
+ if np.abs(source).mean() >= self.min_mean_abs: # remove quiet chunks
+ break
+ if self.aug:
+ source = self.augm_data(source, instr)
+ return torch.tensor(source, dtype=torch.float32)
+
+ def load_random_mix(self):
+ res = []
+ for instr in self.instruments:
+ s1 = self.load_source(self.metadata, instr)
+            # Mixup augmentation: mix several stems of the same type
+ if self.aug:
+ if 'mixup' in self.config['augmentations']:
+ if self.config['augmentations'].mixup:
+ mixup = [s1]
+ for prob in self.config.augmentations.mixup_probs:
+ if random.uniform(0, 1) < prob:
+ s2 = self.load_source(self.metadata, instr)
+ mixup.append(s2)
+ mixup = torch.stack(mixup, dim=0)
+ loud_values = np.random.uniform(
+ low=self.config.augmentations.loudness_min,
+ high=self.config.augmentations.loudness_max,
+ size=(len(mixup),)
+ )
+ loud_values = torch.tensor(loud_values, dtype=torch.float32)
+ mixup *= loud_values[:, None, None]
+ s1 = mixup.mean(dim=0, dtype=torch.float32)
+ res.append(s1)
+ res = torch.stack(res)
+ return res
+
+ def load_aligned_data(self):
+ track_path, track_length = random.choice(self.metadata)
+ attempts = 10
+ while attempts:
+ if track_length >= self.chunk_size:
+ common_offset = np.random.randint(track_length - self.chunk_size + 1)
+ else:
+ common_offset = None
+ res = []
+ silent_chunks = 0
+ for i in self.instruments:
+ for extension in self.file_types:
+ path_to_audio_file = track_path + '/{}.{}'.format(i, extension)
+ if os.path.isfile(path_to_audio_file):
+ try:
+ source = load_chunk(path_to_audio_file, track_length, self.chunk_size, offset=common_offset)
+ except Exception as e:
+                            # FLAC reading sometimes fails; catch the error and use a silent (zero) stem
+ print('Error: {} Path: {}'.format(e, path_to_audio_file))
+ source = np.zeros((2, self.chunk_size), dtype=np.float32)
+ break
+ res.append(source)
+ if np.abs(source).mean() < self.min_mean_abs: # remove quiet chunks
+ silent_chunks += 1
+ if silent_chunks == 0:
+ break
+
+ attempts -= 1
+ if attempts <= 0:
+                print('Max attempts reached!', track_path)
+ if common_offset is None:
+ # If track is too small break immediately
+ break
+
+ res = np.stack(res, axis=0)
+ if self.aug:
+ for i, instr in enumerate(self.instruments):
+ res[i] = self.augm_data(res[i], instr)
+ return torch.tensor(res, dtype=torch.float32)
+
+ def augm_data(self, source, instr):
+ # source.shape = (2, 261120) - first channels, second length
+ source_shape = source.shape
+ applied_augs = []
+ if 'all' in self.config['augmentations']:
+ augs = self.config['augmentations']['all']
+ else:
+ augs = dict()
+
+        # Merge stem-specific augmentations on top of the 'all' block, overriding values where needed
+ if instr in self.config['augmentations']:
+ for el in self.config['augmentations'][instr]:
+ augs[el] = self.config['augmentations'][instr][el]
+
+ # Channel shuffle
+ if 'channel_shuffle' in augs:
+ if augs['channel_shuffle'] > 0:
+ if random.uniform(0, 1) < augs['channel_shuffle']:
+ source = source[::-1].copy()
+ applied_augs.append('channel_shuffle')
+ # Random inverse
+ if 'random_inverse' in augs:
+ if augs['random_inverse'] > 0:
+ if random.uniform(0, 1) < augs['random_inverse']:
+ source = source[:, ::-1].copy()
+ applied_augs.append('random_inverse')
+ # Random polarity (multiply -1)
+ if 'random_polarity' in augs:
+ if augs['random_polarity'] > 0:
+ if random.uniform(0, 1) < augs['random_polarity']:
+ source = -source.copy()
+ applied_augs.append('random_polarity')
+ # Random pitch shift
+ if 'pitch_shift' in augs:
+ if augs['pitch_shift'] > 0:
+ if random.uniform(0, 1) < augs['pitch_shift']:
+ apply_aug = AU.PitchShift(
+ min_semitones=augs['pitch_shift_min_semitones'],
+ max_semitones=augs['pitch_shift_max_semitones'],
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('pitch_shift')
+ # Random seven band parametric eq
+ if 'seven_band_parametric_eq' in augs:
+ if augs['seven_band_parametric_eq'] > 0:
+ if random.uniform(0, 1) < augs['seven_band_parametric_eq']:
+ apply_aug = AU.SevenBandParametricEQ(
+ min_gain_db=augs['seven_band_parametric_eq_min_gain_db'],
+ max_gain_db=augs['seven_band_parametric_eq_max_gain_db'],
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('seven_band_parametric_eq')
+ # Random tanh distortion
+ if 'tanh_distortion' in augs:
+ if augs['tanh_distortion'] > 0:
+ if random.uniform(0, 1) < augs['tanh_distortion']:
+ apply_aug = AU.TanhDistortion(
+ min_distortion=augs['tanh_distortion_min'],
+ max_distortion=augs['tanh_distortion_max'],
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('tanh_distortion')
+ # Random MP3 Compression
+ if 'mp3_compression' in augs:
+ if augs['mp3_compression'] > 0:
+ if random.uniform(0, 1) < augs['mp3_compression']:
+ apply_aug = AU.Mp3Compression(
+ min_bitrate=augs['mp3_compression_min_bitrate'],
+ max_bitrate=augs['mp3_compression_max_bitrate'],
+ backend=augs['mp3_compression_backend'],
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('mp3_compression')
+ # Random AddGaussianNoise
+ if 'gaussian_noise' in augs:
+ if augs['gaussian_noise'] > 0:
+ if random.uniform(0, 1) < augs['gaussian_noise']:
+ apply_aug = AU.AddGaussianNoise(
+ min_amplitude=augs['gaussian_noise_min_amplitude'],
+ max_amplitude=augs['gaussian_noise_max_amplitude'],
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('gaussian_noise')
+ # Random TimeStretch
+ if 'time_stretch' in augs:
+ if augs['time_stretch'] > 0:
+ if random.uniform(0, 1) < augs['time_stretch']:
+ apply_aug = AU.TimeStretch(
+ min_rate=augs['time_stretch_min_rate'],
+ max_rate=augs['time_stretch_max_rate'],
+ leave_length_unchanged=True,
+ p=1.0
+ )
+ source = apply_aug(samples=source, sample_rate=44100)
+ applied_augs.append('time_stretch')
+
+ # Possible fix of shape
+ if source_shape != source.shape:
+ source = source[..., :source_shape[-1]]
+
+ # Random Reverb
+ if 'pedalboard_reverb' in augs:
+ if augs['pedalboard_reverb'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_reverb']:
+ room_size = random.uniform(
+ augs['pedalboard_reverb_room_size_min'],
+ augs['pedalboard_reverb_room_size_max'],
+ )
+ damping = random.uniform(
+ augs['pedalboard_reverb_damping_min'],
+ augs['pedalboard_reverb_damping_max'],
+ )
+ wet_level = random.uniform(
+ augs['pedalboard_reverb_wet_level_min'],
+ augs['pedalboard_reverb_wet_level_max'],
+ )
+ dry_level = random.uniform(
+ augs['pedalboard_reverb_dry_level_min'],
+ augs['pedalboard_reverb_dry_level_max'],
+ )
+ width = random.uniform(
+ augs['pedalboard_reverb_width_min'],
+ augs['pedalboard_reverb_width_max'],
+ )
+ board = PB.Pedalboard([PB.Reverb(
+ room_size=room_size, # 0.1 - 0.9
+ damping=damping, # 0.1 - 0.9
+ wet_level=wet_level, # 0.1 - 0.9
+ dry_level=dry_level, # 0.1 - 0.9
+ width=width, # 0.9 - 1.0
+ freeze_mode=0.0,
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_reverb')
+
+ # Random Chorus
+ if 'pedalboard_chorus' in augs:
+ if augs['pedalboard_chorus'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_chorus']:
+ rate_hz = random.uniform(
+ augs['pedalboard_chorus_rate_hz_min'],
+ augs['pedalboard_chorus_rate_hz_max'],
+ )
+ depth = random.uniform(
+ augs['pedalboard_chorus_depth_min'],
+ augs['pedalboard_chorus_depth_max'],
+ )
+ centre_delay_ms = random.uniform(
+ augs['pedalboard_chorus_centre_delay_ms_min'],
+ augs['pedalboard_chorus_centre_delay_ms_max'],
+ )
+ feedback = random.uniform(
+ augs['pedalboard_chorus_feedback_min'],
+ augs['pedalboard_chorus_feedback_max'],
+ )
+ mix = random.uniform(
+ augs['pedalboard_chorus_mix_min'],
+ augs['pedalboard_chorus_mix_max'],
+ )
+ board = PB.Pedalboard([PB.Chorus(
+ rate_hz=rate_hz,
+ depth=depth,
+ centre_delay_ms=centre_delay_ms,
+ feedback=feedback,
+ mix=mix,
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_chorus')
+
+ # Random Phazer
+ if 'pedalboard_phazer' in augs:
+ if augs['pedalboard_phazer'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_phazer']:
+ rate_hz = random.uniform(
+ augs['pedalboard_phazer_rate_hz_min'],
+ augs['pedalboard_phazer_rate_hz_max'],
+ )
+ depth = random.uniform(
+ augs['pedalboard_phazer_depth_min'],
+ augs['pedalboard_phazer_depth_max'],
+ )
+ centre_frequency_hz = random.uniform(
+ augs['pedalboard_phazer_centre_frequency_hz_min'],
+ augs['pedalboard_phazer_centre_frequency_hz_max'],
+ )
+ feedback = random.uniform(
+ augs['pedalboard_phazer_feedback_min'],
+ augs['pedalboard_phazer_feedback_max'],
+ )
+ mix = random.uniform(
+ augs['pedalboard_phazer_mix_min'],
+ augs['pedalboard_phazer_mix_max'],
+ )
+ board = PB.Pedalboard([PB.Phaser(
+ rate_hz=rate_hz,
+ depth=depth,
+ centre_frequency_hz=centre_frequency_hz,
+ feedback=feedback,
+ mix=mix,
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_phazer')
+
+ # Random Distortion
+ if 'pedalboard_distortion' in augs:
+ if augs['pedalboard_distortion'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_distortion']:
+ drive_db = random.uniform(
+ augs['pedalboard_distortion_drive_db_min'],
+ augs['pedalboard_distortion_drive_db_max'],
+ )
+ board = PB.Pedalboard([PB.Distortion(
+ drive_db=drive_db,
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_distortion')
+
+ # Random PitchShift
+ if 'pedalboard_pitch_shift' in augs:
+ if augs['pedalboard_pitch_shift'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_pitch_shift']:
+ semitones = random.uniform(
+ augs['pedalboard_pitch_shift_semitones_min'],
+ augs['pedalboard_pitch_shift_semitones_max'],
+ )
+ board = PB.Pedalboard([PB.PitchShift(
+ semitones=semitones
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_pitch_shift')
+
+ # Random Resample
+ if 'pedalboard_resample' in augs:
+ if augs['pedalboard_resample'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_resample']:
+ target_sample_rate = random.uniform(
+ augs['pedalboard_resample_target_sample_rate_min'],
+ augs['pedalboard_resample_target_sample_rate_max'],
+ )
+ board = PB.Pedalboard([PB.Resample(
+ target_sample_rate=target_sample_rate
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_resample')
+
+ # Random Bitcrash
+ if 'pedalboard_bitcrash' in augs:
+ if augs['pedalboard_bitcrash'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_bitcrash']:
+ bit_depth = random.uniform(
+ augs['pedalboard_bitcrash_bit_depth_min'],
+ augs['pedalboard_bitcrash_bit_depth_max'],
+ )
+ board = PB.Pedalboard([PB.Bitcrush(
+ bit_depth=bit_depth
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_bitcrash')
+
+ # Random MP3Compressor
+ if 'pedalboard_mp3_compressor' in augs:
+ if augs['pedalboard_mp3_compressor'] > 0:
+ if random.uniform(0, 1) < augs['pedalboard_mp3_compressor']:
+ vbr_quality = random.uniform(
+ augs['pedalboard_mp3_compressor_pedalboard_mp3_compressor_min'],
+ augs['pedalboard_mp3_compressor_pedalboard_mp3_compressor_max'],
+ )
+ board = PB.Pedalboard([PB.MP3Compressor(
+ vbr_quality=vbr_quality
+ )])
+ source = board(source, 44100)
+ applied_augs.append('pedalboard_mp3_compressor')
+
+ # print(applied_augs)
+ return source
+
+ def __getitem__(self, index):
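+        # Build one training sample: a stack of stems and their mixture (the sum of the stems).
+        # Shapes: stems (num_instruments, channels, chunk_size), mix (channels, chunk_size).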
+ if self.dataset_type in [1, 2, 3]:
+ res = self.load_random_mix()
+ else:
+ res = self.load_aligned_data()
+
+ # Randomly change loudness of each stem
+ if self.aug:
+ if 'loudness' in self.config['augmentations']:
+ if self.config['augmentations']['loudness']:
+ loud_values = np.random.uniform(
+ low=self.config['augmentations']['loudness_min'],
+ high=self.config['augmentations']['loudness_max'],
+ size=(len(res),)
+ )
+ loud_values = torch.tensor(loud_values, dtype=torch.float32)
+ res *= loud_values[:, None, None]
+
+ mix = res.sum(0)
+
+ if self.aug:
+ if 'mp3_compression_on_mixture' in self.config['augmentations']:
+ apply_aug = AU.Mp3Compression(
+ min_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_min'],
+ max_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_max'],
+ backend=self.config['augmentations']['mp3_compression_on_mixture_backend'],
+ p=self.config['augmentations']['mp3_compression_on_mixture']
+ )
+ mix_conv = mix.cpu().numpy().astype(np.float32)
+ required_shape = mix_conv.shape
+ mix = apply_aug(samples=mix_conv, sample_rate=44100)
+ # Sometimes it gives longer audio (so we cut)
+ if mix.shape != required_shape:
+ mix = mix[..., :required_shape[-1]]
+ mix = torch.tensor(mix, dtype=torch.float32)
+
+ # If we need to optimize only given stem
+ if self.config.training.target_instrument is not None:
+ index = self.config.training.instruments.index(self.config.training.target_instrument)
+ return res[index:index+1], mix
+
+ return res, mix
diff --git a/docs/LoRA.md b/docs/LoRA.md
new file mode 100644
index 0000000000000000000000000000000000000000..52c2bc86606323d98c7530acb1a2b2b58ff6cfe4
--- /dev/null
+++ b/docs/LoRA.md
@@ -0,0 +1,114 @@
+## Training with LoRA
+
+### What is LoRA?
+
+LoRA (Low-Rank Adaptation) is a technique designed to reduce the computational and memory cost of fine-tuning large-scale neural networks. Instead of fine-tuning all the model parameters, LoRA introduces small trainable low-rank matrices that are injected into the network. This allows significant reductions in the number of trainable parameters, making it more efficient to adapt pre-trained models to new tasks. For more details, you can refer to the original paper.
+
+### Enabling LoRA in Training
+
+To include LoRA in your training pipeline, you need to:
+
+1. Add the `--train_lora` flag to the training command.
+
+2. Add the following LoRA configuration to your config file:
+
+Example:
+```
+lora:
+ r: 8
+ lora_alpha: 16 # alpha / rank > 1
+ lora_dropout: 0.05
+ merge_weights: False
+ fan_in_fan_out: False
+ enable_lora: [True]
+```
+
+Configuration Parameters Explained:
+
+* `r` (Rank): This determines the rank of the low-rank adaptation matrices. A smaller rank reduces memory usage and file size but may limit the model's adaptability to new tasks. Common values are 4, 8, or 16.
+
+* `lora_alpha`: Scaling factor for the LoRA weights. The ratio lora_alpha / r should generally be greater than 1 to ensure sufficient expressive power. For example, with r=8 and lora_alpha=16, the scaling factor is 2.
+
+* `lora_dropout`: Dropout rate applied to LoRA layers. It helps regularize the model and prevent overfitting, especially for smaller datasets. Typical values are in the range [0.0, 0.1].
+
+* `merge_weights`: Whether to merge the LoRA weights into the original model weights during inference. Set this to True only if you want to save the final model with merged weights for deployment.
+
+* `fan_in_fan_out`: Defines the weight initialization convention. Leave this as False for most scenarios unless your model uses a specific convention requiring it.
+
+* `enable_lora`: A list of booleans specifying whether LoRA should be applied to certain layers.
+ * For example, `[True, False, True]` enables LoRA for the 1st and 3rd layers but not the 2nd.
+ * The number of output neurons in the layer must be divisible by the length of enable_lora to ensure proper distribution of LoRA parameters across layers.
+ * For transformer architectures such as GPT models, `enable_lora = [True, False, True]` is typically used to apply LoRA to the Query (Q) and Value (V) projection matrices while skipping the Key (K) projection matrix. This setup allows efficient fine-tuning of the attention mechanism while maintaining computational efficiency.
+
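+To make these parameters concrete, below is a minimal, illustrative sketch of a LoRA-augmented linear layer (a generic PyTorch example, not the implementation used by this repository; the `LoRALinear` name is made up for illustration): `r` sets the rank of the two low-rank matrices, `lora_alpha / r` scales their contribution, and `lora_dropout` is applied only to the LoRA branch.
+
+```
+import torch
+import torch.nn as nn
+
+class LoRALinear(nn.Module):
+    """Illustrative LoRA layer: y = W x + (lora_alpha / r) * B A dropout(x)."""
+    def __init__(self, in_features, out_features, r=8, lora_alpha=16, lora_dropout=0.05):
+        super().__init__()
+        self.base = nn.Linear(in_features, out_features)
+        for p in self.base.parameters():
+            p.requires_grad_(False)  # pretrained weights stay frozen
+        self.lora_a = nn.Parameter(torch.randn(r, in_features) * 0.01)  # trainable, rank r
+        self.lora_b = nn.Parameter(torch.zeros(out_features, r))        # trainable, starts at zero
+        self.scaling = lora_alpha / r  # e.g. 16 / 8 = 2
+        self.dropout = nn.Dropout(lora_dropout)
+
+    def forward(self, x):
+        lora_update = self.dropout(x) @ self.lora_a.t() @ self.lora_b.t()
+        return self.base(x) + self.scaling * lora_update
+```
+
+With `merge_weights: True`, the product `scaling * (B @ A)` would simply be folded into the base weight once at export time, so inference needs no extra computation.
+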
+### Benefits of Using LoRA
+
+* File Size Reduction: With LoRA, only the LoRA layer weights are saved, which significantly reduces the size of the saved model.
+
+* Flexible Fine-Tuning: You can fine-tune the LoRA layers while keeping the base model frozen, preserving the original model's generalization capabilities.
+
+### Using Pretrained Weights with LoRA
+
+To train a model using both pretrained weights and LoRA weights, you need to:
+
+1. Include the `--lora_checkpoint` parameter in the training command.
+
+2. Specify the path to the LoRA checkpoint file.
+
+### Validation and Inference with LoRA
+
+When using a model fine-tuned with LoRA for validation or inference, you must provide the LoRA checkpoint using the `--lora_checkpoint` parameter.
+
+### Example Commands
+
+* Training with LoRA
+
+```
+python train.py --model_type scnet \
+ --config_path configs/config_musdb18_scnet_large_starrytong.yaml \
+ --start_check_point weights/last_scnet.ckpt \
+ --results_path results/ \
+ --data_path datasets/moisesdb/train_tracks \
+ --valid_path datasets/moisesdb/valid \
+ --device_ids 0 \
+ --metrics neg_log_wmse l1_freq sdr \
+ --metric_for_scheduler neg_log_wmse \
+ --train_lora
+```
+
+* Validating with LoRA
+```
+python valid.py --model_type scnet \
+    --lora_checkpoint weights/lora_last_scnet.ckpt \
+ --config_path configs/config_musdb18_scnet_large_starrytong.yaml \
+ --start_check_point weights/last_scnet.ckpt \
+ --store_dir results_store/ \
+ --valid_path datasets/moisesdb/valid \
+ --device_ids 0 \
+ --metrics neg_log_wmse l1_freq si_sdr sdr aura_stft aura_mrstft bleedless fullness
+```
+
+* Inference with LoRA
+```
+python inference.py --lora_checkpoint weights/lora_last_scnet.ckpt \
+ --model_type scnet \
+ --config_path configs/config_musdb18_scnet_large_starrytong.yaml \
+ --start_check_point weights/last_scnet.ckpt \
+ --store_dir inference_results/ \
+ --input_folder datasets/moisesdb/mixtures_for_inference \
+ --device_ids 0
+```
+
+### Train example with BSRoformer and LoRA
+
+You can use this [config](../configs/config_musdb18_bs_roformer_with_lora.yaml) and these [weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt) to fine-tune BSRoformer on your dataset.
+
+```
+python train.py --model_type bs_roformer \
+ --config_path configs/config_musdb18_bs_roformer_with_lora.yaml \
+ --start_check_point weights/model_bs_roformer_ep_17_sdr_9.6568.ckpt \
+ --results_path results/ \
+ --data_path musdb18hq/train \
+ --valid_path musdb18hq/test \
+ --device_ids 0 \
+ --metrics sdr \
+ --train_lora
+```
diff --git a/docs/augmentations.md b/docs/augmentations.md
new file mode 100644
index 0000000000000000000000000000000000000000..41d03585111340cc2c05dfc603753e5af1348a7e
--- /dev/null
+++ b/docs/augmentations.md
@@ -0,0 +1,146 @@
+### Augmentations
+
+Augmentations allow stems to be changed on the fly, increasing the effective size of the dataset by creating new samples from existing ones.
+Augmentations are controlled from the config file. Below you can find an example of a full config
+which includes all available augmentations:
+
+```config
+augmentations:
+ enable: true # enable or disable all augmentations (to fast disable if needed)
+ loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max)
+ loudness_min: 0.5
+ loudness_max: 1.5
+ mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3)
+ mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02)
+ - 0.2
+ - 0.02
+ mixup_loudness_min: 0.5
+ mixup_loudness_max: 1.5
+
+ # apply mp3 compression to mixture only (emulate downloading mp3 from internet)
+ mp3_compression_on_mixture: 0.01
+ mp3_compression_on_mixture_bitrate_min: 32
+ mp3_compression_on_mixture_bitrate_max: 320
+ mp3_compression_on_mixture_backend: "lameenc"
+
+ all:
+ channel_shuffle: 0.5 # Set 0 or lower to disable
+ random_inverse: 0.1 # inverse track (better lower probability)
+ random_polarity: 0.5 # polarity change (multiply waveform to -1)
+ mp3_compression: 0.01
+ mp3_compression_min_bitrate: 32
+ mp3_compression_max_bitrate: 320
+ mp3_compression_backend: "lameenc"
+
+ # pedalboard reverb block
+ pedalboard_reverb: 0.01
+ pedalboard_reverb_room_size_min: 0.1
+ pedalboard_reverb_room_size_max: 0.9
+ pedalboard_reverb_damping_min: 0.1
+ pedalboard_reverb_damping_max: 0.9
+ pedalboard_reverb_wet_level_min: 0.1
+ pedalboard_reverb_wet_level_max: 0.9
+ pedalboard_reverb_dry_level_min: 0.1
+ pedalboard_reverb_dry_level_max: 0.9
+ pedalboard_reverb_width_min: 0.9
+ pedalboard_reverb_width_max: 1.0
+
+ # pedalboard chorus block
+ pedalboard_chorus: 0.01
+ pedalboard_chorus_rate_hz_min: 1.0
+ pedalboard_chorus_rate_hz_max: 7.0
+ pedalboard_chorus_depth_min: 0.25
+ pedalboard_chorus_depth_max: 0.95
+ pedalboard_chorus_centre_delay_ms_min: 3
+ pedalboard_chorus_centre_delay_ms_max: 10
+ pedalboard_chorus_feedback_min: 0.0
+ pedalboard_chorus_feedback_max: 0.5
+ pedalboard_chorus_mix_min: 0.1
+ pedalboard_chorus_mix_max: 0.9
+
+ # pedalboard phazer block
+ pedalboard_phazer: 0.01
+ pedalboard_phazer_rate_hz_min: 1.0
+ pedalboard_phazer_rate_hz_max: 10.0
+ pedalboard_phazer_depth_min: 0.25
+ pedalboard_phazer_depth_max: 0.95
+ pedalboard_phazer_centre_frequency_hz_min: 200
+ pedalboard_phazer_centre_frequency_hz_max: 12000
+ pedalboard_phazer_feedback_min: 0.0
+ pedalboard_phazer_feedback_max: 0.5
+ pedalboard_phazer_mix_min: 0.1
+ pedalboard_phazer_mix_max: 0.9
+
+ # pedalboard distortion block
+ pedalboard_distortion: 0.01
+ pedalboard_distortion_drive_db_min: 1.0
+ pedalboard_distortion_drive_db_max: 25.0
+
+ # pedalboard pitch shift block
+ pedalboard_pitch_shift: 0.01
+ pedalboard_pitch_shift_semitones_min: -7
+ pedalboard_pitch_shift_semitones_max: 7
+
+ # pedalboard resample block
+ pedalboard_resample: 0.01
+ pedalboard_resample_target_sample_rate_min: 4000
+ pedalboard_resample_target_sample_rate_max: 44100
+
+ # pedalboard bitcrash block
+ pedalboard_bitcrash: 0.01
+ pedalboard_bitcrash_bit_depth_min: 4
+ pedalboard_bitcrash_bit_depth_max: 16
+
+ # pedalboard mp3 compressor block
+ pedalboard_mp3_compressor: 0.01
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_min: 0
+ pedalboard_mp3_compressor_pedalboard_mp3_compressor_max: 9.999
+
+ vocals:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.1
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.7
+ bass:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -2
+ pitch_shift_max_semitones: 2
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -3
+ seven_band_parametric_eq_max_gain_db: 6
+ tanh_distortion: 0.2
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.5
+ drums:
+ pitch_shift: 0.33
+ pitch_shift_min_semitones: -5
+ pitch_shift_max_semitones: 5
+ seven_band_parametric_eq: 0.25
+ seven_band_parametric_eq_min_gain_db: -9
+ seven_band_parametric_eq_max_gain_db: 9
+ tanh_distortion: 0.33
+ tanh_distortion_min: 0.1
+ tanh_distortion_max: 0.6
+ other:
+ pitch_shift: 0.1
+ pitch_shift_min_semitones: -4
+ pitch_shift_max_semitones: 4
+ gaussian_noise: 0.1
+ gaussian_noise_min_amplitude: 0.001
+ gaussian_noise_max_amplitude: 0.015
+ time_stretch: 0.01
+ time_stretch_min_rate: 0.8
+ time_stretch_max_rate: 1.25
+```
+
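+Each numeric value in the blocks above is the probability that the corresponding transform is applied to a chunk; the `*_min` / `*_max` keys bound its randomly drawn parameters. Roughly, for each configured augmentation the dataset code does something like the following (an illustrative sketch only, assuming the `audiomentations` package; the real logic lives in `dataset.py`):
+
+```
+import random
+import audiomentations as AU
+
+def maybe_pitch_shift(source, augs, sample_rate=44100):
+    # 'augs' is the 'all' block merged with the current stem's block (stem values win)
+    prob = augs.get('pitch_shift', 0)
+    if prob > 0 and random.uniform(0, 1) < prob:
+        transform = AU.PitchShift(
+            min_semitones=augs['pitch_shift_min_semitones'],
+            max_semitones=augs['pitch_shift_max_semitones'],
+            p=1.0,
+        )
+        source = transform(samples=source, sample_rate=sample_rate)
+    return source
+```
+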
+You can copy-paste the config above into your own config to use augmentations.
+Notes:
+* To completely disable all augmentations you can either remove the `augmentations` section from the config or set `enable` to `false`.
+* If you want to disable a particular augmentation, just set its probability to zero.
+* Augmentations in the `all` subsection are applied to all stems.
+* Augmentations in the `vocals`, `bass`, etc. subsections are applied only to the corresponding stems. You can create such subsections for all stems listed in `training.instruments`.
\ No newline at end of file
diff --git a/docs/bs_roformer_info.md b/docs/bs_roformer_info.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad7bfc9f8f57e54de1be42cdcdb14775811ebe36
--- /dev/null
+++ b/docs/bs_roformer_info.md
@@ -0,0 +1,145 @@
+### Batch sizes for BSRoformer
+
+You can use the table below to choose the BS Roformer `batch_size` parameter for training based on your GPU. Batch size values are provided for a single GPU. If you have several GPUs, multiply the value by the number of GPUs.
+
+| chunk_size | dim | depth | batch_size (A6000 48GB) | batch_size (3090/4090 24GB) | batch_size (16GB) |
+|:----------:|:---:|:-----:|:-----------------------:|:---------------------------:|:-----------------:|
+| 131584 | 128 | 6 | 10 | 5 | 3 |
+| 131584 | 256 | 6 | 8 | 4 | 2 |
+| 131584 | 384 | 6 | 7 | 3 | 2 |
+| 131584 | 512 | 6 | 6 | 3 | 2 |
+| 131584 | 256 | 8 | 6 | 3 | 2 |
+| 131584 | 256 | 12 | 4 | 2 | 1 |
+| 263168 | 128 | 6 | 4 | 2 | 1 |
+| 263168 | 256 | 6 | 3 | 1 | 1 |
+| 352800 | 128 | 6 | 2 | 1 | - |
+| 352800 | 256 | 6 | 2 | 1 | - |
+| 352800 | 384 | 12 | 1 | - | - |
+| 352800 | 512 | 12 | - | - | - |
+
+
+Parameters were obtained with the following base config:
+
+```
+audio:
+ chunk_size: 131584
+ dim_f: 1024
+ dim_t: 515
+ hop_length: 512
+ n_fft: 2048
+ num_channels: 2
+ sample_rate: 44100
+ min_mean_abs: 0.000
+
+model:
+ dim: 384
+ depth: 12
+ stereo: true
+ num_stems: 1
+ time_transformer_depth: 1
+ freq_transformer_depth: 1
+ linear_transformer_depth: 0
+ freqs_per_bands: !!python/tuple
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 2
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 4
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 12
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 24
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 48
+ - 128
+ - 129
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0.1
+ ff_dropout: 0.1
+ flash_attn: false
+ dim_freqs_in: 1025
+ stft_n_fft: 2048
+ stft_hop_length: 512
+ stft_win_length: 2048
+ stft_normalized: false
+ mask_estimator_depth: 2
+ multi_stft_resolution_loss_weight: 1.0
+ multi_stft_resolutions_window_sizes: !!python/tuple
+ - 4096
+ - 2048
+ - 1024
+ - 512
+ - 256
+ multi_stft_hop_size: 147
+ multi_stft_normalized: False
+
+training:
+ batch_size: 1
+ gradient_accumulation_steps: 1
+ grad_clip: 0
+ instruments:
+ - vocals
+ - other
+ lr: 3.0e-05
+ patience: 2
+ reduce_factor: 0.95
+ target_instrument: vocals
+ num_epochs: 1000
+ num_steps: 1000
+ q: 0.95
+ coarse_loss_clip: true
+ ema_momentum: 0.999
+ optimizer: adam
+  other_fix: false # needed to check on the multisong dataset whether 'other' is actually instrumental
+  use_amp: true # enable or disable mixed precision (float16) - usually it should be true
+```
diff --git a/docs/changes.md b/docs/changes.md
new file mode 100644
index 0000000000000000000000000000000000000000..9aeba78f94f0b569849ac6560c40f68c82f6206d
--- /dev/null
+++ b/docs/changes.md
@@ -0,0 +1,20 @@
+### Changes
+
+#### v1.0.2
+
+* Added multi-GPU validation (earlier, validation was performed on a single GPU)
+* `training.batch_size` in the config must now be set for a single GPU (if you use multiple GPUs it will be automatically multiplied by the number of GPUs)
+
+#### v1.0.3
+
+* Added "spawn" fix for multiprocessing
+* Function `get_model_from_config` now takes path of config as input.
+* The latest versions of PyTorch have problems with torch.backends.cudnn.benchmark = True (big slowdown), so version 2.0.1 is pinned in requirements.txt
+* The `--valid_path` parameter for train.py can now accept several validation folders instead of one. A warning is printed if a validation folder is empty.
+* Small fix for AMP usage in Demucs models taken from config
+* Support for Demucs3 mmi model was added
+* GPU memory consumption was reduced during inference and validation.
+* Some changes to fix clicking artifacts at segment edges.
+* Added support for training on FLAC files. Some more error checks added.
+* viperx's Roformer weights and configs added
+* `--extract_instrumental` argument added to inference.py
\ No newline at end of file
diff --git a/docs/dataset_types.md b/docs/dataset_types.md
new file mode 100644
index 0000000000000000000000000000000000000000..345faf7e105aa113deaa173972e5ba630c5c317a
--- /dev/null
+++ b/docs/dataset_types.md
@@ -0,0 +1,75 @@
+### Dataset types for training
+
+* **Type 1 (MUSDB)**: different folders. Each folder contains all needed stems in the format _< stem name >.wav_, the same as in the MUSDB18HQ dataset. In the latest code releases it's possible to use `flac` instead of `wav`.
+
+Example:
+```
+--- Song 1:
+------ vocals.wav
+------ bass.wav
+------ drums.wav
+------ other.wav
+--- Song 2:
+------ vocals.wav
+------ bass.wav
+------ drums.wav
+------ other.wav
+--- Song 3:
+...........
+```
+
+* **Type 2 (Stems)**: each folder is named after a stem. The folder contains wav files consisting only of the required stem.
+```
+--- vocals:
+------ vocals_1.wav
+------ vocals_2.wav
+------ vocals_3.wav
+------ vocals_4.wav
+------ ...
+--- bass:
+------ bass_1.wav
+------ bass_2.wav
+------ bass_3.wav
+------ bass_4.wav
+------ ...
+...........
+```
+
+* **Type 3 (CSV file)**:
+
+You can provide a CSV file (or a list of CSV files) with the following structure:
+```
+instrum,path
+vocals,/path/to/dataset/vocals_1.wav
+vocals,/path/to/dataset2/vocals_v2.wav
+vocals,/path/to/dataset3/vocals_some.wav
+...
+drums,/path/to/dataset/drums_good.wav
+...
+```
+
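+If your stems are already organized in per-instrument folders, a small script can generate this CSV. The sketch below is only an example - the folder layout and file pattern are assumptions you will need to adapt:
+
+```
+import csv
+from glob import glob
+
+# Assumed layout: /path/to/dataset/<instrument>/<file>.wav
+instruments = ['vocals', 'bass', 'drums', 'other']
+
+with open('dataset.csv', 'w', newline='') as f:
+    writer = csv.writer(f)
+    writer.writerow(['instrum', 'path'])
+    for instrument in instruments:
+        for path in sorted(glob('/path/to/dataset/{}/*.wav'.format(instrument))):
+            writer.writerow([instrument, path])
+```
+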
+* **Type 4 (MUSDB Aligned)**:
+
+The same as Type 1, but during training all instruments are taken from the same position in the song.
+
+### Dataset for validation
+
+* The validation dataset must have the same structure as Type 1 datasets (regardless of what dataset type you're using for training), but each folder must also include a `mixture.wav` for each song. `mixture.wav` is the sum of all stems of the song (see the sketch after the example below).
+
+Example:
+```
+--- Song 1:
+------ vocals.wav
+------ bass.wav
+------ drums.wav
+------ other.wav
+------ mixture.wav
+--- Song 2:
+------ vocals.wav
+------ bass.wav
+------ drums.wav
+------ other.wav
+------ mixture.wav
+--- Song 3:
+...........
+```
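+
+If your validation folders contain only the stems, `mixture.wav` can be created by summing them. A minimal sketch (assuming all stems of a song share the same length, channel count and sample rate):
+
+```
+import os
+import numpy as np
+import soundfile as sf
+
+def write_mixture(song_folder, stems=('vocals', 'bass', 'drums', 'other')):
+    # Sum all stems sample by sample and store the result as mixture.wav
+    data = []
+    sample_rate = None
+    for stem in stems:
+        audio, sample_rate = sf.read(os.path.join(song_folder, stem + '.wav'))
+        data.append(audio)
+    mixture = np.sum(data, axis=0)
+    sf.write(os.path.join(song_folder, 'mixture.wav'), mixture, sample_rate)
+```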
diff --git a/docs/ensemble.md b/docs/ensemble.md
new file mode 100644
index 0000000000000000000000000000000000000000..bac03d5edd222ce71138c766f066e232dfdda8e9
--- /dev/null
+++ b/docs/ensemble.md
@@ -0,0 +1,30 @@
+### Ensemble usage
+
+The repository contains an `ensemble.py` script which can be used to ensemble the results of different algorithms.
+
+Arguments:
+* `--files` - Path to all audio-files to ensemble
+* `--type` - Method to do ensemble. One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft. Default: avg_wave.
+* `--weights` - Weights to create ensemble. Number of weights must be equal to number of files
+* `--output` - Path to wav file where ensemble result will be stored (Default: res.wav)
+
+Example:
+```
+python ensemble.py --files ./results_tracks/vocals1.wav ./results_tracks/vocals2.wav --weights 2 1 --type max_fft --output out.wav
+```
+
+### Ensemble types:
+
+* `avg_wave` - ensemble on 1D variant, find average for every sample of waveform independently
+* `median_wave` - ensemble on 1D variant, find median value for every sample of waveform independently
+* `min_wave` - ensemble on 1D variant, find minimum absolute value for every sample of waveform independently
+* `max_wave` - ensemble on 1D variant, find maximum absolute value for every sample of waveform independently
+* `avg_fft` - ensemble on spectrogram (Short-time Fourier transform (STFT), 2D variant), find average for every pixel of spectrogram independently. After averaging use inverse STFT to obtain original 1D-waveform back.
+* `median_fft` - the same as avg_fft but use median instead of mean (only useful for ensembling of 3 or more sources).
+* `min_fft` - the same as avg_fft but use minimum function instead of mean (reduce aggressiveness).
+* `max_fft` - the same as avg_fft but use maximum function instead of mean (the most aggressive).
+
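+For reference, the waveform variants combine the files sample by sample, while the `_fft` variants do the same on STFT spectrograms and then invert the result. A simplified sketch of `avg_wave` (not the exact `ensemble.py` code; it assumes all inputs have the same length and sample rate):
+
+```
+import numpy as np
+import soundfile as sf
+
+def avg_wave(files, weights, output='res.wav'):
+    # Weighted average of the input waveforms, sample by sample
+    waves, rate = [], None
+    for path in files:
+        audio, rate = sf.read(path)
+        waves.append(audio)
+    mix = np.average(np.array(waves), axis=0, weights=weights)
+    sf.write(output, mix, rate)
+
+# avg_wave(['vocals1.wav', 'vocals2.wav'], weights=[2, 1], output='out.wav')
+```
+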
+### Notes
+* `min_fft` can be used to build a more conservative ensemble - it reduces the influence of more aggressive models.
+* It's better to ensemble models of similar quality - in this case the ensemble gives a gain. If one of the models is bad, it will reduce the overall quality.
+* In my experiments `avg_wave` was always better than or equal to the other methods in SDR score.
diff --git a/docs/gui.md b/docs/gui.md
new file mode 100644
index 0000000000000000000000000000000000000000..04252c3b19ae92974c47e04f54c756698d3bbddd
--- /dev/null
+++ b/docs/gui.md
@@ -0,0 +1,31 @@
+## GUI for MSST code
+
+The GUI was prepared by **Bas Curtiz** and is based on the [wxpython](https://en.wikipedia.org/wiki/WxPython) module.
+
+
+
+### How to
+
+How to use GUI with ZFTurbo's Music Source Separation Universal Training Code:
+
+1. Install Python: https://www.python.org/ftp/python/3.11.6/python-3.11.6-amd64.exe
+2. Install Microsoft Visual C++ 2015-2022 (x64): https://aka.ms/vs/17/release/vc_redist.x64.exe
+3. Install Microsoft C++ Build Tools: https://visualstudio.microsoft.com/visual-cpp-build-tools/
+Select Desktop development with C++
+4. Install PyTorch: https://pytorch.org/get-started/locally/
+5. Download and unzip Music Source Separation Universal Training Code:
+https://github.com/ZFTurbo/Music-Source-Separation-Training
+6. Open up CMD inside the folder and enter: `pip install -r requirements.txt`
+7. Enter: `python gui-wx.py`
+8. Download models - assign the config (.yaml) and checkpoint (.bin, .ckpt, or .th)
+
+Video guide on [Youtube](https://youtu.be/M8JKFeN7HfU) (~6.5 minutes).
+
+[](https://youtu.be/M8JKFeN7HfU)
+
+You can also use the GUI as an EXE file on Windows: [Link](https://mega.nz/file/xAAzTCzR#2IapG3RJ9Vew3oC8l9H2zrw1vwUtZSqsUdJAjmARmPs). Put it inside the root folder; this way you can make a shortcut to the exe on your desktop and run it with a double-click.
+
+### Other links
+
+* You can try a [non-official GUI for MSST](https://github.com/SUC-DriverOld/MSST-WebUI).
+* One more version is available [by SiftedSand](https://github.com/SiftedSand/MusicSepGUI).
\ No newline at end of file
diff --git a/docs/mel_roformer_experiments.md b/docs/mel_roformer_experiments.md
new file mode 100644
index 0000000000000000000000000000000000000000..3d68b92d4c8725bcc9ce724922e5802d0a60e36e
--- /dev/null
+++ b/docs/mel_roformer_experiments.md
@@ -0,0 +1,21 @@
+## Mel Roformer models
+
+All experiments were made using the MUSDB18HQ dataset. All metrics were measured on the 'test' set; training used the 'train' set.
+
+### Experiments table
+
+| Average SDR Score | Chunk size | Depth | Dim | mlp expansion factor | Skip connection | Hop size | FFT Size | Dropout | Batch Size | DL Checkpoint | Comment |
+|:-----------------:|:-------------:|:-----------------:|:---:|:--------------------:|:-----:|:-----:|:-----:|:-----:|:----------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------:|
+| 5.1235 | 88200 | 2 | 64 | 1 | No | 441 | 2048 | 0/0 | 32 (48 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_53_sdr_5.1235_config_mel_64_2_1_88200_experimental.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_53_sdr_5.1235.ckpt) | |
+| 6.4698 | 88200 | 4 | 128 | 1 | No | 441 | 2048 | 0.1/0.1 | 28 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_38_sdr_6.4698.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_38_sdr_6.4698.ckpt) | |
+| 6.7022 | 88200 | 4 | 128 | 1 | No | 882 | 4096 | 0/0 | 20 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_166_sdr_6.7022_config_mel_128_4_1_88200_big_fft_4096.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_166_sdr_6.7022.ckpt) | |
+| 7.8127 | 88200 | 6 | 256 | 1 | Yes | 441 | 2048 | 0.1/0.1 | 16 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_168_sdr_7.8127_config_mel_256_6_1_88200.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_168_sdr_7.8127.ckpt) | |
+| 6.4908 | 176400 | 4 | 128 | 1 | Yes | 441 | 2048 | 0.1/0.1 | 8 (48 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_15_sdr_6.4908.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_15_sdr_6.4908.ckpt) | |
+| 6.5224 | 176400 | 4 | 128 | 2 | Yes | 441 | 2048 | 0.1/0.1 | 8 (48 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_9_sdr_6.5254.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_9_sdr_6.5254.ckpt) | |
+| 7.0412 | 352800 | 4 | 128 | 1 | No | 882 | 4096 | 0/0 | 5 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_48_sdr_7.0412_config_mel_128_4_1_352800_big_fft_4096.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_experimental_ep_48_sdr_7.0412.ckpt) | |
+| 8.2175 | 352800 | 4 | 256 | 1 | No | 441 | 2048 | 0/0 | 5 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_1_sdr_8.2175.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_1_sdr_8.2175.ckpt) | Trained longer on different strategies. Looks like it a bit overfit in the end |
+| 1.0557 | 352800 | 4 | 128 | 1 | No | 882 | 2048 | 0/0 | 6 (48 GB) | --- | Looks like big hop size is not great |
+| 6.8652 | 485100 | 4 | 128 | 1 | No | 441 | 2048 | 0.1/0.1 | 5 (48 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_7_sdr_6.8652.yaml) / [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_7_sdr_6.8652.ckpt) | |
+| 8.9400* | 485100 | 8 | 384 | 4 | Yes | 882 | 4096 | 0/0 | 2 (80 GB) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_5_sdr_8.9443_config_mel_384_8_4_485100_big_fft_4096_skip_connect.yaml) / Weights ([part 1](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_5_sdr_8.9443.zip.001), [part2](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.11/model_mel_band_roformer_ep_5_sdr_8.9443.zip.002)) | Very big file with weights > 3GB. Continue to increase metrics |
+
+* Note 1: Some models are probably undertrained.
\ No newline at end of file
diff --git a/docs/pretrained_models.md b/docs/pretrained_models.md
new file mode 100644
index 0000000000000000000000000000000000000000..a1595ab1b46bf5b48b33a100605b600dfb2d0d68
--- /dev/null
+++ b/docs/pretrained_models.md
@@ -0,0 +1,67 @@
+## Pre-trained models
+
+If you have trained some good models, please share them. You can post the config and model weights [in this issue](https://github.com/ZFTurbo/Music-Source-Separation-Training/issues/1).
+
+### Vocal models
+
+| Model Type | Instruments | Metrics (SDR) | Config | Checkpoint |
+|:--------------------------------------------------------------------------------:|:-------------:|:-----------------:|:-----:|:-----:|
+| MDX23C | vocals / other | SDR vocals: 10.17 | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/config_vocals_mdx23c.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_mdx23c_sdr_10.17.ckpt) |
+| HTDemucs4 (MVSep finetuned) | vocals / other | SDR vocals: 8.78 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_vocals_htdemucs.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_htdemucs_sdr_8.78.ckpt) |
+| Segm Models (VitLarge23) | vocals / other | SDR vocals: 9.77 | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/config_vocals_segm_models.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/model_vocals_segm_models_sdr_9.77.ckpt) |
+| Swin Upernet | vocals / other | SDR vocals: 7.57 | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.2/config_vocals_swin_upernet.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.2/model_swin_upernet_ep_56_sdr_10.6703.ckpt) |
+| BS Roformer ([viperx](https://github.com/playdasegunda) edition) | vocals / other | SDR vocals: 10.87 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml) | [Weights](https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_317_sdr_12.9755.ckpt) |
+| MelBand Roformer ([viperx](https://github.com/playdasegunda) edition) | vocals / other | SDR vocals: 9.67 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml) | [Weights](https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt) |
+| MelBand Roformer ([KimberleyJensen](https://github.com/KimberleyJensen/) edition) | vocals / other | SDR vocals: 10.98 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/KimberleyJensen/config_vocals_mel_band_roformer_kj.yaml) | [Weights](https://huggingface.co/KimberleyJSN/melbandroformer/resolve/main/MelBandRoformer.ckpt) |
+
+**Note**: Metrics measured on [Multisong Dataset](https://mvsep.com/en/quality_checker).
+
+### Single stem models
+
+| Model Type | Instruments | Metrics (SDR) | Config | Checkpoint |
+|:-------------------------------------------------------------------------------------------------------------:|:-----------:|:----------------:|:-----:|:-------------------------------------------------------------------------------------------------------------------------------------------------------:|
+| HTDemucs4 FT Drums | drums | SDR drums: 11.13 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th) |
+| HTDemucs4 FT Bass | bass | SDR bass: 11.96 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th) |
+| HTDemucs4 FT Other | other | SDR other: 5.85 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th) |
+| HTDemucs4 FT Vocals (Official repository) | vocals | SDR vocals: 8.38 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th) |
+| BS Roformer ([viperx](https://github.com/playdasegunda) edition) | other | SDR other: 6.85 | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml) | [Weights](https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_937_sdr_10.5309.ckpt) |
+| MelBand Roformer ([aufr33](https://github.com/aufr33) and [viperx](https://github.com/playdasegunda) edition) | crowd | SDR crowd: 5.99 | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/model_mel_band_roformer_crowd.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.4/mel_band_roformer_crowd_aufr33_viperx_sdr_8.7144.ckpt) |
+| MelBand Roformer ([anvuew](https://github.com/anvuew) edition) | dereverb | --- | [Config](https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew.yaml) | [Weights](https://huggingface.co/anvuew/dereverb_mel_band_roformer/resolve/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt) |
+| MelBand Roformer Denoise (by [aufr33](https://github.com/aufr33)) | denoise | --- | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.7/model_mel_band_roformer_denoise.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.7/denoise_mel_band_roformer_aufr33_sdr_27.9959.ckpt) |
+| MelBand Roformer Denoise Aggressive (by [aufr33](https://github.com/aufr33)) | denoise | --- | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.7/model_mel_band_roformer_denoise.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.7/denoise_mel_band_roformer_aufr33_aggr_sdr_27.9768.ckpt) |
+| Apollo LQ MP3 restoration (by [JusperLee](https://github.com/JusperLee)) | restored | --- | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/configs/config_apollo.yaml) | [Weights](https://huggingface.co/JusperLee/Apollo/resolve/main/pytorch_model.bin) |
+| MelBand Roformer Aspiration (by [SUC-DriverOld](https://github.com/SUC-DriverOld)) | aspiration | SDR: 9.85 | [Config](https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/blob/main/config_aspiration_mel_band_roformer.yaml) | [Weights](https://huggingface.co/Sucial/Aspiration_Mel_Band_Roformer/blob/main/aspiration_mel_band_roformer_sdr_18.9845.ckpt) |
+| MDX23C Phantom Centre extraction (by [wesleyr36](https://github.com/wesleyr36)) | similarity | L1Freq: 72.23 | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/config_mdx23c_similarity.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.10/model_mdx23c_ep_271_l1_freq_72.2383.ckpt) |
+| MelBand Roformer Vocals DeReverb/DeEcho (by [SUC-DriverOld](https://github.com/SUC-DriverOld)) | dry | SDR: 10.01 | [Config](https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/config_dereverb-echo_mel_band_roformer.yaml) | [Weights](https://huggingface.co/Sucial/Dereverb-Echo_Mel_Band_Roformer/resolve/main/dereverb-echo_mel_band_roformer_sdr_10.0169.ckpt) |
+
+**Note**: All HTDemucs4 FT models output 4 stems, but only the target stem is of good quality (the remaining stems are placeholders).
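+
+For reference, any Config/Checkpoint pair from these tables can be passed to `inference.py` roughly as follows (the flags mirror those the bundled GUI passes to `inference.py`; the local paths below are placeholders):
+
+```
+python inference.py \
+    --model_type htdemucs \
+    --config_path configs/config_musdb18_htdemucs.yaml \
+    --start_check_point weights/f7e0c4bc-ba3fe64a.th \
+    --input_folder input/ \
+    --store_dir output/
+```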
+
+### Multi-stem models
+
+| Model Type | Instruments | Metrics (SDR) | Config | Checkpoint |
+|:---------------------------------------------------------------------------------------------------:|:----------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:-----:|:-----:|
+| BandIt Plus | speech / music / effects | DnR test avg: 11.50 (speech: 15.64, music: 9.18 effects: 9.69) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/config_dnr_bandit_bsrnn_multi_mus64.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.3/model_bandit_plus_dnr_sdr_11.47.chpt) |
+| HTDemucs4 | bass / drums / vocals / other | Multisong avg: 9.16 (bass: 11.76, drums: 10.88 vocals: 8.24 other: 5.74) | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_htdemucs.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/955717e8-8726e21a.th) |
+| HTDemucs4 (6 stems) | bass / drums / vocals / other / piano / guitar | Multisong (bass: 11.22, drums: 10.22 vocals: 8.05 other: --- piano: --- guitar: ---) | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_htdemucs_6stems.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/5c90dfd2-34c22ccb.th) |
+| Demucs3 mmi | bass / drums / vocals / other | Multisong avg: 8.88 (bass: 11.17, drums: 10.70 vocals: 8.22 other: 5.42) | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_musdb18_demucs3_mmi.yaml) | [Weights](https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/75fc33f5-1941ce65.th) |
+| DrumSep htdemucs (by [inagoy](https://github.com/inagoy)) | kick / snare / cymbals / toms | --- | [Config](https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/config_drumsep.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.5/model_drumsep.th) |
+| DrumSep mdx23c (by [aufr33](https://github.com/aufr33) and [jarredou](https://github.com/jarredou)) | kick / snare / toms / hh / ride / crash | --- | [Config](https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.yaml) | [Weights](https://github.com/jarredou/models/releases/download/aufr33-jarredou_MDX23C_DrumSep_model_v0.1/aufr33-jarredou_DrumSep_model_mdx23c_ep_141_sdr_10.8059.ckpt) |
+
+### Multi-stem models (MUSDB18HQ)
+
+* Models in this list were trained only on the MUSDB18HQ dataset (100 training songs). These weights are useful starting points for fine-tuning (see the example command after the table).
+* Instruments: bass / drums / vocals / other
+
+| Model Type | Metrics (SDR) | Config | Checkpoint |
+|:------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----:|:-----:|
+| MDX23C | MUSDB test avg: 7.15 (bass: 5.77, drums: 7.93 vocals: 9.23 other: 5.68)<br>Multisong avg: 7.02 (bass: 8.40, drums: 7.73 vocals: 7.36 other: 4.57) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/config_musdb18_mdx23c.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.1/model_mdx23c_ep_168_sdr_7.0207.ckpt) |
+| TS BS Mamba2 | MUSDB test avg: 6.87 (bass: 5.82, drums: 8.14 vocals: 8.35 other: 5.16)<br>Multisong avg: 6.66 (bass: 7.87, drums: 7.92 vocals: 7.01 other: 3.85) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_bs_mamba2.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/model_bs_mamba2_ep_11_sdr_6.8723.ckpt) |
+| SCNet (by [starrytong](https://github.com/starrytong)) | Multisong avg: 8.87 (bass: 11.07, drums: 10.79 vocals: 8.27 other: 5.34) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/config_musdb18_scnet.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v.1.0.6/scnet_checkpoint_musdb18.ckpt) |
+| SCNet Large | MUSDB test avg: 9.32 (bass: 8.63, drums: 10.89 vocals: 10.69 other: 7.06)<br>Multisong avg: 9.19 (bass: 11.15, drums: 11.04 vocals: 8.94 other: 5.62) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.8/config_musdb18_scnet_large.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.8/model_scnet_sdr_9.3244.ckpt) |
+| SCNet Large (by [starrytong](https://github.com/starrytong)) | MUSDB test avg: 9.70 (bass: 9.38, drums: 11.15 vocals: 10.94 other: 7.31)<br>Multisong avg: 9.28 (bass: 11.27, drums: 11.23 vocals: 9.05 other: 5.57) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/config_musdb18_scnet_large_starrytong.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.9/SCNet-large_starrytong_fixed.ckpt) |
+| SCNet XL | MUSDB test avg: 9.80 (bass: 9.23, drums: 11.51 vocals: 11.05 other: 7.41)<br>Multisong avg: 9.72 (bass: 11.87, drums: 11.49 vocals: 9.32 other: 6.19) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/config_musdb18_scnet_xl.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.13/model_scnet_ep_54_sdr_9.8051.ckpt) |
+| BS Roformer | MUSDB test avg: 9.65 (bass: 8.48, drums: 11.61 vocals: 11.08 other: 7.44)<br>Multisong avg: 9.38 (bass: 11.08, drums: 11.29 vocals: 9.19 other: 5.96) | [Config](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/config_bs_roformer_384_8_2_485100.yaml) | [Weights](https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.12/model_bs_roformer_ep_17_sdr_9.6568.ckpt) |
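+
+Any of the checkpoints above can serve as a starting point for fine-tuning, for example (the flags mirror those the bundled GUI passes to `train.py`; paths are placeholders):
+
+```
+python train.py \
+    --model_type scnet \
+    --config_path configs/config_musdb18_scnet.yaml \
+    --start_check_point weights/scnet_checkpoint_musdb18.ckpt \
+    --results_path results/ \
+    --data_path /path/to/train \
+    --valid_path /path/to/valid
+```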
+
+### MelRoformer models
+
+[Table of Mel Band Roformers with different parameters](mel_roformer_experiments.md)
\ No newline at end of file
diff --git a/docs/test.md b/docs/test.md
new file mode 100644
index 0000000000000000000000000000000000000000..0fb5ea259e079087e923c14d9bd8cd31aa960baa
--- /dev/null
+++ b/docs/test.md
@@ -0,0 +1,150 @@
+`tests.py` Documentation
+========================
+
+Overview
+--------
+
+The `tests.py` script is designed to verify the functionality of a specific configuration, model weights, and dataset before proceeding with training, validation, or inference. Additionally, it allows the specification of other parameters, which can be passed either through the command line or via the `base_args` variable in the script.
+
+Usage
+-----
+
+To use `tests.py`, provide the desired arguments via the command line using the `--` prefix. It is mandatory to specify the following arguments:
+
+* `--model_type`
+
+* `--config_path`
+
+* `--start_check_point`
+
+* `--data_path`
+
+* `--valid_path`
+
+
+For example:
+
+```
+python tests.py --check_train \
+--config_path config.yaml \
+--model_type scnet \
+--data_path /path/to/data \
+--valid_path /path/to/valid
+```
+
+Alternatively, you can define default arguments in the `base_args` variable directly in the script.
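+
+A minimal sketch of what that might look like, assuming `base_args` holds command-line style tokens (check `tests.py` for the exact format it expects):
+
+```
+base_args = [
+    "--check_train",
+    "--model_type", "scnet",
+    "--config_path", "config.yaml",
+    "--data_path", "/path/to/data",
+    "--valid_path", "/path/to/valid",
+]
+```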
+
+Arguments
+---------
+
+The script accepts the following arguments:
+
+* `--check_train`: Check training functionality.
+
+* `--check_valid`: Check validation functionality.
+
+* `--check_inference`: Check inference functionality.
+
+* `--device_ids`: Specify device IDs for training or inference.
+
+* `--model_type`: Specify the type of model to use.
+
+* `--start_check_point`: Path to the checkpoint to start from.
+
+* `--config_path`: Path to the configuration file.
+
+* `--data_path`: Path to the training data.
+
+* `--valid_path`: Path to the validation data.
+
+* `--results_path`: Path to save training results.
+
+* `--store_dir`: Path to store validation or inference results.
+
+* `--input_folder`: Path to the input folder for inference.
+
+* `--metrics`: List of metrics to evaluate, provided as space-separated values.
+
+* `--max_folders`: Maximum number of folders to process.
+
+* `--dataset_type`: Dataset type. Must be one of: 1, 2, 3, or 4. Default is 1.
+
+* `--num_workers`: Number of workers for the dataloader. Default is 0.
+
+* `--pin_memory`: Use pinned memory in the dataloader.
+
+* `--seed`: Random seed for reproducibility. Default is 0.
+
+* `--use_multistft_loss`: Use MultiSTFT Loss from the auraloss package.
+
+* `--use_mse_loss`: Use Mean Squared Error (MSE) loss.
+
+* `--use_l1_loss`: Use L1 loss.
+
+* `--wandb_key`: API Key for Weights and Biases (wandb). Default is an empty string.
+
+* `--pre_valid`: Run validation before training.
+
+* `--metric_for_scheduler`: Metric to be used for the learning rate scheduler. Choices are `sdr`, `l1_freq`, `si_sdr`, `neg_log_wmse`, `aura_stft`, `aura_mrstft`, `bleedless`, or `fullness`. Default is `sdr`.
+
+* `--train_lora`: Enable training with LoRA.
+
+* `--lora_checkpoint`: Path to the initial LoRA weights checkpoint. Default is an empty string.
+
+* `--extension`: File extension for validation. Default is `wav`.
+
+* `--use_tta`: Enable test-time augmentation during inference. This triples runtime but improves prediction quality.
+
+* `--extract_instrumental`: Invert vocals to obtain instrumental output if available.
+
+* `--disable_detailed_pbar`: Disable the detailed progress bar.
+
+* `--force_cpu`: Force the use of the CPU, even if CUDA is available.
+
+* `--flac_file`: Output FLAC files instead of WAV.
+
+* `--pcm_type`: PCM type for FLAC files. Choices are `PCM_16` or `PCM_24`. Default is `PCM_24`.
+
+* `--draw_spectro`: Generate spectrograms for the resulting stems. The value is the number of seconds of the track to render. Requires `--store_dir` to be set. Default is 0.
+
+
+Example
+-------
+
+To check training, validation, and inference with a specific configuration file, dataset, and checkpoint, we can run:
+
+```
+python tests.py \
+--check_train \
+--check_valid \
+--check_inference \
+--model_type scnet \
+--config_path configs/config_musdb18_scnet_large_starrytong.yaml \
+--start_check_point weights/model_scnet_ep_30_neg_log_wmse_-11.8688.ckpt \
+--data_path datasets/moisesdb/train_tracks \
+--valid_path datasets/moisesdb/valid \
+--use_tta \
+--use_mse_loss
+```
+
+This command validates the setup by:
+
+* Specifying `scnet` as the model type.
+
+* Loading the configuration from `configs/config_musdb18_scnet_large_starrytong.yaml`.
+
+* Using the dataset located at `datasets/moisesdb/train_tracks` for training.
+
+* Using `datasets/moisesdb/valid` for validation.
+
+* Starting from the checkpoint at `weights/model_scnet_ep_30_neg_log_wmse_-11.8688.ckpt`.
+
+* Enabling test-time augmentation and using MSE loss.
+
+
+Additional Script: `admin_test.py`
+----------------------------------
+
+The `admin_test.py` script provides a way to verify the functionality of all configurations and models without specifying model weights or datasets. By default, it performs validation and inference. The configurations and corresponding parameters can be modified using the `MODEL_CONFIGS` variable in the script.
+
+This script is useful for bulk testing and ensuring that multiple configurations are correctly set up. It can help identify potential issues with configurations or models before proceeding to detailed testing with `tests.py` or full-scale training.
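+
+The exact structure of `MODEL_CONFIGS` is defined in `admin_test.py` itself; as a rough, hypothetical illustration, an entry might pair a model type with its config file:
+
+```
+MODEL_CONFIGS = [
+    {"model_type": "scnet", "config_path": "configs/config_musdb18_scnet.yaml"},
+    {"model_type": "htdemucs", "config_path": "configs/config_musdb18_htdemucs.yaml"},
+    # ... one entry per configuration to check
+]
+```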
diff --git a/ensemble.py b/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae8abf28254dde7ef1b1b06422294340f112e7fd
--- /dev/null
+++ b/ensemble.py
@@ -0,0 +1,164 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+
+import os
+import librosa
+import soundfile as sf
+import numpy as np
+import argparse
+
+
+def stft(wave, nfft, hl):
+ wave_left = np.asfortranarray(wave[0])
+ wave_right = np.asfortranarray(wave[1])
+ spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
+ spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
+ spec = np.asfortranarray([spec_left, spec_right])
+ return spec
+
+
+def istft(spec, hl, length):
+ spec_left = np.asfortranarray(spec[0])
+ spec_right = np.asfortranarray(spec[1])
+ wave_left = librosa.istft(spec_left, hop_length=hl, length=length)
+ wave_right = librosa.istft(spec_right, hop_length=hl, length=length)
+ wave = np.asfortranarray([wave_left, wave_right])
+ return wave
+
+
+def absmax(a, *, axis):
+ dims = list(a.shape)
+ dims.pop(axis)
+    indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])  # use a list instead of a tuple so the argmax indices can be inserted below
+ argmax = np.abs(a).argmax(axis=axis)
+ insert_pos = (len(a.shape) + axis) % len(a.shape)
+ indices.insert(insert_pos, argmax)
+ return a[tuple(indices)]
+
+
+def absmin(a, *, axis):
+ dims = list(a.shape)
+ dims.pop(axis)
+    indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])  # use a list instead of a tuple so the argmin indices can be inserted below
+ argmax = np.abs(a).argmin(axis=axis)
+ insert_pos = (len(a.shape) + axis) % len(a.shape)
+ indices.insert(insert_pos, argmax)
+ return a[tuple(indices)]
+
+
+def lambda_max(arr, axis=None, key=None, keepdims=False):
+ idxs = np.argmax(key(arr), axis)
+ if axis is not None:
+ idxs = np.expand_dims(idxs, axis)
+ result = np.take_along_axis(arr, idxs, axis)
+ if not keepdims:
+ result = np.squeeze(result, axis=axis)
+ return result
+ else:
+ return arr.flatten()[idxs]
+
+
+def lambda_min(arr, axis=None, key=None, keepdims=False):
+ idxs = np.argmin(key(arr), axis)
+ if axis is not None:
+ idxs = np.expand_dims(idxs, axis)
+ result = np.take_along_axis(arr, idxs, axis)
+ if not keepdims:
+ result = np.squeeze(result, axis=axis)
+ return result
+ else:
+ return arr.flatten()[idxs]
+
+
+def average_waveforms(pred_track, weights, algorithm):
+ """
+ :param pred_track: shape = (num, channels, length)
+ :param weights: shape = (num, )
+ :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
+ :return: averaged waveform in shape (channels, length)
+ """
+
+ pred_track = np.array(pred_track)
+ final_length = pred_track.shape[-1]
+
+ mod_track = []
+ for i in range(pred_track.shape[0]):
+ if algorithm == 'avg_wave':
+ mod_track.append(pred_track[i] * weights[i])
+ elif algorithm in ['median_wave', 'min_wave', 'max_wave']:
+ mod_track.append(pred_track[i])
+ elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']:
+ spec = stft(pred_track[i], nfft=2048, hl=1024)
+ if algorithm in ['avg_fft']:
+ mod_track.append(spec * weights[i])
+ else:
+ mod_track.append(spec)
+ pred_track = np.array(mod_track)
+
+ if algorithm in ['avg_wave']:
+ pred_track = pred_track.sum(axis=0)
+ pred_track /= np.array(weights).sum().T
+ elif algorithm in ['median_wave']:
+ pred_track = np.median(pred_track, axis=0)
+ elif algorithm in ['min_wave']:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+ elif algorithm in ['max_wave']:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_max(pred_track, axis=0, key=np.abs)
+ elif algorithm in ['avg_fft']:
+ pred_track = pred_track.sum(axis=0)
+ pred_track /= np.array(weights).sum()
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ['min_fft']:
+ pred_track = np.array(pred_track)
+ pred_track = lambda_min(pred_track, axis=0, key=np.abs)
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ['max_fft']:
+ pred_track = np.array(pred_track)
+ pred_track = absmax(pred_track, axis=0)
+ pred_track = istft(pred_track, 1024, final_length)
+ elif algorithm in ['median_fft']:
+ pred_track = np.array(pred_track)
+ pred_track = np.median(pred_track, axis=0)
+ pred_track = istft(pred_track, 1024, final_length)
+ return pred_track
+
+
+def ensemble_files(args):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble")
+ parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft")
+ parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files")
+ parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored")
+ if args is None:
+ args = parser.parse_args()
+ else:
+ args = parser.parse_args(args)
+
+ print('Ensemble type: {}'.format(args.type))
+ print('Number of input files: {}'.format(len(args.files)))
+ if args.weights is not None:
+ weights = args.weights
+ else:
+ weights = np.ones(len(args.files))
+ print('Weights: {}'.format(weights))
+ print('Output file: {}'.format(args.output))
+ data = []
+ for f in args.files:
+ if not os.path.isfile(f):
+ print('Error. Can\'t find file: {}. Check paths.'.format(f))
+ exit()
+ print('Reading file: {}'.format(f))
+ wav, sr = librosa.load(f, sr=None, mono=False)
+ # wav, sr = sf.read(f)
+ print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
+ data.append(wav)
+ data = np.array(data)
+ res = average_waveforms(data, weights, args.type)
+ print('Result shape: {}'.format(res.shape))
+ sf.write(args.output, res.T, sr, 'FLOAT')
+
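+# Example invocation (file names and weights are illustrative):
+#   python ensemble.py --files vocals_model_a.wav vocals_model_b.wav \
+#       --weights 2 1 --type avg_wave --output ensemble_vocals.wav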
+
+if __name__ == "__main__":
+ ensemble_files(None)
diff --git a/gui-wx.py b/gui-wx.py
new file mode 100644
index 0000000000000000000000000000000000000000..799a9f047b8f8de0176efac9f0f2db779a39e2f0
--- /dev/null
+++ b/gui-wx.py
@@ -0,0 +1,635 @@
+import wx
+import wx.adv
+import wx.html
+import wx.html2
+import subprocess
+import os
+import threading
+import queue
+import json
+import webbrowser
+import requests
+import sys
+
+def run_subprocess(cmd, output_queue):
+ try:
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+ for line in process.stdout:
+ output_queue.put(line)
+ process.wait()
+ if process.returncode == 0:
+ output_queue.put("Process completed successfully!")
+ else:
+ output_queue.put(f"Process failed with return code {process.returncode}")
+ except Exception as e:
+ output_queue.put(f"An error occurred: {str(e)}")
+
+def update_output(output_text, output_queue):
+ try:
+ while True:
+ line = output_queue.get_nowait()
+ if wx.Window.FindWindowById(output_text.GetId()):
+ wx.CallAfter(output_text.AppendText, line)
+ else:
+ return # Exit if the text control no longer exists
+ except queue.Empty:
+ pass
+ except RuntimeError:
+ return # Exit if a RuntimeError occurs (e.g., window closed)
+ wx.CallLater(100, update_output, output_text, output_queue)
+
+def open_store_folder(folder_path):
+ if os.path.exists(folder_path):
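+        # Note: os.startfile is Windows-only; macOS/Linux would need 'open' / 'xdg-open' instead.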
+ os.startfile(folder_path)
+ else:
+ wx.MessageBox(f"The folder {folder_path} does not exist.", "Error", wx.OK | wx.ICON_ERROR)
+
+class DarkThemedTextCtrl(wx.TextCtrl):
+ def __init__(self, parent, id=wx.ID_ANY, value="", style=0):
+ super().__init__(parent, id, value, style=style | wx.NO_BORDER)
+ self.SetBackgroundColour(wx.Colour(0, 0, 0))
+ self.SetForegroundColour(wx.WHITE)
+
+class CollapsiblePanel(wx.Panel):
+ def __init__(self, parent, title, *args, **kwargs):
+ wx.Panel.__init__(self, parent, *args, **kwargs)
+ self.SetBackgroundColour(parent.GetBackgroundColour())
+
+ self.toggle_button = wx.Button(self, label=title, style=wx.NO_BORDER)
+ self.toggle_button.SetBackgroundColour(self.GetBackgroundColour())
+ self.toggle_button.Bind(wx.EVT_BUTTON, self.on_toggle)
+
+ self.content_panel = wx.Panel(self)
+ self.content_panel.SetBackgroundColour(self.GetBackgroundColour())
+
+ self.main_sizer = wx.BoxSizer(wx.VERTICAL)
+ self.main_sizer.Add(self.toggle_button, 0, wx.EXPAND | wx.ALL, 5)
+ self.main_sizer.Add(self.content_panel, 0, wx.EXPAND | wx.ALL, 5)
+
+ self.SetSizer(self.main_sizer)
+ self.collapsed = True
+ self.toggle_button.SetLabel(f"▶ {title}")
+ self.content_panel.Hide()
+
+ def on_toggle(self, event):
+ self.collapsed = not self.collapsed
+ self.toggle_button.SetLabel(f"{'▶' if self.collapsed else '▼'} {self.toggle_button.GetLabel()[2:]}")
+ self.content_panel.Show(not self.collapsed)
+ self.Layout()
+ self.GetParent().Layout()
+
+ def get_content_panel(self):
+ return self.content_panel
+
+class CustomToolTip(wx.PopupWindow):
+ def __init__(self, parent, text):
+ wx.PopupWindow.__init__(self, parent)
+
+ # Main panel for tooltip
+ panel = wx.Panel(self)
+ self.st = wx.StaticText(panel, 1, text, pos=(10, 10))
+
+ font = wx.Font(8, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins")
+ self.st.SetFont(font)
+
+ size = self.st.GetBestSize()
+ self.SetSize((size.width + 20, size.height + 20))
+
+ # Adjust the panel size
+ panel.SetSize(self.GetSize())
+ panel.SetBackgroundColour(wx.Colour(255, 255, 255))
+
+ # Bind paint event to draw border
+ panel.Bind(wx.EVT_PAINT, self.on_paint)
+
+ def on_paint(self, event):
+ # Get the device context for the panel (not self)
+ panel = event.GetEventObject() # Get the panel triggering the paint event
+ dc = wx.PaintDC(panel) # Use panel as the target of the PaintDC
+ dc.SetPen(wx.Pen(wx.Colour(210, 210, 210), 1)) # Border color
+ dc.SetBrush(wx.Brush(wx.Colour(255, 255, 255))) # Fill with white
+
+ size = panel.GetSize()
+ dc.DrawRectangle(0, 0, size.width, size.height) # Draw border around panel
+
+class MainFrame(wx.Frame):
+ def __init__(self):
+ super().__init__(parent=None, title="Music Source Separation Training & Inference GUI")
+ self.SetSize(994, 670)
+ self.SetBackgroundColour(wx.Colour(247, 248, 250)) # #F7F8FA
+
+ icon = wx.Icon("gui/favicon.ico", wx.BITMAP_TYPE_ICO)
+ self.SetIcon(icon)
+
+ self.saved_combinations = {}
+
+ # Center the window on the screen
+ self.Center()
+
+ # Set Poppins font for the entire application
+ font_path = "gui/Poppins Regular 400.ttf"
+ bold_font_path = "gui/Poppins Bold 700.ttf"
+ wx.Font.AddPrivateFont(font_path)
+ wx.Font.AddPrivateFont(bold_font_path)
+ self.font = wx.Font(9, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins")
+ self.bold_font = wx.Font(10, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_BOLD, False, "Poppins")
+ self.SetFont(self.font)
+
+ panel = wx.Panel(self)
+ main_sizer = wx.BoxSizer(wx.VERTICAL)
+
+ # Add image (with error handling)
+ try:
+ img = wx.Image("gui/mvsep.png", wx.BITMAP_TYPE_PNG)
+ img_bitmap = wx.Bitmap(img)
+ img_ctrl = wx.StaticBitmap(panel, -1, img_bitmap)
+ main_sizer.Add(img_ctrl, 0, wx.ALIGN_CENTER | wx.TOP, 20)
+        except Exception as e:
+            print(f"Failed to load image gui/mvsep.png: {e}")
+
+ # Add title text
+ title_text = wx.StaticText(panel, label="Music Source Separation Training && Inference GUI")
+ title_text.SetFont(self.bold_font)
+ title_text.SetForegroundColour(wx.BLACK)
+ main_sizer.Add(title_text, 0, wx.ALIGN_CENTER | wx.TOP, 10)
+
+ # Add subtitle text
+ subtitle_text = wx.StaticText(panel, label="Code by ZFTurbo / GUI by Bas Curtiz")
+ subtitle_text.SetForegroundColour(wx.BLACK)
+ main_sizer.Add(subtitle_text, 0, wx.ALIGN_CENTER | wx.TOP, 5)
+
+ # Add GitHub link
+ github_link = wx.adv.HyperlinkCtrl(panel, -1, "GitHub Repository", "https://github.com/ZFTurbo/Music-Source-Separation-Training")
+ github_link.SetNormalColour(wx.Colour(1, 118, 179)) # #0176B3
+ github_link.SetHoverColour(wx.Colour(86, 91, 123)) # #565B7B
+ main_sizer.Add(github_link, 0, wx.ALIGN_CENTER | wx.TOP, 10)
+
+ # Add Download models button on a new line with 10px bottom margin
+ download_models_btn = self.create_styled_button(panel, "Download Models", self.on_download_models)
+ main_sizer.Add(download_models_btn, 0, wx.ALIGN_CENTER | wx.TOP | wx.BOTTOM, 10)
+
+ # Training Configuration
+ self.training_panel = CollapsiblePanel(panel, "Training Configuration")
+ self.training_panel.toggle_button.SetFont(self.bold_font)
+ self.create_training_controls(self.training_panel.get_content_panel())
+ main_sizer.Add(self.training_panel, 0, wx.EXPAND | wx.ALL, 10)
+
+ # Inference Configuration
+ self.inference_panel = CollapsiblePanel(panel, "Inference Configuration")
+ self.inference_panel.toggle_button.SetFont(self.bold_font)
+ self.create_inference_controls(self.inference_panel.get_content_panel())
+ main_sizer.Add(self.inference_panel, 0, wx.EXPAND | wx.ALL, 10)
+
+ panel.SetSizer(main_sizer)
+ self.load_settings()
+
+ def create_styled_button(self, parent, label, handler):
+ btn = wx.Button(parent, label=label, style=wx.BORDER_NONE)
+ btn.SetBackgroundColour(wx.Colour(1, 118, 179)) # #0176B3
+ btn.SetForegroundColour(wx.WHITE)
+ btn.SetFont(self.bold_font)
+
+ def on_enter(event):
+ btn.SetBackgroundColour(wx.Colour(86, 91, 123))
+ event.Skip()
+
+ def on_leave(event):
+ btn.SetBackgroundColour(wx.Colour(1, 118, 179))
+ event.Skip()
+
+ def on_click(event):
+ btn.SetBackgroundColour(wx.Colour(86, 91, 123))
+ handler(event)
+ wx.CallLater(100, lambda: btn.SetBackgroundColour(wx.Colour(1, 118, 179)))
+
+ btn.Bind(wx.EVT_ENTER_WINDOW, on_enter)
+ btn.Bind(wx.EVT_LEAVE_WINDOW, on_leave)
+ btn.Bind(wx.EVT_BUTTON, on_click)
+
+ return btn
+
+ def create_training_controls(self, panel):
+ sizer = wx.BoxSizer(wx.VERTICAL)
+
+ # Model Type
+ model_type_sizer = wx.BoxSizer(wx.HORIZONTAL)
+ model_type_sizer.Add(wx.StaticText(panel, label="Model Type:"), 0, wx.ALIGN_CENTER_VERTICAL)
+ self.model_type = wx.Choice(panel, choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"])
+ self.model_type.SetFont(self.font)
+ model_type_sizer.Add(self.model_type, 0, wx.LEFT, 5)
+ sizer.Add(model_type_sizer, 0, wx.EXPAND | wx.ALL, 5)
+
+ # Config File
+ self.config_entry = self.add_browse_control(panel, sizer, "Config File:", is_folder=False, is_config=True)
+
+ # Start Checkpoint
+ self.checkpoint_entry = self.add_browse_control(panel, sizer, "Checkpoint:", is_folder=False, is_checkpoint=True)
+
+ # Results Path
+ self.result_path_entry = self.add_browse_control(panel, sizer, "Results Path:", is_folder=True)
+
+ # Data Paths
+ self.data_entry = self.add_browse_control(panel, sizer, "Data Paths (separated by ';'):", is_folder=True)
+
+ # Validation Paths
+ self.valid_entry = self.add_browse_control(panel, sizer, "Validation Paths (separated by ';'):", is_folder=True)
+
+ # Number of Workers and Device IDs
+ workers_device_sizer = wx.BoxSizer(wx.HORIZONTAL)
+
+ workers_sizer = wx.BoxSizer(wx.HORIZONTAL)
+ workers_sizer.Add(wx.StaticText(panel, label="Number of Workers:"), 0, wx.ALIGN_CENTER_VERTICAL)
+ self.workers_entry = wx.TextCtrl(panel, value="4")
+ self.workers_entry.SetFont(self.font)
+ workers_sizer.Add(self.workers_entry, 0, wx.LEFT, 5)
+ workers_device_sizer.Add(workers_sizer, 0, wx.EXPAND)
+
+ device_sizer = wx.BoxSizer(wx.HORIZONTAL)
+ device_sizer.Add(wx.StaticText(panel, label="Device IDs (comma-separated):"), 0, wx.ALIGN_CENTER_VERTICAL | wx.LEFT, 20)
+ self.device_entry = wx.TextCtrl(panel, value="0")
+ self.device_entry.SetFont(self.font)
+ device_sizer.Add(self.device_entry, 0, wx.LEFT, 5)
+ workers_device_sizer.Add(device_sizer, 0, wx.EXPAND)
+
+ sizer.Add(workers_device_sizer, 0, wx.EXPAND | wx.ALL, 5)
+
+ # Run Training Button
+ self.run_button = self.create_styled_button(panel, "Run Training", self.run_training)
+ sizer.Add(self.run_button, 0, wx.ALIGN_CENTER | wx.ALL, 10)
+
+ panel.SetSizer(sizer)
+
+ def create_inference_controls(self, panel):
+ sizer = wx.BoxSizer(wx.VERTICAL)
+
+ # Model Type and Saved Combinations
+ infer_model_type_sizer = wx.BoxSizer(wx.HORIZONTAL)
+ infer_model_type_sizer.Add(wx.StaticText(panel, label="Model Type:"), 0, wx.ALIGN_CENTER_VERTICAL)
+ self.infer_model_type = wx.Choice(panel, choices=["apollo", "bandit", "bandit_v2", "bs_roformer", "htdemucs", "mdx23c", "mel_band_roformer", "scnet", "scnet_unofficial", "segm_models", "swin_upernet", "torchseg"])
+ self.infer_model_type.SetFont(self.font)
+ infer_model_type_sizer.Add(self.infer_model_type, 0, wx.LEFT, 5)
+
+ # Add "Preset:" label
+ infer_model_type_sizer.Add(wx.StaticText(panel, label="Preset:"), 0, wx.ALIGN_CENTER_VERTICAL | wx.LEFT, 20)
+
+ # Add dropdown for saved combinations
+ self.saved_combinations_dropdown = wx.Choice(panel, choices=[])
+ self.saved_combinations_dropdown.SetFont(self.font)
+ self.saved_combinations_dropdown.Bind(wx.EVT_CHOICE, self.on_combination_selected)
+
+ # Set the width to 200px and an appropriate height
+ self.saved_combinations_dropdown.SetMinSize((358, -1)) # -1 keeps the height unchanged
+
+ # Add to sizer
+ infer_model_type_sizer.Add(self.saved_combinations_dropdown, 0, wx.LEFT, 5)
+
+ # Add plus button
+ plus_button = self.create_styled_button(panel, "+", self.on_save_combination)
+ plus_button.SetMinSize((30, 30))
+ infer_model_type_sizer.Add(plus_button, 0, wx.LEFT, 5)
+
+ # Add help button with custom tooltip
+ help_button = wx.StaticText(panel, label="?")
+ help_button.SetFont(self.bold_font)
+ help_button.SetForegroundColour(wx.Colour(1, 118, 179)) #0176B3
+ tooltip_text = ("How to add a preset?\n\n"
+ "1. Click Download Models\n"
+ "2. Right-click a model's Config && Checkpoint\n"
+ "3. Save link as && select a proper destination\n"
+ "4. Copy the Model name\n"
+ "5. Close Download Models\n\n"
+ "6. Browse for the Config file\n"
+ "7. Browse for the Checkpoint\n"
+ "8. Select the Model Type\n"
+ "9. Click the + button\n"
+ "10. Paste the Model name && click OK\n\n"
+ "On next use, just select it from the Preset dropdown.")
+
+ self.tooltip = CustomToolTip(self, tooltip_text)
+ self.tooltip.Hide()
+
+ def on_help_enter(event):
+ self.tooltip.Position(help_button.ClientToScreen((0, help_button.GetSize().height)), (0, 0))
+ self.tooltip.Show()
+
+ def on_help_leave(event):
+ self.tooltip.Hide()
+
+ help_button.Bind(wx.EVT_ENTER_WINDOW, on_help_enter)
+ help_button.Bind(wx.EVT_LEAVE_WINDOW, on_help_leave)
+
+ infer_model_type_sizer.Add(help_button, 0, wx.LEFT | wx.ALIGN_CENTER_VERTICAL, 5)
+
+ sizer.Add(infer_model_type_sizer, 0, wx.EXPAND | wx.ALL, 5)
+
+ # Config File
+ self.infer_config_entry = self.add_browse_control(panel, sizer, "Config File:", is_folder=False, is_config=True)
+
+ # Start Checkpoint
+ self.infer_checkpoint_entry = self.add_browse_control(panel, sizer, "Checkpoint:", is_folder=False, is_checkpoint=True)
+
+ # Input Folder
+ self.infer_input_entry = self.add_browse_control(panel, sizer, "Input Folder:", is_folder=True)
+
+ # Store Directory
+ self.infer_store_entry = self.add_browse_control(panel, sizer, "Output Folder:", is_folder=True)
+
+ # Extract Instrumental Checkbox
+ self.extract_instrumental_checkbox = wx.CheckBox(panel, label="Extract Instrumental")
+ self.extract_instrumental_checkbox.SetFont(self.font)
+ sizer.Add(self.extract_instrumental_checkbox, 0, wx.EXPAND | wx.ALL, 5)
+
+ # Run Inference Button
+ self.run_infer_button = self.create_styled_button(panel, "Run Inference", self.run_inference)
+ sizer.Add(self.run_infer_button, 0, wx.ALIGN_CENTER | wx.ALL, 10)
+
+ panel.SetSizer(sizer)
+
+ def add_browse_control(self, panel, sizer, label, is_folder=False, is_config=False, is_checkpoint=False):
+ browse_sizer = wx.BoxSizer(wx.HORIZONTAL)
+ browse_sizer.Add(wx.StaticText(panel, label=label), 0, wx.ALIGN_CENTER_VERTICAL)
+ entry = wx.TextCtrl(panel)
+ entry.SetFont(self.font)
+ browse_sizer.Add(entry, 1, wx.EXPAND | wx.LEFT, 5)
+ browse_button = self.create_styled_button(panel, "Browse", lambda event, entry=entry, is_folder=is_folder, is_config=is_config, is_checkpoint=is_checkpoint: self.browse(event, entry, is_folder, is_config, is_checkpoint))
+ browse_sizer.Add(browse_button, 0, wx.LEFT, 5)
+ sizer.Add(browse_sizer, 0, wx.EXPAND | wx.ALL, 5)
+ return entry
+
+ def browse(self, event, entry, is_folder=False, is_config=False, is_checkpoint=False):
+ if is_folder:
+ dialog = wx.DirDialog(self, "Choose a directory", style=wx.DD_DEFAULT_STYLE | wx.DD_DIR_MUST_EXIST)
+ else:
+ wildcard = "All files (*.*)|*.*"
+ if is_config:
+ wildcard = "YAML files (*.yaml)|*.yaml"
+ elif is_checkpoint:
+ wildcard = "Checkpoint files (*.bin;*.chpt;*.ckpt;*.th)|*.bin;*.chpt;*.ckpt;*.th"
+
+ dialog = wx.FileDialog(self, "Choose a file", style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST, wildcard=wildcard)
+
+ dialog.SetFont(self.font)
+ if dialog.ShowModal() == wx.ID_OK:
+ entry.SetValue(dialog.GetPath())
+ dialog.Destroy()
+
+ def create_output_window(self, title, folder_path):
+ output_frame = wx.Frame(self, title=title, style=wx.DEFAULT_FRAME_STYLE | wx.STAY_ON_TOP)
+ output_frame.SetIcon(self.GetIcon())
+ output_frame.SetSize(994, 670)
+ output_frame.SetBackgroundColour(wx.Colour(0, 0, 0))
+ output_frame.SetFont(self.font)
+
+ # Set the position of the output frame to match the main frame
+ output_frame.SetPosition(self.GetPosition())
+
+ output_title = wx.StaticText(output_frame, label=title)
+ output_title.SetFont(self.bold_font)
+ output_title.SetForegroundColour(wx.WHITE)
+
+ output_text = DarkThemedTextCtrl(output_frame, style=wx.TE_MULTILINE | wx.TE_READONLY)
+ output_text.SetFont(self.font)
+
+ open_folder_button = self.create_styled_button(output_frame, f"Open Output Folder", lambda event: open_store_folder(folder_path))
+
+ sizer = wx.BoxSizer(wx.VERTICAL)
+ sizer.Add(output_title, 0, wx.ALIGN_CENTER | wx.TOP, 10)
+ sizer.Add(output_text, 1, wx.EXPAND | wx.ALL, 10)
+ sizer.Add(open_folder_button, 0, wx.ALIGN_CENTER | wx.BOTTOM, 10)
+ output_frame.SetSizer(sizer)
+
+ return output_frame, output_text
+
+ def run_training(self, event):
+ model_type = self.model_type.GetStringSelection()
+ config_path = self.config_entry.GetValue()
+ start_checkpoint = self.checkpoint_entry.GetValue()
+ results_path = self.result_path_entry.GetValue()
+ data_paths = self.data_entry.GetValue()
+ valid_paths = self.valid_entry.GetValue()
+ num_workers = self.workers_entry.GetValue()
+ device_ids = self.device_entry.GetValue()
+
+ if not model_type:
+ wx.MessageBox("Please select a model type.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not config_path:
+ wx.MessageBox("Please select a config file.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not results_path:
+ wx.MessageBox("Please specify a results path.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not data_paths:
+ wx.MessageBox("Please specify data paths.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not valid_paths:
+ wx.MessageBox("Please specify validation paths.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+
+ cmd = [
+ sys.executable, "train.py",
+ "--model_type", model_type,
+ "--config_path", config_path,
+ "--results_path", results_path,
+ "--data_path", *data_paths.split(';'),
+ "--valid_path", *valid_paths.split(';'),
+ "--num_workers", num_workers,
+ "--device_ids", device_ids
+ ]
+
+ if start_checkpoint:
+ cmd += ["--start_check_point", start_checkpoint]
+
+ output_queue = queue.Queue()
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
+
+ output_frame, output_text = self.create_output_window("Training Output", results_path)
+ output_frame.Show()
+ update_output(output_text, output_queue)
+
+ self.save_settings()
+
+ def run_inference(self, event):
+ model_type = self.infer_model_type.GetStringSelection()
+ config_path = self.infer_config_entry.GetValue()
+ start_checkpoint = self.infer_checkpoint_entry.GetValue()
+ input_folder = self.infer_input_entry.GetValue()
+ store_dir = self.infer_store_entry.GetValue()
+ extract_instrumental = self.extract_instrumental_checkbox.GetValue()
+
+ if not model_type:
+ wx.MessageBox("Please select a model type.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not config_path:
+ wx.MessageBox("Please select a config file.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not input_folder:
+ wx.MessageBox("Please specify an input folder.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+ if not store_dir:
+ wx.MessageBox("Please specify an output folder.", "Input Error", wx.OK | wx.ICON_ERROR)
+ return
+
+ cmd = [
+ sys.executable, "inference.py",
+ "--model_type", model_type,
+ "--config_path", config_path,
+ "--input_folder", input_folder,
+ "--store_dir", store_dir
+ ]
+
+ if start_checkpoint:
+ cmd += ["--start_check_point", start_checkpoint]
+
+ if extract_instrumental:
+ cmd += ["--extract_instrumental"]
+
+ output_queue = queue.Queue()
+ threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True).start()
+
+ output_frame, output_text = self.create_output_window("Inference Output", store_dir)
+ output_frame.Show()
+ update_output(output_text, output_queue)
+
+ self.save_settings()
+
+ def save_settings(self):
+ settings = {
+ "model_type": self.model_type.GetStringSelection(),
+ "config_path": self.config_entry.GetValue(),
+ "start_checkpoint": self.checkpoint_entry.GetValue(),
+ "results_path": self.result_path_entry.GetValue(),
+ "data_paths": self.data_entry.GetValue(),
+ "valid_paths": self.valid_entry.GetValue(),
+ "num_workers": self.workers_entry.GetValue(),
+ "device_ids": self.device_entry.GetValue(),
+ "infer_model_type": self.infer_model_type.GetStringSelection(),
+ "infer_config_path": self.infer_config_entry.GetValue(),
+ "infer_start_checkpoint": self.infer_checkpoint_entry.GetValue(),
+ "infer_input_folder": self.infer_input_entry.GetValue(),
+ "infer_store_dir": self.infer_store_entry.GetValue(),
+ "extract_instrumental": self.extract_instrumental_checkbox.GetValue(),
+ "saved_combinations": self.saved_combinations
+ }
+ with open("settings.json", "w") as f:
+ json.dump(settings, f, indent=2, ensure_ascii=False)
+
+ def load_settings(self):
+ try:
+ with open("settings.json", "r") as f:
+ settings = json.load(f)
+
+ self.model_type.SetStringSelection(settings.get("model_type", ""))
+ self.config_entry.SetValue(settings.get("config_path", ""))
+ self.checkpoint_entry.SetValue(settings.get("start_checkpoint", ""))
+ self.result_path_entry.SetValue(settings.get("results_path", ""))
+ self.data_entry.SetValue(settings.get("data_paths", ""))
+ self.valid_entry.SetValue(settings.get("valid_paths", ""))
+ self.workers_entry.SetValue(settings.get("num_workers", "4"))
+ self.device_entry.SetValue(settings.get("device_ids", "0"))
+
+ self.infer_model_type.SetStringSelection(settings.get("infer_model_type", ""))
+ self.infer_config_entry.SetValue(settings.get("infer_config_path", ""))
+ self.infer_checkpoint_entry.SetValue(settings.get("infer_start_checkpoint", ""))
+ self.infer_input_entry.SetValue(settings.get("infer_input_folder", ""))
+ self.infer_store_entry.SetValue(settings.get("infer_store_dir", ""))
+ self.extract_instrumental_checkbox.SetValue(settings.get("extract_instrumental", False))
+ self.saved_combinations = settings.get("saved_combinations", {})
+
+ self.update_saved_combinations()
+ except FileNotFoundError:
+ pass # If the settings file doesn't exist, use default values
+
+ def on_download_models(self, event):
+ DownloadModelsFrame(self).Show()
+
+ def on_save_combination(self, event):
+ dialog = wx.TextEntryDialog(self, "Enter a name for this preset:", "Save Preset")
+ if dialog.ShowModal() == wx.ID_OK:
+ name = dialog.GetValue()
+ if name:
+ combination = {
+ "model_type": self.infer_model_type.GetStringSelection(),
+ "config_path": self.infer_config_entry.GetValue(),
+ "checkpoint": self.infer_checkpoint_entry.GetValue()
+ }
+ self.saved_combinations[name] = combination
+ self.update_saved_combinations()
+ self.save_settings()
+ dialog.Destroy()
+
+ def on_combination_selected(self, event):
+ name = self.saved_combinations_dropdown.GetStringSelection()
+ if name:
+ combination = self.saved_combinations.get(name)
+ if combination:
+ self.infer_model_type.SetStringSelection(combination["model_type"])
+ self.infer_config_entry.SetValue(combination["config_path"])
+ self.infer_checkpoint_entry.SetValue(combination["checkpoint"])
+
+ def update_saved_combinations(self):
+ self.saved_combinations_dropdown.Clear()
+ for name in self.saved_combinations.keys():
+ self.saved_combinations_dropdown.Append(name)
+
+class DownloadModelsFrame(wx.Frame):
+ def __init__(self, parent):
+ super().__init__(parent, title="Download Models", size=(994, 670), style=wx.DEFAULT_FRAME_STYLE & ~(wx.RESIZE_BORDER | wx.MAXIMIZE_BOX))
+ self.SetBackgroundColour(wx.Colour(247, 248, 250)) # #F7F8FA
+ self.SetFont(wx.Font(9, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL, False, "Poppins"))
+
+ # Set the position of the Download Models frame to match the main frame
+ self.SetPosition(parent.GetPosition())
+
+ # Set the icon for the Download Models frame
+ icon = wx.Icon("gui/favicon.ico", wx.BITMAP_TYPE_ICO)
+ self.SetIcon(icon)
+
+ panel = wx.Panel(self)
+ sizer = wx.BoxSizer(wx.VERTICAL)
+
+ # Add WebView
+ self.web_view = wx.html2.WebView.New(panel)
+ self.web_view.LoadURL("https://bascurtiz.x10.mx/models-checkpoint-config-urls.html")
+ self.web_view.Bind(wx.html2.EVT_WEBVIEW_NAVIGATING, self.on_link_click)
+ self.web_view.Bind(wx.html2.EVT_WEBVIEW_NAVIGATED, self.on_page_load)
+ sizer.Add(self.web_view, 1, wx.EXPAND)
+
+ panel.SetSizer(sizer)
+
+ def on_link_click(self, event):
+ url = event.GetURL()
+ if not url.startswith("https://bascurtiz.x10.mx"):
+ event.Veto() # Prevent the WebView from navigating
+ webbrowser.open(url) # Open the link in the default browser
+
+ def on_page_load(self, event):
+ self.inject_custom_css()
+
+ def inject_custom_css(self):
+ css = """
+ body {
+ margin: 0;
+ padding: 0;
+ }
+ ::-webkit-scrollbar {
+ width: 12px;
+ }
+ ::-webkit-scrollbar-track {
+ background: #f1f1f1;
+ }
+ ::-webkit-scrollbar-thumb {
+ background: #888;
+ }
+ ::-webkit-scrollbar-thumb:hover {
+ background: #555;
+ }
+ """
+ js = f"var style = document.createElement('style'); style.textContent = `{css}`; document.head.appendChild(style);"
+ self.web_view.RunScript(js)
+
+if __name__ == "__main__":
+ app = wx.App()
+ frame = MainFrame()
+ frame.Show()
+ app.MainLoop()
diff --git a/gui.py b/gui.py
index 142fa988508a013ec4980bfdc27827e61edc9a67..f52d98d23ea0b1cab159957e4e8f6e6028189e3a 100644
--- a/gui.py
+++ b/gui.py
@@ -14,180 +14,306 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Arayüz oluşturma fonksiyonu
def create_interface():
- # Gelişmiş ve profesyonel CSS
+    # CSS definition
css = """
/* Genel Tema */
body {
- background: url('/content/studio_bg.jpg') no-repeat center center fixed;
+ background: url('/content/logo.jpg') no-repeat center center fixed;
background-size: cover;
- background-color: #1a0d0d; /* Daha koyu ve sofistike kırmızı ton */
+        background-color: #2d0b0b; /* Dark red, matching the dubbing-studio theme */
min-height: 100vh;
margin: 0;
- padding: 2rem;
- font-family: 'Montserrat', sans-serif; /* Daha modern ve profesyonel font */
- color: #D4A017; /* Altın tonu metin, lüks hissi */
- overflow-x: hidden;
+ padding: 1rem;
+ font-family: 'Poppins', sans-serif;
+        color: #C0C0C0; /* Metallic silver text, professional look */
}
- body::before {
+ body::after {
content: '';
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
- background: linear-gradient(135deg, rgba(26, 13, 13, 0.9), rgba(45, 11, 11, 0.85));
+        background: rgba(45, 11, 11, 0.9); /* Darker red overlay */
z-index: -1;
}
/* Logo Stilleri */
.logo-container {
- position: fixed;
- top: 1.5rem;
- left: 2rem;
+ position: absolute;
+ top: 1rem;
+ left: 50%;
+ transform: translateX(-50%);
display: flex;
align-items: center;
- z-index: 3000;
- background: rgba(0, 0, 0, 0.7);
- padding: 0.5rem 1rem;
- border-radius: 10px;
- box-shadow: 0 4px 15px rgba(212, 160, 23, 0.3);
+        z-index: 2000; /* Above other elements, always visible */
}
.logo-img {
- width: 100px;
+ width: 120px;
height: auto;
- filter: drop-shadow(0 0 5px rgba(212, 160, 23, 0.5));
}
/* Başlık Stilleri */
.header-text {
text-align: center;
- padding: 4rem 0 2rem;
- color: #D4A017; /* Altın tonu, profesyonel ve dikkat çekici */
- font-size: 3rem;
- font-weight: 800;
- text-transform: uppercase;
- letter-spacing: 2px;
- text-shadow: 0 0 15px rgba(212, 160, 23, 0.7), 0 0 5px rgba(255, 64, 64, 0.5);
- z-index: 1500;
- animation: text-glow 4s infinite ease-in-out;
+        padding: 80px 20px 20px; /* Leave room for the logo */
+        color: #ff4040; /* Red, matching the dubbing theme */
+        font-size: 2.5rem; /* Larger, more striking heading */
+        font-weight: 900; /* Bolder, more dramatic */
+        text-shadow: 0 0 10px rgba(255, 64, 64, 0.5); /* Red glow effect */
+        z-index: 1500; /* Above the tabs, below the logo */
}
- /* Profesyonel Panel Stili */
- .dubbing-panel {
- background: rgba(26, 13, 13, 0.95);
- border: 1px solid #D4A017;
- border-radius: 12px;
- padding: 1.5rem;
- box-shadow: 0 8px 25px rgba(212, 160, 23, 0.2);
- transition: transform 0.3s ease;
+    /* Metallic red shine animation */
+ @keyframes metallic-red-shine {
+ 0% { filter: brightness(1) saturate(1) drop-shadow(0 0 5px #ff4040); }
+ 50% { filter: brightness(1.3) saturate(1.7) drop-shadow(0 0 15px #ff6b6b); }
+ 100% { filter: brightness(1) saturate(1) drop-shadow(0 0 5px #ff4040); }
+ }
+
+    /* Dubbing-themed style */
+    .dubbing-theme {
+        background: linear-gradient(to bottom, #800000, #2d0b0b); /* Dark red gradient */
+        border-radius: 15px;
+        padding: 1rem;
+        box-shadow: 0 10px 20px rgba(255, 64, 64, 0.3); /* Red shadow */
}
- .dubbing-panel:hover {
- transform: translateY(-5px);
+    /* Footer styles (above the tabs, transparent) */
+    .footer {
+        text-align: center;
+        padding: 10px;
+        color: #ff4040; /* Red text, matching the dubbing theme */
+        font-size: 14px;
+        margin-top: 20px;
+        position: relative;
+        z-index: 1001; /* Above the tabs, below the logo */
}
- /* Düğme Stilleri */
+    /* Button and upload-area styles */
button {
- background: linear-gradient(45deg, #800000, #A31818); /* Koyu kırmızıdan parlak kırmızıya */
- border: 1px solid #D4A017 !important;
- color: #FFF !important;
+ transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1) !important;
+        background: #800000 !important; /* Dark red, matching the dubbing theme */
+        border: 1px solid #ff4040 !important; /* Red border */
+        color: #C0C0C0 !important; /* Metallic silver text */
border-radius: 8px !important;
- padding: 10px 20px !important;
- font-weight: 600;
- text-transform: uppercase;
- letter-spacing: 1px;
- transition: all 0.3s ease !important;
- box-shadow: 0 4px 15px rgba(212, 160, 23, 0.3);
+ padding: 8px 16px !important;
+ position: relative;
+ overflow: hidden !important;
+ font-size: 0.9rem !important;
}
button:hover {
- background: linear-gradient(45deg, #A31818, #D42F2F) !important;
transform: scale(1.05) !important;
- box-shadow: 0 6px 20px rgba(212, 160, 23, 0.5) !important;
+        box-shadow: 0 10px 40px rgba(255, 64, 64, 0.7) !important; /* More pronounced red shadow */
+        background: #ff4040 !important; /* Lighter red hover effect */
}
- /* Yükleme Alanı */
+ button::before {
+ content: '';
+ position: absolute;
+ top: -50%;
+ left: -50%;
+ width: 200%;
+ height: 200%;
+ background: linear-gradient(45deg,
+ transparent 20%,
+            rgba(192, 192, 192, 0.3) 50%, /* Metallic silver tone */
+ transparent 80%);
+ animation: button-shine 3s infinite linear;
+ }
+
+    /* Image and audio upload-area style */
.compact-upload.horizontal {
- background: rgba(26, 13, 13, 0.9);
- border: 1px solid #D4A017;
- border-radius: 8px;
- padding: 0.5rem 1rem;
- display: flex;
- align-items: center;
- gap: 10px;
- max-width: 450px;
- transition: all 0.3s ease;
+ display: inline-flex !important;
+ align-items: center !important;
+ gap: 8px !important;
+ max-width: 400px !important;
+ height: 40px !important;
+ padding: 0 12px !important;
+        border: 1px solid #ff4040 !important; /* Red border */
+        background: rgba(128, 0, 0, 0.5) !important; /* Dark red, transparent */
+        border-radius: 8px !important;
+        transition: all 0.2s ease !important;
+        color: #C0C0C0 !important; /* Metallic silver text */
}
.compact-upload.horizontal:hover {
- border-color: #FFD700; /* Parlak altın hover efekti */
- background: rgba(45, 11, 11, 0.95);
+        border-color: #ff6b6b !important; /* Lighter red */
+        background: rgba(128, 0, 0, 0.7) !important; /* Darker red hover */
+ }
+
+ .compact-upload.horizontal .w-full {
+ flex: 1 1 auto !important;
+ min-width: 120px !important;
+ margin: 0 !important;
+        color: #C0C0C0 !important; /* Metallic silver */
}
.compact-upload.horizontal button {
- background: #800000;
- border: 1px solid #D4A017;
- padding: 6px 12px;
- font-size: 0.8rem;
+ padding: 4px 12px !important;
+ font-size: 0.75em !important;
+ height: 28px !important;
+ min-width: 80px !important;
+ border-radius: 4px !important;
+        background: #800000 !important; /* Dark red */
+        border: 1px solid #ff4040 !important; /* Red border */
+        color: #C0C0C0 !important; /* Metallic silver */
+ }
+
+ .compact-upload.horizontal .text-gray-500 {
+ font-size: 0.7em !important;
+        color: rgba(192, 192, 192, 0.6) !important; /* Translucent metallic silver */
+ white-space: nowrap !important;
+ overflow: hidden !important;
+ text-overflow: ellipsis !important;
+ max-width: 180px !important;
+ }
+
+    /* Extra-narrow version */
+ .compact-upload.horizontal.x-narrow {
+ max-width: 320px !important;
+ height: 36px !important;
+ padding: 0 10px !important;
+ gap: 6px !important;
}
- /* Sekme Stilleri */
+ .compact-upload.horizontal.x-narrow button {
+ padding: 3px 10px !important;
+ font-size: 0.7em !important;
+ height: 26px !important;
+ min-width: 70px !important;
+ }
+
+ .compact-upload.horizontal.x-narrow .text-gray-500 {
+ font-size: 0.65em !important;
+ max-width: 140px !important;
+ }
+
+    /* Shared styles for the tabs */
.gr-tab {
- background: rgba(26, 13, 13, 0.9);
- border: 1px solid #D4A017;
- border-radius: 12px 12px 0 0;
- padding: 0.75rem 1.5rem;
- color: #D4A017;
- font-weight: 600;
- text-transform: uppercase;
- transition: all 0.3s ease;
+        background: rgba(128, 0, 0, 0.5) !important; /* Dark red, transparent */
+        border-radius: 12px 12px 0 0 !important;
+        margin: 0 5px !important;
+        color: #C0C0C0 !important; /* Metallic silver */
+        border: 1px solid #ff4040 !important; /* Red border */
+        z-index: 1500; /* Below the logo, above other elements */
}
.gr-tab-selected {
- background: #800000;
- color: #FFF;
- border-bottom: none;
- box-shadow: 0 4px 15px rgba(212, 160, 23, 0.4);
+        background: #800000 !important; /* Dark red */
+        box-shadow: 0 4px 12px rgba(255, 64, 64, 0.7) !important; /* More pronounced red shadow */
+        color: #ffffff !important; /* White text (contrast for the selected tab) */
+        border: 1px solid #ff6b6b !important; /* Lighter red */
}
- /* Altbilgi */
- .footer {
- text-align: center;
- padding: 1rem;
- color: #D4A017;
- font-size: 1rem;
- font-weight: 500;
- text-shadow: 0 0 5px rgba(212, 160, 23, 0.3);
- position: relative;
- z-index: 1001;
+    /* Manual Ensemble specific styles */
+    .compact-header {
+        font-size: 0.95em !important;
+        margin: 0.8rem 0 0.5rem 0 !important;
+        color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ .compact-grid {
+ gap: 0.4rem !important;
+ max-height: 50vh;
+ overflow-y: auto;
+ padding: 10px;
+        background: rgba(128, 0, 0, 0.3) !important; /* Dark red, transparent */
+        border-radius: 12px;
+        border: 1px solid #ff4040 !important; /* Red border */
+ }
+
+ .compact-dropdown {
+ --padding: 8px 12px !important;
+ --radius: 10px !important;
+        border: 1px solid #ff4040 !important; /* Red border */
+        background: rgba(128, 0, 0, 0.5) !important; /* Dark red, transparent */
+        color: #C0C0C0 !important; /* Metallic silver text */
+ }
+
+ .tooltip-icon {
+ font-size: 1.4em !important;
+        color: #C0C0C0 !important; /* Metallic silver */
+ cursor: help;
+ margin-left: 0.5rem !important;
+ }
+
+ .log-box {
+ font-family: 'Fira Code', monospace !important;
+ font-size: 0.85em !important;
+        background-color: rgba(128, 0, 0, 0.3) !important; /* Dark red, transparent */
+        border: 1px solid #ff4040 !important; /* Red border */
+        border-radius: 8px;
+        padding: 1rem !important;
+        color: #C0C0C0 !important; /* Metallic silver text */
}
/* Animasyonlar */
@keyframes text-glow {
- 0% { text-shadow: 0 0 5px rgba(212, 160, 23, 0.5); }
- 50% { text-shadow: 0 0 20px rgba(212, 160, 23, 1), 0 0 10px rgba(255, 64, 64, 0.7); }
- 100% { text-shadow: 0 0 5px rgba(212, 160, 23, 0.5); }
+ 0% { text-shadow: 0 0 5px rgba(192, 192, 192, 0); }
+ 50% { text-shadow: 0 0 15px rgba(192, 192, 192, 1); }
+ 100% { text-shadow: 0 0 5px rgba(192, 192, 192, 0); }
+ }
+
+ @keyframes button-shine {
+ 0% { transform: rotate(0deg) translateX(-50%); }
+ 100% { transform: rotate(360deg) translateX(-50%); }
+ }
+
+    /* Responsive settings */
+ @media (max-width: 768px) {
+ .compact-grid {
+ max-height: 40vh;
+ }
+
+ .compact-upload.horizontal {
+ max-width: 100% !important;
+ width: 100% !important;
+ }
+
+ .compact-upload.horizontal .text-gray-500 {
+ max-width: 100px !important;
+ }
+
+ .compact-upload.horizontal.x-narrow {
+ height: 40px !important;
+ padding: 0 8px !important;
+ }
+
+ .logo-container {
+            width: 80px; /* Smaller logo on mobile devices */
+ top: 1rem;
+ left: 50%;
+ transform: translateX(-50%);
+ }
+
+ .header-text {
+            padding: 60px 20px 20px; /* Less spacing on mobile */
+            font-size: 1.8rem; /* Slightly smaller heading on mobile */
+ }
}
"""
# Arayüz tasarımı
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
- # Üst Panel: Logo ve Başlık
- with gr.Row():
- with gr.Column(scale=1, min_width=200):
- gr.HTML("""
-
-

-
- """)
- with gr.Column(scale=4):
- gr.HTML("""
-
- """)
+ with gr.Column():
+            # Logo (as a PNG, matching the dubbing theme)
+ logo_html = """
+
+

+
+ """
+ gr.HTML(logo_html)
+
+            # Title (striking, dubbing-themed)
+ gr.HTML("""
+
+ """)
with gr.Tabs():
with gr.Tab("Audio Separation", elem_id="separation_tab"):
diff --git a/gui/favicon.ico b/gui/favicon.ico
new file mode 100644
index 0000000000000000000000000000000000000000..18b86d2d3713162a811233760a6debee0647d1b9
Binary files /dev/null and b/gui/favicon.ico differ
diff --git a/gui/mvsep.png b/gui/mvsep.png
new file mode 100644
index 0000000000000000000000000000000000000000..079229fafb6bab82c3529f49a55cd1b0b61206cf
Binary files /dev/null and b/gui/mvsep.png differ
diff --git a/gui/tutorial_screenshot.jpg b/gui/tutorial_screenshot.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f85eb51141196168b57f13629526c4bd44dbcbe9
Binary files /dev/null and b/gui/tutorial_screenshot.jpg differ
diff --git a/gui/wx_msst_screen.png b/gui/wx_msst_screen.png
new file mode 100644
index 0000000000000000000000000000000000000000..a8075e8ccbbf1f1a75c12e5c6aee5007eb57a263
Binary files /dev/null and b/gui/wx_msst_screen.png differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc70e2c39852d5af384b963d5bb0a6c9dc29e94
--- /dev/null
+++ b/main.py
@@ -0,0 +1,112 @@
+import os
+import threading
+import urllib.request
+import time
+import sys
+import random
+import argparse
+import librosa
+from tqdm.auto import tqdm
+import glob
+import torch
+import soundfile as sf
+import torch.nn as nn
+from datetime import datetime
+import numpy as np
+import shutil
+from gui import create_interface
+from pyngrok import ngrok
+
+import warnings
+warnings.filterwarnings("ignore")
+
+def generate_random_port():
+ """Generates a random port between 1000 and 9000."""
+ return random.randint(1000, 9000)
+
+def start_gradio(port, share=False):
+ """Starts the Gradio interface with optional sharing."""
+ demo = create_interface()
+ demo.launch(
+ server_port=port,
+ server_name='0.0.0.0',
+ share=share,
+ allowed_paths=[os.path.join(os.path.expanduser("~"), "Music-Source-Separation", "input"), "/tmp", "/content"],
+ inline=False
+ )
+
+def start_localtunnel(port):
+ """Starts the Gradio interface with localtunnel sharing."""
+ print(f"Starting Localtunnel on port {port}...")
+ os.system('npm install -g localtunnel &>/dev/null')
+
+ with open('url.txt', 'w') as file:
+ file.write('')
+ os.system(f'lt --port {port} >> url.txt 2>&1 &')
+ time.sleep(2)
+
+ endpoint_ip = urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip("\n")
+ with open('url.txt', 'r') as file:
+ tunnel_url = file.read().replace("your url is: ", "").strip()
+
+ print(f"Share Link: {tunnel_url}")
+ print(f"Password IP: {endpoint_ip}")
+
+ start_gradio(port, share=False)
+
+def start_ngrok(port, ngrok_token):
+ """Starts the Gradio interface with ngrok sharing."""
+ print(f"Starting Ngrok on port {port}...")
+ try:
+ ngrok.set_auth_token(ngrok_token)
+ ngrok.kill()
+ tunnel = ngrok.connect(port)
+ print(f"Ngrok URL: {tunnel.public_url}")
+
+ start_gradio(port, share=False)
+ except Exception as e:
+ print(f"Error starting ngrok: {e}")
+ sys.exit(1)
+
+def main(method="gradio", port=None, ngrok_token=""):
+ """Main entry point for the application."""
+ # Pick a port automatically or use the one supplied by the user
+ port = port or generate_random_port()
+ print(f"Selected port: {port}")
+
+ # Dispatch according to the chosen sharing method
+ if method == "gradio":
+ print("Starting Gradio with built-in sharing...")
+ start_gradio(port, share=True)
+ elif method == "localtunnel":
+ start_localtunnel(port)
+ elif method == "ngrok":
+ if not ngrok_token:
+ print("Error: Ngrok token is required for ngrok method!")
+ sys.exit(1)
+ start_ngrok(port, ngrok_token)
+ else:
+ print("Error: Invalid method! Use 'gradio', 'localtunnel', or 'ngrok'.")
+ sys.exit(1)
+
+ # Keep the process running (if needed)
+ try:
+ while True:
+ time.sleep(5)
+ except KeyboardInterrupt:
+ print("\n🛑 Process stopped by user")
+ sys.exit(0)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Music Source Separation Web UI")
+ parser.add_argument("--method", type=str, default="gradio", choices=["gradio", "localtunnel", "ngrok"], help="Sharing method (default: gradio)")
+ parser.add_argument("--port", type=int, default=None, help="Server port (default: random between 1000-9000)")
+ parser.add_argument("--ngrok-token", type=str, default="", help="Ngrok authentication token (required for ngrok)")
+ args = parser.parse_args()
+
+ main(method=args.method, port=args.port, ngrok_token=args.ngrok_token)
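+
+ # Illustrative invocations (not part of the original script); the port and the
+ # ngrok token below are placeholders:
+ #
+ #   python main.py                                    # Gradio built-in sharing, random port
+ #   python main.py --method localtunnel --port 7860   # localtunnel on a fixed port
+ #   python main.py --method ngrok --ngrok-token YOUR_NGROK_TOKEN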
diff --git a/metrics.py b/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..245e8f225bff2cc4472a76f43fbda9ec2a6f78da
--- /dev/null
+++ b/metrics.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import librosa
+import torch.nn.functional as F
+from typing import Dict, List, Tuple
+
+def sdr(references: np.ndarray, estimates: np.ndarray) -> np.ndarray:
+ """
+ Compute Signal-to-Distortion Ratio (SDR) for one or more audio tracks.
+
+ SDR is a measure of how well the predicted source (estimate) matches the reference source.
+ It is calculated as the ratio of the energy of the reference signal to the energy of the error (difference between reference and estimate).
+ The result is returned in decibels (dB).
+ Parameters:
+ ----------
+ references : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples), where num_sources is the number of sources,
+ num_channels is the number of channels (e.g., 1 for mono, 2 for stereo), and num_samples is the length of the audio signal.
+
+ estimates : np.ndarray
+ A 3D numpy array of shape (num_sources, num_channels, num_samples) representing the estimated sources.
+
+ Returns:
+ -------
+ np.ndarray
+ A 1D numpy array containing the SDR values for each source.
+ """
+ eps = 1e-8 # to avoid numerical errors
+ num = np.sum(np.square(references), axis=(1, 2))
+ den = np.sum(np.square(references - estimates), axis=(1, 2))
+ num += eps
+ den += eps
+ return 10 * np.log10(num / den)
+
+
+def si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
+ """
+ Compute Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) for one or more audio tracks.
+
+ SI-SDR is a variant of the SDR metric that is invariant to the scaling of the estimate relative to the reference.
+ It is calculated by applying the optimal scaling factor between estimate and reference before computing the SDR.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples), where num_channels is the number of channels
+ (e.g., 1 for mono, 2 for stereo) and num_samples is the length of the audio signal.
+
+ estimate : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples) representing the estimated source.
+
+ Returns:
+ -------
+ float
+ The SI-SDR value for the source. It is a scalar representing the Signal-to-Distortion Ratio in decibels (dB).
+ """
+ eps = 1e-8 # To avoid numerical errors
+ scale = np.sum(estimate * reference + eps, axis=(0, 1)) / np.sum(reference ** 2 + eps, axis=(0, 1))
+ scale = np.expand_dims(scale, axis=(0, 1)) # Reshape the scalar to (1, 1) for broadcasting
+
+ reference = reference * scale
+ si_sdr = np.mean(10 * np.log10(
+ np.sum(reference ** 2, axis=(0, 1)) / (np.sum((reference - estimate) ** 2, axis=(0, 1)) + eps) + eps))
+
+ return si_sdr
+
+
+def L1Freq_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ fft_size: int = 2048,
+ hop_size: int = 1024,
+ device: str = 'cpu'
+) -> float:
+ """
+ Compute the L1 Frequency Metric between the reference and estimated audio signals.
+
+ This metric compares the magnitude spectrograms of the reference and estimated audio signals
+ using the Short-Time Fourier Transform (STFT) and calculates the L1 loss between them. The result
+ is scaled to the range [0, 100] where a higher value indicates better performance.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples) representing the reference (ground truth) audio signal.
+
+ estimate : np.ndarray
+ A 2D numpy array of shape (num_channels, num_samples) representing the estimated (predicted) audio signal.
+
+ fft_size : int, optional
+ The size of the FFT (Short-Time Fourier Transform). Default is 2048.
+
+ hop_size : int, optional
+ The hop size between STFT frames. Default is 1024.
+
+ device : str, optional
+ The device to run the computation on ('cpu' or 'cuda'). Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The L1 Frequency Metric in the range [0, 100], where higher values indicate better performance.
+ """
+
+ reference = torch.from_numpy(reference).to(device)
+ estimate = torch.from_numpy(estimate).to(device)
+
+ reference_stft = torch.stft(reference, fft_size, hop_size, return_complex=True)
+ estimated_stft = torch.stft(estimate, fft_size, hop_size, return_complex=True)
+
+ reference_mag = torch.abs(reference_stft)
+ estimate_mag = torch.abs(estimated_stft)
+
+ loss = 10 * F.l1_loss(estimate_mag, reference_mag)
+
+ ret = 100 / (1. + float(loss.cpu().numpy()))
+
+ return ret
+
+
+def LogWMSE_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ mixture: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the Log-WMSE (Logarithmic Weighted Mean Squared Error) between the reference, estimate, and mixture signals.
+
+ This metric evaluates the quality of the estimated signal compared to the reference signal in the
+ context of audio source separation. The result is given in logarithmic scale, which helps in evaluating
+ signals with large amplitude differences.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ mixture : np.ndarray
+ The mixed audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The Log-WMSE value, which quantifies the difference between the reference and estimated signal on a logarithmic scale.
+ """
+ from torch_log_wmse import LogWMSE
+ log_wmse = LogWMSE(
+ audio_length=reference.shape[-1] / 44100, # audio length in seconds
+ sample_rate=44100, # sample rate of 44100 Hz
+ return_as_loss=False, # return as loss (False means return as metric)
+ bypass_filter=False, # bypass frequency filtering (False means apply filter)
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).unsqueeze(0).to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).unsqueeze(0).to(device)
+ mixture = torch.from_numpy(mixture).unsqueeze(0).to(device)
+
+ res = log_wmse(mixture, reference, estimate)
+ return float(res.cpu().numpy())
+
+
+def AuraSTFT_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the AuraSTFT metric, which evaluates the spectral difference between the reference and estimated
+ audio signals using Short-Time Fourier Transform (STFT) loss.
+
+ The AuraSTFT metric combines a log-magnitude STFT term with a spectral convergence term, and it is commonly used
+ to assess the quality of audio separation tasks. The result is returned as a value scaled to the range [0, 100].
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The AuraSTFT metric value, scaled to the range [0, 100], which quantifies the difference between
+ the reference and estimated signal in the spectral domain.
+ """
+
+ from auraloss.freq import STFTLoss
+
+ stft_loss = STFTLoss(
+ w_log_mag=1.0, # weight for log magnitude
+ w_lin_mag=0.0, # weight for linear magnitude
+ w_sc=1.0, # weight for spectral convergence
+ device=device,
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).to(device)
+
+ res = 100 / (1. + 10 * stft_loss(reference, estimate))
+ return float(res.cpu().numpy())
+
+
+def AuraMRSTFT_metric(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ device: str = 'cpu',
+) -> float:
+ """
+ Calculate the AuraMRSTFT metric, which evaluates the spectral difference between the reference and estimated
+ audio signals using Multi-Resolution Short-Time Fourier Transform (STFT) loss.
+
+ The AuraMRSTFT metric uses multi-resolution STFT analysis, which allows better representation of both
+ low- and high-frequency components in the audio signals. The result is returned as a value scaled to the range [0, 100].
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The ground truth audio signal of shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal of shape (channels, time).
+
+ device : str, optional
+ The device to run the computation on, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ float
+ The AuraMRSTFT metric value, scaled to the range [0, 100], which quantifies the difference between
+ the reference and estimated signal in the multi-resolution spectral domain.
+ """
+
+ from auraloss.freq import MultiResolutionSTFTLoss
+
+ mrstft_loss = MultiResolutionSTFTLoss(
+ fft_sizes=[1024, 2048, 4096],
+ hop_sizes=[256, 512, 1024],
+ win_lengths=[1024, 2048, 4096],
+ scale="mel", # mel scale for frequency resolution
+ n_bins=128, # number of bins for mel scale
+ sample_rate=44100,
+ perceptual_weighting=True, # apply perceptual weighting
+ device=device
+ )
+
+ reference = torch.from_numpy(reference).unsqueeze(0).float().to(device)
+ estimate = torch.from_numpy(estimate).unsqueeze(0).float().to(device)
+
+ res = 100 / (1. + 10 * mrstft_loss(reference, estimate))
+ return float(res.cpu().numpy())
+
+
+def bleed_full(
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ sr: int = 44100,
+ n_fft: int = 4096,
+ hop_length: int = 1024,
+ n_mels: int = 512,
+ device: str = 'cpu',
+) -> Tuple[float, float]:
+ """
+ Calculate the 'bleed' and 'fullness' metrics between a reference and an estimated audio signal.
+
+ The 'bleedless' score penalizes energy that appears in the estimate but not in the reference (i.e., bleed from
+ other sources), while the 'fullness' score penalizes energy that is present in the reference but missing from
+ the estimate. Both are computed on decibel-scaled mel spectrograms, and higher values are better.
+
+ Parameters:
+ ----------
+ reference : np.ndarray
+ The reference audio signal, shape (channels, time), where channels is the number of audio channels
+ (e.g., 1 for mono, 2 for stereo) and time is the length of the audio in samples.
+
+ estimate : np.ndarray
+ The estimated audio signal, shape (channels, time).
+
+ sr : int, optional
+ The sample rate of the audio signals. Default is 44100 Hz.
+
+ n_fft : int, optional
+ The FFT size used to compute the STFT. Default is 4096.
+
+ hop_length : int, optional
+ The hop length for STFT computation. Default is 1024.
+
+ n_mels : int, optional
+ The number of mel frequency bins. Default is 512.
+
+ device : str, optional
+ The device for computation, either 'cpu' or 'cuda'. Default is 'cpu'.
+
+ Returns:
+ -------
+ tuple
+ A tuple containing two values:
+ - `bleedless` (float): A score indicating how much 'bleeding' the estimated signal has (higher is better).
+ - `fullness` (float): A score indicating how 'full' the estimated signal is (higher is better).
+ """
+
+ from torchaudio.transforms import AmplitudeToDB
+
+ reference = torch.from_numpy(reference).float().to(device)
+ estimate = torch.from_numpy(estimate).float().to(device)
+
+ window = torch.hann_window(n_fft).to(device)
+
+ # Compute STFTs with the Hann window
+ D1 = torch.abs(torch.stft(reference, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True,
+ pad_mode="constant"))
+ D2 = torch.abs(torch.stft(estimate, n_fft=n_fft, hop_length=hop_length, window=window, return_complex=True,
+ pad_mode="constant"))
+
+ mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
+ mel_filter_bank = torch.from_numpy(mel_basis).to(device)
+
+ S1_mel = torch.matmul(mel_filter_bank, D1)
+ S2_mel = torch.matmul(mel_filter_bank, D2)
+
+ S1_db = AmplitudeToDB(stype="magnitude", top_db=80)(S1_mel)
+ S2_db = AmplitudeToDB(stype="magnitude", top_db=80)(S2_mel)
+
+ diff = S2_db - S1_db
+
+ positive_diff = diff[diff > 0]
+ negative_diff = diff[diff < 0]
+
+ average_positive = torch.mean(positive_diff) if positive_diff.numel() > 0 else torch.tensor(0.0).to(device)
+ average_negative = torch.mean(negative_diff) if negative_diff.numel() > 0 else torch.tensor(0.0).to(device)
+
+ bleedless = 100 * 1 / (average_positive + 1)
+ fullness = 100 * 1 / (-average_negative + 1)
+
+ return float(bleedless.cpu().numpy()), float(fullness.cpu().numpy())
+
+
+def get_metrics(
+ metrics: List[str],
+ reference: np.ndarray,
+ estimate: np.ndarray,
+ mix: np.ndarray,
+ device: str = 'cpu',
+) -> Dict[str, float]:
+ """
+ Calculate a list of metrics to evaluate the performance of audio source separation models.
+
+ The function computes the specified metrics based on the reference, estimate, and mixture.
+
+ Parameters:
+ ----------
+ metrics : List[str]
+ A list of metric names to compute (e.g., ['sdr', 'si_sdr', 'l1_freq']).
+
+ reference : np.ndarray
+ The reference audio (true signal) with shape (channels, length).
+
+ estimate : np.ndarray
+ The estimated audio (predicted signal) with shape (channels, length).
+
+ mix : np.ndarray
+ The mixed audio signal with shape (channels, length).
+
+ device : str, optional, default='cpu'
+ The device ('cpu' or 'cuda') to perform the calculations on.
+
+ Returns:
+ -------
+ Dict[str, float]
+ A dictionary containing the computed metric values.
+ """
+ result = dict()
+
+ # Adjust the length to be the same across all inputs
+ min_length = min(reference.shape[1], estimate.shape[1])
+ reference = reference[..., :min_length]
+ estimate = estimate[..., :min_length]
+ mix = mix[..., :min_length]
+
+ if 'sdr' in metrics:
+ references = np.expand_dims(reference, axis=0)
+ estimates = np.expand_dims(estimate, axis=0)
+ result['sdr'] = sdr(references, estimates)[0]
+
+ if 'si_sdr' in metrics:
+ result['si_sdr'] = si_sdr(reference, estimate)
+
+ if 'l1_freq' in metrics:
+ result['l1_freq'] = L1Freq_metric(reference, estimate, device=device)
+
+ if 'log_wmse' in metrics:
+ result['log_wmse'] = LogWMSE_metric(reference, estimate, mix, device)
+
+ if 'aura_stft' in metrics:
+ result['aura_stft'] = AuraSTFT_metric(reference, estimate, device)
+
+ if 'aura_mrstft' in metrics:
+ result['aura_mrstft'] = AuraMRSTFT_metric(reference, estimate, device)
+
+ if 'bleedless' in metrics or 'fullness' in metrics:
+ bleedless, fullness = bleed_full(reference, estimate, device=device)
+ if 'bleedless' in metrics:
+ result['bleedless'] = bleedless
+ if 'fullness' in metrics:
+ result['fullness'] = fullness
+
+ return result
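+
+ # Minimal usage sketch (illustrative, not part of the original module). Inputs
+ # follow the get_metrics convention of shape (num_channels, num_samples); only
+ # metrics without optional extra dependencies are requested here:
+ #
+ #   rng = np.random.default_rng(0)
+ #   reference = rng.standard_normal((2, 44100)).astype(np.float32)
+ #   estimate = reference + 0.1 * rng.standard_normal((2, 44100)).astype(np.float32)
+ #   print(get_metrics(['sdr', 'si_sdr', 'l1_freq'], reference, estimate, mix=reference))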
diff --git a/models/bandit/core/__init__.py b/models/bandit/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4d6d7953709c2f86a6b484e49c7715b58bbe86a
--- /dev/null
+++ b/models/bandit/core/__init__.py
@@ -0,0 +1,744 @@
+import os.path
+from collections import defaultdict
+from itertools import chain, combinations
+from typing import (
+ Any,
+ Dict,
+ Iterator,
+ Mapping, Optional,
+ Tuple, Type,
+ TypedDict
+)
+
+import pytorch_lightning as pl
+import torch
+import torchaudio as ta
+import torchmetrics as tm
+from asteroid import losses as asteroid_losses
+# from deepspeed.ops.adam import DeepSpeedCPUAdam
+# from geoopt import optim as gooptim
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import nn, optim
+from torch.optim import lr_scheduler
+from torch.optim.lr_scheduler import LRScheduler
+
+from models.bandit.core import loss, metrics as metrics_, model
+from models.bandit.core.data._types import BatchedDataDict
+from models.bandit.core.data.augmentation import BaseAugmentor, StemAugmentor
+from models.bandit.core.utils import audio as audio_
+from models.bandit.core.utils.audio import BaseFader
+
+# from pandas.io.json._normalize import nested_to_record
+
+ConfigDict = TypedDict('ConfigDict', {'name': str, 'kwargs': Dict[str, Any]})
+
+
+class SchedulerConfigDict(ConfigDict):
+ monitor: str
+
+
+OptimizerSchedulerConfigDict = TypedDict(
+ 'OptimizerSchedulerConfigDict',
+ {"optimizer": ConfigDict, "scheduler": SchedulerConfigDict},
+ total=False
+)
+
+
+class LRSchedulerReturnDict(TypedDict, total=False):
+ scheduler: LRScheduler
+ monitor: str
+
+
+class ConfigureOptimizerReturnDict(TypedDict, total=False):
+ optimizer: torch.optim.Optimizer
+ lr_scheduler: LRSchedulerReturnDict
+
+
+OutputType = Dict[str, Any]
+MetricsType = Dict[str, torch.Tensor]
+
+
+def get_optimizer_class(name: str) -> Type[optim.Optimizer]:
+
+ if name == "DeepSpeedCPUAdam":
+ # Imported lazily so the deepspeed dependency is only needed when requested.
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
+ return DeepSpeedCPUAdam
+
+ if name in optim.__dict__:
+ return optim.__dict__[name]
+
+ raise NameError(f"Unknown optimizer: {name}")
+
+
+def parse_optimizer_config(
+ config: OptimizerSchedulerConfigDict,
+ parameters: Iterator[nn.Parameter]
+) -> ConfigureOptimizerReturnDict:
+ optim_class = get_optimizer_class(config["optimizer"]["name"])
+ optimizer = optim_class(parameters, **config["optimizer"]["kwargs"])
+
+ optim_dict: ConfigureOptimizerReturnDict = {
+ "optimizer": optimizer,
+ }
+
+ if "scheduler" in config:
+
+ lr_scheduler_class_ = config["scheduler"]["name"]
+ lr_scheduler_class = lr_scheduler.__dict__[lr_scheduler_class_]
+ lr_scheduler_dict: LRSchedulerReturnDict = {
+ "scheduler": lr_scheduler_class(
+ optimizer,
+ **config["scheduler"]["kwargs"]
+ )
+ }
+
+ if lr_scheduler_class_ == "ReduceLROnPlateau":
+ lr_scheduler_dict["monitor"] = config["scheduler"]["monitor"]
+
+ optim_dict["lr_scheduler"] = lr_scheduler_dict
+
+ return optim_dict
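+
+ # Illustrative shape of the config this parser expects (values are placeholders,
+ # not taken from any shipped config file); "monitor" is only needed when the
+ # scheduler is ReduceLROnPlateau:
+ #
+ #   {
+ #       "optimizer": {"name": "AdamW", "kwargs": {"lr": 1e-3}},
+ #       "scheduler": {
+ #           "name": "ReduceLROnPlateau",
+ #           "kwargs": {"factor": 0.5, "patience": 10},
+ #           "monitor": "val/loss",
+ #       },
+ #   }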
+
+
+def parse_model_config(config: ConfigDict) -> Any:
+ name = config["name"]
+
+ for module in [model]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+_LEGACY_LOSS_NAMES = ["HybridL1Loss"]
+
+
+def _parse_legacy_loss_config(config: ConfigDict) -> nn.Module:
+ name = config["name"]
+
+ if name == "HybridL1Loss":
+ return loss.TimeFreqL1Loss(**config["kwargs"])
+
+ raise NameError
+
+
+def parse_loss_config(config: ConfigDict) -> nn.Module:
+ name = config["name"]
+
+ if name in _LEGACY_LOSS_NAMES:
+ return _parse_legacy_loss_config(config)
+
+ for module in [loss, nn.modules.loss, asteroid_losses]:
+ if name in module.__dict__:
+ # print(config["kwargs"])
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+def get_metric(config: ConfigDict) -> tm.Metric:
+ name = config["name"]
+
+ for module in [tm, metrics_]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+ raise NameError
+
+
+def parse_metric_config(config: Dict[str, ConfigDict]) -> tm.MetricCollection:
+ metrics = {}
+
+ for metric in config:
+ metrics[metric] = get_metric(config[metric])
+
+ return tm.MetricCollection(metrics)
+
+
+def parse_fader_config(config: ConfigDict) -> BaseFader:
+ name = config["name"]
+
+ for module in [audio_]:
+ if name in module.__dict__:
+ return module.__dict__[name](**config["kwargs"])
+
+ raise NameError
+
+
+class LightningSystem(pl.LightningModule):
+ _VOX_STEMS = ["speech", "vocals"]
+ _BG_STEMS = ["background", "effects", "mne"]
+
+ def __init__(
+ self,
+ config: Dict,
+ loss_adjustment: float = 1.0,
+ attach_fader: bool = False
+ ) -> None:
+ super().__init__()
+ self.optimizer_config = config["optimizer"]
+ self.model = parse_model_config(config["model"])
+ self.loss = parse_loss_config(config["loss"])
+ self.metrics = nn.ModuleDict(
+ {
+ stem: parse_metric_config(config["metrics"]["dev"])
+ for stem in self.model.stems
+ }
+ )
+
+ self.metrics.disallow_fsdp = True
+
+ self.test_metrics = nn.ModuleDict(
+ {
+ stem: parse_metric_config(config["metrics"]["test"])
+ for stem in self.model.stems
+ }
+ )
+
+ self.test_metrics.disallow_fsdp = True
+
+ self.fs = config["model"]["kwargs"]["fs"]
+
+ self.fader_config = config["inference"]["fader"]
+ if attach_fader:
+ self.fader = parse_fader_config(config["inference"]["fader"])
+ else:
+ self.fader = None
+
+ self.augmentation: Optional[BaseAugmentor]
+ if config.get("augmentation", None) is not None:
+ self.augmentation = StemAugmentor(**config["augmentation"])
+ else:
+ self.augmentation = None
+
+ self.predict_output_path: Optional[str] = None
+ self.loss_adjustment = loss_adjustment
+
+ self.val_prefix = None
+ self.test_prefix = None
+
+
+ def configure_optimizers(self) -> Any:
+ return parse_optimizer_config(
+ self.optimizer_config,
+ self.trainer.model.parameters()
+ )
+
+ def compute_loss(self, batch: BatchedDataDict, output: OutputType) -> Dict[
+ str, torch.Tensor]:
+ return {"loss": self.loss(output, batch)}
+
+ def update_metrics(
+ self,
+ batch: BatchedDataDict,
+ output: OutputType,
+ mode: str
+ ) -> None:
+
+ if mode == "test":
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
+ for stem, metric in metrics.items():
+
+ if stem == "mne:+":
+ stem = "mne"
+
+ # print(f"matching for {stem}")
+ if mode == "train":
+ metric.update(
+ output["audio"][stem],#.cpu(),
+ batch["audio"][stem],#.cpu()
+ )
+ else:
+ if stem not in batch["audio"]:
+ matched = False
+ if stem in self._VOX_STEMS:
+ for bstem in self._VOX_STEMS:
+ if bstem in batch["audio"]:
+ batch["audio"][stem] = batch["audio"][bstem]
+ matched = True
+ break
+ elif stem in self._BG_STEMS:
+ for bstem in self._BG_STEMS:
+ if bstem in batch["audio"]:
+ batch["audio"][stem] = batch["audio"][bstem]
+ matched = True
+ break
+ else:
+ matched = True
+
+ # print(batch["audio"].keys())
+
+ if matched:
+ # print(f"matched {stem}!")
+ if stem == "mne" and "mne" not in output["audio"]:
+ output["audio"]["mne"] = output["audio"]["music"] + output["audio"]["effects"]
+
+ metric.update(
+ output["audio"][stem],#.cpu(),
+ batch["audio"][stem],#.cpu(),
+ )
+
+ # print(metric.compute())
+ def compute_metrics(self, mode: str="dev") -> Dict[
+ str, torch.Tensor]:
+
+ if mode == "test":
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
+ metric_dict = {}
+
+ for stem, metric in metrics.items():
+ md = metric.compute()
+ metric_dict.update(
+ {f"{stem}/{k}": v for k, v in md.items()}
+ )
+
+ self.log_dict(metric_dict, prog_bar=True, logger=False)
+
+ return metric_dict
+
+ def reset_metrics(self, test_mode: bool = False) -> None:
+
+ if test_mode:
+ metrics = self.test_metrics
+ else:
+ metrics = self.metrics
+
+ for _, metric in metrics.items():
+ metric.reset()
+
+
+ def forward(self, batch: BatchedDataDict) -> Any:
+ batch, output = self.model(batch)
+
+
+ return batch, output
+
+ def common_step(self, batch: BatchedDataDict, mode: str) -> Any:
+ batch, output = self.forward(batch)
+ # print(batch)
+ # print(output)
+ loss_dict = self.compute_loss(batch, output)
+
+ with torch.no_grad():
+ self.update_metrics(batch, output, mode=mode)
+
+ if mode == "train":
+ self.log("loss", loss_dict["loss"], prog_bar=True)
+
+ return output, loss_dict
+
+
+ def training_step(self, batch: BatchedDataDict) -> Dict[str, Any]:
+
+ if self.augmentation is not None:
+ with torch.no_grad():
+ batch = self.augmentation(batch)
+
+ _, loss_dict = self.common_step(batch, mode="train")
+
+ with torch.inference_mode():
+ self.log_dict_with_prefix(
+ loss_dict,
+ "train",
+ batch_size=batch["audio"]["mixture"].shape[0]
+ )
+
+ loss_dict["loss"] *= self.loss_adjustment
+
+ return loss_dict
+
+ def on_train_batch_end(
+ self, outputs: STEP_OUTPUT, batch: BatchedDataDict, batch_idx: int
+ ) -> None:
+
+ metric_dict = self.compute_metrics()
+ self.log_dict_with_prefix(metric_dict, "train")
+ self.reset_metrics()
+
+ def validation_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int,
+ dataloader_idx: int = 0
+ ) -> Dict[str, Any]:
+
+ with torch.inference_mode():
+ curr_val_prefix = f"val{dataloader_idx}" if dataloader_idx > 0 else "val"
+
+ if curr_val_prefix != self.val_prefix:
+ # print(f"Switching to validation dataloader {dataloader_idx}")
+ if self.val_prefix is not None:
+ self._on_validation_epoch_end()
+ self.val_prefix = curr_val_prefix
+ _, loss_dict = self.common_step(batch, mode="val")
+
+ self.log_dict_with_prefix(
+ loss_dict,
+ self.val_prefix,
+ batch_size=batch["audio"]["mixture"].shape[0],
+ prog_bar=True,
+ add_dataloader_idx=False
+ )
+
+ return loss_dict
+
+ def on_validation_epoch_end(self) -> None:
+ self._on_validation_epoch_end()
+
+ def _on_validation_epoch_end(self) -> None:
+ metric_dict = self.compute_metrics()
+ self.log_dict_with_prefix(metric_dict, self.val_prefix, prog_bar=True,
+ add_dataloader_idx=False)
+ # self.logger.save()
+ # print(self.val_prefix, "Validation metrics:", metric_dict)
+ self.reset_metrics()
+
+
+ def old_predtest_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int,
+ dataloader_idx: int = 0
+ ) -> Tuple[BatchedDataDict, OutputType]:
+
+ audio_batch = batch["audio"]["mixture"]
+ track_batch = batch.get("track", ["" for _ in range(len(audio_batch))])
+
+ output_list_of_dicts = [
+ self.fader(
+ audio[None, ...],
+ lambda a: self.test_forward(a, track)
+ )
+ for audio, track in zip(audio_batch, track_batch)
+ ]
+
+ output_dict_of_lists = defaultdict(list)
+
+ for output_dict in output_list_of_dicts:
+ for stem, audio in output_dict.items():
+ output_dict_of_lists[stem].append(audio)
+
+ output = {
+ "audio": {
+ stem: torch.concat(output_list, dim=0)
+ for stem, output_list in output_dict_of_lists.items()
+ }
+ }
+
+ return batch, output
+
+ def predtest_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = -1,
+ dataloader_idx: int = 0
+ ) -> Tuple[BatchedDataDict, OutputType]:
+
+ if getattr(self.model, "bypass_fader", False):
+ batch, output = self.model(batch)
+ else:
+ audio_batch = batch["audio"]["mixture"]
+ output = self.fader(
+ audio_batch,
+ lambda a: self.test_forward(a, "", batch=batch)
+ )
+
+ return batch, output
+
+ def test_forward(
+ self,
+ audio: torch.Tensor,
+ track: str = "",
+ batch: BatchedDataDict = None
+ ) -> torch.Tensor:
+
+ if self.fader is None:
+ self.attach_fader()
+
+ cond = batch.get("condition", None)
+
+ if cond is not None and cond.shape[0] == 1:
+ cond = cond.repeat(audio.shape[0], 1)
+
+ _, output = self.forward(
+ {"audio": {"mixture": audio},
+ "track": track,
+ "condition": cond,
+ }
+ ) # TODO: support track properly
+
+ return output["audio"]
+
+ def on_test_epoch_start(self) -> None:
+ self.attach_fader(force_reattach=True)
+
+ def test_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int,
+ dataloader_idx: int = 0
+ ) -> Any:
+ curr_test_prefix = f"test{dataloader_idx}"
+
+ # print(batch["audio"].keys())
+
+ if curr_test_prefix != self.test_prefix:
+ # print(f"Switching to test dataloader {dataloader_idx}")
+ if self.test_prefix is not None:
+ self._on_test_epoch_end()
+ self.test_prefix = curr_test_prefix
+
+ with torch.inference_mode():
+ _, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ # print(output)
+ self.update_metrics(batch, output, mode="test")
+
+ return output
+
+ def on_test_epoch_end(self) -> None:
+ self._on_test_epoch_end()
+
+ def _on_test_epoch_end(self) -> None:
+ metric_dict = self.compute_metrics(mode="test")
+ self.log_dict_with_prefix(metric_dict, self.test_prefix, prog_bar=True,
+ add_dataloader_idx=False)
+ # self.logger.save()
+ # print(self.test_prefix, "Test metrics:", metric_dict)
+ self.reset_metrics()
+
+ def predict_step(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
+ ) -> Any:
+ assert self.predict_output_path is not None
+
+ batch_size = batch["audio"]["mixture"].shape[0]
+
+ if include_track_name is None:
+ include_track_name = batch_size > 1
+
+ with torch.inference_mode():
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ print('Pred test finished...')
+ torch.cuda.empty_cache()
+ metric_dict = {}
+
+ if get_residual:
+ mixture = batch["audio"]["mixture"]
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
+ residual = mixture - extracted
+ print(extracted.shape, mixture.shape, residual.shape)
+
+ output["audio"]["residual"] = residual
+
+ if get_no_vox_combinations:
+ no_vox_stems = [
+ stem for stem in output["audio"] if
+ stem not in self._VOX_STEMS
+ ]
+ no_vox_combinations = chain.from_iterable(
+ combinations(no_vox_stems, r) for r in
+ range(2, len(no_vox_stems) + 1)
+ )
+
+ for combination in no_vox_combinations:
+ combination_ = list(combination)
+ output["audio"]["+".join(combination_)] = sum(
+ [output["audio"][stem] for stem in combination_]
+ )
+
+ if treat_batch_as_channels:
+ for stem in output["audio"]:
+ output["audio"][stem] = output["audio"][stem].reshape(
+ 1, -1, output["audio"][stem].shape[-1]
+ )
+ batch_size = 1
+
+ for b in range(batch_size):
+ print("!!", b)
+ for stem in output["audio"]:
+ print(f"Saving audio for {stem} to {self.predict_output_path}")
+ track_name = batch["track"][b].split("/")[-1]
+
+ if batch.get("audio", {}).get(stem, None) is not None:
+ self.test_metrics[stem].reset()
+ metrics = self.test_metrics[stem](
+ batch["audio"][stem][[b], ...],
+ output["audio"][stem][[b], ...]
+ )
+ snr = metrics["snr"]
+ sisnr = metrics["sisnr"]
+ sdr = metrics["sdr"]
+ metric_dict[stem] = metrics
+ print(
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
+ )
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+ else:
+ filename = f"{stem}.wav"
+
+ if include_track_name:
+ output_dir = os.path.join(
+ self.predict_output_path,
+ track_name
+ )
+ else:
+ output_dir = self.predict_output_path
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ if fs is None:
+ fs = self.fs
+
+ ta.save(
+ os.path.join(output_dir, filename),
+ output["audio"][stem][b, ...].cpu(),
+ fs,
+ )
+
+ return metric_dict
+
+ def get_stems(
+ self,
+ batch: BatchedDataDict,
+ batch_idx: int = 0,
+ dataloader_idx: int = 0,
+ include_track_name: Optional[bool] = None,
+ get_no_vox_combinations: bool = True,
+ get_residual: bool = False,
+ treat_batch_as_channels: bool = False,
+ fs: Optional[int] = None,
+ ) -> Any:
+ assert self.predict_output_path is not None
+
+ batch_size = batch["audio"]["mixture"].shape[0]
+
+ if include_track_name is None:
+ include_track_name = batch_size > 1
+
+ with torch.inference_mode():
+ batch, output = self.predtest_step(batch, batch_idx, dataloader_idx)
+ torch.cuda.empty_cache()
+ metric_dict = {}
+
+ if get_residual:
+ mixture = batch["audio"]["mixture"]
+ extracted = sum([output["audio"][stem] for stem in output["audio"]])
+ residual = mixture - extracted
+ # print(extracted.shape, mixture.shape, residual.shape)
+
+ output["audio"]["residual"] = residual
+
+ if get_no_vox_combinations:
+ no_vox_stems = [
+ stem for stem in output["audio"] if
+ stem not in self._VOX_STEMS
+ ]
+ no_vox_combinations = chain.from_iterable(
+ combinations(no_vox_stems, r) for r in
+ range(2, len(no_vox_stems) + 1)
+ )
+
+ for combination in no_vox_combinations:
+ combination_ = list(combination)
+ output["audio"]["+".join(combination_)] = sum(
+ [output["audio"][stem] for stem in combination_]
+ )
+
+ if treat_batch_as_channels:
+ for stem in output["audio"]:
+ output["audio"][stem] = output["audio"][stem].reshape(
+ 1, -1, output["audio"][stem].shape[-1]
+ )
+ batch_size = 1
+
+ result = {}
+ for b in range(batch_size):
+ for stem in output["audio"]:
+ track_name = batch["track"][b].split("/")[-1]
+
+ if batch.get("audio", {}).get(stem, None) is not None:
+ self.test_metrics[stem].reset()
+ metrics = self.test_metrics[stem](
+ batch["audio"][stem][[b], ...],
+ output["audio"][stem][[b], ...]
+ )
+ snr = metrics["snr"]
+ sisnr = metrics["sisnr"]
+ sdr = metrics["sdr"]
+ metric_dict[stem] = metrics
+ print(
+ track_name,
+ f"snr={snr:2.2f} dB",
+ f"sisnr={sisnr:2.2f}",
+ f"sdr={sdr:2.2f} dB",
+ )
+ filename = f"{stem} - snr={snr:2.2f}dB - sdr={sdr:2.2f}dB.wav"
+ else:
+ filename = f"{stem}.wav"
+
+ if include_track_name:
+ output_dir = os.path.join(
+ self.predict_output_path,
+ track_name
+ )
+ else:
+ output_dir = self.predict_output_path
+
+ os.makedirs(output_dir, exist_ok=True)
+
+ if fs is None:
+ fs = self.fs
+
+ result[stem] = output["audio"][stem][b, ...].cpu().numpy()
+
+ return result
+
+ def load_state_dict(
+ self, state_dict: Mapping[str, Any], strict: bool = False
+ ) -> Any:
+
+ return super().load_state_dict(state_dict, strict=False)
+
+
+ def set_predict_output_path(self, path: str) -> None:
+ self.predict_output_path = path
+ os.makedirs(self.predict_output_path, exist_ok=True)
+
+ self.attach_fader()
+
+ def attach_fader(self, force_reattach=False) -> None:
+ if self.fader is None or force_reattach:
+ self.fader = parse_fader_config(self.fader_config)
+ self.fader.to(self.device)
+
+
+ def log_dict_with_prefix(
+ self,
+ dict_: Dict[str, torch.Tensor],
+ prefix: str,
+ batch_size: Optional[int] = None,
+ **kwargs: Any
+ ) -> None:
+ self.log_dict(
+ {f"{prefix}/{k}": v for k, v in dict_.items()},
+ batch_size=batch_size,
+ logger=True,
+ sync_dist=True,
+ **kwargs,
+ )
\ No newline at end of file
diff --git a/models/bandit/core/data/__init__.py b/models/bandit/core/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1087fe2c4d7d3048295cdf73c0725a015bc0d129
--- /dev/null
+++ b/models/bandit/core/data/__init__.py
@@ -0,0 +1,2 @@
+from .dnr.datamodule import DivideAndRemasterDataModule
+from .musdb.datamodule import MUSDB18DataModule
\ No newline at end of file
diff --git a/models/bandit/core/data/_types.py b/models/bandit/core/data/_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..9499f9a80b5dec7b5b0e7882849e4f7b2c801ccf
--- /dev/null
+++ b/models/bandit/core/data/_types.py
@@ -0,0 +1,18 @@
+from typing import Dict, Sequence, TypedDict
+
+import torch
+
+AudioDict = Dict[str, torch.Tensor]
+
+DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str})
+
+BatchedDataDict = TypedDict(
+ 'BatchedDataDict',
+ {'audio': AudioDict, 'track': Sequence[str]}
+)
+
+
+class DataDictWithLanguage(TypedDict):
+ audio: AudioDict
+ track: str
+ language: str
diff --git a/models/bandit/core/data/augmentation.py b/models/bandit/core/data/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..238214bf17a69e71f48e8761e1ead05b17d0fa5a
--- /dev/null
+++ b/models/bandit/core/data/augmentation.py
@@ -0,0 +1,107 @@
+from abc import ABC
+from typing import Any, Dict, Union
+
+import torch
+import torch_audiomentations as tam
+from torch import nn
+
+from models.bandit.core.data._types import BatchedDataDict, DataDict
+
+
+class BaseAugmentor(nn.Module, ABC):
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
+ DataDict, BatchedDataDict]:
+ raise NotImplementedError
+
+
+class StemAugmentor(BaseAugmentor):
+ def __init__(
+ self,
+ audiomentations: Dict[str, Dict[str, Any]],
+ fix_clipping: bool = True,
+ scaler_margin: float = 0.5,
+ apply_both_default_and_common: bool = False,
+ ) -> None:
+ super().__init__()
+
+ augmentations = {}
+
+ self.has_default = "[default]" in audiomentations
+ self.has_common = "[common]" in audiomentations
+ self.apply_both_default_and_common = apply_both_default_and_common
+
+ for stem in audiomentations:
+ if audiomentations[stem]["name"] == "Compose":
+ augmentations[stem] = getattr(
+ tam,
+ audiomentations[stem]["name"]
+ )(
+ [
+ getattr(tam, aug["name"])(**aug["kwargs"])
+ for aug in
+ audiomentations[stem]["kwargs"]["transforms"]
+ ],
+ **audiomentations[stem]["kwargs"]["kwargs"],
+ )
+ else:
+ augmentations[stem] = getattr(
+ tam,
+ audiomentations[stem]["name"]
+ )(
+ **audiomentations[stem]["kwargs"]
+ )
+
+ self.augmentations = nn.ModuleDict(augmentations)
+ self.fix_clipping = fix_clipping
+ self.scaler_margin = scaler_margin
+
+ def check_and_fix_clipping(
+ self, item: Union[DataDict, BatchedDataDict]
+ ) -> Union[DataDict, BatchedDataDict]:
+ max_abs = []
+
+ for stem in item["audio"]:
+ max_abs.append(item["audio"][stem].abs().max().item())
+
+ if max(max_abs) > 1.0:
+ scaler = 1.0 / (max(max_abs) + torch.rand(
+ (1,),
+ device=item["audio"]["mixture"].device
+ ) * self.scaler_margin)
+
+ for stem in item["audio"]:
+ item["audio"][stem] *= scaler
+
+ return item
+
+ def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[
+ DataDict, BatchedDataDict]:
+
+ for stem in item["audio"]:
+ if stem == "mixture":
+ continue
+
+ if self.has_common:
+ item["audio"][stem] = self.augmentations["[common]"](
+ item["audio"][stem]
+ ).samples
+
+ if stem in self.augmentations:
+ item["audio"][stem] = self.augmentations[stem](
+ item["audio"][stem]
+ ).samples
+ elif self.has_default:
+ if not self.has_common or self.apply_both_default_and_common:
+ item["audio"][stem] = self.augmentations["[default]"](
+ item["audio"][stem]
+ ).samples
+
+ item["audio"]["mixture"] = sum(
+ [item["audio"][stem] for stem in item["audio"]
+ if stem != "mixture"]
+ ) # type: ignore[call-overload, assignment]
+
+ if self.fix_clipping:
+ item = self.check_and_fix_clipping(item)
+
+ return item
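+
+ # Illustrative configuration for StemAugmentor (stem names, transform names and
+ # parameter values are placeholders, not a shipped config). "[common]" is applied
+ # to every stem; "[default]" only to stems without their own entry:
+ #
+ #   augmentor = StemAugmentor(audiomentations={
+ #       "[common]": {"name": "Gain",
+ #                    "kwargs": {"min_gain_in_db": -6.0, "max_gain_in_db": 6.0, "p": 0.5}},
+ #       "speech": {"name": "PitchShift",
+ #                  "kwargs": {"min_transpose_semitones": -2.0, "max_transpose_semitones": 2.0,
+ #                             "p": 0.3, "sample_rate": 44100}},
+ #   })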
diff --git a/models/bandit/core/data/augmented.py b/models/bandit/core/data/augmented.py
new file mode 100644
index 0000000000000000000000000000000000000000..84d19599a6579eb5afd304ef6da76a6cbca49045
--- /dev/null
+++ b/models/bandit/core/data/augmented.py
@@ -0,0 +1,35 @@
+import warnings
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+from torch.utils import data
+
+
+class AugmentedDataset(data.Dataset):
+ def __init__(
+ self,
+ dataset: data.Dataset,
+ augmentation: nn.Module = nn.Identity(),
+ target_length: Optional[int] = None,
+ ) -> None:
+ warnings.warn(
+ "This class is no longer used. Attach augmentation to "
+ "the LightningSystem instead.",
+ DeprecationWarning,
+ )
+
+ self.dataset = dataset
+ self.augmentation = augmentation
+
+ self.ds_length: int = len(dataset) # type: ignore[arg-type]
+ self.length = target_length if target_length is not None else self.ds_length
+
+ def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str,
+ torch.Tensor]]]:
+ item = self.dataset[index % self.ds_length]
+ item = self.augmentation(item)
+ return item
+
+ def __len__(self) -> int:
+ return self.length
diff --git a/models/bandit/core/data/base.py b/models/bandit/core/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7b6c33a85b93c32209138e3d21bfc8e0f270cac
--- /dev/null
+++ b/models/bandit/core/data/base.py
@@ -0,0 +1,69 @@
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pedalboard as pb
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from models.bandit.core.data._types import AudioDict, DataDict
+
+
+class BaseSourceSeparationDataset(data.Dataset, ABC):
+ def __init__(
+ self, split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int,
+ npy_memmap: bool,
+ recompute_mixture: bool
+ ):
+ self.split = split
+ self.stems = stems
+ self.stems_no_mixture = [s for s in stems if s != "mixture"]
+ self.files = files
+ self.data_path = data_path
+ self.fs = fs
+ self.npy_memmap = npy_memmap
+ self.recompute_mixture = recompute_mixture
+
+ @abstractmethod
+ def get_stem(
+ self,
+ *,
+ stem: str,
+ identifier: Dict[str, Any]
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+ def _get_audio(self, stems, identifier: Dict[str, Any]):
+ audio = {}
+ for stem in stems:
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier)
+
+ return audio
+
+ def get_audio(self, identifier: Dict[str, Any]) -> AudioDict:
+
+ if self.recompute_mixture:
+ audio = self._get_audio(
+ self.stems_no_mixture,
+ identifier=identifier
+ )
+ audio["mixture"] = self.compute_mixture(audio)
+ return audio
+ else:
+ return self._get_audio(self.stems, identifier=identifier)
+
+ @abstractmethod
+ def get_identifier(self, index: int) -> Dict[str, Any]:
+ pass
+
+ def compute_mixture(self, audio: AudioDict) -> torch.Tensor:
+
+ return sum(
+ audio[stem] for stem in audio if stem != "mixture"
+ )
diff --git a/models/bandit/core/data/dnr/__init__.py b/models/bandit/core/data/dnr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/bandit/core/data/dnr/datamodule.py b/models/bandit/core/data/dnr/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5550608aabf460eb1781576112ed60185dd318
--- /dev/null
+++ b/models/bandit/core/data/dnr/datamodule.py
@@ -0,0 +1,74 @@
+import os
+from typing import Mapping, Optional
+
+import pytorch_lightning as pl
+
+from .dataset import (
+ DivideAndRemasterDataset,
+ DivideAndRemasterDeterministicChunkDataset,
+ DivideAndRemasterRandomChunkDataset,
+ DivideAndRemasterRandomChunkDatasetWithSpeechReverb
+)
+
+
+def DivideAndRemasterDataModule(
+ data_root: str = "$DATA_ROOT/DnR/v2",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_speech_reverb: bool = False
+ # augmentor=None
+) -> pl.LightningDataModule:
+ if train_kwargs is None:
+ train_kwargs = {}
+
+ if val_kwargs is None:
+ val_kwargs = {}
+
+ if test_kwargs is None:
+ test_kwargs = {}
+
+ if datamodule_kwargs is None:
+ datamodule_kwargs = {}
+
+ if num_workers is None:
+ num_workers = os.cpu_count()
+
+ if num_workers is None:
+ num_workers = 32
+
+ num_workers = min(num_workers, 64)
+
+ if use_speech_reverb:
+ train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb
+ else:
+ train_cls = DivideAndRemasterRandomChunkDataset
+
+ train_dataset = train_cls(
+ data_root, "train", **train_kwargs
+ )
+
+ # if augmentor is not None:
+ # train_dataset = AugmentedDataset(train_dataset, augmentor)
+
+ datamodule = pl.LightningDataModule.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=DivideAndRemasterDeterministicChunkDataset(
+ data_root, "val", **val_kwargs
+ ),
+ test_dataset=DivideAndRemasterDataset(
+ data_root,
+ "test",
+ **test_kwargs
+ ),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
+ )
+
+ datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign]
+
+ return datamodule
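+
+ # Usage sketch (illustrative): the kwargs mirror the dataset constructors above,
+ # and "$DATA_ROOT" must be replaced with the actual DnR v2 root containing the
+ # tr/cv/tt subfolders:
+ #
+ #   datamodule = DivideAndRemasterDataModule(
+ #       data_root="$DATA_ROOT/DnR/v2np",
+ #       batch_size=2,
+ #       train_kwargs={"target_length": 100, "chunk_size_second": 6.0},
+ #       val_kwargs={"chunk_size_second": 6.0, "hop_size_second": 3.0},
+ #   )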
diff --git a/models/bandit/core/data/dnr/dataset.py b/models/bandit/core/data/dnr/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b241cf0dd474eafbfc9db3ec2f4987d12596de4
--- /dev/null
+++ b/models/bandit/core/data/dnr/dataset.py
@@ -0,0 +1,392 @@
+import os
+from abc import ABC
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import pedalboard as pb
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from models.bandit.core.data._types import AudioDict, DataDict
+from models.bandit.core.data.base import BaseSourceSeparationDataset
+
+
+class DivideAndRemasterBaseDataset(BaseSourceSeparationDataset, ABC):
+ ALLOWED_STEMS = ["mixture", "speech", "music", "effects", "mne"]
+ STEM_NAME_MAP = {
+ "mixture": "mix",
+ "speech": "speech",
+ "music": "music",
+ "effects": "sfx",
+ }
+ SPLIT_NAME_MAP = {"train": "tr", "val": "cv", "test": "tt"}
+
+ FULL_TRACK_LENGTH_SECOND = 60
+ FULL_TRACK_LENGTH_SAMPLES = FULL_TRACK_LENGTH_SECOND * 44100
+
+ def __init__(
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ recompute_mixture: bool = False,
+ ) -> None:
+ super().__init__(
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=recompute_mixture
+ )
+
+ def get_stem(
+ self,
+ *,
+ stem: str,
+ identifier: Dict[str, Any]
+ ) -> torch.Tensor:
+
+ if stem == "mne":
+ return self.get_stem(
+ stem="music",
+ identifier=identifier) + self.get_stem(
+ stem="effects",
+ identifier=identifier)
+
+ track = identifier["track"]
+ path = os.path.join(self.data_path, track)
+
+ if self.npy_memmap:
+ audio = np.load(
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.npy"),
+ mmap_mode="r"
+ )
+ else:
+ # noinspection PyUnresolvedReferences
+ audio, _ = ta.load(
+ os.path.join(path, f"{self.STEM_NAME_MAP[stem]}.wav")
+ )
+
+ return audio
+
+ def get_identifier(self, index):
+ return dict(track=self.files[index])
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ audio = self.get_audio(identifier)
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class DivideAndRemasterDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(
+ os.path.join(data_path, f)
+ )
+ ]
+ # pprint(list(enumerate(files)))
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def __len__(self) -> int:
+ return self.n_tracks
+
+
+class DivideAndRemasterRandomChunkDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(
+ os.path.join(data_path, f)
+ )
+ ]
+
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ self.target_length = target_length
+ self.chunk_size = int(chunk_size_second * fs)
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_tracks)
+
+ def get_stem(
+ self,
+ *,
+ stem: str,
+ identifier: Dict[str, Any],
+ chunk_here: bool = False,
+ ) -> torch.Tensor:
+
+ stem = super().get_stem(
+ stem=stem,
+ identifier=identifier
+ )
+
+ if chunk_here:
+ start = np.random.randint(
+ 0,
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
+ )
+ end = start + self.chunk_size
+
+ stem = stem[:, start:end]
+
+ return stem
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ # self.index_lock = index
+ audio = self.get_audio(identifier)
+ # self.index_lock = None
+
+ start = np.random.randint(
+ 0,
+ self.FULL_TRACK_LENGTH_SAMPLES - self.chunk_size
+ )
+ end = start + self.chunk_size
+
+ audio = {
+ k: v[:, start:end] for k, v in audio.items()
+ }
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class DivideAndRemasterDeterministicChunkDataset(DivideAndRemasterBaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ data_path = os.path.join(data_root, self.SPLIT_NAME_MAP[split])
+
+ files = sorted(os.listdir(data_path))
+ files = [
+ f
+ for f in files
+ if (not f.startswith(".")) and os.path.isdir(
+ os.path.join(data_path, f)
+ )
+ ]
+ # pprint(list(enumerate(files)))
+ if split == "train":
+ assert len(files) == 3406, len(files)
+ elif split == "val":
+ assert len(files) == 487, len(files)
+ elif split == "test":
+ assert len(files) == 973, len(files)
+
+ self.n_tracks = len(files)
+
+ self.chunk_size = int(chunk_size_second * fs)
+ self.hop_size = int(hop_size_second * fs)
+ self.n_chunks_per_track = int(
+ (
+ self.FULL_TRACK_LENGTH_SECOND - chunk_size_second) / hop_size_second
+ )
+
+ self.length = self.n_tracks * self.n_chunks_per_track
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_tracks)
+
+ def __len__(self) -> int:
+ return self.length
+
+ def __getitem__(self, item: int) -> DataDict:
+
+ index = item % self.n_tracks
+ chunk = item // self.n_tracks
+
+ data_ = super().__getitem__(index)
+
+ audio = data_["audio"]
+
+ start = chunk * self.hop_size
+ end = start + self.chunk_size
+
+ for stem in self.stems:
+ data_["audio"][stem] = audio[stem][:, start:end]
+
+ return data_
+
+
+class DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
+ DivideAndRemasterRandomChunkDataset
+):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_length: int,
+ chunk_size_second: float,
+ stems: Optional[List[str]] = None,
+ fs: int = 44100,
+ npy_memmap: bool = True,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+
+ stems_no_mixture = [s for s in stems if s != "mixture"]
+
+ super().__init__(
+ data_root=data_root,
+ split=split,
+ target_length=target_length,
+ chunk_size_second=chunk_size_second,
+ stems=stems_no_mixture,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ )
+
+ self.stems = stems
+ self.stems_no_mixture = stems_no_mixture
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ data_ = super().__getitem__(index)
+
+ dry = data_["audio"]["speech"][:]
+ n_samples = dry.shape[-1]
+
+ wet_level = np.random.rand()
+
+ speech = pb.Reverb(
+ room_size=np.random.rand(),
+ damping=np.random.rand(),
+ wet_level=wet_level,
+ dry_level=(1 - wet_level),
+ width=np.random.rand()
+ ).process(dry, self.fs, buffer_size=8192 * 4)[..., :n_samples]
+
+ data_["audio"]["speech"] = speech
+
+ data_["audio"]["mixture"] = sum(
+ [data_["audio"][s] for s in self.stems_no_mixture]
+ )
+
+ return data_
+
+ def __len__(self) -> int:
+ return super().__len__()
+
+
+if __name__ == "__main__":
+
+ from pprint import pprint
+ from tqdm.auto import tqdm
+
+ for split_ in ["train", "val", "test"]:
+ ds = DivideAndRemasterRandomChunkDatasetWithSpeechReverb(
+ data_root="$DATA_ROOT/DnR/v2np",
+ split=split_,
+ target_length=100,
+ chunk_size_second=6.0
+ )
+
+ print(split_, len(ds))
+
+ for track_ in tqdm(ds): # type: ignore
+ pprint(track_)
+ track_["audio"] = {k: v.shape for k, v in track_["audio"].items()}
+ pprint(track_)
+ # break
+
+ break
diff --git a/models/bandit/core/data/dnr/preprocess.py b/models/bandit/core/data/dnr/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d0b58690f3bae726b0655dbade6393c89bf8c9e
--- /dev/null
+++ b/models/bandit/core/data/dnr/preprocess.py
@@ -0,0 +1,54 @@
+import glob
+import os
+from typing import Tuple
+
+import numpy as np
+import torchaudio as ta
+from tqdm.contrib.concurrent import process_map
+
+
+def process_one(inputs: Tuple[str, str, int]) -> None:
+ infile, outfile, target_fs = inputs
+
+ out_dir = os.path.dirname(outfile)
+ os.makedirs(out_dir, exist_ok=True)
+
+ data, fs = ta.load(infile)
+
+ if fs != target_fs:
+ data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser")
+ fs = target_fs
+
+ data = data.numpy()
+ data = data.astype(np.float32)
+
+ if os.path.exists(outfile):
+ data_ = np.load(outfile)
+ if np.allclose(data, data_):
+ return
+
+ np.save(outfile, data)
+
+
+def preprocess(
+ data_path: str,
+ output_path: str,
+ fs: int
+) -> None:
+ files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True)
+ print(files)
+ outfiles = [
+ f.replace(data_path, output_path).replace(".wav", ".npy") for f in
+ files
+ ]
+
+ os.makedirs(output_path, exist_ok=True)
+ inputs = list(zip(files, outfiles, [fs] * len(files)))
+
+ process_map(process_one, inputs, chunksize=32)
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire()
diff --git a/models/bandit/core/data/musdb/__init__.py b/models/bandit/core/data/musdb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/bandit/core/data/musdb/datamodule.py b/models/bandit/core/data/musdb/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8984daebd535b25f0551d348c91dbd1702fb9da
--- /dev/null
+++ b/models/bandit/core/data/musdb/datamodule.py
@@ -0,0 +1,77 @@
+import os.path
+from typing import Mapping, Optional
+
+import pytorch_lightning as pl
+
+from models.bandit.core.data.musdb.dataset import (
+ MUSDB18BaseDataset,
+ MUSDB18FullTrackDataset,
+ MUSDB18SadDataset,
+ MUSDB18SadOnTheFlyAugmentedDataset
+)
+
+
+def MUSDB18DataModule(
+ data_root: str = "$DATA_ROOT/MUSDB18/HQ",
+ target_stem: str = "vocals",
+ batch_size: int = 2,
+ num_workers: int = 8,
+ train_kwargs: Optional[Mapping] = None,
+ val_kwargs: Optional[Mapping] = None,
+ test_kwargs: Optional[Mapping] = None,
+ datamodule_kwargs: Optional[Mapping] = None,
+ use_on_the_fly: bool = True,
+ npy_memmap: bool = True
+) -> pl.LightningDataModule:
+ if train_kwargs is None:
+ train_kwargs = {}
+
+ if val_kwargs is None:
+ val_kwargs = {}
+
+ if test_kwargs is None:
+ test_kwargs = {}
+
+ if datamodule_kwargs is None:
+ datamodule_kwargs = {}
+
+ train_dataset: MUSDB18BaseDataset
+
+ if use_on_the_fly:
+ train_dataset = MUSDB18SadOnTheFlyAugmentedDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
+ )
+ else:
+ train_dataset = MUSDB18SadDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="train",
+ target_stem=target_stem,
+ **train_kwargs
+ )
+
+ datamodule = pl.LightningDataModule.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=MUSDB18SadDataset(
+ data_root=os.path.join(data_root, "saded-np"),
+ split="val",
+ target_stem=target_stem,
+ **val_kwargs
+ ),
+ test_dataset=MUSDB18FullTrackDataset(
+ data_root=os.path.join(data_root, "canonical"),
+ split="test",
+ **test_kwargs
+ ),
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **datamodule_kwargs
+ )
+
+ datamodule.predict_dataloader = ( # type: ignore[method-assign]
+ datamodule.test_dataloader
+ )
+
+ return datamodule
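+
+# Minimal usage sketch (assumes the MUSDB18-HQ tree already contains the preprocessed
+# "saded-np" segments and the "canonical" full-track copy that the paths above point to):
+#
+#   dm = MUSDB18DataModule(
+#       data_root="/data/MUSDB18/HQ",
+#       target_stem="vocals",
+#       batch_size=2,
+#       num_workers=8,
+#   )
+#   train_loader = dm.train_dataloader()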
diff --git a/models/bandit/core/data/musdb/dataset.py b/models/bandit/core/data/musdb/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..53f5d9afdfe383600b5f89767c4ef1f4b54f4a47
--- /dev/null
+++ b/models/bandit/core/data/musdb/dataset.py
@@ -0,0 +1,280 @@
+import os
+from abc import ABC
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torchaudio as ta
+from torch.utils import data
+
+from models.bandit.core.data._types import AudioDict, DataDict
+from models.bandit.core.data.base import BaseSourceSeparationDataset
+
+
+class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC):
+
+ ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"]
+
+ def __init__(
+ self,
+ split: str,
+ stems: List[str],
+ files: List[str],
+ data_path: str,
+ fs: int = 44100,
+ npy_memmap=False,
+ ) -> None:
+ super().__init__(
+ split=split,
+ stems=stems,
+ files=files,
+ data_path=data_path,
+ fs=fs,
+ npy_memmap=npy_memmap,
+ recompute_mixture=False
+ )
+
+ def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
+ track = identifier["track"]
+ path = os.path.join(self.data_path, track)
+ # noinspection PyUnresolvedReferences
+
+ if self.npy_memmap:
+ audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r")
+ else:
+ audio, _ = ta.load(os.path.join(path, f"{stem}.wav"))
+
+ return audio
+
+ def get_identifier(self, index):
+ return dict(track=self.files[index])
+
+ def __getitem__(self, index: int) -> DataDict:
+ identifier = self.get_identifier(index)
+ audio = self.get_audio(identifier)
+
+ return {"audio": audio, "track": f"{self.split}/{identifier['track']}"}
+
+
+class MUSDB18FullTrackDataset(MUSDB18BaseDataset):
+
+ N_TRAIN_TRACKS = 100
+ N_TEST_TRACKS = 50
+ VALIDATION_FILES = [
+ "Actions - One Minute Smile",
+ "Clara Berry And Wooldog - Waltz For My Victims",
+ "Johnny Lokke - Promises & Lies",
+ "Patrick Talbot - A Reason To Leave",
+ "Triviul - Angelsaint",
+ "Alexander Ross - Goodbye Bolero",
+ "Fergessen - Nos Palpitants",
+ "Leaf - Summerghost",
+ "Skelpolu - Human Mistakes",
+ "Young Griffo - Pennies",
+ "ANiMAL - Rockshow",
+ "James May - On The Line",
+ "Meaxic - Take A Step",
+ "Traffic Experiment - Sirens",
+ ]
+
+ def __init__(
+ self, data_root: str, split: str, stems: Optional[List[str]] = None
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+ self.stems = stems
+
+ if split == "test":
+ subset = "test"
+ elif split in ["train", "val"]:
+ subset = "train"
+ else:
+ raise NameError(f"Unknown split: {split}")
+
+ data_path = os.path.join(data_root, subset)
+
+ files = sorted(os.listdir(data_path))
+ files = [f for f in files if not f.startswith(".")]
+ # pprint(list(enumerate(files)))
+ if subset == "train":
+ assert len(files) == 100, len(files)
+ if split == "train":
+ files = [f for f in files if f not in self.VALIDATION_FILES]
+ assert len(files) == 100 - len(self.VALIDATION_FILES)
+ else:
+ files = [f for f in files if f in self.VALIDATION_FILES]
+ assert len(files) == len(self.VALIDATION_FILES)
+ else:
+ split = "test"
+ assert len(files) == 50
+
+ self.n_tracks = len(files)
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files
+ )
+
+ def __len__(self) -> int:
+ return self.n_tracks
+
+class MUSDB18SadDataset(MUSDB18BaseDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: Optional[int] = None,
+ npy_memmap=False,
+ ) -> None:
+
+ if stems is None:
+ stems = self.ALLOWED_STEMS
+
+ data_path = os.path.join(data_root, target_stem, split)
+
+ files = sorted(os.listdir(data_path))
+ files = [f for f in files if not f.startswith(".")]
+
+ super().__init__(
+ data_path=data_path,
+ split=split,
+ stems=stems,
+ files=files,
+ npy_memmap=npy_memmap
+ )
+ self.n_segments = len(files)
+ self.target_stem = target_stem
+ self.target_length = (
+ target_length if target_length is not None else self.n_segments
+ )
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ index = index % self.n_segments
+
+ return super().__getitem__(index)
+
+ def get_identifier(self, index):
+ return super().get_identifier(index % self.n_segments)
+
+
+class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset):
+ def __init__(
+ self,
+ data_root: str,
+ split: str,
+ target_stem: str,
+ stems: Optional[List[str]] = None,
+ target_length: int = 20000,
+ apply_probability: Optional[float] = None,
+ chunk_size_second: float = 3.0,
+ random_scale_range_db: Tuple[float, float] = (-10, 10),
+ drop_probability: float = 0.1,
+ rescale: bool = True,
+ ) -> None:
+ super().__init__(data_root, split, target_stem, stems)
+
+ if apply_probability is None:
+ apply_probability = (target_length - self.n_segments) / target_length
+
+ self.apply_probability = apply_probability
+ self.drop_probability = drop_probability
+ self.chunk_size_second = chunk_size_second
+ self.random_scale_range_db = random_scale_range_db
+ self.rescale = rescale
+
+ self.chunk_size_sample = int(self.chunk_size_second * self.fs)
+ self.target_length = target_length
+
+ def __len__(self) -> int:
+ return self.target_length
+
+ def __getitem__(self, index: int) -> DataDict:
+
+ index = index % self.n_segments
+
+ # if np.random.rand() > self.apply_probability:
+ # return super().__getitem__(index)
+
+ audio = {}
+ identifier = self.get_identifier(index)
+
+ # assert self.target_stem in self.stems_no_mixture
+ for stem in self.stems_no_mixture:
+ if stem == self.target_stem:
+ identifier_ = identifier
+ else:
+ if np.random.rand() < self.apply_probability:
+ index_ = np.random.randint(self.n_segments)
+ identifier_ = self.get_identifier(index_)
+ else:
+ identifier_ = identifier
+
+ audio[stem] = self.get_stem(stem=stem, identifier=identifier_)
+
+ # if stem == self.target_stem:
+
+ if self.chunk_size_sample < audio[stem].shape[-1]:
+ chunk_start = np.random.randint(
+ audio[stem].shape[-1] - self.chunk_size_sample
+ )
+ else:
+ chunk_start = 0
+
+ if np.random.rand() < self.drop_probability:
+ # db_scale = "-inf"
+ linear_scale = 0.0
+ else:
+ db_scale = np.random.uniform(*self.random_scale_range_db)
+ linear_scale = np.power(10, db_scale / 20)
+ # db_scale = f"{db_scale:+2.1f}"
+ # print(linear_scale)
+ chunk_slice = slice(chunk_start, chunk_start + self.chunk_size_sample)
+ audio[stem][..., chunk_slice] = linear_scale * audio[stem][..., chunk_slice]
+
+ audio["mixture"] = self.compute_mixture(audio)
+
+ if self.rescale:
+ max_abs_val = max(
+ [torch.max(torch.abs(audio[stem])) for stem in self.stems]
+ ) # type: ignore[type-var]
+ if max_abs_val > 1:
+ audio = {k: v / max_abs_val for k, v in audio.items()}
+
+ track = identifier["track"]
+
+ return {"audio": audio, "track": f"{self.split}/{track}"}
+
+# if __name__ == "__main__":
+#
+# from pprint import pprint
+# from tqdm.auto import tqdm
+#
+# for split_ in ["train", "val", "test"]:
+# ds = MUSDB18SadOnTheFlyAugmentedDataset(
+# data_root="$DATA_ROOT/MUSDB18/HQ/saded",
+# split=split_,
+# target_stem="vocals"
+# )
+#
+# print(split_, len(ds))
+#
+# for track_ in tqdm(ds):
+# track_["audio"] = {
+# k: v.shape for k, v in track_["audio"].items()
+# }
+# pprint(track_)
diff --git a/models/bandit/core/data/musdb/preprocess.py b/models/bandit/core/data/musdb/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d5892e5c3f4acef5bbb183a746d76475c461810
--- /dev/null
+++ b/models/bandit/core/data/musdb/preprocess.py
@@ -0,0 +1,238 @@
+import glob
+import os
+
+import numpy as np
+import torch
+import torchaudio as ta
+from torch import nn
+from torch.nn import functional as F
+from tqdm.auto import tqdm
+from tqdm.contrib.concurrent import process_map
+
+from models.bandit.core.data._types import DataDict
+from models.bandit.core.data.musdb.dataset import MUSDB18FullTrackDataset
+import pyloudnorm as pyln
+
+class SourceActivityDetector(nn.Module):
+ def __init__(
+ self,
+ analysis_stem: str,
+ output_path: str,
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
+ target_lufs: float = -24
+ ) -> None:
+ super().__init__()
+
+ self.fs = fs
+ self.segment_length = int(segment_length_second * self.fs)
+ self.hop_length = int(hop_length_second * self.fs)
+ self.n_chunks = n_chunks
+ assert self.segment_length % self.n_chunks == 0
+ self.chunk_size = self.segment_length // self.n_chunks
+ self.chunk_epsilon = chunk_epsilon
+ self.energy_threshold_quantile = energy_threshold_quantile
+ self.segment_epsilon = segment_epsilon
+ self.salient_proportion_threshold = salient_proportion_threshold
+ self.analysis_stem = analysis_stem
+
+ self.meter = pyln.Meter(self.fs)
+ self.target_lufs = target_lufs
+
+ self.output_path = output_path
+
+ def forward(self, data: DataDict) -> None:
+
+ stem_ = self.analysis_stem if self.analysis_stem != "none" else "mixture"
+
+ x = data["audio"][stem_]
+
+ xnp = x.numpy()
+ loudness = self.meter.integrated_loudness(xnp.T)
+
+ for stem in data["audio"]:
+ s = data["audio"][stem]
+ s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T
+ s = torch.as_tensor(s)
+ data["audio"][stem] = s
+
+ if x.ndim == 3:
+ assert x.shape[0] == 1
+ x = x[0]
+
+ n_chan, n_samples = x.shape
+
+ n_segments = (
+ int(
+ np.ceil((n_samples - self.segment_length) / self.hop_length)
+ ) + 1
+ )
+
+ segments = torch.zeros((n_segments, n_chan, self.segment_length))
+ for i in range(n_segments):
+ start = i * self.hop_length
+ end = start + self.segment_length
+ end = min(end, n_samples)
+
+ xseg = x[:, start:end]
+
+ if end - start < self.segment_length:
+ xseg = F.pad(
+ xseg,
+ pad=(0, self.segment_length - (end - start)),
+ value=torch.nan
+ )
+
+ segments[i, :, :] = xseg
+
+ chunks = segments.reshape(
+ (n_segments, n_chan, self.n_chunks, self.chunk_size)
+ )
+
+ if self.analysis_stem != "none":
+ chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3))
+ chunk_energies = torch.nan_to_num(chunk_energies, nan=0)
+ chunk_energies[chunk_energies == 0] = self.chunk_epsilon
+
+ energy_threshold = torch.nanquantile(
+ chunk_energies, q=self.energy_threshold_quantile
+ )
+
+ if energy_threshold < self.segment_epsilon:
+ energy_threshold = self.segment_epsilon # type: ignore[assignment]
+
+ chunks_above_threshold = chunk_energies > energy_threshold
+ n_chunks_above_threshold = torch.mean(
+ chunks_above_threshold.to(torch.float), dim=-1
+ )
+
+ segment_above_threshold = (
+ n_chunks_above_threshold > self.salient_proportion_threshold
+ )
+
+ if torch.sum(segment_above_threshold) == 0:
+ return
+
+ else:
+ segment_above_threshold = torch.ones((n_segments,))
+
+ for i in range(n_segments):
+ if not segment_above_threshold[i]:
+ continue
+
+ outpath = os.path.join(
+ self.output_path,
+ self.analysis_stem,
+ f"{data['track']} - {self.analysis_stem}{i:03d}",
+ )
+ os.makedirs(outpath, exist_ok=True)
+
+ for stem in data["audio"]:
+ if stem == self.analysis_stem:
+ segment = torch.nan_to_num(segments[i, :, :], nan=0)
+ else:
+ start = i * self.hop_length
+ end = start + self.segment_length
+ end = min(n_samples, end)
+
+ segment = data["audio"][stem][:, start:end]
+
+ if end - start < self.segment_length:
+ segment = F.pad(
+ segment,
+ (0, self.segment_length - (end - start))
+ )
+
+ assert segment.shape[-1] == self.segment_length, segment.shape
+
+ # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs)
+
+ np.save(os.path.join(outpath, f"{stem}.wav"), segment)
+
+
+def preprocess(
+ analysis_stem: str,
+ output_path: str = "/data/MUSDB18/HQ/saded-np",
+ fs: int = 44100,
+ segment_length_second: float = 6.0,
+ hop_length_second: float = 3.0,
+ n_chunks: int = 10,
+ chunk_epsilon: float = 1e-5,
+ energy_threshold_quantile: float = 0.15,
+ segment_epsilon: float = 1e-3,
+ salient_proportion_threshold: float = 0.5,
+) -> None:
+
+ sad = SourceActivityDetector(
+ analysis_stem=analysis_stem,
+ output_path=output_path,
+ fs=fs,
+ segment_length_second=segment_length_second,
+ hop_length_second=hop_length_second,
+ n_chunks=n_chunks,
+ chunk_epsilon=chunk_epsilon,
+ energy_threshold_quantile=energy_threshold_quantile,
+ segment_epsilon=segment_epsilon,
+ salient_proportion_threshold=salient_proportion_threshold,
+ )
+
+ for split in ["train", "val", "test"]:
+ ds = MUSDB18FullTrackDataset(
+ data_root="/data/MUSDB18/HQ/canonical",
+ split=split,
+ )
+
+ tracks = []
+ for i, track in enumerate(tqdm(ds, total=len(ds))):
+ if i % 32 == 0 and tracks:
+ process_map(sad, tracks, max_workers=8)
+ tracks = []
+ tracks.append(track)
+ process_map(sad, tracks, max_workers=8)
+
+def loudness_norm_one(inputs):
+ infile, outfile, target_lufs = inputs
+
+ audio, fs = ta.load(infile)
+ audio = audio.mean(dim=0, keepdim=True).numpy().T
+
+ meter = pyln.Meter(fs)
+ loudness = meter.integrated_loudness(audio)
+ audio = pyln.normalize.loudness(audio, loudness, target_lufs)
+
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
+ np.save(outfile, audio.T)
+
+def loudness_norm(
+ data_path: str,
+ # output_path: str,
+ target_lufs: float = -17.0,
+):
+ files = glob.glob(
+ os.path.join(data_path, "**", "*.wav"), recursive=True
+ )
+
+ outfiles = [
+ f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files
+ ]
+
+ files = [(f, o, target_lufs) for f, o in zip(files, outfiles)]
+
+ process_map(loudness_norm_one, files, chunksize=2)
+
+
+
+if __name__ == "__main__":
+
+ import fire
+
+ fire.Fire()
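+
+# Usage sketch (hypothetical paths; both entry points are exposed via `fire.Fire()`):
+#
+#   # 1) run source-activity detection and write salient 6-second segments per stem
+#   python -m models.bandit.core.data.musdb.preprocess preprocess --analysis_stem vocals
+#
+#   # 2) loudness-normalize previously extracted .wav segments (downmixed to mono) and
+#   #    write them as .npy into the mirrored "saded-np" tree
+#   python -m models.bandit.core.data.musdb.preprocess loudness_norm --data_path /data/MUSDB18/HQ/saded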
diff --git a/models/bandit/core/data/musdb/validation.yaml b/models/bandit/core/data/musdb/validation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f8752478d285d1d13d5e842225af1de95cae57a
--- /dev/null
+++ b/models/bandit/core/data/musdb/validation.yaml
@@ -0,0 +1,15 @@
+validation:
+ - 'Actions - One Minute Smile'
+ - 'Clara Berry And Wooldog - Waltz For My Victims'
+ - 'Johnny Lokke - Promises & Lies'
+ - 'Patrick Talbot - A Reason To Leave'
+ - 'Triviul - Angelsaint'
+ - 'Alexander Ross - Goodbye Bolero'
+ - 'Fergessen - Nos Palpitants'
+ - 'Leaf - Summerghost'
+ - 'Skelpolu - Human Mistakes'
+ - 'Young Griffo - Pennies'
+ - 'ANiMAL - Rockshow'
+ - 'James May - On The Line'
+ - 'Meaxic - Take A Step'
+ - 'Traffic Experiment - Sirens'
\ No newline at end of file
diff --git a/models/bandit/core/loss/__init__.py b/models/bandit/core/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ab803aecde4f686e34d93f3f2d585e0a9867525
--- /dev/null
+++ b/models/bandit/core/loss/__init__.py
@@ -0,0 +1,2 @@
+from ._multistem import MultiStemWrapperFromConfig
+from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss
diff --git a/models/bandit/core/loss/_complex.py b/models/bandit/core/loss/_complex.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d97e5d8bab3fdb095c2ba7c77faebef26e8f1ce
--- /dev/null
+++ b/models/bandit/core/loss/_complex.py
@@ -0,0 +1,34 @@
+from typing import Any
+
+import torch
+from torch import nn
+from torch.nn.modules import loss as _loss
+from torch.nn.modules.loss import _Loss
+
+
+class ReImLossWrapper(_Loss):
+ def __init__(self, module: _Loss) -> None:
+ super().__init__()
+ self.module = module
+
+ def forward(
+ self,
+ preds: torch.Tensor,
+ target: torch.Tensor
+ ) -> torch.Tensor:
+ return self.module(
+ torch.view_as_real(preds),
+ torch.view_as_real(target)
+ )
+
+
+class ReImL1Loss(ReImLossWrapper):
+ def __init__(self, **kwargs: Any) -> None:
+ l1_loss = _loss.L1Loss(**kwargs)
+ super().__init__(module=(l1_loss))
+
+
+class ReImL2Loss(ReImLossWrapper):
+ def __init__(self, **kwargs: Any) -> None:
+ l2_loss = _loss.MSELoss(**kwargs)
+ super().__init__(module=(l2_loss))
diff --git a/models/bandit/core/loss/_multistem.py b/models/bandit/core/loss/_multistem.py
new file mode 100644
index 0000000000000000000000000000000000000000..675e0ffbecf1f9f5efb0369bcb9d5c590efcfc31
--- /dev/null
+++ b/models/bandit/core/loss/_multistem.py
@@ -0,0 +1,45 @@
+from typing import Any, Dict
+
+import torch
+from asteroid import losses as asteroid_losses
+from torch import nn
+from torch.nn.modules.loss import _Loss
+
+from . import snr
+
+
+def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss:
+
+ for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]:
+ if name in module.__dict__:
+ return module.__dict__[name](**kwargs)
+
+ raise NameError(f"Could not find a loss named '{name}'")
+
+
+class MultiStemWrapper(_Loss):
+ def __init__(self, module: _Loss, modality: str = "audio") -> None:
+ super().__init__()
+ self.loss = module
+ self.modality = modality
+
+ def forward(
+ self,
+ preds: Dict[str, Dict[str, torch.Tensor]],
+ target: Dict[str, Dict[str, torch.Tensor]],
+ ) -> torch.Tensor:
+ loss = {
+ stem: self.loss(
+ preds[self.modality][stem],
+ target[self.modality][stem]
+ )
+ for stem in preds[self.modality] if stem in target[self.modality]
+ }
+
+ return sum(list(loss.values()))
+
+
+class MultiStemWrapperFromConfig(MultiStemWrapper):
+ def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None:
+ loss = parse_loss(name, kwargs)
+ super().__init__(module=loss, modality=modality)
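+
+# Example sketch (tensor names are placeholders): parse_loss looks the name up in torch.nn,
+# the local snr module, and asteroid's losses, so a config entry like ("L1Loss", {}) resolves
+# to torch.nn.L1Loss. The wrapper then sums that loss over every stem present in both
+# prediction and target:
+#
+#   loss_fn = MultiStemWrapperFromConfig(name="L1Loss", kwargs={}, modality="audio")
+#   loss = loss_fn(
+#       {"audio": {"speech": est_speech, "music": est_music}},
+#       {"audio": {"speech": ref_speech, "music": ref_music}},
+#   )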
diff --git a/models/bandit/core/loss/_timefreq.py b/models/bandit/core/loss/_timefreq.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ea9d5994ca645546b5ccb7e6eafaa3d2fbcf959
--- /dev/null
+++ b/models/bandit/core/loss/_timefreq.py
@@ -0,0 +1,113 @@
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+from torch.nn.modules.loss import _Loss
+
+from models.bandit.core.loss._multistem import MultiStemWrapper
+from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper
+from models.bandit.core.loss.snr import SignalNoisePNormRatio
+
+class TimeFreqWrapper(_Loss):
+ def __init__(
+ self,
+ time_module: _Loss,
+ freq_module: Optional[_Loss] = None,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ multistem: bool = True,
+ ) -> None:
+ super().__init__()
+
+ if freq_module is None:
+ freq_module = time_module
+
+ if multistem:
+ time_module = MultiStemWrapper(time_module, modality="audio")
+ freq_module = MultiStemWrapper(freq_module, modality="spectrogram")
+
+ self.time_module = time_module
+ self.freq_module = freq_module
+
+ self.time_weight = time_weight
+ self.freq_weight = freq_weight
+
+ # TODO: add better type hints
+ def forward(self, preds: Any, target: Any) -> torch.Tensor:
+
+ return self.time_weight * self.time_module(
+ preds, target
+ ) + self.freq_weight * self.freq_module(preds, target)
+
+
+class TimeFreqL1Loss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = (nn.L1Loss(**tkwargs))
+ freq_module = ReImL1Loss(**fkwargs)
+ super().__init__(
+ time_module,
+ freq_module,
+ time_weight,
+ freq_weight,
+ multistem
+ )
+
+
+class TimeFreqL2Loss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = nn.MSELoss(**tkwargs)
+ freq_module = ReImL2Loss(**fkwargs)
+ super().__init__(
+ time_module,
+ freq_module,
+ time_weight,
+ freq_weight,
+ multistem
+ )
+
+
+
+class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper):
+ def __init__(
+ self,
+ time_weight: float = 1.0,
+ freq_weight: float = 1.0,
+ tkwargs: Optional[Dict[str, Any]] = None,
+ fkwargs: Optional[Dict[str, Any]] = None,
+ multistem: bool = True,
+ ) -> None:
+ if tkwargs is None:
+ tkwargs = {}
+ if fkwargs is None:
+ fkwargs = {}
+ time_module = SignalNoisePNormRatio(**tkwargs)
+ freq_module = SignalNoisePNormRatio(**fkwargs)
+ super().__init__(
+ time_module,
+ freq_module,
+ time_weight,
+ freq_weight,
+ multistem
+ )
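+
+# Sketch of the expected input structure (not from the original; tensor names are placeholders).
+# With multistem=True the wrapped losses expect nested dicts keyed by modality and then by stem,
+# and the total is time_weight * L_time("audio") + freq_weight * L_freq("spectrogram"), where the
+# "spectrogram" entries are complex tensors:
+#
+#   loss_fn = TimeFreqL1Loss(time_weight=1.0, freq_weight=1.0)
+#   loss = loss_fn(
+#       {"audio": {"speech": est_wav}, "spectrogram": {"speech": est_spec}},
+#       {"audio": {"speech": ref_wav}, "spectrogram": {"speech": ref_spec}},
+#   )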
diff --git a/models/bandit/core/loss/snr.py b/models/bandit/core/loss/snr.py
new file mode 100644
index 0000000000000000000000000000000000000000..2996dd57080db687599c1fd673d6807041a04b52
--- /dev/null
+++ b/models/bandit/core/loss/snr.py
@@ -0,0 +1,146 @@
+import torch
+from torch.nn.modules.loss import _Loss
+from torch.nn import functional as F
+
+class SignalNoisePNormRatio(_Loss):
+ def __init__(
+ self,
+ p: float = 1.0,
+ scale_invariant: bool = False,
+ zero_mean: bool = False,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-3,
+ ) -> None:
+ assert reduction != "sum", NotImplementedError
+ super().__init__(reduction=reduction)
+ assert not zero_mean
+
+ self.p = p
+
+ self.EPS = EPS
+ self.take_log = take_log
+
+ self.scale_invariant = scale_invariant
+
+ def forward(
+ self,
+ est_target: torch.Tensor,
+ target: torch.Tensor
+ ) -> torch.Tensor:
+
+ target_ = target
+ if self.scale_invariant:
+ ndim = target.ndim
+ dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
+ s_target_energy = (
+ torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
+ )
+
+ if ndim > 2:
+ dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True)
+ s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True)
+
+ target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8)
+ target = target_ * target_scaler
+
+ if torch.is_complex(est_target):
+ est_target = torch.view_as_real(est_target)
+ target = torch.view_as_real(target)
+
+
+ batch_size = est_target.shape[0]
+ est_target = est_target.reshape(batch_size, -1)
+ target = target.reshape(batch_size, -1)
+ # target_ = target_.reshape(batch_size, -1)
+
+ if self.p == 1:
+ e_error = torch.abs(est_target-target).mean(dim=-1)
+ e_target = torch.abs(target).mean(dim=-1)
+ elif self.p == 2:
+ e_error = torch.square(est_target-target).mean(dim=-1)
+ e_target = torch.square(target).mean(dim=-1)
+ else:
+ raise NotImplementedError
+
+ if self.take_log:
+ loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS))
+ else:
+ loss = (e_error + self.EPS)/(e_target + self.EPS)
+
+ if self.reduction == "mean":
+ loss = loss.mean()
+ elif self.reduction == "sum":
+ loss = loss.sum()
+
+ return loss
+
+
+
+class MultichannelSingleSrcNegSDR(_Loss):
+ def __init__(
+ self,
+ sdr_type: str,
+ p: float = 2.0,
+ zero_mean: bool = True,
+ take_log: bool = True,
+ reduction: str = "mean",
+ EPS: float = 1e-8,
+ ) -> None:
+ assert reduction != "sum", NotImplementedError
+ super().__init__(reduction=reduction)
+
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
+ self.sdr_type = sdr_type
+ self.zero_mean = zero_mean
+ self.take_log = take_log
+ self.EPS = EPS
+
+ self.p = p
+
+ def forward(
+ self,
+ est_target: torch.Tensor,
+ target: torch.Tensor
+ ) -> torch.Tensor:
+ if target.size() != est_target.size() or target.ndim != 3:
+ raise TypeError(
+ f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead"
+ )
+ # Step 1. Zero-mean norm
+ if self.zero_mean:
+ mean_source = torch.mean(target, dim=[1, 2], keepdim=True)
+ mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True)
+ target = target - mean_source
+ est_target = est_target - mean_estimate
+ # Step 2. Pair-wise SI-SDR.
+ if self.sdr_type in ["sisdr", "sdsdr"]:
+ # [batch, 1]
+ dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True)
+ # [batch, 1]
+ s_target_energy = (
+ torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS
+ )
+ # [batch, time]
+ scaled_target = dot * target / s_target_energy
+ else:
+ # [batch, time]
+ scaled_target = target
+ if self.sdr_type in ["sdsdr", "snr"]:
+ e_noise = est_target - target
+ else:
+ e_noise = est_target - scaled_target
+ # [batch]
+
+ if self.p == 2.0:
+ losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / (
+ torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS
+ )
+ else:
+ losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / (
+ torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS
+ )
+ if self.take_log:
+ losses = 10 * torch.log10(losses + self.EPS)
+ losses = losses.mean() if self.reduction == "mean" else losses
+ return -losses
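+
+# Usage sketch: with the defaults (p=1, take_log=True) SignalNoisePNormRatio returns
+# 10 * log10 of the ratio between the mean absolute error and the mean absolute target,
+# so lower (more negative) values are better:
+#
+#   loss_fn = SignalNoisePNormRatio(p=1.0)
+#   loss = loss_fn(estimate, reference)  # (batch, ...) tensors, real or complex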
diff --git a/models/bandit/core/metrics/__init__.py b/models/bandit/core/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c638b4df585ad6c3c6490d9e67b7fc197f0d06f4
--- /dev/null
+++ b/models/bandit/core/metrics/__init__.py
@@ -0,0 +1,9 @@
+from .snr import (
+ ChunkMedianScaleInvariantSignalDistortionRatio,
+ ChunkMedianScaleInvariantSignalNoiseRatio,
+ ChunkMedianSignalDistortionRatio,
+ ChunkMedianSignalNoiseRatio,
+ SafeSignalDistortionRatio,
+)
+
+# from .mushra import EstimatedMushraScore
diff --git a/models/bandit/core/metrics/_squim.py b/models/bandit/core/metrics/_squim.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec76b5fb5e27d0f6a6aaa5ececc5161482150bfc
--- /dev/null
+++ b/models/bandit/core/metrics/_squim.py
@@ -0,0 +1,383 @@
+from dataclasses import dataclass
+
+from torchaudio._internal import load_state_dict_from_url
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def transform_wb_pesq_range(x: float) -> float:
+ """The metric defined by ITU-T P.862 is often called 'PESQ score', which is defined
+ for narrow-band signals and has a value range of [-0.5, 4.5] exactly. Here, we use the metric
+ defined by ITU-T P.862.2, commonly known as 'wide-band PESQ' and will be referred to as "PESQ score".
+
+ Args:
+ x (float): Narrow-band PESQ score.
+
+ Returns:
+ (float): Wide-band PESQ score.
+ """
+ return 0.999 + (4.999 - 0.999) / (1 + math.exp(-1.3669 * x + 3.8224))
+
+
+PESQRange: Tuple[float, float] = (
+ 1.0, # P.862.2 uses a different input filter than P.862, and the lower bound of
+ # the raw score is not -0.5 anymore. It's hard to figure out the true lower bound.
+ # We are using 1.0 as a reasonable approximation.
+ transform_wb_pesq_range(4.5),
+)
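+# For reference (a quick check, not part of the original): transform_wb_pesq_range(4.5) evaluates
+# to roughly 4.64, so PESQRange spans approximately (1.0, 4.64) on the wide-band scale.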
+
+
+class RangeSigmoid(nn.Module):
+ def __init__(self, val_range: Tuple[float, float] = (0.0, 1.0)) -> None:
+ super(RangeSigmoid, self).__init__()
+ assert isinstance(val_range, tuple) and len(val_range) == 2
+ self.val_range: Tuple[float, float] = val_range
+ self.sigmoid: nn.modules.Module = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = self.sigmoid(x) * (self.val_range[1] - self.val_range[0]) + self.val_range[0]
+ return out
+
+
+class Encoder(nn.Module):
+ """Encoder module that transform 1D waveform to 2D representations.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 512)
+ win_len (int, optional): kernel size in the Conv1D layer. (Default: 32)
+ """
+
+ def __init__(self, feat_dim: int = 512, win_len: int = 32) -> None:
+ super(Encoder, self).__init__()
+
+ self.conv1d = nn.Conv1d(1, feat_dim, win_len, stride=win_len // 2, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Apply waveforms to convolutional layer and ReLU layer.
+
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+ (torch.Tensor): Feature Tensor with dimensions `(batch, channel, frame)`.
+ """
+ out = x.unsqueeze(dim=1)
+ out = F.relu(self.conv1d(out))
+ return out
+
+
+class SingleRNN(nn.Module):
+ def __init__(self, rnn_type: str, input_size: int, hidden_size: int, dropout: float = 0.0) -> None:
+ super(SingleRNN, self).__init__()
+
+ self.rnn_type = rnn_type
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+
+ self.rnn: nn.modules.Module = getattr(nn, rnn_type)(
+ input_size,
+ hidden_size,
+ 1,
+ dropout=dropout,
+ batch_first=True,
+ bidirectional=True,
+ )
+
+ self.proj = nn.Linear(hidden_size * 2, input_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # input shape: batch, seq, dim
+ out, _ = self.rnn(x)
+ out = self.proj(out)
+ return out
+
+
+class DPRNN(nn.Module):
+ """*Dual-path recurrent neural networks (DPRNN)* :cite:`luo2020dual`.
+
+ Args:
+ feat_dim (int, optional): The feature dimension after Encoder module. (Default: 64)
+ hidden_dim (int, optional): Hidden dimension in the RNN layer of DPRNN. (Default: 128)
+ num_blocks (int, optional): Number of DPRNN layers. (Default: 6)
+ rnn_type (str, optional): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"]. (Default: "LSTM")
+ d_model (int, optional): The number of expected features in the input. (Default: 256)
+ chunk_size (int, optional): Chunk size of input for DPRNN. (Default: 100)
+ chunk_stride (int, optional): Stride of chunk input for DPRNN. (Default: 50)
+ """
+
+ def __init__(
+ self,
+ feat_dim: int = 64,
+ hidden_dim: int = 128,
+ num_blocks: int = 6,
+ rnn_type: str = "LSTM",
+ d_model: int = 256,
+ chunk_size: int = 100,
+ chunk_stride: int = 50,
+ ) -> None:
+ super(DPRNN, self).__init__()
+
+ self.num_blocks = num_blocks
+
+ self.row_rnn = nn.ModuleList([])
+ self.col_rnn = nn.ModuleList([])
+ self.row_norm = nn.ModuleList([])
+ self.col_norm = nn.ModuleList([])
+ for _ in range(num_blocks):
+ self.row_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.col_rnn.append(SingleRNN(rnn_type, feat_dim, hidden_dim))
+ self.row_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.col_norm.append(nn.GroupNorm(1, feat_dim, eps=1e-8))
+ self.conv = nn.Sequential(
+ nn.Conv2d(feat_dim, d_model, 1),
+ nn.PReLU(),
+ )
+ self.chunk_size = chunk_size
+ self.chunk_stride = chunk_stride
+
+ def pad_chunk(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ # input shape: (B, N, T)
+ seq_len = x.shape[-1]
+
+ rest = self.chunk_size - (self.chunk_stride + seq_len % self.chunk_size) % self.chunk_size
+ out = F.pad(x, [self.chunk_stride, rest + self.chunk_stride])
+
+ return out, rest
+
+ def chunking(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+ out, rest = self.pad_chunk(x)
+ batch_size, feat_dim, seq_len = out.shape
+
+ segments1 = out[:, :, : -self.chunk_stride].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+ segments2 = out[:, :, self.chunk_stride :].contiguous().view(batch_size, feat_dim, -1, self.chunk_size)
+ out = torch.cat([segments1, segments2], dim=3)
+ out = out.view(batch_size, feat_dim, -1, self.chunk_size).transpose(2, 3).contiguous()
+
+ return out, rest
+
+ def merging(self, x: torch.Tensor, rest: int) -> torch.Tensor:
+ batch_size, dim, _, _ = x.shape
+ out = x.transpose(2, 3).contiguous().view(batch_size, dim, -1, self.chunk_size * 2)
+ out1 = out[:, :, :, : self.chunk_size].contiguous().view(batch_size, dim, -1)[:, :, self.chunk_stride :]
+ out2 = out[:, :, :, self.chunk_size :].contiguous().view(batch_size, dim, -1)[:, :, : -self.chunk_stride]
+ out = out1 + out2
+ if rest > 0:
+ out = out[:, :, :-rest]
+ out = out.contiguous()
+ return out
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, rest = self.chunking(x)
+ batch_size, _, dim1, dim2 = x.shape
+ out = x
+ for row_rnn, row_norm, col_rnn, col_norm in zip(self.row_rnn, self.row_norm, self.col_rnn, self.col_norm):
+ row_in = out.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1).contiguous()
+ row_out = row_rnn(row_in)
+ row_out = row_out.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+ row_out = row_norm(row_out)
+ out = out + row_out
+
+ col_in = out.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1).contiguous()
+ col_out = col_rnn(col_in)
+ col_out = col_out.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+ col_out = col_norm(col_out)
+ out = out + col_out
+ out = self.conv(out)
+ out = self.merging(out, rest)
+ out = out.transpose(1, 2).contiguous()
+ return out
+
+
+class AutoPool(nn.Module):
+ def __init__(self, pool_dim: int = 1) -> None:
+ super(AutoPool, self).__init__()
+ self.pool_dim: int = pool_dim
+ self.softmax: nn.modules.Module = nn.Softmax(dim=pool_dim)
+ self.register_parameter("alpha", nn.Parameter(torch.ones(1)))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ weight = self.softmax(torch.mul(x, self.alpha))
+ out = torch.sum(torch.mul(x, weight), dim=self.pool_dim)
+ return out
+
+
+class SquimObjective(nn.Module):
+ """Speech Quality and Intelligibility Measures (SQUIM) model that predicts **objective** metric scores
+ for speech enhancement (e.g., STOI, PESQ, and SI-SDR).
+
+ Args:
+ encoder (torch.nn.Module): Encoder module to transform 1D waveform to 2D feature representation.
+ dprnn (torch.nn.Module): DPRNN module to model sequential feature.
+ branches (torch.nn.ModuleList): Transformer branches, each of which estimates one objective metric score.
+ """
+
+ def __init__(
+ self,
+ encoder: nn.Module,
+ dprnn: nn.Module,
+ branches: nn.ModuleList,
+ ):
+ super(SquimObjective, self).__init__()
+ self.encoder = encoder
+ self.dprnn = dprnn
+ self.branches = branches
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """
+ Args:
+ x (torch.Tensor): Input waveforms. Tensor with dimensions `(batch, time)`.
+
+ Returns:
+ List(torch.Tensor): List of score Tensors. Each Tensor has dimension `(batch,)`.
+ """
+ if x.ndim != 2:
+ raise ValueError(f"The input must be a 2D Tensor. Found dimension {x.ndim}.")
+ x = x / (torch.mean(x**2, dim=1, keepdim=True) ** 0.5 * 20)
+ out = self.encoder(x)
+ out = self.dprnn(out)
+ scores = []
+ for branch in self.branches:
+ scores.append(branch(out).squeeze(dim=1))
+ return scores
+
+
+def _create_branch(d_model: int, nhead: int, metric: str) -> nn.modules.Module:
+ """Create branch module after DPRNN model for predicting metric score.
+
+ Args:
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ metric (str): The metric name to predict.
+
+ Returns:
+ (nn.Module): Returned module to predict corresponding metric score.
+ """
+ layer1 = nn.TransformerEncoderLayer(d_model, nhead, d_model * 4, dropout=0.0, batch_first=True)
+ layer2 = AutoPool()
+ if metric == "stoi":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(),
+ )
+ elif metric == "pesq":
+ layer3 = nn.Sequential(
+ nn.Linear(d_model, d_model),
+ nn.PReLU(),
+ nn.Linear(d_model, 1),
+ RangeSigmoid(val_range=PESQRange),
+ )
+ else:
+ layer3: nn.modules.Module = nn.Sequential(nn.Linear(d_model, d_model), nn.PReLU(), nn.Linear(d_model, 1))
+ return nn.Sequential(layer1, layer2, layer3)
+
+
+def squim_objective_model(
+ feat_dim: int,
+ win_len: int,
+ d_model: int,
+ nhead: int,
+ hidden_dim: int,
+ num_blocks: int,
+ rnn_type: str,
+ chunk_size: int,
+ chunk_stride: Optional[int] = None,
+) -> SquimObjective:
+ """Build a custome :class:`torchaudio.prototype.models.SquimObjective` model.
+
+ Args:
+ feat_dim (int): The feature dimension after the Encoder module.
+ win_len (int): Kernel size in the Encoder module.
+ d_model (int): The number of expected features in the input.
+ nhead (int): Number of heads in the multi-head attention model.
+ hidden_dim (int): Hidden dimension in the RNN layer of DPRNN.
+ num_blocks (int): Number of DPRNN layers.
+ rnn_type (str): Type of RNN in DPRNN. Valid options are ["RNN", "LSTM", "GRU"].
+ chunk_size (int): Chunk size of input for DPRNN.
+ chunk_stride (int or None, optional): Stride of chunk input for DPRNN.
+ """
+ if chunk_stride is None:
+ chunk_stride = chunk_size // 2
+ encoder = Encoder(feat_dim, win_len)
+ dprnn = DPRNN(feat_dim, hidden_dim, num_blocks, rnn_type, d_model, chunk_size, chunk_stride)
+ branches = nn.ModuleList(
+ [
+ _create_branch(d_model, nhead, "stoi"),
+ _create_branch(d_model, nhead, "pesq"),
+ _create_branch(d_model, nhead, "sisdr"),
+ ]
+ )
+ return SquimObjective(encoder, dprnn, branches)
+
+
+def squim_objective_base() -> SquimObjective:
+ """Build :class:`torchaudio.prototype.models.SquimObjective` model with default arguments."""
+ return squim_objective_model(
+ feat_dim=256,
+ win_len=64,
+ d_model=256,
+ nhead=4,
+ hidden_dim=256,
+ num_blocks=2,
+ rnn_type="LSTM",
+ chunk_size=71,
+ )
+
+@dataclass
+class SquimObjectiveBundle:
+
+ _path: str
+ _sample_rate: float
+
+ def _get_state_dict(self, dl_kwargs):
+ url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(url, **dl_kwargs)
+ return state_dict
+
+ def get_model(self, *, dl_kwargs=None) -> SquimObjective:
+ """Construct the SquimObjective model, and load the pretrained weight.
+
+ The weight file is downloaded from the internet and cached with
+ :func:`torch.hub.load_state_dict_from_url`
+
+ Args:
+ dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+
+ Returns:
+ Variation of :py:class:`~torchaudio.models.SquimObjective`.
+ """
+ model = squim_objective_base()
+ model.load_state_dict(self._get_state_dict(dl_kwargs))
+ model.eval()
+ return model
+
+ @property
+ def sample_rate(self):
+ """Sample rate of the audio that the model is trained on.
+
+ :type: float
+ """
+ return self._sample_rate
+
+
+SQUIM_OBJECTIVE = SquimObjectiveBundle(
+ "squim_objective_dns2020.pth",
+ _sample_rate=16000,
+)
+SQUIM_OBJECTIVE.__doc__ = """SquimObjective pipeline trained using the approach described in
+ :cite:`kumar2023torchaudio` on the *DNS 2020 Dataset* :cite:`reddy2020interspeech`.
+
+ The underlying model is constructed by :py:func:`torchaudio.models.squim_objective_base`.
+ The weights are released under the `Creative Commons Attribution 4.0 International License
+ <https://creativecommons.org/licenses/by/4.0/>`__.
+
+ Please refer to :py:class:`SquimObjectiveBundle` for usage instructions.
+ """
+
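+# Usage sketch (mirrors the torchaudio SQUIM pipeline; `wav` is a hypothetical waveform tensor
+# of shape (batch, time) sampled at SQUIM_OBJECTIVE.sample_rate, i.e. 16 kHz):
+#
+#   model = SQUIM_OBJECTIVE.get_model()
+#   stoi_hyp, pesq_hyp, si_sdr_hyp = model(wav)  # three (batch,) score tensors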
diff --git a/models/bandit/core/metrics/snr.py b/models/bandit/core/metrics/snr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2830b2cbecfa681c449d09e2d4c35a20fc98128
--- /dev/null
+++ b/models/bandit/core/metrics/snr.py
@@ -0,0 +1,150 @@
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torchmetrics as tm
+from torch._C import _LinAlgError
+from torchmetrics import functional as tmF
+
+
+class SafeSignalDistortionRatio(tm.SignalDistortionRatio):
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ def update(self, *args, **kwargs) -> Any:
+ try:
+ super().update(*args, **kwargs)
+ except Exception:
+ # some SDR computations can fail numerically; skip those updates
+ pass
+
+ def compute(self) -> Any:
+ if self.total == 0:
+ return torch.tensor(torch.nan)
+ return super().compute()
+
+
+class BaseChunkMedianSignalRatio(tm.Metric):
+ def __init__(
+ self,
+ func: Callable,
+ window_size: int,
+ hop_size: Optional[int] = None,
+ zero_mean: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # self.zero_mean = zero_mean
+ self.func = func
+ self.window_size = window_size
+ if hop_size is None:
+ hop_size = window_size
+ self.hop_size = hop_size
+
+ self.add_state(
+ "sum_snr",
+ default=torch.tensor(0.0),
+ dist_reduce_fx="sum"
+ )
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+
+ n_samples = target.shape[-1]
+
+ n_chunks = int(
+ np.ceil((n_samples - self.window_size) / self.hop_size) + 1
+ )
+
+ snr_chunk = []
+
+ for i in range(n_chunks):
+ start = i * self.hop_size
+
+ if n_samples - start < self.window_size:
+ continue
+
+ end = start + self.window_size
+
+ try:
+ chunk_snr = self.func(
+ preds[..., start:end],
+ target[..., start:end]
+ )
+
+ # print(preds.shape, chunk_snr.shape)
+
+ if torch.all(torch.isfinite(chunk_snr)):
+ snr_chunk.append(chunk_snr)
+ except _LinAlgError:
+ pass
+
+ snr_chunk = torch.stack(snr_chunk, dim=-1)
+ snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1)
+
+ self.sum_snr += snr_batch.sum()
+ self.total += snr_batch.numel()
+
+ def compute(self) -> Any:
+ return self.sum_snr / self.total
+
+
+class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self,
+ window_size: int,
+ hop_size: Optional[int] = None,
+ zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self,
+ window_size: int,
+ hop_size: Optional[int] = None,
+ zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.scale_invariant_signal_noise_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio):
+ def __init__(
+ self,
+ window_size: int,
+ hop_size: Optional[int] = None,
+ zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
+
+
+class ChunkMedianScaleInvariantSignalDistortionRatio(
+ BaseChunkMedianSignalRatio
+ ):
+ def __init__(
+ self,
+ window_size: int,
+ hop_size: Optional[int] = None,
+ zero_mean: bool = False
+ ) -> None:
+ super().__init__(
+ func=tmF.scale_invariant_signal_distortion_ratio,
+ window_size=window_size,
+ hop_size=hop_size,
+ zero_mean=zero_mean,
+ )
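+
+# Usage sketch (torchmetrics-style update/compute; the window and hop sizes in samples are
+# illustrative values, not from the original):
+#
+#   metric = ChunkMedianSignalNoiseRatio(window_size=44100, hop_size=44100)
+#   metric.update(preds, target)  # (..., time) tensors
+#   value = metric.compute()      # mean over items of each item's median per-chunk SNR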
diff --git a/models/bandit/core/model/__init__.py b/models/bandit/core/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ac48eb69d6f844ba5b73b213eae4cfab157cac
--- /dev/null
+++ b/models/bandit/core/model/__init__.py
@@ -0,0 +1,3 @@
+from .bsrnn.wrapper import (
+ MultiMaskMultiSourceBandSplitRNNSimple,
+)
diff --git a/models/bandit/core/model/_spectral.py b/models/bandit/core/model/_spectral.py
new file mode 100644
index 0000000000000000000000000000000000000000..564cd28600719579227a6085eed5e9d6ee521312
--- /dev/null
+++ b/models/bandit/core/model/_spectral.py
@@ -0,0 +1,58 @@
+from typing import Dict, Optional
+
+import torch
+import torchaudio as ta
+from torch import nn
+
+
+class _SpectralComponent(nn.Module):
+ def __init__(
+ self,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ assert power is None
+
+ window_fn = torch.__dict__[window_fn]
+
+ self.stft = (
+ ta.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+ )
+
+ self.istft = (
+ ta.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+ )
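+
+# Sketch (not from the original): subclasses use self.stft / self.istft as a complex STFT
+# round trip, e.g.
+#
+#   spec = _SpectralComponent(n_fft=2048, hop_length=512)
+#   X = spec.stft(waveform)                           # complex (..., n_freq, n_time)
+#   x_hat = spec.istft(X, length=waveform.shape[-1])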
diff --git a/models/bandit/core/model/bsrnn/__init__.py b/models/bandit/core/model/bsrnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c27826197fc8f4eb7a7036d8037966a58d8b38d4
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/__init__.py
@@ -0,0 +1,23 @@
+from abc import ABC
+from typing import Iterable, Mapping, Union
+
+from torch import nn
+
+from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
+from models.bandit.core.model.bsrnn.tfmodel import (
+ SeqBandModellingModule,
+ TransformerTimeFreqModule,
+)
+
+
+class BandsplitCoreBase(nn.Module, ABC):
+ band_split: nn.Module
+ tf_model: nn.Module
+ mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]]
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ @staticmethod
+ def mask(x, m):
+ return x * m
diff --git a/models/bandit/core/model/bsrnn/bandsplit.py b/models/bandit/core/model/bsrnn/bandsplit.py
new file mode 100644
index 0000000000000000000000000000000000000000..63e6255857fb2d538634be317332afb2f93e145d
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/bandsplit.py
@@ -0,0 +1,139 @@
+from typing import List, Tuple
+
+import torch
+from torch import nn
+
+from models.bandit.core.model.bsrnn.utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class NormFC(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ bandwidth: int,
+ in_channel: int,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.treat_channel_as_feature = treat_channel_as_feature
+
+ if normalize_channel_independently:
+ raise NotImplementedError
+
+ reim = 2
+
+ self.norm = nn.LayerNorm(in_channel * bandwidth * reim)
+
+ fc_in = bandwidth * reim
+
+ if treat_channel_as_feature:
+ fc_in *= in_channel
+ else:
+ assert emb_dim % in_channel == 0
+ emb_dim = emb_dim // in_channel
+
+ self.fc = nn.Linear(fc_in, emb_dim)
+
+ def forward(self, xb):
+ # xb = (batch, n_time, in_chan, reim * band_width)
+
+ batch, n_time, in_chan, ribw = xb.shape
+ xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
+ # (batch, n_time, in_chan * reim * band_width)
+
+ if not self.treat_channel_as_feature:
+ xb = xb.reshape(batch, n_time, in_chan, ribw)
+ # (batch, n_time, in_chan, reim * band_width)
+
+ zb = self.fc(xb)
+ # (batch, n_time, emb_dim)
+ # OR
+ # (batch, n_time, in_chan, emb_dim_per_chan)
+
+ if not self.treat_channel_as_feature:
+ batch, n_time, in_chan, emb_dim_per_chan = zb.shape
+ # (batch, n_time, in_chan, emb_dim_per_chan)
+ zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))
+
+ return zb # (batch, n_time, emb_dim)
+
+
+class BandSplitModule(nn.Module):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ in_channel: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ check_nonzero_bandwidth(band_specs)
+
+ if require_no_gap:
+ check_no_gap(band_specs)
+
+ if require_no_overlap:
+ check_no_overlap(band_specs)
+
+ self.band_specs = band_specs
+ # list of [fstart, fend) in index.
+ # Note that fend is exclusive.
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+ self.emb_dim = emb_dim
+
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ (
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channel=in_channel,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ )
+ )
+ for bw in self.band_widths
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+
+ batch, in_chan, _, n_time = x.shape
+
+ z = torch.zeros(
+ size=(batch, self.n_bands, n_time, self.emb_dim),
+ device=x.device
+ )
+
+ xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2
+ xr = torch.permute(
+ xr,
+ (0, 3, 1, 4, 2)
+ ) # batch, n_time, in_chan, 2, n_freq
+ batch, n_time, in_chan, reim, band_width = xr.shape
+ for i, nfm in enumerate(self.norm_fc_modules):
+ # print(f"bandsplit/band{i:02d}")
+ fstart, fend = self.band_specs[i]
+ xb = xr[..., fstart:fend]
+ # (batch, n_time, in_chan, reim, band_width)
+ xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
+ # (batch, n_time, in_chan, reim * band_width)
+ # z.append(nfm(xb)) # (batch, n_time, emb_dim)
+ z[:, i, :, :] = nfm(xb.contiguous())
+
+ # z = torch.stack(z, dim=1)
+
+ return z
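+
+# Shape sketch (the band edges below are made-up FFT-bin indices; any gap-free cover of the
+# frequency axis is accepted when require_no_gap=True):
+#
+#   module = BandSplitModule(
+#       band_specs=[(0, 341), (341, 682), (682, 1025)], emb_dim=128, in_channel=2
+#   )
+#   z = module(spec)  # spec: complex (batch, 2, 1025, n_time) -> z: (batch, 3, n_time, 128)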
diff --git a/models/bandit/core/model/bsrnn/core.py b/models/bandit/core/model/bsrnn/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fd36259002a395e7b7864f605fcab5b4422e422
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/core.py
@@ -0,0 +1,661 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from models.bandit.core.model.bsrnn import BandsplitCoreBase
+from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule
+from models.bandit.core.model.bsrnn.maskestim import (
+ MaskEstimationModule,
+ MultAddMaskEstimationModule,
+ OverlappingMaskEstimationModule
+)
+from models.bandit.core.model.bsrnn.tfmodel import (
+ ConvolutionalTimeFreqModule,
+ SeqBandModellingModule,
+ TransformerTimeFreqModule
+)
+
+
+class MultiMaskBandSplitCoreBase(BandsplitCoreBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x, cond=None, compute_residual: bool = True):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+ # print(x.shape)
+ batch, in_chan, n_freq, n_time = x.shape
+ x = torch.reshape(x, (-1, 1, n_freq, n_time))
+
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
+
+ # if torch.any(torch.isnan(z)):
+ # raise ValueError("z nan")
+
+ # print(z)
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
+ # print(q)
+
+
+ # if torch.any(torch.isnan(q)):
+ # raise ValueError("q nan")
+
+ out = {}
+
+ for stem, mem in self.mask_estim.items():
+ m = mem(q, cond=cond)
+
+ # if torch.any(torch.isnan(m)):
+ # raise ValueError("m nan", stem)
+
+ s = self.mask(x, m)
+ s = torch.reshape(s, (batch, in_chan, n_freq, n_time))
+ out[stem] = s
+
+ return {"spectrogram": out}
+
+
+
+ def instantiate_mask_estim(self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int,
+ hidden_activation: str,
+
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False
+ ):
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if "mne:+" in stems:
+ stems = [s for s in stems if s != "mne:+"]
+
+ if overlapping_band:
+ assert freq_weights is not None
+ assert n_freq is not None
+
+ if mult_add_mask:
+
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: MultAddMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+ else:
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: OverlappingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+ else:
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: MaskEstimationModule(
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for stem in stems
+ }
+ )
+
+ def instantiate_bandsplit(self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ emb_dim: int = 128
+ ):
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+class SingleMaskBandsplitCoreBase(BandsplitCoreBase):
+ def __init__(self, **kwargs) -> None:
+ super().__init__()
+
+ def forward(self, x):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
+ m = self.mask_estim(q) # (batch, in_chan, n_freq, n_time)
+
+ s = self.mask(x, m)
+
+ return s
+
+
+class SingleMaskBandsplitCoreRNN(
+ SingleMaskBandsplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__()
+ self.band_split = (BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ ))
+ self.tf_model = (SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ))
+ self.mask_estim = (MaskEstimationModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ ))
+
+
+class SingleMaskBandsplitCoreTransformer(
+ SingleMaskBandsplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__()
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+ self.tf_model = TransformerTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+ self.mask_estim = MaskEstimationModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+
+class MultiSourceMultiMaskBandSplitCoreRNN(MultiMaskBandSplitCoreBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = True,
+ mult_add_mask: bool = False
+ ) -> None:
+
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim
+ )
+
+
+ self.tf_model = (
+ SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ )
+
+ self.mult_add_mask = mult_add_mask
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+ @staticmethod
+ def _mult_add_mask(x, m):
+
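+        # m stacks a multiplicative and an additive mask along its last
+        # dimension: m[..., 0] scales the input spectrogram and m[..., 1]
+        # is added on top.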
+ assert m.ndim == 5
+
+ mm = m[..., 0]
+ am = m[..., 1]
+
+ # print(mm.shape, am.shape, x.shape, m.shape)
+
+ return x * mm + am
+
+ def mask(self, x, m):
+ if self.mult_add_mask:
+
+ return self._mult_add_mask(x, m)
+ else:
+ return super().mask(x, m)
+
+
+class MultiSourceMultiMaskBandSplitCoreTransformer(
+ MultiMaskBandSplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+            use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False
+ ) -> None:
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim
+ )
+ self.tf_model = TransformerTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+
+
+class MultiSourceMultiMaskBandSplitCoreConv(
+ MultiMaskBandSplitCoreBase,
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+            use_freq_weights: bool = True,
+ rnn_type: str = "LSTM",
+ cond_dim: int = 0,
+ mult_add_mask: bool = False
+ ) -> None:
+ super().__init__()
+ self.instantiate_bandsplit(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim
+ )
+ self.tf_model = ConvolutionalTimeFreqModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=tf_dropout,
+ )
+
+ self.instantiate_mask_estim(
+ in_channel=in_channel,
+ stems=stems,
+ band_specs=band_specs,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=overlapping_band,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+
+class PatchingMaskBandsplitCoreBase(MultiMaskBandSplitCoreBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def mask(self, x, m):
+ # x.shape = (batch, n_channel, n_freq, n_time)
+        # m.shape = (batch, n_channel, kernel_freq, kernel_time, n_freq, n_time)
+
+ _, n_channel, kernel_freq, kernel_time, n_freq, n_time = m.shape
+ padding = ((kernel_freq - 1) // 2, (kernel_time - 1) // 2)
+
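+        # unfold() extracts a (kernel_freq, kernel_time) patch around every
+        # time-frequency bin; each patch is masked independently and fold()
+        # sums the overlapping contributions back onto the full grid.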
+ xf = F.unfold(
+ x,
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
+ )
+
+ xf = xf.view(
+ -1,
+ n_channel,
+ kernel_freq,
+ kernel_time,
+ n_freq,
+ n_time,
+ )
+
+ sf = xf * m
+
+ sf = sf.view(
+ -1,
+ n_channel * kernel_freq * kernel_time,
+ n_freq * n_time,
+ )
+
+ s = F.fold(
+ sf,
+ output_size=(n_freq, n_time),
+ kernel_size=(kernel_freq, kernel_time),
+ padding=padding,
+ stride=(1, 1),
+ ).view(
+ -1,
+ n_channel,
+ n_freq,
+ n_time,
+ )
+
+ return s
+
+ def old_mask(self, x, m):
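+        # Earlier, roll-based implementation of mask(); presumably kept for
+        # reference only.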
+ # x.shape = (batch, n_channel, n_freq, n_time)
+ # m.shape = (kernel_freq, kernel_time, batch, n_channel, n_freq, n_time)
+
+ s = torch.zeros_like(x)
+
+ _, n_channel, n_freq, n_time = x.shape
+ kernel_freq, kernel_time, _, _, _, _ = m.shape
+
+ # print(x.shape, m.shape)
+
+ kernel_freq_half = (kernel_freq - 1) // 2
+ kernel_time_half = (kernel_time - 1) // 2
+
+ for ifreq in range(kernel_freq):
+ for itime in range(kernel_time):
+ df, dt = kernel_freq_half - ifreq, kernel_time_half - itime
+ x = x.roll(shifts=(df, dt), dims=(2, 3))
+
+                # if df > 0:
+                #     x[:, :, :df, :] = 0
+                # elif df < 0:
+                #     x[:, :, df:, :] = 0
+
+                # if dt > 0:
+                #     x[:, :, :, :dt] = 0
+                # elif dt < 0:
+                #     x[:, :, :, dt:] = 0
+
+ fslice = slice(max(0, df), min(n_freq, n_freq + df))
+ tslice = slice(max(0, dt), min(n_time, n_time + dt))
+
+ s[:, :, fslice, tslice] += x[:, :, fslice, tslice] * m[ifreq,
+ itime, :,
+ :, fslice,
+ tslice]
+
+ return s
+
+
+class MultiSourceMultiPatchingMaskBandSplitCoreRNN(
+ PatchingMaskBandsplitCoreBase
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: List[Tuple[float, float]],
+ mask_kernel_freq: int,
+ mask_kernel_time: int,
+ conv_kernel_freq: int,
+ conv_kernel_time: int,
+ kernel_norm_mlp_version: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ overlapping_band: bool = False,
+ freq_weights: Optional[List[torch.Tensor]] = None,
+ n_freq: Optional[int] = None,
+ ) -> None:
+
+ super().__init__()
+ self.band_split = BandSplitModule(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+ self.tf_model = (
+ SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ )
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if overlapping_band:
+ assert freq_weights is not None
+ assert n_freq is not None
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: PatchingMaskEstimationModule(
+ band_specs=band_specs,
+ freq_weights=freq_weights,
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version
+ )
+ for stem in stems
+ }
+ )
+ else:
+ raise NotImplementedError
diff --git a/models/bandit/core/model/bsrnn/maskestim.py b/models/bandit/core/model/bsrnn/maskestim.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b9289dfa702e02ff4d4f0dc76196fd39bb68e34
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/maskestim.py
@@ -0,0 +1,347 @@
+import warnings
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch.nn.modules import activation
+
+from models.bandit.core.model.bsrnn.utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class BaseNormMLP(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True, ):
+
+ super().__init__()
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+ self.hidden_activation_kwargs = hidden_activation_kwargs
+ self.norm = nn.LayerNorm(emb_dim)
+ self.hidden = torch.jit.script(nn.Sequential(
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+ activation.__dict__[hidden_activation](
+ **self.hidden_activation_kwargs
+ ),
+ ))
+
+ self.bandwidth = bandwidth
+ self.in_channel = in_channel
+
+ self.complex_mask = complex_mask
+ self.reim = 2 if complex_mask else 1
+ self.glu_mult = 2
+
+
+class NormMLP(BaseNormMLP):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ bandwidth=bandwidth,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ self.output = torch.jit.script(
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+ )
+
+ def reshape_output(self, mb):
+ # print(mb.shape)
+ batch, n_time, _ = mb.shape
+ if self.complex_mask:
+ mb = mb.reshape(
+ batch,
+ n_time,
+ self.in_channel,
+ self.bandwidth,
+ self.reim
+ ).contiguous()
+ # print(mb.shape)
+ mb = torch.view_as_complex(
+ mb
+ ) # (batch, n_time, in_channel, bandwidth)
+ else:
+ mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)
+
+ mb = torch.permute(
+ mb,
+ (0, 2, 3, 1)
+ ) # (batch, in_channel, bandwidth, n_time)
+
+ return mb
+
+ def forward(self, qb):
+ # qb = (batch, n_time, emb_dim)
+
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb0")
+
+
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
+
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb1")
+
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("qb2")
+ mb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ # if torch.any(torch.isnan(qb)):
+ # raise ValueError("mb")
+ mb = self.reshape_output(mb) # (batch, in_channel, bandwidth, n_time)
+
+ return mb
+
+
+class MultAddNormMLP(NormMLP):
+    def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: Optional[int], hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None:
+ super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask)
+
+ self.output2 = torch.jit.script(
+ nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channel * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+ )
+
+ def forward(self, qb):
+
+ qb = self.norm(qb) # (batch, n_time, emb_dim)
+ qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ mmb = self.output(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ mmb = self.reshape_output(mmb) # (batch, in_channel, bandwidth, n_time)
+ amb = self.output2(qb) # (batch, n_time, bandwidth * in_channel * reim)
+ amb = self.reshape_output(amb) # (batch, in_channel, bandwidth, n_time)
+
+ return mmb, amb
+
+
+class MaskEstimationModuleSuperBase(nn.Module):
+ pass
+
+
+class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+            hidden_activation_kwargs: Optional[Dict] = None,
+            complex_mask: bool = True,
+            norm_mlp_cls: Type[nn.Module] = NormMLP,
+            norm_mlp_kwargs: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if norm_mlp_kwargs is None:
+ norm_mlp_kwargs = {}
+
+ self.norm_mlp = nn.ModuleList(
+ [
+ (
+ norm_mlp_cls(
+ bandwidth=self.band_widths[b],
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ **norm_mlp_kwargs,
+ )
+ )
+ for b in range(self.n_bands)
+ ]
+ )
+
+ def compute_masks(self, q):
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = []
+
+ for b, nmlp in enumerate(self.norm_mlp):
+ # print(f"maskestim/{b:02d}")
+ qb = q[:, b, :, :]
+ mb = nmlp(qb)
+ masks.append(mb)
+
+ return masks
+
+
+
+class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs: List[Tuple[float, float]],
+ freq_weights: List[torch.Tensor],
+ n_freq: int,
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+            hidden_activation_kwargs: Optional[Dict] = None,
+            complex_mask: bool = True,
+            norm_mlp_cls: Type[nn.Module] = NormMLP,
+            norm_mlp_kwargs: Optional[Dict] = None,
+ use_freq_weights: bool = True,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+
+ # if cond_dim > 0:
+ # raise NotImplementedError
+
+ super().__init__(
+ band_specs=band_specs,
+ emb_dim=emb_dim + cond_dim,
+ mlp_dim=mlp_dim,
+ in_channel=in_channel,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ norm_mlp_cls=norm_mlp_cls,
+ norm_mlp_kwargs=norm_mlp_kwargs,
+ )
+
+ self.n_freq = n_freq
+ self.band_specs = band_specs
+ self.in_channel = in_channel
+
+ if freq_weights is not None:
+ for i, fw in enumerate(freq_weights):
+ self.register_buffer(f"freq_weights/{i}", fw)
+
+ self.use_freq_weights = use_freq_weights
+ else:
+ self.use_freq_weights = False
+
+ self.cond_dim = cond_dim
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ if cond is not None:
+ if cond.ndim == 2:
+ cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
+ elif cond.ndim == 3:
+ assert cond.shape[1] == n_time
+ else:
+ raise ValueError(f"Invalid cond shape: {cond.shape}")
+
+ q = torch.cat([q, cond], dim=-1)
+ elif self.cond_dim > 0:
+ cond = torch.ones(
+ (batch, n_bands, n_time, self.cond_dim),
+ device=q.device,
+ dtype=q.dtype,
+ )
+ q = torch.cat([q, cond], dim=-1)
+ else:
+ pass
+
+ mask_list = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+ masks = torch.zeros(
+ (batch, self.in_channel, self.n_freq, n_time),
+ device=q.device,
+ dtype=mask_list[0].dtype,
+ )
+
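+        # Scatter-add each band's mask into the full-resolution frequency
+        # grid; for overlapping bands, the per-bin-normalized frequency
+        # weights blend the contributions of neighbouring bands.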
+ for im, mask in enumerate(mask_list):
+ fstart, fend = self.band_specs[im]
+ if self.use_freq_weights:
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+ mask = mask * fw
+ masks[:, :, fstart:fend, :] += mask
+
+ return masks
+
+
+class MaskEstimationModule(OverlappingMaskEstimationModule):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channel: Optional[int],
+ hidden_activation: str = "Tanh",
+            hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ **kwargs,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+ check_no_overlap(band_specs)
+ super().__init__(
+ in_channel=in_channel,
+ band_specs=band_specs,
+ freq_weights=None,
+ n_freq=None,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ masks = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channel, bandwidth, n_time)]
+
+ # TODO: currently this requires band specs to have no gap and no overlap
+ masks = torch.concat(
+ masks,
+ dim=2
+ ) # (batch, in_channel, n_freq, n_time)
+
+ return masks
diff --git a/models/bandit/core/model/bsrnn/tfmodel.py b/models/bandit/core/model/bsrnn/tfmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba710798c5ab49936bd63c914f20da516cbc6af9
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/tfmodel.py
@@ -0,0 +1,317 @@
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.modules import rnn
+
+import torch.backends.cuda
+
+
+class TimeFrequencyModellingModule(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+
+class ResidualRNN(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ rnn_dim: int,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ use_batch_trick: bool = True,
+ use_layer_norm: bool = True,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ self.use_layer_norm = use_layer_norm
+ if use_layer_norm:
+ self.norm = nn.LayerNorm(emb_dim)
+ else:
+ self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)
+
+ self.rnn = rnn.__dict__[rnn_type](
+ input_size=emb_dim,
+ hidden_size=rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional,
+ )
+
+ self.fc = nn.Linear(
+ in_features=rnn_dim * (2 if bidirectional else 1),
+ out_features=emb_dim
+ )
+
+ self.use_batch_trick = use_batch_trick
+ if not self.use_batch_trick:
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+ def forward(self, z):
+ # z = (batch, n_uncrossed, n_across, emb_dim)
+
+ z0 = torch.clone(z)
+
+ # print(z.device)
+
+ if self.use_layer_norm:
+ z = self.norm(z) # (batch, n_uncrossed, n_across, emb_dim)
+ else:
+ z = torch.permute(
+ z, (0, 3, 1, 2)
+ ) # (batch, emb_dim, n_uncrossed, n_across)
+
+ z = self.norm(z) # (batch, emb_dim, n_uncrossed, n_across)
+
+ z = torch.permute(
+ z, (0, 2, 3, 1)
+ ) # (batch, n_uncrossed, n_across, emb_dim)
+
+ batch, n_uncrossed, n_across, emb_dim = z.shape
+
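+        # Batch trick: fold the un-crossed axis into the batch dimension so
+        # a single RNN call processes all sequences along the crossed axis,
+        # instead of looping over the un-crossed axis in Python.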
+ if self.use_batch_trick:
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+
+ z = self.rnn(z.contiguous())[0] # (batch * n_uncrossed, n_across, dir_rnn_dim)
+
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+ # (batch, n_uncrossed, n_across, dir_rnn_dim)
+ else:
+ # Note: this is EXTREMELY SLOW
+ zlist = []
+ for i in range(n_uncrossed):
+ zi = self.rnn(z[:, i, :, :])[0] # (batch, n_across, emb_dim)
+ zlist.append(zi)
+
+ z = torch.stack(
+ zlist,
+ dim=1
+ ) # (batch, n_uncrossed, n_across, dir_rnn_dim)
+
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
+
+ z = z + z0
+
+ return z
+
+
+class SeqBandModellingModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ parallel_mode=False,
+ ) -> None:
+ super().__init__()
+ self.seqband = nn.ModuleList([])
+
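+        # In the default (sequential) mode, 2 * n_modules residual RNNs are
+        # applied alternately along the time and band axes (forward()
+        # transposes the two axes after every module). In parallel mode,
+        # each pair processes both axes from the same input and the results
+        # are summed.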
+ if parallel_mode:
+ for _ in range(n_modules):
+ self.seqband.append(
+ nn.ModuleList(
+ [ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )]
+ )
+ )
+ else:
+
+ for _ in range(2 * n_modules):
+ self.seqband.append(
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+ )
+
+ self.parallel_mode = parallel_mode
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ if self.parallel_mode:
+ for sbm_pair in self.seqband:
+ # z: (batch, n_bands, n_time, emb_dim)
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
+ z = zt + zf.transpose(1, 2)
+ else:
+ for sbm in self.seqband:
+ z = sbm(z)
+ z = z.transpose(1, 2)
+
+ # (batch, n_bands, n_time, emb_dim)
+ # --> (batch, n_time, n_bands, emb_dim)
+ # OR
+ # (batch, n_time, n_bands, emb_dim)
+ # --> (batch, n_bands, n_time, emb_dim)
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
+
+
+class ResidualTransformer(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ self.tf = nn.TransformerEncoderLayer(
+ d_model=emb_dim,
+ nhead=4,
+ dim_feedforward=rnn_dim,
+ batch_first=True
+ )
+
+ self.is_causal = not bidirectional
+ self.dropout = dropout
+
+ def forward(self, z):
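+        # Fold the un-crossed axis into the batch so the encoder layer only
+        # attends along the remaining (crossed) axis.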
+ batch, n_uncrossed, n_across, emb_dim = z.shape
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+        z = self.tf(z, is_causal=self.is_causal)  # (batch * n_uncrossed, n_across, emb_dim)
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, emb_dim))
+
+ return z
+
+
+class TransformerTimeFreqModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.norm = nn.LayerNorm(emb_dim)
+ self.seqband = nn.ModuleList([])
+
+ for _ in range(2 * n_modules):
+ self.seqband.append(
+ ResidualTransformer(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ )
+ )
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+ z = self.norm(z) # (batch, n_bands, n_time, emb_dim)
+
+ for sbm in self.seqband:
+ z = sbm(z)
+ z = z.transpose(1, 2)
+
+ # (batch, n_bands, n_time, emb_dim)
+ # --> (batch, n_time, n_bands, emb_dim)
+ # OR
+ # (batch, n_time, n_bands, emb_dim)
+ # --> (batch, n_bands, n_time, emb_dim)
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
+
+
+
+class ResidualConvolution(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+ self.norm = nn.InstanceNorm2d(emb_dim, affine=True)
+
+ self.conv = nn.Sequential(
+ nn.Conv2d(
+ in_channels=emb_dim,
+ out_channels=rnn_dim,
+ kernel_size=(3, 3),
+ padding="same",
+ stride=(1, 1),
+ ),
+ nn.Tanhshrink()
+ )
+
+ self.is_causal = not bidirectional
+ self.dropout = dropout
+
+ self.fc = nn.Conv2d(
+ in_channels=rnn_dim,
+ out_channels=emb_dim,
+ kernel_size=(1, 1),
+ padding="same",
+ stride=(1, 1),
+ )
+
+
+ def forward(self, z):
+        # z = (batch, emb_dim, n_uncrossed, n_across)
+
+        z0 = torch.clone(z)
+
+        z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)
+        z = self.conv(z)  # (batch, rnn_dim, n_uncrossed, n_across)
+        z = self.fc(z)  # (batch, emb_dim, n_uncrossed, n_across)
+ z = z + z0
+
+ return z
+
+
+class ConvolutionalTimeFreqModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.seqband = torch.jit.script(nn.Sequential(
+ *[ResidualConvolution(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ dropout=dropout,
+ ) for _ in range(2 * n_modules) ]))
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ z = torch.permute(z, (0, 3, 1, 2)) # (batch, emb_dim, n_bands, n_time)
+
+ z = self.seqband(z) # (batch, emb_dim, n_bands, n_time)
+
+ z = torch.permute(z, (0, 2, 3, 1)) # (batch, n_bands, n_time, emb_dim)
+
+ return z
diff --git a/models/bandit/core/model/bsrnn/utils.py b/models/bandit/core/model/bsrnn/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf8636e65fe9e7fdd13fa063760018df90a01cff
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/utils.py
@@ -0,0 +1,583 @@
+import os
+from abc import abstractmethod
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from librosa import hz_to_midi, midi_to_hz
+from torch import Tensor
+from torchaudio import functional as taF
+from spafe.fbanks import bark_fbanks
+from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+from torchaudio.functional.functional import _create_triangular_filterbank
+
+
+def band_widths_from_specs(band_specs):
+ return [e - i for i, e in band_specs]
+
+
+def check_nonzero_bandwidth(band_specs):
+ # pprint(band_specs)
+ for fstart, fend in band_specs:
+ if fend - fstart <= 0:
+ raise ValueError("Bands cannot be zero-width")
+
+
+def check_no_overlap(band_specs):
+    # bands are half-open [fstart, fend) index ranges, so adjacent bands may
+    # share a boundary without overlapping
+    fend_prev = -1
+    for fstart_curr, fend_curr in band_specs:
+        if fstart_curr < fend_prev:
+            raise ValueError("Bands cannot overlap")
+        fend_prev = fend_curr
+
+
+def check_no_gap(band_specs):
+ fstart, _ = band_specs[0]
+ assert fstart == 0
+
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr - fend_prev > 1:
+ raise ValueError("Bands cannot leave gap")
+ fend_prev = fend_curr
+
+
+class BandsplitSpecification:
+ def __init__(self, nfft: int, fs: int) -> None:
+ self.fs = fs
+ self.nfft = nfft
+ self.nyquist = fs / 2
+ self.max_index = nfft // 2 + 1
+
+ self.split500 = self.hertz_to_index(500)
+ self.split1k = self.hertz_to_index(1000)
+ self.split2k = self.hertz_to_index(2000)
+ self.split4k = self.hertz_to_index(4000)
+ self.split8k = self.hertz_to_index(8000)
+ self.split16k = self.hertz_to_index(16000)
+ self.split20k = self.hertz_to_index(20000)
+
+ self.above20k = [(self.split20k, self.max_index)]
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+ def index_to_hertz(self, index: int):
+ return index * self.fs / self.nfft
+
+ def hertz_to_index(self, hz: float, round: bool = True):
+ index = hz * self.nfft / self.fs
+
+ if round:
+ index = int(np.round(index))
+
+ return index
+
+ def get_band_specs_with_bandwidth(
+ self,
+ start_index,
+ end_index,
+ bandwidth_hz
+ ):
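+        # Tile [start_index, end_index) with contiguous bands of roughly
+        # bandwidth_hz each; the last band is clipped to end_index.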
+ band_specs = []
+ lower = start_index
+
+ while lower < end_index:
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+ upper = min(upper, end_index)
+
+ band_specs.append((lower, upper))
+ lower = upper
+
+ return band_specs
+
+ @abstractmethod
+ def get_band_specs(self):
+ raise NotImplementedError
+
+
+class VocalBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ self.version = version
+
+ def get_band_specs(self):
+ return getattr(self, f"version{self.version}")()
+
+    def version1(self):
+ return self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
+ )
+
+ def version2(self):
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k,
+ end_index=self.split20k,
+ bandwidth_hz=2000
+ )
+
+ return below16k + below20k + self.above20k
+
+ def version3(self):
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=2000
+ )
+
+ return below8k + below16k + self.above16k
+
+ def version4(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split8k,
+ bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=2000
+ )
+
+ return below1k + below8k + below16k + self.above16k
+
+ def version5(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split16k,
+ bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k,
+ end_index=self.split20k,
+ bandwidth_hz=2000
+ )
+ return below1k + below16k + below20k + self.above20k
+
+ def version6(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split4k,
+ bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k,
+ end_index=self.split8k,
+ bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + self.above16k
+
+ def version7(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split4k,
+ bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k,
+ end_index=self.split8k,
+ bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k,
+ end_index=self.split20k,
+ bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+class OtherBandsplitSpecification(VocalBandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+class BassBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below500 = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split500, bandwidth_hz=50
+ )
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=self.split500,
+ end_index=self.split1k,
+ bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split4k,
+ bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k,
+ end_index=self.split8k,
+ bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=2000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+class DrumBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
+ )
+ below2k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k,
+ end_index=self.split2k,
+ bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split2k,
+ end_index=self.split4k,
+ bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k,
+ end_index=self.split8k,
+ bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k,
+ end_index=self.split16k,
+ bandwidth_hz=1000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+
+
+class PerceptualBandsplitSpecification(BandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+ self.n_bands = n_bands
+ if f_max is None:
+ f_max = fs / 2
+
+ self.filterbank = fbank_fn(
+ n_bands, fs, f_min, f_max, self.max_index
+ )
+
+ weight_per_bin = torch.sum(
+ self.filterbank,
+ dim=0,
+ keepdim=True
+ ) # (1, n_freqs)
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
+
+ freq_weights = []
+ band_specs = []
+ for i in range(self.n_bands):
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+ if isinstance(active_bins, int):
+ active_bins = (active_bins, active_bins)
+ if len(active_bins) == 0:
+ continue
+ start_index = active_bins[0]
+ end_index = active_bins[-1] + 1
+ band_specs.append((start_index, end_index))
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+ self.freq_weights = freq_weights
+ self.band_specs = band_specs
+
+ def get_band_specs(self):
+ return self.band_specs
+
+ def get_freq_weights(self):
+ return self.freq_weights
+
+ def save_to_file(self, dir_path: str) -> None:
+
+ os.makedirs(dir_path, exist_ok=True)
+
+ import pickle
+
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+ pickle.dump(
+ {
+ "band_specs": self.band_specs,
+ "freq_weights": self.freq_weights,
+ "filterbank": self.filterbank,
+ },
+ f,
+ )
+
+def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = taF.melscale_fbanks(
+ n_mels=n_bands,
+ sample_rate=fs,
+ f_min=f_min,
+ f_max=f_max,
+ n_freqs=n_freqs,
+ ).T
+
+ fb[0, 0] = 1.0
+
+ return fb
+
+
+class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
+ scale="constant"):
+
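+    # Rectangular bands centred on log-spaced (MIDI-scale) frequencies; each
+    # band extends one band-spacing (in octaves) on either side of its
+    # centre, so neighbouring bands overlap. The first and last bands are
+    # extended to cover the spectrum edges.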
+ nfft = 2 * (n_freqs - 1)
+ df = fs / nfft
+ # init freqs
+ f_max = f_max or fs / 2
+ f_min = f_min or 0
+ f_min = fs / nfft
+
+ n_octaves = np.log2(f_max / f_min)
+ n_octaves_per_band = n_octaves / n_bands
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+ low_midi = max(0, hz_to_midi(f_min))
+ high_midi = hz_to_midi(f_max)
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
+ hz_pts = midi_to_hz(midi_points)
+
+ low_pts = hz_pts / bandwidth_mult
+ high_pts = hz_pts * bandwidth_mult
+
+ low_bins = np.floor(low_pts / df).astype(int)
+ high_bins = np.ceil(high_pts / df).astype(int)
+
+ fb = np.zeros((n_bands, n_freqs))
+
+ for i in range(n_bands):
+ fb[i, low_bins[i]:high_bins[i]+1] = 1.0
+
+ fb[0, :low_bins[0]] = 1.0
+ fb[-1, high_bins[-1]+1:] = 1.0
+
+ return torch.as_tensor(fb)
+
+class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+def bark_filterbank(
+ n_bands, fs, f_min, f_max, n_freqs
+):
+ nfft = 2 * (n_freqs -1)
+ fb, _ = bark_fbanks.bark_filter_banks(
+ nfilts=n_bands,
+ nfft=nfft,
+ fs=fs,
+ low_freq=f_min,
+ high_freq=f_max,
+ scale="constant"
+ )
+
+ return torch.as_tensor(fb)
+
+class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+def triangular_bark_filterbank(
+ n_bands, fs, f_min, f_max, n_freqs
+):
+
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ # calculate mel freq bins
+ m_min = hz2bark(f_min)
+ m_max = hz2bark(f_max)
+
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ f_pts = 600 * torch.sinh(m_pts / 6)
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ fb = fb.T
+
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ fb[first_active_band, :first_active_bin] = 1.0
+
+ return fb
+
+class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+
+def minibark_filterbank(
+ n_bands, fs, f_min, f_max, n_freqs
+):
+ fb = bark_filterbank(
+ n_bands,
+ fs,
+ f_min,
+ f_max,
+ n_freqs
+ )
+
+ fb[fb < np.sqrt(0.5)] = 0.0
+
+ return fb
+
+class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+
+
+
+def erb_filterbank(
+ n_bands: int,
+ fs: int,
+ f_min: float,
+ f_max: float,
+ n_freqs: int,
+) -> Tensor:
+ # freq bins
+ A = (1000 * np.log(10)) / (24.7 * 4.37)
+ all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+ # calculate mel freq bins
+ m_min = hz2erb(f_min)
+ m_max = hz2erb(f_max)
+
+ m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+ f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
+
+ # create filterbank
+ fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+ fb = fb.T
+
+
+ first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+ first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+ fb[first_active_band, :first_active_bin] = 1.0
+
+ return fb
+
+
+
+class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None
+ ) -> None:
+ super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+if __name__ == "__main__":
+ import pandas as pd
+
+ band_defs = []
+
+ for bands in [VocalBandsplitSpecification]:
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+ for i, (f_min, f_max) in enumerate(mbs):
+ band_defs.append({
+ "band": band_name,
+ "band_index": i,
+ "f_min": f_min,
+ "f_max": f_max
+ })
+
+ df = pd.DataFrame(band_defs)
+ df.to_csv("vox7bands.csv", index=False)
\ No newline at end of file
diff --git a/models/bandit/core/model/bsrnn/wrapper.py b/models/bandit/core/model/bsrnn/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31c087db33eb215effa3c3fc492999c5672c55e
--- /dev/null
+++ b/models/bandit/core/model/bsrnn/wrapper.py
@@ -0,0 +1,882 @@
+from pprint import pprint
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from models.bandit.core.model._spectral import _SpectralComponent
+from models.bandit.core.model.bsrnn.utils import (
+ BarkBandsplitSpecification, BassBandsplitSpecification,
+ DrumBandsplitSpecification,
+ EquivalentRectangularBandsplitSpecification, MelBandsplitSpecification,
+ MusicalBandsplitSpecification, OtherBandsplitSpecification,
+ TriangularBarkBandsplitSpecification, VocalBandsplitSpecification,
+)
+from .core import (
+ MultiSourceMultiMaskBandSplitCoreConv,
+ MultiSourceMultiMaskBandSplitCoreRNN,
+ MultiSourceMultiMaskBandSplitCoreTransformer,
+ MultiSourceMultiPatchingMaskBandSplitCoreRNN, SingleMaskBandsplitCoreRNN,
+ SingleMaskBandsplitCoreTransformer,
+)
+
+import pytorch_lightning as pl
+
+def get_band_specs(band_specs, n_fft, fs, n_bands=None):
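+    # `band_specs` is a string tag selecting a predefined band layout.
+    # Perceptual layouts (tribark/bark/erb/musical/mel) yield overlapping
+    # bands with per-bin frequency weights; the vocal ("vox7") layouts are
+    # non-overlapping and need no weights.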
+ if band_specs in ["dnr:speech", "dnr:vox7", "musdb:vocals", "musdb:vox7"]:
+ bsm = VocalBandsplitSpecification(
+ nfft=n_fft, fs=fs
+ ).get_band_specs()
+ freq_weights = None
+ overlapping_band = False
+ elif "tribark" in band_specs:
+ assert n_bands is not None
+ specs = TriangularBarkBandsplitSpecification(
+ nfft=n_fft,
+ fs=fs,
+ n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "bark" in band_specs:
+ assert n_bands is not None
+ specs = BarkBandsplitSpecification(
+ nfft=n_fft,
+ fs=fs,
+ n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "erb" in band_specs:
+ assert n_bands is not None
+ specs = EquivalentRectangularBandsplitSpecification(
+ nfft=n_fft,
+ fs=fs,
+ n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif "musical" in band_specs:
+ assert n_bands is not None
+ specs = MusicalBandsplitSpecification(
+ nfft=n_fft,
+ fs=fs,
+ n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ elif band_specs == "dnr:mel" or "mel" in band_specs:
+ assert n_bands is not None
+ specs = MelBandsplitSpecification(
+ nfft=n_fft,
+ fs=fs,
+ n_bands=n_bands
+ )
+ bsm = specs.get_band_specs()
+ freq_weights = specs.get_freq_weights()
+ overlapping_band = True
+ else:
+        raise ValueError(f"Unknown band_specs: {band_specs}")
+
+ return bsm, freq_weights, overlapping_band
+
+
+def get_band_specs_map(band_specs_map, n_fft, fs, n_bands=None):
+ if band_specs_map == "musdb:all":
+ bsm = {
+ "vocals": VocalBandsplitSpecification(
+ nfft=n_fft, fs=fs
+ ).get_band_specs(),
+ "drums": DrumBandsplitSpecification(
+ nfft=n_fft, fs=fs
+ ).get_band_specs(),
+ "bass": BassBandsplitSpecification(
+ nfft=n_fft, fs=fs
+ ).get_band_specs(),
+ "other": OtherBandsplitSpecification(
+ nfft=n_fft, fs=fs
+ ).get_band_specs(),
+ }
+ freq_weights = None
+ overlapping_band = False
+ elif band_specs_map == "dnr:vox7":
+ bsm_, freq_weights, overlapping_band = get_band_specs(
+ "dnr:speech", n_fft, fs, n_bands
+ )
+ bsm = {
+ "speech": bsm_,
+ "music": bsm_,
+ "effects": bsm_
+ }
+ elif "dnr:vox7:" in band_specs_map:
+ stem = band_specs_map.split(":")[-1]
+ bsm_, freq_weights, overlapping_band = get_band_specs(
+ "dnr:speech", n_fft, fs, n_bands
+ )
+ bsm = {
+ stem: bsm_
+ }
+ else:
+        raise ValueError(f"Unknown band_specs_map: {band_specs_map}")
+
+ return bsm, freq_weights, overlapping_band
+
+
+class BandSplitWrapperBase(pl.LightningModule):
+ bsrnn: nn.Module
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+
+class SingleMaskMultiSourceBandSplitBase(
+ BandSplitWrapperBase,
+ _SpectralComponent
+):
+ def __init__(
+ self,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+        if isinstance(band_specs_map, str):
+            self.band_specs_map, self.freq_weights, self.overlapping_band = get_band_specs_map(
+                band_specs_map,
+                n_fft,
+                fs,
+                n_bands=n_bands
+            )
+        else:
+            # band specs were passed explicitly; assume non-overlapping bands
+            # with no frequency weights
+            self.band_specs_map = band_specs_map
+            self.freq_weights = None
+            self.overlapping_band = False
+
+ self.stems = list(self.band_specs_map.keys())
+
+ def forward(self, batch):
+ audio = batch["audio"]
+
+ with torch.no_grad():
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
+ audio}
+
+ X = batch["spectrogram"]["mixture"]
+ length = batch["audio"]["mixture"].shape[-1]
+
+ output = {"spectrogram": {}, "audio": {}}
+
+ for stem, bsrnn in self.bsrnn.items():
+ S = bsrnn(X)
+ s = self.istft(S, length)
+ output["spectrogram"][stem] = S
+ output["audio"][stem] = s
+
+ return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBase(
+ BandSplitWrapperBase,
+ _SpectralComponent
+):
+ def __init__(
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+        if isinstance(band_specs, str):
+            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                band_specs,
+                n_fft,
+                fs,
+                n_bands
+            )
+        else:
+            # band specs were passed explicitly; assume non-overlapping bands
+            # with no frequency weights
+            self.band_specs = band_specs
+            self.freq_weights = None
+            self.overlapping_band = False
+
+ self.stems = stems
+
+ def forward(self, batch):
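+        # batch["audio"] maps stem names (including "mixture") to waveforms.
+        # STFTs of all provided stems are precomputed, but only the mixture
+        # spectrogram is fed to the core model; each estimated stem
+        # spectrogram is then inverted back to audio.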
+ # with torch.no_grad():
+ audio = batch["audio"]
+ cond = batch.get("condition", None)
+ with torch.no_grad():
+ batch["spectrogram"] = {stem: self.stft(audio[stem]) for stem in
+ audio}
+
+ X = batch["spectrogram"]["mixture"]
+ length = batch["audio"]["mixture"].shape[-1]
+
+ output = self.bsrnn(X, cond=cond)
+ output["audio"] = {}
+
+ for stem, S in output["spectrogram"].items():
+ s = self.istft(S, length)
+ output["audio"][stem] = s
+
+ return batch, output
+
+
+class MultiMaskMultiSourceBandSplitBaseSimple(
+ BandSplitWrapperBase,
+ _SpectralComponent
+):
+ def __init__(
+ self,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ ) -> None:
+ super().__init__(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+        if isinstance(band_specs, str):
+            self.band_specs, self.freq_weights, self.overlapping_band = get_band_specs(
+                band_specs,
+                n_fft,
+                fs,
+                n_bands
+            )
+        else:
+            # band specs were passed explicitly; assume non-overlapping bands
+            # with no frequency weights
+            self.band_specs = band_specs
+            self.freq_weights = None
+            self.overlapping_band = False
+
+ self.stems = stems
+
+ def forward(self, batch):
+ with torch.no_grad():
+ X = self.stft(batch)
+ length = batch.shape[-1]
+ output = self.bsrnn(X, cond=None)
+ res = []
+ for stem, S in output["spectrogram"].items():
+ s = self.istft(S, length)
+ res.append(s)
+ res = torch.stack(res, dim=1)
+ return res
+
+
+class SingleMaskMultiSourceBandSplitRNN(SingleMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ) -> None:
+ super().__init__(
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.bsrnn = nn.ModuleDict(
+ {
+ src: SingleMaskBandsplitCoreRNN(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
+ )
+
+
+class SingleMaskMultiSourceBandSplitTransformer(
+ SingleMaskMultiSourceBandSplitBase
+):
+ def __init__(
+ self,
+ in_channel: int,
+ band_specs_map: Union[str, Dict[str, List[Tuple[float, float]]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ tf_dropout: float = 0.0,
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ) -> None:
+ super().__init__(
+ band_specs_map=band_specs_map,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.bsrnn = nn.ModuleDict(
+ {
+ src: SingleMaskBandsplitCoreTransformer(
+ band_specs=specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ tf_dropout=tf_dropout,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+ for src, specs in self.band_specs_map.items()
+ }
+ )
+
+
+class MultiMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+ self.normalize_input = normalize_input
+ self.cond_dim = cond_dim
+
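+        # Optionally freeze the band-split encoder and the time-frequency
+        # model so that only the mask estimators remain trainable.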
+ if freeze_encoder:
+ for param in self.bsrnn.band_split.parameters():
+ param.requires_grad = False
+
+ for param in self.bsrnn.tf_model.parameters():
+ param.requires_grad = False
+
+
+class MultiMaskMultiSourceBandSplitRNNSimple(MultiMaskMultiSourceBandSplitBaseSimple):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: int = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False,
+ freeze_encoder: bool = False,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+ self.normalize_input = normalize_input
+ self.cond_dim = cond_dim
+
+ if freeze_encoder:
+ for param in self.bsrnn.band_split.parameters():
+ param.requires_grad = False
+
+ for param in self.bsrnn.tf_model.parameters():
+ param.requires_grad = False
+
+
+class MultiMaskMultiSourceBandSplitTransformer(
+ MultiMaskMultiSourceBandSplitBase
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreTransformer(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+
+
+class MultiMaskMultiSourceBandSplitConv(
+ MultiMaskMultiSourceBandSplitBase
+):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ cond_dim: int = 0,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: Optional[int] = None,
+ use_freq_weights: bool = True,
+ normalize_input: bool = False,
+ mult_add_mask: bool = False
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiMaskBandSplitCoreConv(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ cond_dim=cond_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ mult_add_mask=mult_add_mask
+ )
+
+
+class PatchingMaskMultiSourceBandSplitRNN(MultiMaskMultiSourceBandSplitBase):
+ def __init__(
+ self,
+ in_channel: int,
+ stems: List[str],
+ band_specs: Union[str, List[Tuple[float, float]]],
+ kernel_norm_mlp_version: int = 1,
+ mask_kernel_freq: int = 3,
+ mask_kernel_time: int = 3,
+ conv_kernel_freq: int = 1,
+ conv_kernel_time: int = 1,
+ fs: int = 44100,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ n_bands: Optional[int] = None,
+ ) -> None:
+ super().__init__(
+ stems=stems,
+ band_specs=band_specs,
+ fs=fs,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ n_bands=n_bands,
+ )
+
+ self.bsrnn = MultiSourceMultiPatchingMaskBandSplitCoreRNN(
+ stems=stems,
+ band_specs=self.band_specs,
+ in_channel=in_channel,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ overlapping_band=self.overlapping_band,
+ freq_weights=self.freq_weights,
+ n_freq=n_fft // 2 + 1,
+ mask_kernel_freq=mask_kernel_freq,
+ mask_kernel_time=mask_kernel_time,
+ conv_kernel_freq=conv_kernel_freq,
+ conv_kernel_time=conv_kernel_time,
+ kernel_norm_mlp_version=kernel_norm_mlp_version,
+ )
diff --git a/models/bandit/core/utils/__init__.py b/models/bandit/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/bandit/core/utils/audio.py b/models/bandit/core/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..adae756bf2b02a994a42fcc007da1e1ff7bb6cfb
--- /dev/null
+++ b/models/bandit/core/utils/audio.py
@@ -0,0 +1,463 @@
+from collections import defaultdict
+
+from tqdm.auto import tqdm
+from typing import Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def merge(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_chunks: int,
+ chunk_size: int,
+ ) -> torch.Tensor:
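+ # Shape sketch (assuming the layout produced by `unfold` below):
+ # combined: (original_batch_size * n_chunks, n_channel, chunk_size)
+ # -> (original_batch_size * n_channel, chunk_size, n_chunks),
+ # which is the column layout expected by F.fold in the merge_chunks_* helpers.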
+ combined = torch.reshape(
+ combined,
+ (original_batch_size, n_chunks, n_channel, chunk_size)
+ )
+ combined = torch.permute(combined, (0, 2, 3, 1)).reshape(
+ original_batch_size * n_channel,
+ chunk_size,
+ n_chunks
+ )
+
+ return combined
+
+
+@torch.jit.script
+def unfold(
+ padded_audio: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ chunk_size: int,
+ hop_size: int
+ ) -> torch.Tensor:
+
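+ # F.unfold with a (1, chunk_size) kernel and (1, hop_size) stride slides a
+ # chunk_size-sample window over the padded audio; the result is reshaped to
+ # (original_batch_size * n_chunks, n_channel, chunk_size) so that every chunk
+ # can be passed through the model as an independent batch item.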
+ unfolded_input = F.unfold(
+ padded_audio[:, :, None, :],
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size)
+ )
+
+ _, _, n_chunks = unfolded_input.shape
+ unfolded_input = unfolded_input.view(
+ original_batch_size,
+ n_channel,
+ chunk_size,
+ n_chunks
+ )
+ unfolded_input = torch.permute(
+ unfolded_input,
+ (0, 3, 1, 2)
+ ).reshape(
+ original_batch_size * n_chunks,
+ n_channel,
+ chunk_size
+ )
+
+ return unfolded_input
+
+
+@torch.jit.script
+# @torch.compile
+def merge_chunks_all(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor
+):
+ combined = merge(
+ combined,
+ original_batch_size,
+ n_channel,
+ n_chunks,
+ chunk_size
+ )
+
+ combined = combined * standard_window[:, None].to(combined.device)
+
+ combined = F.fold(
+ combined.to(torch.float32), output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size)
+ )
+
+ combined = combined.view(
+ original_batch_size,
+ n_channel,
+ n_padded_samples
+ )
+
+ pad_front, pad_back = edge_frame_pad_sizes
+ combined = combined[..., pad_front:-pad_back]
+
+ combined = combined[..., :n_samples]
+
+ return combined
+
+
+# @torch.jit.script
+def merge_chunks_edge(
+ combined: torch.Tensor,
+ original_batch_size: int,
+ n_channel: int,
+ n_samples: int,
+ n_padded_samples: int,
+ n_chunks: int,
+ chunk_size: int,
+ hop_size: int,
+ edge_frame_pad_sizes: Tuple[int, int],
+ standard_window: torch.Tensor,
+ first_window: torch.Tensor,
+ last_window: torch.Tensor
+):
+ combined = merge(
+ combined,
+ original_batch_size,
+ n_channel,
+ n_chunks,
+ chunk_size
+ )
+
+ combined[..., 0] = combined[..., 0] * first_window
+ combined[..., -1] = combined[..., -1] * last_window
+ combined[..., 1:-1] = combined[..., 1:-1] * standard_window[:, None]
+
+ combined = F.fold(
+ combined, output_size=(1, n_padded_samples),
+ kernel_size=(1, chunk_size),
+ stride=(1, hop_size)
+ )
+
+ combined = combined.view(
+ original_batch_size,
+ n_channel,
+ n_padded_samples
+ )
+
+ combined = combined[..., :n_samples]
+
+ return combined
+
+
+class BaseFader(nn.Module):
+ def __init__(
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool,
+ batch_size: int,
+ ) -> None:
+ super().__init__()
+
+ self.chunk_size = int(chunk_size_second * fs)
+ self.hop_size = int(hop_size_second * fs)
+ self.overlap_size = self.chunk_size - self.hop_size
+ self.fade_edge_frames = fade_edge_frames
+ self.batch_size = batch_size
+
+ # @torch.jit.script
+ def prepare(self, audio):
+
+ if self.fade_edge_frames:
+ audio = F.pad(audio, self.edge_frame_pad_sizes, mode="reflect")
+
+ n_samples = audio.shape[-1]
+ n_chunks = int(
+ np.ceil((n_samples - self.chunk_size) / self.hop_size) + 1
+ )
+
+ padded_size = (n_chunks - 1) * self.hop_size + self.chunk_size
+ pad_size = padded_size - n_samples
+
+ padded_audio = F.pad(audio, (0, pad_size))
+
+ return padded_audio, n_chunks
+
+ def forward(
+ self,
+ audio: torch.Tensor,
+ model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+ ):
+
+ original_dtype = audio.dtype
+ original_device = audio.device
+
+ audio = audio.to("cpu")
+
+ original_batch_size, n_channel, n_samples = audio.shape
+ padded_audio, n_chunks = self.prepare(audio)
+ del audio
+ n_padded_samples = padded_audio.shape[-1]
+
+ if n_channel > 1:
+ padded_audio = padded_audio.view(
+ original_batch_size * n_channel, 1, n_padded_samples
+ )
+
+ unfolded_input = unfold(
+ padded_audio,
+ original_batch_size,
+ n_channel,
+ self.chunk_size, self.hop_size
+ )
+
+ n_total_chunks, n_channel, chunk_size = unfolded_input.shape
+
+ n_batch = np.ceil(n_total_chunks / self.batch_size).astype(int)
+
+ chunks_in = [
+ unfolded_input[
+ b * self.batch_size:(b + 1) * self.batch_size, ...].clone()
+ for b in range(n_batch)
+ ]
+
+ all_chunks_out = defaultdict(
+ lambda: torch.zeros_like(
+ unfolded_input, device="cpu"
+ )
+ )
+
+ # for b, cin in enumerate(tqdm(chunks_in)):
+ for b, cin in enumerate(chunks_in):
+ if torch.allclose(cin, torch.tensor(0.0)):
+ del cin
+ continue
+
+ chunks_out = model_fn(cin.to(original_device))
+ del cin
+ for s, c in chunks_out.items():
+ all_chunks_out[s][b * self.batch_size:(b + 1) * self.batch_size,
+ ...] = c.cpu()
+ del chunks_out
+
+ del unfolded_input
+ del padded_audio
+
+ if self.fade_edge_frames:
+ fn = merge_chunks_all
+ else:
+ fn = merge_chunks_edge
+ outputs = {}
+
+ torch.cuda.empty_cache()
+
+ for s, c in all_chunks_out.items():
+ combined: torch.Tensor = fn(
+ c,
+ original_batch_size,
+ n_channel,
+ n_samples,
+ n_padded_samples,
+ n_chunks,
+ self.chunk_size,
+ self.hop_size,
+ self.edge_frame_pad_sizes,
+ self.standard_window,
+ self.__dict__.get("first_window", self.standard_window),
+ self.__dict__.get("last_window", self.standard_window)
+ )
+
+ outputs[s] = combined.to(
+ dtype=original_dtype,
+ device=original_device
+ )
+
+ return {
+ "audio": outputs
+ }
+ #
+ # def old_forward(
+ # self,
+ # audio: torch.Tensor,
+ # model_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]],
+ # ):
+ #
+ # n_samples = audio.shape[-1]
+ # original_batch_size = audio.shape[0]
+ #
+ # padded_audio, n_chunks = self.prepare(audio)
+ #
+ # ndim = padded_audio.ndim
+ # broadcaster = [1 for _ in range(ndim - 1)] + [self.chunk_size]
+ #
+ # outputs = defaultdict(
+ # lambda: torch.zeros_like(
+ # padded_audio, device=audio.device, dtype=torch.float64
+ # )
+ # )
+ #
+ # all_chunks_out = []
+ # len_chunks_in = []
+ #
+ # batch_size_ = int(self.batch_size // original_batch_size)
+ # for b in range(int(np.ceil(n_chunks / batch_size_))):
+ # chunks_in = []
+ # for j in range(batch_size_):
+ # i = b * batch_size_ + j
+ # if i == n_chunks:
+ # break
+ #
+ # start = i * hop_size
+ # end = start + self.chunk_size
+ # chunk_in = padded_audio[..., start:end]
+ # chunks_in.append(chunk_in)
+ #
+ # chunks_in = torch.concat(chunks_in, dim=0)
+ # chunks_out = model_fn(chunks_in)
+ # all_chunks_out.append(chunks_out)
+ # len_chunks_in.append(len(chunks_in))
+ #
+ # for b, (chunks_out, lci) in enumerate(
+ # zip(all_chunks_out, len_chunks_in)
+ # ):
+ # for stem in chunks_out:
+ # for j in range(lci // original_batch_size):
+ # i = b * batch_size_ + j
+ #
+ # if self.fade_edge_frames:
+ # window = self.standard_window
+ # else:
+ # if i == 0:
+ # window = self.first_window
+ # elif i == n_chunks - 1:
+ # window = self.last_window
+ # else:
+ # window = self.standard_window
+ #
+ # start = i * hop_size
+ # end = start + self.chunk_size
+ #
+ # chunk_out = chunks_out[stem][j * original_batch_size: (j + 1) * original_batch_size,
+ # ...]
+ # contrib = window.view(*broadcaster) * chunk_out
+ # outputs[stem][..., start:end] = (
+ # outputs[stem][..., start:end] + contrib
+ # )
+ #
+ # if self.fade_edge_frames:
+ # pad_front, pad_back = self.edge_frame_pad_sizes
+ # outputs = {k: v[..., pad_front:-pad_back] for k, v in
+ # outputs.items()}
+ #
+ # outputs = {k: v[..., :n_samples].to(audio.dtype) for k, v in
+ # outputs.items()}
+ #
+ # return {
+ # "audio": outputs
+ # }
+
+
+class LinearFader(BaseFader):
+ def __init__(
+ self,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ fade_edge_frames: bool = False,
+ batch_size: int = 1,
+ ) -> None:
+
+ assert hop_size_second >= chunk_size_second / 2
+
+ super().__init__(
+ chunk_size_second=chunk_size_second,
+ hop_size_second=hop_size_second,
+ fs=fs,
+ fade_edge_frames=fade_edge_frames,
+ batch_size=batch_size,
+ )
+
+ in_fade = torch.linspace(0.0, 1.0, self.overlap_size + 1)[:-1]
+ out_fade = torch.linspace(1.0, 0.0, self.overlap_size + 1)[1:]
+ center_ones = torch.ones(self.chunk_size - 2 * self.overlap_size)
+ inout_ones = torch.ones(self.overlap_size)
+
+ # registering the windows as buffers / frozen Parameters lets Lightning handle device placement for us
+ self.register_buffer(
+ "standard_window",
+ torch.concat([in_fade, center_ones, out_fade])
+ )
+
+ self.fade_edge_frames = fade_edge_frames
+ # BaseFader.prepare/forward expect the plural attribute name
+ self.edge_frame_pad_sizes = (self.overlap_size, self.overlap_size)
+
+ if not self.fade_edge_frames:
+ self.first_window = nn.Parameter(
+ torch.concat([inout_ones, center_ones, out_fade]),
+ requires_grad=False
+ )
+ self.last_window = nn.Parameter(
+ torch.concat([in_fade, center_ones, inout_ones]),
+ requires_grad=False
+ )
+
+
+class OverlapAddFader(BaseFader):
+ def __init__(
+ self,
+ window_type: str,
+ chunk_size_second: float,
+ hop_size_second: float,
+ fs: int,
+ batch_size: int = 1,
+ ) -> None:
+ assert (chunk_size_second / hop_size_second) % 2 == 0
+ assert int(chunk_size_second * fs) % 2 == 0
+
+ super().__init__(
+ chunk_size_second=chunk_size_second,
+ hop_size_second=hop_size_second,
+ fs=fs,
+ fade_edge_frames=True,
+ batch_size=batch_size,
+ )
+
+ self.hop_multiplier = self.chunk_size / (2 * self.hop_size)
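+ # For a periodic (sym=False) Hann-like window with an even chunk/hop ratio,
+ # the overlapped window copies sum to roughly chunk_size / (2 * hop_size), so
+ # dividing by hop_multiplier keeps the overlap-add at approximately unity gain
+ # (the __main__ check below verifies near-perfect reconstruction).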
+ # print(f"hop multiplier: {self.hop_multiplier}")
+
+ self.edge_frame_pad_sizes = (
+ 2 * self.overlap_size,
+ 2 * self.overlap_size
+ )
+
+ self.register_buffer(
+ "standard_window", torch.windows.__dict__[window_type](
+ self.chunk_size, sym=False, # dtype=torch.float64
+ ) / self.hop_multiplier
+ )
+
+
+if __name__ == "__main__":
+ import torchaudio as ta
+ fs = 44100
+ ola = OverlapAddFader(
+ "hann",
+ 6.0,
+ 1.0,
+ fs,
+ batch_size=16
+ )
+ audio_, _ = ta.load(
+ "$DATA_ROOT/MUSDB18/HQ/canonical/test/BKS - Too "
+ "Much/vocals.wav"
+ )
+ audio_ = audio_[None, ...]
+ out = ola(audio_, lambda x: {"stem": x})["audio"]["stem"]
+ print(torch.allclose(out, audio_))
diff --git a/models/bandit/model_from_config.py b/models/bandit/model_from_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ea586d7dfdbd6b89d6b7f2f400e6c8d04da5e4
--- /dev/null
+++ b/models/bandit/model_from_config.py
@@ -0,0 +1,31 @@
+import sys
+import os.path
+import torch
+
+code_path = os.path.dirname(os.path.abspath(__file__)) + '/'
+sys.path.append(code_path)
+
+import yaml
+from ml_collections import ConfigDict
+
+torch.set_float32_matmul_precision("medium")
+
+
+def get_model(
+ config_path,
+ weights_path,
+ device,
+):
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
+
+ with open(config_path) as f:
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+
+ model = MultiMaskMultiSourceBandSplitRNNSimple(
+ **config.model
+ )
+ # fall back to the checkpoint bundled next to this file if no weights path is given
+ if not weights_path:
+ weights_path = code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt'
+ d = torch.load(weights_path, map_location=device)
+ model.load_state_dict(d)
+ model.to(device)
+ return model, config
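+
+# Usage sketch (paths are placeholders, not files shipped with this module):
+# model, config = get_model("path/to/config.yaml", "path/to/checkpoint.chpt", device="cuda")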
diff --git a/models/bandit_v2/bandit.py b/models/bandit_v2/bandit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4e13f479891065cf1f7dae0720721128347979
--- /dev/null
+++ b/models/bandit_v2/bandit.py
@@ -0,0 +1,367 @@
+from typing import Dict, List, Optional
+
+import torch
+import torchaudio as ta
+from torch import nn
+import pytorch_lightning as pl
+
+from .bandsplit import BandSplitModule
+from .maskestim import OverlappingMaskEstimationModule
+from .tfmodel import SeqBandModellingModule
+from .utils import MusicalBandsplitSpecification
+
+
+
+class BaseEndToEndModule(pl.LightningModule):
+ def __init__(
+ self,
+ ) -> None:
+ super().__init__()
+
+
+class BaseBandit(BaseEndToEndModule):
+ def __init__(
+ self,
+ in_channels: int,
+ fs: int,
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+
+ self.instantiate_spectral(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ )
+
+ self.instantiate_bandsplit(
+ in_channels=in_channels,
+ band_type=band_type,
+ n_bands=n_bands,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ n_fft=n_fft,
+ fs=fs,
+ )
+
+ self.instantiate_tf_modelling(
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+ def instantiate_spectral(
+ self,
+ n_fft: int = 2048,
+ win_length: Optional[int] = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Optional[Dict] = None,
+ power: Optional[int] = None,
+ normalized: bool = True,
+ center: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ ):
+ assert power is None
+
+ window_fn = torch.__dict__[window_fn]
+
+ self.stft = ta.transforms.Spectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+
+ self.istft = ta.transforms.InverseSpectrogram(
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ pad_mode=pad_mode,
+ pad=0,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ normalized=normalized,
+ center=center,
+ onesided=onesided,
+ )
+
+ def instantiate_bandsplit(
+ self,
+ in_channels: int,
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ emb_dim: int = 128,
+ n_fft: int = 2048,
+ fs: int = 44100,
+ ):
+ assert band_type == "musical"
+
+ self.band_specs = MusicalBandsplitSpecification(
+ nfft=n_fft, fs=fs, n_bands=n_bands
+ )
+
+ self.band_split = BandSplitModule(
+ in_channels=in_channels,
+ band_specs=self.band_specs.get_band_specs(),
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ emb_dim=emb_dim,
+ )
+
+ def instantiate_tf_modelling(
+ self,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ ):
+ try:
+ self.tf_model = torch.compile(
+ SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ disable=True,
+ )
+ except Exception as e:
+ self.tf_model = SeqBandModellingModule(
+ n_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ )
+
+ def mask(self, x, m):
+ return x * m
+
+ def forward(self, batch, mode="train"):
+ # The model operates on mono input; a (batch, channels, samples) tensor is
+ # flattened so each channel is processed independently and re-assembled below.
+ init_shape = batch.shape
+ if not isinstance(batch, dict):
+ mono = batch.view(-1, 1, batch.shape[-1])
+ batch = {
+ "mixture": {
+ "audio": mono
+ }
+ }
+
+ with torch.no_grad():
+ mixture = batch["mixture"]["audio"]
+
+ x = self.stft(mixture)
+ batch["mixture"]["spectrogram"] = x
+
+ if "sources" in batch.keys():
+ for stem in batch["sources"].keys():
+ s = batch["sources"][stem]["audio"]
+ s = self.stft(s)
+ batch["sources"][stem]["spectrogram"] = s
+
+ batch = self.separate(batch)
+
+ # Re-assemble the original channel layout and stack the stems into a single
+ # tensor of shape (batch, n_stems, channels, samples).
+ b = []
+ for s in self.stems:
+ r = batch['estimates'][s]['audio'].view(-1, init_shape[1], init_shape[2])
+ b.append(r)
+ batch = torch.stack(b, dim=1)
+ return batch
+
+ def encode(self, batch):
+ x = batch["mixture"]["spectrogram"]
+ length = batch["mixture"]["audio"].shape[-1]
+
+ z = self.band_split(x) # (batch, emb_dim, n_band, n_time)
+ q = self.tf_model(z) # (batch, emb_dim, n_band, n_time)
+
+ return x, q, length
+
+ def separate(self, batch):
+ raise NotImplementedError
+
+
+class Bandit(BaseBandit):
+ def __init__(
+ self,
+ in_channels: int,
+ stems: List[str],
+ band_type: str = "musical",
+ n_bands: int = 64,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ n_sqm_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ mlp_dim: int = 512,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Dict | None = None,
+ complex_mask: bool = True,
+ use_freq_weights: bool = True,
+ n_fft: int = 2048,
+ win_length: int | None = 2048,
+ hop_length: int = 512,
+ window_fn: str = "hann_window",
+ wkwargs: Dict | None = None,
+ power: int | None = None,
+ center: bool = True,
+ normalized: bool = True,
+ pad_mode: str = "constant",
+ onesided: bool = True,
+ fs: int = 44100,
+ stft_precisions="32",
+ bandsplit_precisions="bf16",
+ tf_model_precisions="bf16",
+ mask_estim_precisions="bf16",
+ ):
+ super().__init__(
+ in_channels=in_channels,
+ band_type=band_type,
+ n_bands=n_bands,
+ require_no_overlap=require_no_overlap,
+ require_no_gap=require_no_gap,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ n_sqm_modules=n_sqm_modules,
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ window_fn=window_fn,
+ wkwargs=wkwargs,
+ power=power,
+ center=center,
+ normalized=normalized,
+ pad_mode=pad_mode,
+ onesided=onesided,
+ fs=fs,
+ )
+
+ self.stems = stems
+
+ self.instantiate_mask_estim(
+ in_channels=in_channels,
+ stems=stems,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ n_freq=n_fft // 2 + 1,
+ use_freq_weights=use_freq_weights,
+ )
+
+ def instantiate_mask_estim(
+ self,
+ in_channels: int,
+ stems: List[str],
+ emb_dim: int,
+ mlp_dim: int,
+ hidden_activation: str,
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ n_freq: Optional[int] = None,
+ use_freq_weights: bool = False,
+ ):
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ assert n_freq is not None
+
+ self.mask_estim = nn.ModuleDict(
+ {
+ stem: OverlappingMaskEstimationModule(
+ band_specs=self.band_specs.get_band_specs(),
+ freq_weights=self.band_specs.get_freq_weights(),
+ n_freq=n_freq,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ use_freq_weights=use_freq_weights,
+ )
+ for stem in stems
+ }
+ )
+
+ def separate(self, batch):
+ batch["estimates"] = {}
+
+ x, q, length = self.encode(batch)
+
+ for stem, mem in self.mask_estim.items():
+ m = mem(q)
+
+ s = self.mask(x, m.to(x.dtype))
+ s = torch.reshape(s, x.shape)
+ batch["estimates"][stem] = {
+ "audio": self.istft(s, length),
+ "spectrogram": s,
+ }
+
+ return batch
+
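+
+# Minimal usage sketch (hypothetical stem names; real settings come from the
+# training configs). With in_channels=1, a stereo input is folded into the batch
+# dimension and restored on the way out:
+#
+# model = Bandit(in_channels=1, stems=["speech", "music", "sfx"], fs=44100)
+# mixture = torch.randn(2, 2, 44100) # (batch, channels, samples)
+# estimates = model(mixture) # (batch, n_stems, channels, samples)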
diff --git a/models/bandit_v2/bandsplit.py b/models/bandit_v2/bandsplit.py
new file mode 100644
index 0000000000000000000000000000000000000000..a14ea52bfa318264d536c9f934d0e28db63e15dc
--- /dev/null
+++ b/models/bandit_v2/bandsplit.py
@@ -0,0 +1,130 @@
+from typing import List, Tuple
+
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint_sequential
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class NormFC(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ bandwidth: int,
+ in_channels: int,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ if not treat_channel_as_feature:
+ raise NotImplementedError
+
+ self.treat_channel_as_feature = treat_channel_as_feature
+
+ if normalize_channel_independently:
+ raise NotImplementedError
+
+ reim = 2
+
+ norm = nn.LayerNorm(in_channels * bandwidth * reim)
+
+ fc_in = bandwidth * reim
+
+ if treat_channel_as_feature:
+ fc_in *= in_channels
+ else:
+ assert emb_dim % in_channels == 0
+ emb_dim = emb_dim // in_channels
+
+ fc = nn.Linear(fc_in, emb_dim)
+
+ self.combined = nn.Sequential(norm, fc)
+
+ def forward(self, xb):
+ return checkpoint_sequential(self.combined, 1, xb, use_reentrant=False)
+
+
+class BandSplitModule(nn.Module):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ in_channels: int,
+ require_no_overlap: bool = False,
+ require_no_gap: bool = True,
+ normalize_channel_independently: bool = False,
+ treat_channel_as_feature: bool = True,
+ ) -> None:
+ super().__init__()
+
+ check_nonzero_bandwidth(band_specs)
+
+ if require_no_gap:
+ check_no_gap(band_specs)
+
+ if require_no_overlap:
+ check_no_overlap(band_specs)
+
+ self.band_specs = band_specs
+ # list of [fstart, fend) in index.
+ # Note that fend is exclusive.
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+ self.emb_dim = emb_dim
+
+ try:
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ torch.compile(
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channels=in_channels,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ ),
+ disable=True,
+ )
+ for bw in self.band_widths
+ ]
+ )
+ except Exception as e:
+ self.norm_fc_modules = nn.ModuleList(
+ [ # type: ignore
+ NormFC(
+ emb_dim=emb_dim,
+ bandwidth=bw,
+ in_channels=in_channels,
+ normalize_channel_independently=normalize_channel_independently,
+ treat_channel_as_feature=treat_channel_as_feature,
+ )
+ for bw in self.band_widths
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ # x = complex spectrogram (batch, in_chan, n_freq, n_time)
+
+ batch, in_chan, n_freq, n_time = x.shape
+
+ z = torch.zeros(
+ size=(batch, self.n_bands, n_time, self.emb_dim), device=x.device
+ )
+
+ x = torch.permute(x, (0, 3, 1, 2)).contiguous()
+
+ for i, nfm in enumerate(self.norm_fc_modules):
+ fstart, fend = self.band_specs[i]
+ xb = x[:, :, :, fstart:fend]
+ xb = torch.view_as_real(xb)
+ xb = torch.reshape(xb, (batch, n_time, -1))
+ z[:, i, :, :] = nfm(xb)
+
+ return z
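+
+# Shape sketch (illustrative values): with n_fft=2048 the complex spectrogram has
+# 1025 frequency bins, so a 64-band split maps
+# (batch, in_channels, 1025, n_time) -> (batch, 64, n_time, emb_dim).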
diff --git a/models/bandit_v2/film.py b/models/bandit_v2/film.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30795332ea0e06865ea3d883767db17bb02353c
--- /dev/null
+++ b/models/bandit_v2/film.py
@@ -0,0 +1,25 @@
+from torch import nn
+import torch
+
+class FiLM(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, gamma, beta):
+ return gamma * x + beta
+
+
+class BTFBroadcastedFiLM(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.film = FiLM()
+
+ def forward(self, x, gamma, beta):
+
+ gamma = gamma[None, None, None, :]
+ beta = beta[None, None, None, :]
+
+ return self.film(x, gamma, beta)
+
+
+
\ No newline at end of file
diff --git a/models/bandit_v2/maskestim.py b/models/bandit_v2/maskestim.py
new file mode 100644
index 0000000000000000000000000000000000000000..65215d86a5e94dafdb71744aafadf7aaab93330d
--- /dev/null
+++ b/models/bandit_v2/maskestim.py
@@ -0,0 +1,281 @@
+from typing import Dict, List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch.nn.modules import activation
+from torch.utils.checkpoint import checkpoint_sequential
+
+from .utils import (
+ band_widths_from_specs,
+ check_no_gap,
+ check_no_overlap,
+ check_nonzero_bandwidth,
+)
+
+
+class BaseNormMLP(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ):
+ super().__init__()
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+ self.hidden_activation_kwargs = hidden_activation_kwargs
+ self.norm = nn.LayerNorm(emb_dim)
+ self.hidden = nn.Sequential(
+ nn.Linear(in_features=emb_dim, out_features=mlp_dim),
+ activation.__dict__[hidden_activation](**self.hidden_activation_kwargs),
+ )
+
+ self.bandwidth = bandwidth
+ self.in_channels = in_channels
+
+ self.complex_mask = complex_mask
+ self.reim = 2 if complex_mask else 1
+ self.glu_mult = 2
+
+
+class NormMLP(BaseNormMLP):
+ def __init__(
+ self,
+ emb_dim: int,
+ mlp_dim: int,
+ bandwidth: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs=None,
+ complex_mask: bool = True,
+ ) -> None:
+ super().__init__(
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ bandwidth=bandwidth,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ self.output = nn.Sequential(
+ nn.Linear(
+ in_features=mlp_dim,
+ out_features=bandwidth * in_channels * self.reim * 2,
+ ),
+ nn.GLU(dim=-1),
+ )
+
+ try:
+ self.combined = torch.compile(
+ nn.Sequential(self.norm, self.hidden, self.output), disable=True
+ )
+ except Exception as e:
+ self.combined = nn.Sequential(self.norm, self.hidden, self.output)
+
+ def reshape_output(self, mb):
+ # print(mb.shape)
+ batch, n_time, _ = mb.shape
+ if self.complex_mask:
+ mb = mb.reshape(
+ batch, n_time, self.in_channels, self.bandwidth, self.reim
+ ).contiguous()
+ # print(mb.shape)
+ mb = torch.view_as_complex(mb) # (batch, n_time, in_channels, bandwidth)
+ else:
+ mb = mb.reshape(batch, n_time, self.in_channels, self.bandwidth)
+
+ mb = torch.permute(mb, (0, 2, 3, 1)) # (batch, in_channels, bandwidth, n_time)
+
+ return mb
+
+ def forward(self, qb):
+ # qb = (batch, n_time, emb_dim)
+ # qb = self.norm(qb) # (batch, n_time, emb_dim)
+ # qb = self.hidden(qb) # (batch, n_time, mlp_dim)
+ # mb = self.output(qb) # (batch, n_time, bandwidth * in_channels * reim)
+
+ mb = checkpoint_sequential(self.combined, 2, qb, use_reentrant=False)
+ mb = self.reshape_output(mb) # (batch, in_channels, bandwidth, n_time)
+
+ return mb
+
+
+class MaskEstimationModuleSuperBase(nn.Module):
+ pass
+
+
+class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Optional[Dict] = None,
+ ) -> None:
+ super().__init__()
+
+ self.band_widths = band_widths_from_specs(band_specs)
+ self.n_bands = len(band_specs)
+
+ if hidden_activation_kwargs is None:
+ hidden_activation_kwargs = {}
+
+ if norm_mlp_kwargs is None:
+ norm_mlp_kwargs = {}
+
+ self.norm_mlp = nn.ModuleList(
+ [
+ norm_mlp_cls(
+ bandwidth=self.band_widths[b],
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ **norm_mlp_kwargs,
+ )
+ for b in range(self.n_bands)
+ ]
+ )
+
+ def compute_masks(self, q):
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = []
+
+ for b, nmlp in enumerate(self.norm_mlp):
+ # print(f"maskestim/{b:02d}")
+ qb = q[:, b, :, :]
+ mb = nmlp(qb)
+ masks.append(mb)
+
+ return masks
+
+ def compute_mask(self, q, b):
+ batch, n_bands, n_time, emb_dim = q.shape
+ qb = q[:, b, :, :]
+ mb = self.norm_mlp[b](qb)
+ return mb
+
+
+class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
+ def __init__(
+ self,
+ in_channels: int,
+ band_specs: List[Tuple[float, float]],
+ freq_weights: List[torch.Tensor],
+ n_freq: int,
+ emb_dim: int,
+ mlp_dim: int,
+ cond_dim: int = 0,
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ norm_mlp_cls: Type[nn.Module] = NormMLP,
+ norm_mlp_kwargs: Optional[Dict] = None,
+ use_freq_weights: bool = False,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+
+ if cond_dim > 0:
+ raise NotImplementedError
+
+ super().__init__(
+ band_specs=band_specs,
+ emb_dim=emb_dim + cond_dim,
+ mlp_dim=mlp_dim,
+ in_channels=in_channels,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ norm_mlp_cls=norm_mlp_cls,
+ norm_mlp_kwargs=norm_mlp_kwargs,
+ )
+
+ self.n_freq = n_freq
+ self.band_specs = band_specs
+ self.in_channels = in_channels
+
+ if freq_weights is not None and use_freq_weights:
+ for i, fw in enumerate(freq_weights):
+ self.register_buffer(f"freq_weights/{i}", fw)
+
+ self.use_freq_weights = use_freq_weights
+ else:
+ self.use_freq_weights = False
+
+ def forward(self, q):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ batch, n_bands, n_time, emb_dim = q.shape
+
+ masks = torch.zeros(
+ (batch, self.in_channels, self.n_freq, n_time),
+ device=q.device,
+ dtype=torch.complex64,
+ )
+
+ for im in range(n_bands):
+ fstart, fend = self.band_specs[im]
+
+ mask = self.compute_mask(q, im)
+
+ if self.use_freq_weights:
+ fw = self.get_buffer(f"freq_weights/{im}")[:, None]
+ mask = mask * fw
+ masks[:, :, fstart:fend, :] += mask
+
+ return masks
+
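+# Overlapping bands contribute additively to the full-resolution mask above. When
+# use_freq_weights is True, each band's contribution is scaled by its normalized
+# filterbank weights (which sum to one per bin when produced by
+# PerceptualBandsplitSpecification), so overlapping regions are not double-counted.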
+
+class MaskEstimationModule(OverlappingMaskEstimationModule):
+ def __init__(
+ self,
+ band_specs: List[Tuple[float, float]],
+ emb_dim: int,
+ mlp_dim: int,
+ in_channels: Optional[int],
+ hidden_activation: str = "Tanh",
+ hidden_activation_kwargs: Optional[Dict] = None,
+ complex_mask: bool = True,
+ **kwargs,
+ ) -> None:
+ check_nonzero_bandwidth(band_specs)
+ check_no_gap(band_specs)
+ check_no_overlap(band_specs)
+ super().__init__(
+ in_channels=in_channels,
+ band_specs=band_specs,
+ freq_weights=None,
+ n_freq=None,
+ emb_dim=emb_dim,
+ mlp_dim=mlp_dim,
+ hidden_activation=hidden_activation,
+ hidden_activation_kwargs=hidden_activation_kwargs,
+ complex_mask=complex_mask,
+ )
+
+ def forward(self, q, cond=None):
+ # q = (batch, n_bands, n_time, emb_dim)
+
+ masks = self.compute_masks(
+ q
+ ) # [n_bands * (batch, in_channels, bandwidth, n_time)]
+
+ # TODO: currently this requires band specs to have no gap and no overlap
+ masks = torch.concat(masks, dim=2) # (batch, in_channels, n_freq, n_time)
+
+ return masks
diff --git a/models/bandit_v2/tfmodel.py b/models/bandit_v2/tfmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..21aef03d1f0e814c20db05fe7d14f8019f07713b
--- /dev/null
+++ b/models/bandit_v2/tfmodel.py
@@ -0,0 +1,145 @@
+import warnings
+
+import torch
+import torch.backends.cuda
+from torch import nn
+from torch.nn.modules import rnn
+from torch.utils.checkpoint import checkpoint_sequential
+
+
+class TimeFrequencyModellingModule(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+
+class ResidualRNN(nn.Module):
+ def __init__(
+ self,
+ emb_dim: int,
+ rnn_dim: int,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ use_batch_trick: bool = True,
+ use_layer_norm: bool = True,
+ ) -> None:
+ # n_group is the size of the 2nd dim
+ super().__init__()
+
+ assert use_layer_norm
+ assert use_batch_trick
+
+ self.use_layer_norm = use_layer_norm
+ self.norm = nn.LayerNorm(emb_dim)
+ self.rnn = rnn.__dict__[rnn_type](
+ input_size=emb_dim,
+ hidden_size=rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=bidirectional,
+ )
+
+ self.fc = nn.Linear(
+ in_features=rnn_dim * (2 if bidirectional else 1), out_features=emb_dim
+ )
+
+ self.use_batch_trick = use_batch_trick
+ if not self.use_batch_trick:
+ warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")
+
+ def forward(self, z):
+ # z = (batch, n_uncrossed, n_across, emb_dim)
+
+ z0 = torch.clone(z)
+ z = self.norm(z)
+
+ batch, n_uncrossed, n_across, emb_dim = z.shape
+ z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
+ z = self.rnn(z)[0]
+ z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
+
+ z = self.fc(z) # (batch, n_uncrossed, n_across, emb_dim)
+
+ z = z + z0
+
+ return z
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0: int, dim1: int) -> None:
+ super().__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, z):
+ return z.transpose(self.dim0, self.dim1)
+
+
+class SeqBandModellingModule(TimeFrequencyModellingModule):
+ def __init__(
+ self,
+ n_modules: int = 12,
+ emb_dim: int = 128,
+ rnn_dim: int = 256,
+ bidirectional: bool = True,
+ rnn_type: str = "LSTM",
+ parallel_mode=False,
+ ) -> None:
+ super().__init__()
+
+ self.n_modules = n_modules
+
+ if parallel_mode:
+ self.seqband = nn.ModuleList([])
+ for _ in range(n_modules):
+ self.seqband.append(
+ nn.ModuleList(
+ [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ ]
+ )
+ )
+ else:
+ seqband = []
+ for _ in range(2 * n_modules):
+ seqband += [
+ ResidualRNN(
+ emb_dim=emb_dim,
+ rnn_dim=rnn_dim,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ ),
+ Transpose(1, 2),
+ ]
+
+ self.seqband = nn.Sequential(*seqband)
+
+ self.parallel_mode = parallel_mode
+
+ def forward(self, z):
+ # z = (batch, n_bands, n_time, emb_dim)
+
+ if self.parallel_mode:
+ for sbm_pair in self.seqband:
+ # z: (batch, n_bands, n_time, emb_dim)
+ sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
+ zt = sbm_t(z) # (batch, n_bands, n_time, emb_dim)
+ zf = sbm_f(z.transpose(1, 2)) # (batch, n_time, n_bands, emb_dim)
+ z = zt + zf.transpose(1, 2)
+ else:
+ z = checkpoint_sequential(
+ self.seqband, self.n_modules, z, use_reentrant=False
+ )
+
+ q = z
+ return q # (batch, n_bands, n_time, emb_dim)
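+
+# Note on the sequential (non-parallel) path: every ResidualRNN is followed by a
+# Transpose(1, 2), so the stack alternates between modelling across time
+# (batch, n_bands, n_time, emb_dim) and across bands (batch, n_time, n_bands, emb_dim).
+# With 2 * n_modules RNN/Transpose pairs, the even number of transposes leaves the
+# output in the original (batch, n_bands, n_time, emb_dim) layout.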
diff --git a/models/bandit_v2/utils.py b/models/bandit_v2/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4eab5d8c5b5396ed717f5b9c365a6900eddd2f
--- /dev/null
+++ b/models/bandit_v2/utils.py
@@ -0,0 +1,523 @@
+import os
+from abc import abstractmethod
+from typing import Callable
+
+import numpy as np
+import torch
+from librosa import hz_to_midi, midi_to_hz
+from torchaudio import functional as taF
+
+# from spafe.fbanks import bark_fbanks
+# from spafe.utils.converters import erb2hz, hz2bark, hz2erb
+
+
+def band_widths_from_specs(band_specs):
+ return [e - i for i, e in band_specs]
+
+
+def check_nonzero_bandwidth(band_specs):
+ # pprint(band_specs)
+ for fstart, fend in band_specs:
+ if fend - fstart <= 0:
+ raise ValueError("Bands cannot be zero-width")
+
+
+def check_no_overlap(band_specs):
+ # fend is exclusive, so two bands only overlap if one starts strictly before
+ # the previous one ends
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr < fend_prev:
+ raise ValueError("Bands cannot overlap")
+ fend_prev = fend_curr
+
+
+def check_no_gap(band_specs):
+ fstart, _ = band_specs[0]
+ assert fstart == 0
+
+ fend_prev = -1
+ for fstart_curr, fend_curr in band_specs:
+ if fstart_curr - fend_prev > 1:
+ raise ValueError("Bands cannot leave gap")
+ fend_prev = fend_curr
+
+
+class BandsplitSpecification:
+ def __init__(self, nfft: int, fs: int) -> None:
+ self.fs = fs
+ self.nfft = nfft
+ self.nyquist = fs / 2
+ self.max_index = nfft // 2 + 1
+
+ self.split500 = self.hertz_to_index(500)
+ self.split1k = self.hertz_to_index(1000)
+ self.split2k = self.hertz_to_index(2000)
+ self.split4k = self.hertz_to_index(4000)
+ self.split8k = self.hertz_to_index(8000)
+ self.split16k = self.hertz_to_index(16000)
+ self.split20k = self.hertz_to_index(20000)
+
+ self.above20k = [(self.split20k, self.max_index)]
+ self.above16k = [(self.split16k, self.split20k)] + self.above20k
+
+ def index_to_hertz(self, index: int):
+ return index * self.fs / self.nfft
+
+ def hertz_to_index(self, hz: float, round: bool = True):
+ index = hz * self.nfft / self.fs
+
+ if round:
+ index = int(np.round(index))
+
+ return index
+
+ def get_band_specs_with_bandwidth(self, start_index, end_index, bandwidth_hz):
+ band_specs = []
+ lower = start_index
+
+ while lower < end_index:
+ upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
+ upper = min(upper, end_index)
+
+ band_specs.append((lower, upper))
+ lower = upper
+
+ return band_specs
+
+ @abstractmethod
+ def get_band_specs(self):
+ raise NotImplementedError
+
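+# Worked example: with nfft=2048 and fs=44100, hertz_to_index(1000) == 46 and
+# hertz_to_index(100) == 5, so get_band_specs_with_bandwidth(0, 46, 100) yields
+# roughly 5-bin bands (0, 5), (5, 10), ..., (45, 46), with the last band clipped
+# at end_index.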
+
+class VocalBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ self.version = version
+
+ def get_band_specs(self):
+ return getattr(self, f"version{self.version}")()
+
+ # plain method (not a @property) so that get_band_specs()'s
+ # getattr(self, f"version{self.version}")() call also works for version "1"
+ def version1(self):
+ return self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.max_index, bandwidth_hz=1000
+ )
+
+ def version2(self):
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+
+ return below16k + below20k + self.above20k
+
+ def version3(self):
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below8k + below16k + self.above16k
+
+ def version4(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+
+ return below1k + below8k + below16k + self.above16k
+
+ def version5(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below16k + below20k + self.above20k
+
+ def version6(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + self.above16k
+
+ def version7(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ below20k = self.get_band_specs_with_bandwidth(
+ start_index=self.split16k, end_index=self.split20k, bandwidth_hz=2000
+ )
+ return below1k + below4k + below8k + below16k + below20k + self.above20k
+
+
+class OtherBandsplitSpecification(VocalBandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs, version="7")
+
+
+class BassBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below500 = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split500, bandwidth_hz=50
+ )
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=self.split500, end_index=self.split1k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split4k, bandwidth_hz=500
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=1000
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=2000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below500 + below1k + below4k + below8k + below16k + above16k
+
+
+class DrumBandsplitSpecification(BandsplitSpecification):
+ def __init__(self, nfft: int, fs: int) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+
+ def get_band_specs(self):
+ below1k = self.get_band_specs_with_bandwidth(
+ start_index=0, end_index=self.split1k, bandwidth_hz=50
+ )
+ below2k = self.get_band_specs_with_bandwidth(
+ start_index=self.split1k, end_index=self.split2k, bandwidth_hz=100
+ )
+ below4k = self.get_band_specs_with_bandwidth(
+ start_index=self.split2k, end_index=self.split4k, bandwidth_hz=250
+ )
+ below8k = self.get_band_specs_with_bandwidth(
+ start_index=self.split4k, end_index=self.split8k, bandwidth_hz=500
+ )
+ below16k = self.get_band_specs_with_bandwidth(
+ start_index=self.split8k, end_index=self.split16k, bandwidth_hz=1000
+ )
+ above16k = [(self.split16k, self.max_index)]
+
+ return below1k + below2k + below4k + below8k + below16k + above16k
+
+
+class PerceptualBandsplitSpecification(BandsplitSpecification):
+ def __init__(
+ self,
+ nfft: int,
+ fs: int,
+ fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
+ n_bands: int,
+ f_min: float = 0.0,
+ f_max: float = None,
+ ) -> None:
+ super().__init__(nfft=nfft, fs=fs)
+ self.n_bands = n_bands
+ if f_max is None:
+ f_max = fs / 2
+
+ self.filterbank = fbank_fn(n_bands, fs, f_min, f_max, self.max_index)
+
+ weight_per_bin = torch.sum(self.filterbank, dim=0, keepdim=True) # (1, n_freqs)
+ normalized_mel_fb = self.filterbank / weight_per_bin # (n_mels, n_freqs)
+
+ freq_weights = []
+ band_specs = []
+ for i in range(self.n_bands):
+ active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
+ if isinstance(active_bins, int):
+ active_bins = (active_bins, active_bins)
+ if len(active_bins) == 0:
+ continue
+ start_index = active_bins[0]
+ end_index = active_bins[-1] + 1
+ band_specs.append((start_index, end_index))
+ freq_weights.append(normalized_mel_fb[i, start_index:end_index])
+
+ self.freq_weights = freq_weights
+ self.band_specs = band_specs
+
+ def get_band_specs(self):
+ return self.band_specs
+
+ def get_freq_weights(self):
+ return self.freq_weights
+
+ def save_to_file(self, dir_path: str) -> None:
+ os.makedirs(dir_path, exist_ok=True)
+
+ import pickle
+
+ with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
+ pickle.dump(
+ {
+ "band_specs": self.band_specs,
+ "freq_weights": self.freq_weights,
+ "filterbank": self.filterbank,
+ },
+ f,
+ )
+
+
+def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
+ fb = taF.melscale_fbanks(
+ n_mels=n_bands,
+ sample_rate=fs,
+ f_min=f_min,
+ f_max=f_max,
+ n_freqs=n_freqs,
+ ).T
+
+ fb[0, 0] = 1.0
+
+ return fb
+
+
+class MelBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=mel_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs, scale="constant"):
+ nfft = 2 * (n_freqs - 1)
+ df = fs / nfft
+ # init freqs
+ f_max = f_max or fs / 2
+ f_min = f_min or 0
+ f_min = fs / nfft
+
+ n_octaves = np.log2(f_max / f_min)
+ n_octaves_per_band = n_octaves / n_bands
+ bandwidth_mult = np.power(2.0, n_octaves_per_band)
+
+ low_midi = max(0, hz_to_midi(f_min))
+ high_midi = hz_to_midi(f_max)
+ midi_points = np.linspace(low_midi, high_midi, n_bands)
+ hz_pts = midi_to_hz(midi_points)
+
+ low_pts = hz_pts / bandwidth_mult
+ high_pts = hz_pts * bandwidth_mult
+
+ low_bins = np.floor(low_pts / df).astype(int)
+ high_bins = np.ceil(high_pts / df).astype(int)
+
+ fb = np.zeros((n_bands, n_freqs))
+
+ for i in range(n_bands):
+ fb[i, low_bins[i] : high_bins[i] + 1] = 1.0
+
+ fb[0, : low_bins[0]] = 1.0
+ fb[-1, high_bins[-1] + 1 :] = 1.0
+
+ return torch.as_tensor(fb)
+
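+# The resulting filterbank is rectangular (0/1 weights) with log-spaced band
+# centres; neighbouring bands overlap by construction, which is why the bandit_v2
+# model pairs this specification with OverlappingMaskEstimationModule rather than
+# the gap-free, non-overlapping MaskEstimationModule.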
+
+class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
+ def __init__(
+ self, nfft: int, fs: int, n_bands: int, f_min: float = 0.0, f_max: float = None
+ ) -> None:
+ super().__init__(
+ fbank_fn=musical_filterbank,
+ nfft=nfft,
+ fs=fs,
+ n_bands=n_bands,
+ f_min=f_min,
+ f_max=f_max,
+ )
+
+
+# def bark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+# nfft = 2 * (n_freqs -1)
+# fb, _ = bark_fbanks.bark_filter_banks(
+# nfilts=n_bands,
+# nfft=nfft,
+# fs=fs,
+# low_freq=f_min,
+# high_freq=f_max,
+# scale="constant"
+# )
+
+# return torch.as_tensor(fb)
+
+# class BarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def triangular_bark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+
+# all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+# # calculate mel freq bins
+# m_min = hz2bark(f_min)
+# m_max = hz2bark(f_max)
+
+# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+# f_pts = 600 * torch.sinh(m_pts / 6)
+
+# # create filterbank
+# fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+# fb = fb.T
+
+# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+# fb[first_active_band, :first_active_bin] = 1.0
+
+# return fb
+
+# class TriangularBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=triangular_bark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def minibark_filterbank(
+# n_bands, fs, f_min, f_max, n_freqs
+# ):
+# fb = bark_filterbank(
+# n_bands,
+# fs,
+# f_min,
+# f_max,
+# n_freqs
+# )
+
+# fb[fb < np.sqrt(0.5)] = 0.0
+
+# return fb
+
+# class MiniBarkBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=minibark_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+
+# def erb_filterbank(
+# n_bands: int,
+# fs: int,
+# f_min: float,
+# f_max: float,
+# n_freqs: int,
+# ) -> Tensor:
+# # freq bins
+# A = (1000 * np.log(10)) / (24.7 * 4.37)
+# all_freqs = torch.linspace(0, fs // 2, n_freqs)
+
+# # calculate mel freq bins
+# m_min = hz2erb(f_min)
+# m_max = hz2erb(f_max)
+
+# m_pts = torch.linspace(m_min, m_max, n_bands + 2)
+# f_pts = (torch.pow(10, (m_pts / A)) - 1)/ 0.00437
+
+# # create filterbank
+# fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+# fb = fb.T
+
+
+# first_active_band = torch.nonzero(torch.sum(fb, dim=-1))[0, 0]
+# first_active_bin = torch.nonzero(fb[first_active_band, :])[0, 0]
+
+# fb[first_active_band, :first_active_bin] = 1.0
+
+# return fb
+
+
+# class EquivalentRectangularBandsplitSpecification(PerceptualBandsplitSpecification):
+# def __init__(
+# self,
+# nfft: int,
+# fs: int,
+# n_bands: int,
+# f_min: float = 0.0,
+# f_max: float = None
+# ) -> None:
+# super().__init__(fbank_fn=erb_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)
+
+if __name__ == "__main__":
+ import pandas as pd
+
+ band_defs = []
+
+ for bands in [VocalBandsplitSpecification]:
+ band_name = bands.__name__.replace("BandsplitSpecification", "")
+
+ mbs = bands(nfft=2048, fs=44100).get_band_specs()
+
+ for i, (f_min, f_max) in enumerate(mbs):
+ band_defs.append(
+ {"band": band_name, "band_index": i, "f_min": f_min, "f_max": f_max}
+ )
+
+ df = pd.DataFrame(band_defs)
+ df.to_csv("vox7bands.csv", index=False)
diff --git a/models/bs_roformer/__init__.py b/models/bs_roformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..980e0afa5b7b4fd66168bce6905a94e7c91c380e
--- /dev/null
+++ b/models/bs_roformer/__init__.py
@@ -0,0 +1,2 @@
+from models.bs_roformer.bs_roformer import BSRoformer
+from models.bs_roformer.mel_band_roformer import MelBandRoformer
diff --git a/models/bs_roformer/attend.py b/models/bs_roformer/attend.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6dc4b3079cff5b3c8c90cea8df2301afd18918b
--- /dev/null
+++ b/models/bs_roformer/attend.py
@@ -0,0 +1,126 @@
+from functools import wraps
+from packaging import version
+from collections import namedtuple
+
+import os
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from einops import rearrange, reduce
+
+# constants
+
+FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(v, d):
+ return v if exists(v) else d
+
+def once(fn):
+ called = False
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+ return inner
+
+print_once = once(print)
+
+# main class
+
+class Attend(nn.Module):
+ def __init__(
+ self,
+ dropout = 0.,
+ flash = False,
+ scale = None
+ ):
+ super().__init__()
+ self.scale = scale
+ self.dropout = dropout
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.flash = flash
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+ # determine efficient attention configs for cuda and cpu
+
+ self.cpu_config = FlashAttentionConfig(True, True, True)
+ self.cuda_config = None
+
+ if not torch.cuda.is_available() or not flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+ device_version = version.parse(f'{device_properties.major}.{device_properties.minor}')
+
+ if device_version >= version.parse('8.0'):
+ if os.name == 'nt':
+ print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+ else:
+            print_once('GPU Compute Capability equal to or above 8.0, using flash attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(True, False, False)
+ else:
+ print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda')
+ self.cuda_config = FlashAttentionConfig(False, True, True)
+
+ def flash_attn(self, q, k, v):
+ _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
+
+ if exists(self.scale):
+ default_scale = q.shape[-1] ** -0.5
+ q = q * (self.scale / default_scale)
+
+ # Check if there is a compatible device for flash attention
+
+ config = self.cuda_config if is_cuda else self.cpu_config
+
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
+
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
+ out = F.scaled_dot_product_attention(
+ q, k, v,
+ dropout_p = self.dropout if self.training else 0.
+ )
+
+ return out
+
+ def forward(self, q, k, v):
+ """
+ einstein notation
+ b - batch
+ h - heads
+ n, i, j - sequence length (base sequence length, source, target)
+ d - feature dimension
+ """
+
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
+
+ scale = default(self.scale, q.shape[-1] ** -0.5)
+
+ if self.flash:
+ return self.flash_attn(q, k, v)
+
+ # similarity
+
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
+
+ # attention
+
+ attn = sim.softmax(dim=-1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate values
+
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
+
+ return out
diff --git a/models/bs_roformer/bs_roformer.py b/models/bs_roformer/bs_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..195593ed4794f808034ca422993bc63d68ae9643
--- /dev/null
+++ b/models/bs_roformer/bs_roformer.py
@@ -0,0 +1,622 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from models.bs_roformer.attend import Attend
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack
+from einops.layers.torch import Rearrange
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+# norm
+
+def l2norm(t):
+ return F.normalize(t, dim = -1, p = 2)
+
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+class FeedForward(Module):
+ def __init__(
+ self,
+ dim,
+ mult=4,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ rotary_embed=None,
+ flash=True
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+    this flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ self.attend = Attend(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange('b h d n -> b n (h d)'),
+ nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(
+ self,
+ x
+ ):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+ else:
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
+ rotary_embed=rotary_embed, flash=flash_attn)
+
+ self.layers.append(ModuleList([
+ attn,
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+class BandSplit(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in,
+ dim_out,
+ dim_hidden=None,
+ depth=1,
+ activation=nn.Tanh
+):
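+    # e.g. depth=2 builds Linear(dim_in, dim_hidden) -> Tanh -> Linear(dim_hidden, dim_out);
+    # depth=1 reduces to a single Linear(dim_in, dim_out)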
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...],
+ depth,
+ mlp_expansion_factor=4
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ net = []
+
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+ nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+DEFAULT_FREQS_PER_BANDS = (
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
+ 12, 12, 12, 12, 12, 12, 12, 12,
+ 24, 24, 24, 24, 24, 24, 24, 24,
+ 48, 48, 48, 48, 48, 48, 48, 48,
+ 128, 129,
+)
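+# the bands above sum to 24*2 + 12*4 + 8*12 + 8*24 + 8*48 + 128 + 129 = 1025 freq bins,
+# matching the default stft_n_fft = 2048 (2048 // 2 + 1 = 1025) asserted in __init__ below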
+
+
+class BSRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
+ # in the paper, they divide into ~60 bands, test with 1 for starters
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=2,
+ multi_stft_resolution_loss_weight=1.,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn,
+ norm_output=False
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ )
+ tran_modules.append(
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.final_norm = RMSNorm(dim)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized
+ )
+
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
+
+ assert len(freqs_per_bands) > 1
+ assert sum(
+ freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
+
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
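+        # each band contributes 2 * freqs * channels features (real + imag per bin),
+        # e.g. a 2-bin band gives 4 features for mono input and 8 for stereo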
+
+ self.band_split = BandSplit(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex
+ )
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size,
+ normalized=multi_stft_normalized
+ )
+
+ def forward(
+ self,
+ raw_audio,
+ target=None,
+ return_loss_breakdown=False
+ ):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+        # check whether the model is running on MPS (the macOS GPU accelerator)
+        x_is_mps = device.type == "mps"
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+ channels = raw_audio.shape[1]
+        assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo must be set to True for stereo input (2 channels) and to False for mono input (1 channel)'
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+ stft_window = self.stft_window_fn(device=device)
+
+        # RuntimeError: FFT operations are only supported on macOS 14+
+        # Since it is tedious to detect the exact macOS version, a simple try/except fallback is used
+ try:
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+        except Exception:
+ stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs,
+ window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to(
+ device)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
+
+ x = rearrange(stft_repr, 'b f t c -> b t (f c)')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = transformer_block
+
+ x, ft_ps = pack([x], 'b * d')
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ x, = unpack(x, ft_ps, 'b * d')
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, 'b t f d -> b f t d')
+ x, ps = pack([x], '* t d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ x, = unpack(x, ps, '* t d')
+ x = rearrange(x, 'b f t d -> b t f d')
+ x, ps = pack([x], '* f d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ x, = unpack(x, ps, '* f d')
+
+ if self.skip_connection:
+ store[i] = x
+
+ x = self.final_norm(x)
+
+ num_stems = len(self.mask_estimators)
+
+ if self.use_torch_checkpoint:
+ mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
+ else:
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ mask = torch.view_as_complex(mask)
+
+ stft_repr = stft_repr * mask
+
+ # istft
+
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+ # same as torch.stft() fix for MacOS MPS above
+ try:
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
+        except Exception:
+ recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device)
+
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems)
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, '... t -> ... 1 t')
+
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
\ No newline at end of file
diff --git a/models/bs_roformer/mel_band_roformer.py b/models/bs_roformer/mel_band_roformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0d2c40f2e00eb0b99521e6506cf0c0027561541
--- /dev/null
+++ b/models/bs_roformer/mel_band_roformer.py
@@ -0,0 +1,668 @@
+from functools import partial
+
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from models.bs_roformer.attend import Attend
+from torch.utils.checkpoint import checkpoint
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+from rotary_embedding_torch import RotaryEmbedding
+
+from einops import rearrange, pack, unpack, reduce, repeat
+from einops.layers.torch import Rearrange
+
+from librosa import filters
+
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+def pad_at_dim(t, pad, dim=-1, value=0.):
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+ zeros = ((0, 0) * dims_from_right)
+ return F.pad(t, (*zeros, *pad), value=value)
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1, p=2)
+
+
+# norm
+
+class RMSNorm(Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+# attention
+
+class FeedForward(Module):
+ def __init__(
+ self,
+ dim,
+ mult=4,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = int(dim * mult)
+ self.net = nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim_inner),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_inner, dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(Module):
+ def __init__(
+ self,
+ dim,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ rotary_embed=None,
+ flash=True
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ dim_inner = heads * dim_head
+
+ self.rotary_embed = rotary_embed
+
+ self.attend = Attend(flash=flash, dropout=dropout)
+
+ self.norm = RMSNorm(dim)
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
+
+ self.to_gates = nn.Linear(dim, heads)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(dim_inner, dim, bias=False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x):
+ x = self.norm(x)
+
+ q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
+
+ if exists(self.rotary_embed):
+ q = self.rotary_embed.rotate_queries_or_keys(q)
+ k = self.rotary_embed.rotate_queries_or_keys(k)
+
+ out = self.attend(q, k, v)
+
+ gates = self.to_gates(x)
+ out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+
+class LinearAttention(Module):
+ """
+    this flavor of linear attention was proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head=32,
+ heads=8,
+ scale=8,
+ flash=False,
+ dropout=0.
+ ):
+ super().__init__()
+ dim_inner = dim_head * heads
+ self.norm = RMSNorm(dim)
+
+ self.to_qkv = nn.Sequential(
+ nn.Linear(dim, dim_inner * 3, bias=False),
+ Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
+ )
+
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
+
+ self.attend = Attend(
+ scale=scale,
+ dropout=dropout,
+ flash=flash
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange('b h d n -> b n (h d)'),
+ nn.Linear(dim_inner, dim, bias=False)
+ )
+
+ def forward(
+ self,
+ x
+ ):
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x)
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.temperature.exp()
+
+ out = self.attend(q, k, v)
+
+ return self.to_out(out)
+
+
+class Transformer(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.,
+ ff_dropout=0.,
+ ff_mult=4,
+ norm_output=True,
+ rotary_embed=None,
+ flash_attn=True,
+ linear_attn=False
+ ):
+ super().__init__()
+ self.layers = ModuleList([])
+
+ for _ in range(depth):
+ if linear_attn:
+ attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn)
+ else:
+ attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout,
+ rotary_embed=rotary_embed, flash=flash_attn)
+
+ self.layers.append(ModuleList([
+ attn,
+ FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
+
+ def forward(self, x):
+
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+
+# bandsplit module
+
+class BandSplit(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+def MLP(
+ dim_in,
+ dim_out,
+ dim_hidden=None,
+ depth=1,
+ activation=nn.Tanh
+):
+ dim_hidden = default(dim_hidden, dim_in)
+
+ net = []
+ dims = (dim_in, *((dim_hidden,) * depth), dim_out)
+
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
+ is_last = ind == (len(dims) - 2)
+
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
+
+ if is_last:
+ continue
+
+ net.append(activation())
+
+ return nn.Sequential(*net)
+
+
+class MaskEstimator(Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...],
+ depth,
+ mlp_expansion_factor=4
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_freqs = ModuleList([])
+ dim_hidden = dim * mlp_expansion_factor
+
+ for dim_in in dim_inputs:
+ net = []
+
+ mlp = nn.Sequential(
+ MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
+ nn.GLU(dim=-1)
+ )
+
+ self.to_freqs.append(mlp)
+
+ def forward(self, x):
+ x = x.unbind(dim=-2)
+
+ outs = []
+
+ for band_features, mlp in zip(x, self.to_freqs):
+ freq_out = mlp(band_features)
+ outs.append(freq_out)
+
+ return torch.cat(outs, dim=-1)
+
+
+# main class
+
+class MelBandRoformer(Module):
+
+ @beartype
+ def __init__(
+ self,
+ dim,
+ *,
+ depth,
+ stereo=False,
+ num_stems=1,
+ time_transformer_depth=2,
+ freq_transformer_depth=2,
+ linear_transformer_depth=0,
+ num_bands=60,
+ dim_head=64,
+ heads=8,
+ attn_dropout=0.1,
+ ff_dropout=0.1,
+ flash_attn=True,
+ dim_freqs_in=1025,
+ sample_rate=44100, # needed for mel filter bank from librosa
+ stft_n_fft=2048,
+ stft_hop_length=512,
+ # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
+ stft_win_length=2048,
+ stft_normalized=False,
+ stft_window_fn: Optional[Callable] = None,
+ mask_estimator_depth=1,
+ multi_stft_resolution_loss_weight=1.,
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
+ multi_stft_hop_size=147,
+ multi_stft_normalized=False,
+ multi_stft_window_fn: Callable = torch.hann_window,
+ match_input_audio_length=False, # if True, pad output tensor to match length of input tensor
+ mlp_expansion_factor=4,
+ use_torch_checkpoint=False,
+ skip_connection=False,
+ ):
+ super().__init__()
+
+ self.stereo = stereo
+ self.audio_channels = 2 if stereo else 1
+ self.num_stems = num_stems
+ self.use_torch_checkpoint = use_torch_checkpoint
+ self.skip_connection = skip_connection
+
+ self.layers = ModuleList([])
+
+ transformer_kwargs = dict(
+ dim=dim,
+ heads=heads,
+ dim_head=dim_head,
+ attn_dropout=attn_dropout,
+ ff_dropout=ff_dropout,
+ flash_attn=flash_attn
+ )
+
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
+
+ for _ in range(depth):
+ tran_modules = []
+ if linear_transformer_depth > 0:
+ tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs))
+ tran_modules.append(
+ Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs)
+ )
+ tran_modules.append(
+ Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs)
+ )
+ self.layers.append(nn.ModuleList(tran_modules))
+
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
+
+ self.stft_kwargs = dict(
+ n_fft=stft_n_fft,
+ hop_length=stft_hop_length,
+ win_length=stft_win_length,
+ normalized=stft_normalized
+ )
+
+ freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]
+
+ # create mel filter bank
+ # with librosa.filters.mel as in section 2 of paper
+
+ mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)
+
+ mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
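+        # shape (num_bands, stft_n_fft // 2 + 1), e.g. (60, 1025) with the defaults above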
+
+ # for some reason, it doesn't include the first freq? just force a value for now
+
+ mel_filter_bank[0][0] = 1.
+
+ # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
+ # so let's force a positive value
+
+ mel_filter_bank[-1, -1] = 1.
+
+ # binary as in paper (then estimated masks are averaged for overlapping regions)
+
+ freqs_per_band = mel_filter_bank > 0
+ assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
+
+ repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
+ freq_indices = repeated_freq_indices[freqs_per_band]
+
+ if stereo:
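+            # interleave the two channels: after the 'b s f t c -> b (f s) t c' rearrange in
+            # forward(), bin f of channel s sits at row 2 * f + s, so index f becomes (2f, 2f + 1)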
+ freq_indices = repeat(freq_indices, 'f -> f s', s=2)
+ freq_indices = freq_indices * 2 + torch.arange(2)
+ freq_indices = rearrange(freq_indices, 'f s -> (f s)')
+
+ self.register_buffer('freq_indices', freq_indices, persistent=False)
+ self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)
+
+ num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
+ num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
+
+ self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
+ self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)
+
+ # band split and mask estimator
+
+ freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
+
+ self.band_split = BandSplit(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex
+ )
+
+ self.mask_estimators = nn.ModuleList([])
+
+ for _ in range(num_stems):
+ mask_estimator = MaskEstimator(
+ dim=dim,
+ dim_inputs=freqs_per_bands_with_complex,
+ depth=mask_estimator_depth,
+ mlp_expansion_factor=mlp_expansion_factor,
+ )
+
+ self.mask_estimators.append(mask_estimator)
+
+ # for the multi-resolution stft loss
+
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
+ self.multi_stft_n_fft = stft_n_fft
+ self.multi_stft_window_fn = multi_stft_window_fn
+
+ self.multi_stft_kwargs = dict(
+ hop_length=multi_stft_hop_size,
+ normalized=multi_stft_normalized
+ )
+
+ self.match_input_audio_length = match_input_audio_length
+
+ def forward(
+ self,
+ raw_audio,
+ target=None,
+ return_loss_breakdown=False
+ ):
+ """
+ einops
+
+ b - batch
+ f - freq
+ t - time
+ s - audio channel (1 for mono, 2 for stereo)
+ n - number of 'stems'
+ c - complex (2)
+ d - feature dimension
+ """
+
+ device = raw_audio.device
+
+ if raw_audio.ndim == 2:
+ raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
+
+ batch, channels, raw_audio_length = raw_audio.shape
+
+ istft_length = raw_audio_length if self.match_input_audio_length else None
+
+        assert (not self.stereo and channels == 1) or (
+            self.stereo and channels == 2), 'stereo must be set to True for stereo input (2 channels) and to False for mono input (1 channel)'
+
+ # to stft
+
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
+
+ stft_window = self.stft_window_fn(device=device)
+
+ stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
+ stft_repr = torch.view_as_real(stft_repr)
+
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
+
+ # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
+ stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
+
+ # index out all frequencies for all frequency ranges across bands ascending in one go
+
+ batch_arange = torch.arange(batch, device=device)[..., None]
+
+ # account for stereo
+
+ x = stft_repr[batch_arange, self.freq_indices]
+
+ # fold the complex (real and imag) into the frequencies dimension
+
+ x = rearrange(x, 'b f t c -> b t (f c)')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(self.band_split, x, use_reentrant=False)
+ else:
+ x = self.band_split(x)
+
+ # axial / hierarchical attention
+
+ store = [None] * len(self.layers)
+ for i, transformer_block in enumerate(self.layers):
+
+ if len(transformer_block) == 3:
+ linear_transformer, time_transformer, freq_transformer = transformer_block
+
+ x, ft_ps = pack([x], 'b * d')
+ if self.use_torch_checkpoint:
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
+ else:
+ x = linear_transformer(x)
+ x, = unpack(x, ft_ps, 'b * d')
+ else:
+ time_transformer, freq_transformer = transformer_block
+
+ if self.skip_connection:
+ # Sum all previous
+ for j in range(i):
+ x = x + store[j]
+
+ x = rearrange(x, 'b t f d -> b f t d')
+ x, ps = pack([x], '* t d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(time_transformer, x, use_reentrant=False)
+ else:
+ x = time_transformer(x)
+
+ x, = unpack(x, ps, '* t d')
+ x = rearrange(x, 'b f t d -> b t f d')
+ x, ps = pack([x], '* f d')
+
+ if self.use_torch_checkpoint:
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
+ else:
+ x = freq_transformer(x)
+
+ x, = unpack(x, ps, '* f d')
+
+ if self.skip_connection:
+ store[i] = x
+
+ num_stems = len(self.mask_estimators)
+ if self.use_torch_checkpoint:
+ masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1)
+ else:
+ masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
+ masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)
+
+ # modulate frequency representation
+
+ stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
+
+ # complex number multiplication
+
+ stft_repr = torch.view_as_complex(stft_repr)
+ masks = torch.view_as_complex(masks)
+
+ masks = masks.type(stft_repr.dtype)
+
+ # need to average the estimated mask for the overlapped frequencies
+
+ scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
+
+ stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
+ masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
+
+ denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)
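+        # num_bands_per_freq counts how many mel bands cover each bin, so dividing by it turns
+        # the scatter-added sum into an average over the overlapping band estimates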
+
+ masks_averaged = masks_summed / denom.clamp(min=1e-8)
+
+ # modulate stft repr with estimated mask
+
+ stft_repr = stft_repr * masks_averaged
+
+ # istft
+
+ stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
+
+ recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
+ length=istft_length)
+
+ recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)
+
+ if num_stems == 1:
+ recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
+
+ # if a target is passed in, calculate loss for learning
+
+ if not exists(target):
+ return recon_audio
+
+ if self.num_stems > 1:
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
+
+ if target.ndim == 2:
+ target = rearrange(target, '... t -> ... 1 t')
+
+ target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
+
+ loss = F.l1_loss(recon_audio, target)
+
+ multi_stft_resolution_loss = 0.
+
+ for window_size in self.multi_stft_resolutions_window_sizes:
+ res_stft_kwargs = dict(
+ n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
+ win_length=window_size,
+ return_complex=True,
+ window=self.multi_stft_window_fn(window_size, device=device),
+ **self.multi_stft_kwargs,
+ )
+
+ recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
+ target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
+
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
+
+ weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
+
+ total_loss = loss + weighted_multi_resolution_loss
+
+ if not return_loss_breakdown:
+ return total_loss
+
+ return total_loss, (loss, multi_stft_resolution_loss)
diff --git a/models/demucs4ht.py b/models/demucs4ht.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c279c31a7ac7e12af4375a5715eb291ad5405c
--- /dev/null
+++ b/models/demucs4ht.py
@@ -0,0 +1,713 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+import numpy as np
+import torch
+import json
+from omegaconf import OmegaConf
+from demucs.demucs import Demucs
+from demucs.hdemucs import HDemucs
+
+import math
+from openunmix.filtering import wiener
+from torch import nn
+from torch.nn import functional as F
+from fractions import Fraction
+from einops import rearrange
+
+from demucs.transformer import CrossTransformerEncoder
+
+from demucs.demucs import rescale_module
+from demucs.states import capture_init
+from demucs.spec import spectro, ispectro
+from demucs.hdemucs import pad1d, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
+
+
+class HTDemucs(nn.Module):
+ """
+ Spectrogram and hybrid Demucs model.
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
+ Frequency layers can still access information across time steps thanks to the DConv residual.
+
+    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
+    as the frequency branch and then the two are combined. The opposite happens in the decoder.
+
+    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
+    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
+    the Open Unmix implementation [Stoter et al. 2019].
+
+    The loss is always on the temporal domain, by backpropagating through the above
+    output methods and iSTFT. This allows us to define hybrid models nicely. However, it slightly
+    breaks Wiener filtering, as running more iterations at test time will change the spectrogram
+    contribution without changing the one from the waveform, which will lead to worse performance.
+    I tried using the residual option in the OpenUnmix Wiener implementation, but it didn't improve.
+    CaC, on the other hand, provides similar performance and works naturally with hybrid models.
+
+    This model also uses frequency embeddings to improve the efficiency of convolutions
+    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
+
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
+ """
+
+ @capture_init
+ def __init__(
+ self,
+ sources,
+ # Channels
+ audio_channels=2,
+ channels=48,
+ channels_time=None,
+ growth=2,
+ # STFT
+ nfft=4096,
+ num_subbands=1,
+ wiener_iters=0,
+ end_iters=0,
+ wiener_residual=False,
+ cac=True,
+ # Main structure
+ depth=4,
+ rewrite=True,
+ # Frequency branch
+ multi_freqs=None,
+ multi_freqs_depth=3,
+ freq_emb=0.2,
+ emb_scale=10,
+ emb_smooth=True,
+ # Convolutions
+ kernel_size=8,
+ time_stride=2,
+ stride=4,
+ context=1,
+ context_enc=0,
+ # Normalization
+ norm_starts=4,
+ norm_groups=4,
+ # DConv residual branch
+ dconv_mode=1,
+ dconv_depth=2,
+ dconv_comp=8,
+ dconv_init=1e-3,
+ # Before the Transformer
+ bottom_channels=0,
+ # Transformer
+ t_layers=5,
+ t_emb="sin",
+ t_hidden_scale=4.0,
+ t_heads=8,
+ t_dropout=0.0,
+ t_max_positions=10000,
+ t_norm_in=True,
+ t_norm_in_group=False,
+ t_group_norm=False,
+ t_norm_first=True,
+ t_norm_out=True,
+ t_max_period=10000.0,
+ t_weight_decay=0.0,
+ t_lr=None,
+ t_layer_scale=True,
+ t_gelu=True,
+ t_weight_pos_embed=1.0,
+ t_sin_random_shift=0,
+ t_cape_mean_normalize=True,
+ t_cape_augment=True,
+ t_cape_glob_loc_scale=[5000.0, 1.0, 1.4],
+ t_sparse_self_attn=False,
+ t_sparse_cross_attn=False,
+ t_mask_type="diag",
+ t_mask_random_seed=42,
+ t_sparse_attn_window=500,
+ t_global_window=100,
+ t_sparsity=0.95,
+ t_auto_sparsity=False,
+        # ------ Particular parameters
+ t_cross_first=False,
+ # Weight init
+ rescale=0.1,
+ # Metadata
+ samplerate=44100,
+ segment=10,
+ use_train_segment=False,
+ ):
+ """
+ Args:
+ sources (list[str]): list of source names.
+ audio_channels (int): input/output audio channels.
+ channels (int): initial number of hidden channels.
+ channels_time: if not None, use a different `channels` value for the time branch.
+ growth: increase the number of hidden channels by this factor at each layer.
+            nfft: number of fft bins. Note that changing this requires careful computation of
+ various shape parameters and will not work out of the box for hybrid models.
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
+ wiener_residual: add residual source before wiener filtering.
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
+ in input and output. no further processing is done before ISTFT.
+ depth (int): number of layers in the encoder and in the decoder.
+ rewrite (bool): add 1x1 convolution to each layer.
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
+ layers will be wrapped.
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
+ the actual value controls the weight of the embedding.
+ emb_scale: equivalent to scaling the embedding learning rate
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
+ kernel_size: kernel_size for encoder and decoder layers.
+ stride: stride for encoder and decoder layers.
+ time_stride: stride for the final time layer, after the merge.
+ context: context for 1x1 conv in the decoder.
+ context_enc: context for 1x1 conv in the encoder.
+ norm_starts: layer at which group norm starts being used.
+ decoder layers are numbered in reverse order.
+ norm_groups: number of groups for group norm.
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
+ dconv_depth: depth of residual DConv branch.
+ dconv_comp: compression of DConv branch.
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
+ dconv_init: initial scale for the DConv branch LayerScale.
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
+ transformer in order to change the number of channels
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
+ t_emb: "sin", "cape" or "scaled"
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
+ for instance if C = 384 (the number of channels in the transformer) and
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
+ 384 * 4 = 1536
+ t_heads: number of heads for the transformer
+ t_dropout: dropout in the transformer
+ t_max_positions: max_positions for the "scaled" positional embedding, only
+ useful if t_emb="scaled"
+            t_norm_in: (bool) norm before adding the positional embedding and getting into the
+ transformer layers
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
+ timesteps (GroupNorm with group=1)
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
+ timesteps (GroupNorm with group=1)
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
+ t_max_period: (float) denominator in the sinusoidal embedding expression
+ t_weight_decay: (float) weight decay for the transformer
+ t_lr: (float) specific learning rate for the transformer
+ t_layer_scale: (bool) Layer Scale for the transformer
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
+ t_weight_pos_embed: (float) weighting of the positional embedding
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
+ see: https://arxiv.org/abs/2106.03143
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
+                during inference, see: https://arxiv.org/abs/2106.03143
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
+ see: https://arxiv.org/abs/2106.03143
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
+ unless you designed really specific masks)
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
+ that generated the random part of the mask
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
+                a key (j), the mask is True if |i-j|<=t_sparse_attn_window
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
+ and mask[:, :t_global_window] will be True
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
+ level of the random part of the mask.
+ t_cross_first: (bool) if True cross attention is the first layer of the
+ transformer (False seems to be better)
+ rescale: weight rescaling trick
+ use_train_segment: (bool) if True, the actual size that is used during the
+ training is used during inference.
+ """
+ super().__init__()
+ self.num_subbands = num_subbands
+ self.cac = cac
+ self.wiener_residual = wiener_residual
+ self.audio_channels = audio_channels
+ self.sources = sources
+ self.kernel_size = kernel_size
+ self.context = context
+ self.stride = stride
+ self.depth = depth
+ self.bottom_channels = bottom_channels
+ self.channels = channels
+ self.samplerate = samplerate
+ self.segment = segment
+ self.use_train_segment = use_train_segment
+ self.nfft = nfft
+ self.hop_length = nfft // 4
+ self.wiener_iters = wiener_iters
+ self.end_iters = end_iters
+ self.freq_emb = None
+ assert wiener_iters == end_iters
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ self.tencoder = nn.ModuleList()
+ self.tdecoder = nn.ModuleList()
+
+ chin = audio_channels
+ chin_z = chin # number of channels for the freq branch
+ if self.cac:
+ chin_z *= 2
+ if self.num_subbands > 1:
+ chin_z *= self.num_subbands
+ chout = channels_time or channels
+ chout_z = channels
+ freqs = nfft // 2
+
+ for index in range(depth):
+ norm = index >= norm_starts
+ freq = freqs > 1
+ stri = stride
+ ker = kernel_size
+ if not freq:
+ assert freqs == 1
+ ker = time_stride * 2
+ stri = time_stride
+
+ pad = True
+ last_freq = False
+ if freq and freqs <= kernel_size:
+ ker = freqs
+ pad = False
+ last_freq = True
+
+ kw = {
+ "kernel_size": ker,
+ "stride": stri,
+ "freq": freq,
+ "pad": pad,
+ "norm": norm,
+ "rewrite": rewrite,
+ "norm_groups": norm_groups,
+ "dconv_kw": {
+ "depth": dconv_depth,
+ "compress": dconv_comp,
+ "init": dconv_init,
+ "gelu": True,
+ },
+ }
+ kwt = dict(kw)
+ kwt["freq"] = 0
+ kwt["kernel_size"] = kernel_size
+ kwt["stride"] = stride
+ kwt["pad"] = True
+ kw_dec = dict(kw)
+ multi = False
+ if multi_freqs and index < multi_freqs_depth:
+ multi = True
+ kw_dec["context_freq"] = False
+
+ if last_freq:
+ chout_z = max(chout, chout_z)
+ chout = chout_z
+
+ enc = HEncLayer(
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
+ )
+ if freq:
+ tenc = HEncLayer(
+ chin,
+ chout,
+ dconv=dconv_mode & 1,
+ context=context_enc,
+ empty=last_freq,
+ **kwt
+ )
+ self.tencoder.append(tenc)
+
+ if multi:
+ enc = MultiWrap(enc, multi_freqs)
+ self.encoder.append(enc)
+ if index == 0:
+ chin = self.audio_channels * len(self.sources)
+ chin_z = chin
+ if self.cac:
+ chin_z *= 2
+ if self.num_subbands > 1:
+ chin_z *= self.num_subbands
+ dec = HDecLayer(
+ chout_z,
+ chin_z,
+ dconv=dconv_mode & 2,
+ last=index == 0,
+ context=context,
+ **kw_dec
+ )
+ if multi:
+ dec = MultiWrap(dec, multi_freqs)
+ if freq:
+ tdec = HDecLayer(
+ chout,
+ chin,
+ dconv=dconv_mode & 2,
+ empty=last_freq,
+ last=index == 0,
+ context=context,
+ **kwt
+ )
+ self.tdecoder.insert(0, tdec)
+ self.decoder.insert(0, dec)
+
+ chin = chout
+ chin_z = chout_z
+ chout = int(growth * chout)
+ chout_z = int(growth * chout_z)
+ if freq:
+ if freqs <= kernel_size:
+ freqs = 1
+ else:
+ freqs //= stride
+ if index == 0 and freq_emb:
+ self.freq_emb = ScaledEmbedding(
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
+ )
+ self.freq_emb_scale = freq_emb
+
+ if rescale:
+ rescale_module(self, reference=rescale)
+
+ transformer_channels = channels * growth ** (depth - 1)
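+        # e.g. 48 * 2 ** 3 = 384 with the default channels, growth and depth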
+ if bottom_channels:
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
+ self.channel_downsampler = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+ self.channel_upsampler_t = nn.Conv1d(
+ transformer_channels, bottom_channels, 1
+ )
+ self.channel_downsampler_t = nn.Conv1d(
+ bottom_channels, transformer_channels, 1
+ )
+
+ transformer_channels = bottom_channels
+
+ if t_layers > 0:
+ self.crosstransformer = CrossTransformerEncoder(
+ dim=transformer_channels,
+ emb=t_emb,
+ hidden_scale=t_hidden_scale,
+ num_heads=t_heads,
+ num_layers=t_layers,
+ cross_first=t_cross_first,
+ dropout=t_dropout,
+ max_positions=t_max_positions,
+ norm_in=t_norm_in,
+ norm_in_group=t_norm_in_group,
+ group_norm=t_group_norm,
+ norm_first=t_norm_first,
+ norm_out=t_norm_out,
+ max_period=t_max_period,
+ weight_decay=t_weight_decay,
+ lr=t_lr,
+ layer_scale=t_layer_scale,
+ gelu=t_gelu,
+ sin_random_shift=t_sin_random_shift,
+ weight_pos_embed=t_weight_pos_embed,
+ cape_mean_normalize=t_cape_mean_normalize,
+ cape_augment=t_cape_augment,
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
+ sparse_self_attn=t_sparse_self_attn,
+ sparse_cross_attn=t_sparse_cross_attn,
+ mask_type=t_mask_type,
+ mask_random_seed=t_mask_random_seed,
+ sparse_attn_window=t_sparse_attn_window,
+ global_window=t_global_window,
+ sparsity=t_sparsity,
+ auto_sparsity=t_auto_sparsity,
+ )
+ else:
+ self.crosstransformer = None
+
+ def _spec(self, x):
+ hl = self.hop_length
+ nfft = self.nfft
+ x0 = x # noqa
+
+ # We re-pad the signal in order to keep the property
+ # that the size of the output is exactly the size of the input
+ # divided by the stride (here hop_length), when divisible.
+        # This is achieved by padding by 1/4th of the kernel size (here nfft),
+        # which is not supported by torch.stft out of the box.
+        # Having all convolution operations follow this convention allows us to easily
+ # align the time and frequency branches later on.
+ assert hl == nfft // 4
+ le = int(math.ceil(x.shape[-1] / hl))
+ pad = hl // 2 * 3
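+        # e.g. with the default nfft = 4096: hl = 1024 and pad = 1536 samples on the left,
+        # plus 1536 and whatever extra is needed on the right to reach a multiple of hl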
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
+
+ z = spectro(x, nfft, hl)[..., :-1, :]
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
+ z = z[..., 2: 2 + le]
+ return z
+
+ def _ispec(self, z, length=None, scale=0):
+ hl = self.hop_length // (4**scale)
+ z = F.pad(z, (0, 0, 0, 1))
+ z = F.pad(z, (2, 2))
+ pad = hl // 2 * 3
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
+ x = ispectro(z, hl, length=le)
+ x = x[..., pad: pad + length]
+ return x
+
+ def _magnitude(self, z):
+ # return the magnitude of the spectrogram, except when cac is True,
+ # in which case we just move the complex dimension to the channel one.
+ if self.cac:
+ B, C, Fr, T = z.shape
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
+ m = m.reshape(B, C * 2, Fr, T)
+ else:
+ m = z.abs()
+ return m
+
+ def _mask(self, z, m):
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
+ niters = self.wiener_iters
+ if self.cac:
+ B, S, C, Fr, T = m.shape
+ out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
+ out = torch.view_as_complex(out.contiguous())
+ return out
+ if self.training:
+ niters = self.end_iters
+ if niters < 0:
+ z = z[:, None]
+ return z / (1e-8 + z.abs()) * m
+ else:
+ return self._wiener(m, z, niters)
+
+ def _wiener(self, mag_out, mix_stft, niters):
+ # apply wiener filtering from OpenUnmix.
+ init = mix_stft.dtype
+ wiener_win_len = 300
+ residual = self.wiener_residual
+
+ B, S, C, Fq, T = mag_out.shape
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
+
+ outs = []
+ for sample in range(B):
+ pos = 0
+ out = []
+ for pos in range(0, T, wiener_win_len):
+ frame = slice(pos, pos + wiener_win_len)
+ z_out = wiener(
+ mag_out[sample, frame],
+ mix_stft[sample, frame],
+ niters,
+ residual=residual,
+ )
+ out.append(z_out.transpose(-1, -2))
+ outs.append(torch.cat(out, dim=0))
+ out = torch.view_as_complex(torch.stack(outs, 0))
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
+ if residual:
+ out = out[:, :-1]
+ assert list(out.shape) == [B, S, C, Fq, T]
+ return out.to(init)
+
+ def valid_length(self, length: int):
+ """
+ Return a length that is appropriate for evaluation.
+ In our case, always return the training length, unless
+ it is smaller than the given length, in which case this
+ raises an error.
+ """
+ if not self.use_train_segment:
+ return length
+ training_length = int(self.segment * self.samplerate)
+ if training_length < length:
+ raise ValueError(
+ f"Given length {length} is longer than "
+ f"training length {training_length}")
+ return training_length
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, mix):
+ length = mix.shape[-1]
+ length_pre_pad = None
+ if self.use_train_segment:
+ if self.training:
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
+ else:
+ training_length = int(self.segment * self.samplerate)
+ # print('Training length: {} Segment: {} Sample rate: {}'.format(training_length, self.segment, self.samplerate))
+ if mix.shape[-1] < training_length:
+ length_pre_pad = mix.shape[-1]
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
+ # print("Mix: {}".format(mix.shape))
+ # print("Length: {}".format(length))
+ z = self._spec(mix)
+ # print("Z: {} Type: {}".format(z.shape, z.dtype))
+ mag = self._magnitude(z)
+ x = mag
+ # print("MAG: {} Type: {}".format(x.shape, x.dtype))
+
+ if self.num_subbands > 1:
+ x = self.cac2cws(x)
+ # print("After SUBBANDS: {} Type: {}".format(x.shape, x.dtype))
+
+ B, C, Fq, T = x.shape
+
+ # unlike previous Demucs, we always normalize because it is easier.
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
+ std = x.std(dim=(1, 2, 3), keepdim=True)
+ x = (x - mean) / (1e-5 + std)
+ # x will be the freq. branch input.
+
+ # Prepare the time branch input.
+ xt = mix
+ meant = xt.mean(dim=(1, 2), keepdim=True)
+ stdt = xt.std(dim=(1, 2), keepdim=True)
+ xt = (xt - meant) / (1e-5 + stdt)
+
+ # print("XT: {}".format(xt.shape))
+
+ # okay, this is a giant mess I know...
+ saved = [] # skip connections, freq.
+ saved_t = [] # skip connections, time.
+ lengths = [] # saved lengths to properly remove padding, freq branch.
+ lengths_t = [] # saved lengths for time branch.
+ for idx, encode in enumerate(self.encoder):
+ lengths.append(x.shape[-1])
+ inject = None
+ if idx < len(self.tencoder):
+ # we have not yet merged branches.
+ lengths_t.append(xt.shape[-1])
+ tenc = self.tencoder[idx]
+ xt = tenc(xt)
+ # print("Encode XT {}: {}".format(idx, xt.shape))
+ if not tenc.empty:
+ # save for skip connection
+ saved_t.append(xt)
+ else:
+ # tenc contains just the first conv., so that now time and freq.
+ # branches have the same shape and can be merged.
+ inject = xt
+ x = encode(x, inject)
+ # print("Encode X {}: {}".format(idx, x.shape))
+ if idx == 0 and self.freq_emb is not None:
+ # add frequency embedding to allow for non equivariant convolutions
+ # over the frequency axis.
+ frs = torch.arange(x.shape[-2], device=x.device)
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
+ x = x + self.freq_emb_scale * emb
+
+ saved.append(x)
+ if self.crosstransformer:
+ if self.bottom_channels:
+ b, c, f, t = x.shape
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_upsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_upsampler_t(xt)
+
+ x, xt = self.crosstransformer(x, xt)
+ # print("Cross Tran X {}, XT: {}".format(x.shape, xt.shape))
+
+ if self.bottom_channels:
+ x = rearrange(x, "b c f t-> b c (f t)")
+ x = self.channel_downsampler(x)
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
+ xt = self.channel_downsampler_t(xt)
+
+ for idx, decode in enumerate(self.decoder):
+ skip = saved.pop(-1)
+ x, pre = decode(x, skip, lengths.pop(-1))
+ # print('Decode {} X: {}'.format(idx, x.shape))
+ # `pre` contains the output just before final transposed convolution,
+ # which is used when the freq. and time branch separate.
+
+ offset = self.depth - len(self.tdecoder)
+ if idx >= offset:
+ tdec = self.tdecoder[idx - offset]
+ length_t = lengths_t.pop(-1)
+ if tdec.empty:
+ assert pre.shape[2] == 1, pre.shape
+ pre = pre[:, :, 0]
+ xt, _ = tdec(pre, None, length_t)
+ else:
+ skip = saved_t.pop(-1)
+ xt, _ = tdec(xt, skip, length_t)
+ # print('Decode {} XT: {}'.format(idx, xt.shape))
+
+ # Let's make sure we used all stored skip connections.
+ assert len(saved) == 0
+ assert len(lengths_t) == 0
+ assert len(saved_t) == 0
+
+ S = len(self.sources)
+
+ if self.num_subbands > 1:
+ x = x.view(B, -1, Fq, T)
+ # print("X view 1: {}".format(x.shape))
+ x = self.cws2cac(x)
+ # print("X view 2: {}".format(x.shape))
+
+ x = x.view(B, S, -1, Fq * self.num_subbands, T)
+ x = x * std[:, None] + mean[:, None]
+ # print("X returned: {}".format(x.shape))
+
+ zout = self._mask(z, x)
+ if self.use_train_segment:
+ if self.training:
+ x = self._ispec(zout, length)
+ else:
+ x = self._ispec(zout, training_length)
+ else:
+ x = self._ispec(zout, length)
+
+ if self.use_train_segment:
+ if self.training:
+ xt = xt.view(B, S, -1, length)
+ else:
+ xt = xt.view(B, S, -1, training_length)
+ else:
+ xt = xt.view(B, S, -1, length)
+ xt = xt * stdt[:, None] + meant[:, None]
+ x = xt + x
+ if length_pre_pad:
+ x = x[..., :length_pre_pad]
+ return x
+
+
+def get_model(args):
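+    # Instantiate the Demucs variant named by `args.model`, merging the shared
+    # training options with the model-specific section of the config.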
+ extra = {
+ 'sources': list(args.training.instruments),
+ 'audio_channels': args.training.channels,
+ 'samplerate': args.training.samplerate,
+ # 'segment': args.model_segment or 4 * args.dset.segment,
+ 'segment': args.training.segment,
+ }
+ klass = {
+ 'demucs': Demucs,
+ 'hdemucs': HDemucs,
+ 'htdemucs': HTDemucs,
+ }[args.model]
+ kw = OmegaConf.to_container(getattr(args, args.model), resolve=True)
+ model = klass(**extra, **kw)
+ return model
+
+
diff --git a/models/ex_bi_mamba2.py b/models/ex_bi_mamba2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bcf797cce8b28e88c7aab9adfa2037070504763
--- /dev/null
+++ b/models/ex_bi_mamba2.py
@@ -0,0 +1,303 @@
+# https://github.com/Human9000/nd-Mamba2-torch
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+from abc import abstractmethod
+
+
+def silu(x):
+    return x * torch.sigmoid(x)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, d: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(d))
+
+ def forward(self, x, z):
+ x = x * silu(z)
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
+
+
+class Mamba2(nn.Module):
+ def __init__(self, d_model: int, # model dimension (D)
+ n_layer: int = 24, # number of Mamba-2 layers in the language model
+ d_state: int = 128, # state dimension (N)
+ d_conv: int = 4, # convolution kernel size
+ expand: int = 2, # expansion factor (E)
+ headdim: int = 64, # head dimension (P)
+ chunk_size: int = 64, # matrix partition size (Q)
+ ):
+ super().__init__()
+ self.n_layer = n_layer
+ self.d_state = d_state
+ self.headdim = headdim
+ # self.chunk_size = torch.tensor(chunk_size, dtype=torch.int32)
+ self.chunk_size = chunk_size
+
+ self.d_inner = expand * d_model
+ assert self.d_inner % self.headdim == 0, "self.d_inner must be divisible by self.headdim"
+ self.nheads = self.d_inner // self.headdim
+
+ d_in_proj = 2 * self.d_inner + 2 * self.d_state + self.nheads
+ self.in_proj = nn.Linear(d_model, d_in_proj, bias=False)
+
+ conv_dim = self.d_inner + 2 * d_state
+ self.conv1d = nn.Conv1d(conv_dim, conv_dim, d_conv, groups=conv_dim, padding=d_conv - 1, )
+ self.dt_bias = nn.Parameter(torch.empty(self.nheads, ))
+ self.A_log = nn.Parameter(torch.empty(self.nheads, ))
+ self.D = nn.Parameter(torch.empty(self.nheads, ))
+ self.norm = RMSNorm(self.d_inner, )
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False, )
+
+ def forward(self, u: Tensor):
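+        # Mamba-2 forward: a single input projection yields the gate z, the conv
+        # stream xBC (split below into x, B, C) and per-head step sizes dt; the
+        # chunked SSD scan mixes the sequence, and a gated RMSNorm plus output
+        # projection map back to d_model.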
+ A = -torch.exp(self.A_log) # (nheads,)
+ zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
+ z, xBC, dt = torch.split(
+ zxbcdt,
+ [
+ self.d_inner,
+ self.d_inner + 2 * self.d_state,
+ self.nheads,
+ ],
+ dim=-1,
+ )
+ dt = F.softplus(dt + self.dt_bias) # (batch, seqlen, nheads)
+
+ # Pad or truncate xBC seqlen to d_conv
+ xBC = silu(
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, : u.shape[1], :]
+ ) # (batch, seqlen, d_inner + 2 * d_state))
+ x, B, C = torch.split(
+ xBC, [self.d_inner, self.d_state, self.d_state], dim=-1
+ )
+
+ _b, _l, _hp = x.shape
+ _h = _hp // self.headdim
+ _p = self.headdim
+ x = x.reshape(_b, _l, _h, _p)
+
+ y = self.ssd(x * dt.unsqueeze(-1),
+ A * dt,
+ B.unsqueeze(2),
+ C.unsqueeze(2), )
+
+ y = y + x * self.D.unsqueeze(-1)
+
+ _b, _l, _h, _p = y.shape
+ y = y.reshape(_b, _l, _h * _p)
+
+ y = self.norm(y, z)
+ y = self.out_proj(y)
+
+ return y
+
+ def segsum(self, x: Tensor) -> Tensor:
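+        # Stable "segment sum": builds a (T, T) lower-triangular matrix of
+        # cumulative sums of x, with -inf above the diagonal so that exp(segsum)
+        # gives the causal decay matrix used by the SSD kernel.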
+ T = x.size(-1)
+ device = x.device
+ x = x[..., None].repeat(1, 1, 1, 1, T)
+ mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
+ x = x.masked_fill(~mask, 0)
+ x_segsum = torch.cumsum(x, dim=-2)
+ mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+ return x_segsum
+
+ def ssd(self, x, A, B, C):
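+        # Chunked SSD scan: the sequence is split into chunks of `chunk_size`;
+        # intra-chunk outputs use the quadratic (attention-like) form, inter-chunk
+        # contributions are propagated through a recurrence over chunk states.
+        # Assumes the sequence length is a multiple of chunk_size (callers pad).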
+ chunk_size = self.chunk_size
+ # if x.shape[1] % chunk_size == 0:
+ #
+ x = x.reshape(x.shape[0], x.shape[1] // chunk_size, chunk_size, x.shape[2], x.shape[3], )
+ B = B.reshape(B.shape[0], B.shape[1] // chunk_size, chunk_size, B.shape[2], B.shape[3], )
+ C = C.reshape(C.shape[0], C.shape[1] // chunk_size, chunk_size, C.shape[2], C.shape[3], )
+ A = A.reshape(A.shape[0], A.shape[1] // chunk_size, chunk_size, A.shape[2])
+ A = A.permute(0, 3, 1, 2)
+ A_cumsum = torch.cumsum(A, dim=-1)
+
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
+ L = torch.exp(self.segsum(A))
+ Y_diag = torch.einsum("bclhn, bcshn, bhcls, bcshp -> bclhp", C, B, L, x)
+
+ # 2. Compute the state for each intra-chunk
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
+ decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
+ states = torch.einsum("bclhn, bhcl, bclhp -> bchpn", B, decay_states, x)
+
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+ # (middle term of factorization of off-diag blocks; A terms)
+
+ initial_states = torch.zeros_like(states[:, :1])
+ states = torch.cat([initial_states, states], dim=1)
+
+ decay_chunk = torch.exp(self.segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))[0]
+ new_states = torch.einsum("bhzc, bchpn -> bzhpn", decay_chunk, states)
+ states = new_states[:, :-1]
+
+ # 4. Compute state -> output conversion per chunk
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
+ state_decay_out = torch.exp(A_cumsum)
+ Y_off = torch.einsum("bclhn, bchpn, bhcl -> bclhp", C, states, state_decay_out)
+
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+ # Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
+ Y = Y_diag + Y_off
+ Y = Y.reshape(Y.shape[0], Y.shape[1] * Y.shape[2], Y.shape[3], Y.shape[4], )
+
+ return Y
+
+
+class _BiMamba2(nn.Module):
+ def __init__(self,
+ cin: int,
+ cout: int,
+ d_model: int, # model dimension (D)
+ n_layer: int = 24, # number of Mamba-2 layers in the language model
+ d_state: int = 128, # state dimension (N)
+ d_conv: int = 4, # convolution kernel size
+ expand: int = 2, # expansion factor (E)
+ headdim: int = 64, # head dimension (P)
+ chunk_size: int = 64, # matrix partition size (Q)
+ ):
+ super().__init__()
+        self.fc_in = nn.Linear(cin, d_model, bias=False)  # project input channels to the model dimension
+        self.mamba2_for = Mamba2(d_model, n_layer, d_state, d_conv, expand, headdim, chunk_size, )  # forward direction
+        self.mamba2_back = Mamba2(d_model, n_layer, d_state, d_conv, expand, headdim, chunk_size, )  # backward direction
+        self.fc_out = nn.Linear(d_model, cout, bias=False)  # project the model dimension to the output channels
+ self.chunk_size = chunk_size
+
+ @abstractmethod
+ def forward(self, x):
+ pass
+
+
+class BiMamba2_1D(_BiMamba2):
+ def __init__(self, cin, cout, d_model, **mamba2_args):
+ super().__init__(cin, cout, d_model, **mamba2_args)
+
+ def forward(self, x):
+ l = x.shape[2]
+        x = F.pad(x, (0, (64 - x.shape[2] % 64) % 64))  # pad the length to a multiple of 64
+        x = x.transpose(1, 2)  # (b, c, l) -> (b, l, c)
+        x = self.fc_in(x)  # project channels to the model dimension
+        x1 = self.mamba2_for(x)
+        x2 = self.mamba2_back(x.flip(1)).flip(1)
+        x = x1 + x2
+        x = self.fc_out(x)  # project back to the output channels
+        x = x.transpose(1, 2)  # (b, l, c) -> (b, c, l)
+        x = x[:, :, :l]  # crop back to the original length
+ return x
+
+
+class BiMamba2_2D(_BiMamba2):
+ def __init__(self, cin, cout, d_model, **mamba2_args):
+ super().__init__(cin, cout, d_model, **mamba2_args)
+
+ def forward(self, x):
+ h, w = x.shape[2:]
+ x = F.pad(x, (0, (8 - x.shape[3] % 8) % 8,
+ 0, (8 - x.shape[2] % 8) % 8)
+                  )  # pad h and w to multiples of 8
+ _b, _c, _h, _w = x.shape
+ x = x.permute(0, 2, 3, 1).reshape(_b, _h * _w, _c)
+        x = self.fc_in(x)  # project channels to the model dimension
+        x1 = self.mamba2_for(x)
+        x2 = self.mamba2_back(x.flip(1)).flip(1)
+        x = x1 + x2
+        x = self.fc_out(x)  # project back to the output channels
+        x = x.reshape(_b, _h, _w, -1, )
+        x = x.permute(0, 3, 1, 2)
+        x = x.reshape(_b, -1, _h, _w, )
+        x = x[:, :, :h, :w]  # crop back to the original size
+ return x
+
+
+class BiMamba2_3D(_BiMamba2):
+ def __init__(self, cin, cout, d_model, **mamba2_args):
+ super().__init__(cin, cout, d_model, **mamba2_args)
+
+ def forward(self, x):
+ d, h, w = x.shape[2:]
+ x = F.pad(x, (0, (4 - x.shape[4] % 4) % 4,
+ 0, (4 - x.shape[3] % 4) % 4,
+ 0, (4 - x.shape[2] % 4) % 4)
+                  )  # pad d, h and w to multiples of 4
+ _b, _c, _d, _h, _w = x.shape
+ x = x.permute(0, 2, 3, 4, 1).reshape(_b, _d * _h * _w, _c)
+        x = self.fc_in(x)  # project channels to the model dimension
+        x1 = self.mamba2_for(x)
+        x2 = self.mamba2_back(x.flip(1)).flip(1)
+        x = x1 + x2
+        x = self.fc_out(x)  # project back to the output channels
+        x = x.reshape(_b, _d, _h, _w, -1)
+        x = x.permute(0, 4, 1, 2, 3)
+        x = x.reshape(_b, -1, _d, _h, _w, )
+        x = x[:, :, :d, :h, :w]  # crop back to the original size
+ return x
+
+
+class BiMamba2(_BiMamba2):
+ def __init__(self, cin, cout, d_model, **mamba2_args):
+ super().__init__(cin, cout, d_model, **mamba2_args)
+
+ def forward(self, x):
+ size = x.shape[2:]
+ out_size = list(x.shape)
+ out_size[1] = -1
+
+        x = torch.flatten(x, 2)  # flatten all spatial dims: (b, c, prod(size))
+        l = x.shape[2]
+        _s = self.chunk_size
+        x = F.pad(x, [0, (_s - x.shape[2] % _s) % _s])  # pad the length to a multiple of chunk_size
+        x = x.transpose(1, 2)  # (b, c, l) -> (b, l, c)
+        x = self.fc_in(x)  # project channels to the model dimension
+        x1 = self.mamba2_for(x)
+        x2 = self.mamba2_back(x.flip(1)).flip(1)
+        x = x1 + x2
+        x = self.fc_out(x)  # project back to the output channels
+        x = x.transpose(1, 2)  # (b, l, c) -> (b, c, l)
+        x = x[:, :, :l]  # crop back to the original length
+ x = x.reshape(out_size)
+
+ return x
+
+
+def test_export_jit_script(net, x):
+ y = net(x)
+ net_script = torch.jit.script(net)
+ torch.jit.save(net_script, 'net.jit.script')
+ net2 = torch.jit.load('net.jit.script')
+ y = net2(x)
+ print(y.shape)
+
+
+def test_export_onnx(net, x):
+ torch.onnx.export(net,
+ x,
+                      "net.onnx",  # output ONNX file name
+                      export_params=True,  # store the trained parameters
+                      opset_version=14,  # ONNX opset version
+                      do_constant_folding=False,  # whether to apply constant-folding optimization
+                      input_names=['input'],  # input tensor names
+                      output_names=['output'],  # output tensor names
+                      dynamic_axes={'input': {0: 'batch_size'},  # dict of dynamic axes
+ 'output': {0: 'batch_size'}})
+
+
+if __name__ == '__main__':
+    # Generic multi-dimensional bidirectional Mamba-2
+ from torchnssd import (
+ export_jit_script,
+ export_onnx,
+ statistics,
+ test_run,
+ )
+
+ net_n = BiMamba2_1D(61, 128, 32).cuda()
+ net_n.eval()
+ x = torch.randn(1, 61, 63).cuda()
+ export_jit_script(net_n)
+ export_onnx(net_n, x)
+ test_run(net_n, x)
+ statistics(net_n, (61, 63))
diff --git a/models/look2hear/models/__init__.py b/models/look2hear/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72fb1f70bbc43f10bd74583d582ddb45848e45e
--- /dev/null
+++ b/models/look2hear/models/__init__.py
@@ -0,0 +1,49 @@
+###
+# Author: Kai Li
+# Date: 2022-02-12 15:16:35
+# Email: lk21@mails.tsinghua.edu.cn
+# LastEditTime: 2022-10-04 16:24:53
+###
+from .base_model import BaseModel
+from .apollo import Apollo
+
+__all__ = [
+ "BaseModel",
+ "Apollo"
+]
+
+
+def register_model(custom_model):
+ """Register a custom model, gettable with `models.get`.
+
+ Args:
+ custom_model: Custom model to register.
+
+ """
+ if (
+ custom_model.__name__ in globals().keys()
+ or custom_model.__name__.lower() in globals().keys()
+ ):
+ raise ValueError(
+ f"Model {custom_model.__name__} already exists. Choose another name."
+ )
+ globals().update({custom_model.__name__: custom_model})
+
+
+def get(identifier):
+    """Returns a model class from a string (case-insensitive).
+
+ Args:
+ identifier (str): the model name.
+
+ Returns:
+ :class:`torch.nn.Module`
+ """
+ if isinstance(identifier, str):
+ to_get = {k.lower(): v for k, v in globals().items()}
+ cls = to_get.get(identifier.lower())
+ if cls is None:
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
+ return cls
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
diff --git a/models/look2hear/models/apollo.py b/models/look2hear/models/apollo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de9afd468e4635e581c0ff41dce7acc4eb249be
--- /dev/null
+++ b/models/look2hear/models/apollo.py
@@ -0,0 +1,324 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .base_model import BaseModel
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dimension, groups=1):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.ones(dimension))
+ self.groups = groups
+ self.eps = 1e-5
+
+ def forward(self, input):
+ # input size: (B, N, T)
+ B, N, T = input.shape
+ assert N % self.groups == 0
+
+ input_float = input.reshape(B, self.groups, -1, T).float()
+ input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps)
+
+ return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1)
+
+
+class RMVN(nn.Module):
+ """
+ Rescaled MVN.
+ """
+
+ def __init__(self, dimension, groups=1):
+ super(RMVN, self).__init__()
+
+ self.mean = nn.Parameter(torch.zeros(dimension))
+ self.std = nn.Parameter(torch.ones(dimension))
+ self.groups = groups
+ self.eps = 1e-5
+
+ def forward(self, input):
+ # input size: (B, N, *)
+ B, N = input.shape[:2]
+ assert N % self.groups == 0
+ input_reshape = input.reshape(B, self.groups, N // self.groups, -1)
+ T = input_reshape.shape[-1]
+
+ input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / (
+ input_reshape.var(2).unsqueeze(2) + self.eps).sqrt()
+ input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1)
+
+ return input_norm.reshape(input.shape)
+
+
+class Roformer(nn.Module):
+ """
+ Transformer with rotary positional embedding.
+ """
+
+ def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
+ input_drop=0., attention_drop=0., causal=True):
+ super().__init__()
+
+ self.input_size = input_size
+ self.hidden_size = hidden_size // num_head
+ self.num_head = num_head
+ self.theta = theta # base frequency for RoPE
+ self.window = window
+ # pre-calculate rotary embeddings
+ cos_freq, sin_freq = self._calc_rotary_emb()
+ self.register_buffer("cos_freq", cos_freq) # win, N
+ self.register_buffer("sin_freq", sin_freq) # win, N
+
+ self.attention_drop = attention_drop
+ self.causal = causal
+ self.eps = 1e-5
+
+ self.input_norm = RMSNorm(self.input_size)
+ self.input_drop = nn.Dropout(p=input_drop)
+ self.weight = nn.Conv1d(self.input_size, self.hidden_size * self.num_head * 3, 1, bias=False)
+ self.output = nn.Conv1d(self.hidden_size * self.num_head, self.input_size, 1, bias=False)
+
+ self.MLP = nn.Sequential(RMSNorm(self.input_size),
+ nn.Conv1d(self.input_size, self.input_size * 8, 1, bias=False),
+ nn.SiLU()
+ )
+ self.MLP_output = nn.Conv1d(self.input_size * 4, self.input_size, 1, bias=False)
+
+ def _calc_rotary_emb(self):
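+        # Precompute RoPE tables: each (even, odd) pair of feature channels is
+        # rotated by the angle pos / theta**(2i / hidden_size); cos/sin values are
+        # cached for positions 0..window-1 and applied in _add_rotary_emb /
+        # _add_rotary_sequence.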
+ freq = 1. / (self.theta ** (
+ torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size)) # theta_i
+ freq = freq.reshape(1, -1) # 1, N//2
+ pos = torch.arange(0, self.window).reshape(-1, 1) # win, 1
+ cos_freq = torch.cos(pos * freq) # win, N//2
+ sin_freq = torch.sin(pos * freq) # win, N//2
+ cos_freq = torch.stack([cos_freq] * 2, -1).reshape(self.window, self.hidden_size) # win, N
+ sin_freq = torch.stack([sin_freq] * 2, -1).reshape(self.window, self.hidden_size) # win, N
+
+ return cos_freq, sin_freq
+
+ def _add_rotary_emb(self, feature, pos):
+ # feature shape: ..., N
+ N = feature.shape[-1]
+
+ feature_reshape = feature.reshape(-1, N)
+ pos = min(pos, self.window - 1)
+ cos_freq = self.cos_freq[pos]
+ sin_freq = self.sin_freq[pos]
+ reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
+ feature_reshape_neg = (
+ torch.flip(feature_reshape.reshape(-1, N // 2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(
+ -1, N)
+ feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
+
+ return feature_rope.reshape(feature.shape)
+
+ def _add_rotary_sequence(self, feature):
+ # feature shape: ..., T, N
+ T, N = feature.shape[-2:]
+ feature_reshape = feature.reshape(-1, T, N)
+
+ cos_freq = self.cos_freq[:T]
+ sin_freq = self.sin_freq[:T]
+ reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
+ feature_reshape_neg = (
+ torch.flip(feature_reshape.reshape(-1, N // 2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(
+ -1, T, N)
+ feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)
+
+ return feature_rope.reshape(feature.shape)
+
+ def forward(self, input):
+ # input shape: B, N, T
+
+ B, _, T = input.shape
+
+ weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size * 3,
+ T).mT
+ Q, K, V = torch.split(weight, self.hidden_size, dim=-1) # B, num_head, T, N
+
+ # rotary positional embedding
+ Q_rot = self._add_rotary_sequence(Q)
+ K_rot = self._add_rotary_sequence(K)
+
+ attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(),
+ dropout_p=self.attention_drop,
+ is_causal=self.causal) # B, num_head, T, N
+ attention_output = attention_output.mT.reshape(B, -1, T)
+ output = self.output(attention_output) + input
+
+ gate, z = self.MLP(output).chunk(2, dim=1)
+ output = output + self.MLP_output(F.silu(gate) * z)
+
+ return output, (K_rot, V)
+
+
+class ConvActNorm1d(nn.Module):
+ def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
+ super(ConvActNorm1d, self).__init__()
+
+ self.in_channel = in_channel
+ self.kernel = kernel
+ self.causal = causal
+ if not causal:
+ self.conv = nn.Sequential(
+ nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel - 1) // 2, groups=in_channel),
+ RMSNorm(in_channel),
+ nn.Conv1d(in_channel, hidden_channel, 1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_channel, in_channel, 1)
+ )
+ else:
+ self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel - 1, groups=in_channel),
+ RMSNorm(in_channel),
+ nn.Conv1d(in_channel, hidden_channel, 1),
+ nn.SiLU(),
+ nn.Conv1d(hidden_channel, in_channel, 1)
+ )
+
+ def forward(self, input):
+
+ output = self.conv(input)
+ if self.causal:
+ output = output[..., :-self.kernel + 1]
+ return input + output
+
+
+class ICB(nn.Module):
+ def __init__(self, in_channel, kernel=7, causal=False):
+ super(ICB, self).__init__()
+
+ self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal),
+ ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal),
+ ConvActNorm1d(in_channel, in_channel * 4, kernel, causal=causal)
+ )
+
+ def forward(self, input):
+ return self.blocks(input)
+
+
+class BSNet(nn.Module):
+ def __init__(self, feature_dim, kernel=7):
+ super(BSNet, self).__init__()
+
+ self.feature_dim = feature_dim
+
+ self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False)
+ self.seq_net = ICB(self.feature_dim, kernel=kernel)
+
+ def forward(self, input):
+ # input shape: B, nband, N, T
+
+ B, nband, N, T = input.shape
+
+ # band comm
+ band_input = input.permute(0, 3, 2, 1).reshape(B * T, -1, nband)
+ band_output, _ = self.band_net(band_input)
+ band_output = band_output.reshape(B, T, -1, nband).permute(0, 3, 2, 1)
+
+ # sequence modeling
+ output = self.seq_net(band_output.reshape(B * nband, -1, T)).reshape(B, nband, -1, T) # B, nband, N, T
+
+ return output
+
+
+class Apollo(BaseModel):
+ def __init__(
+ self,
+ sr: int,
+ win: int,
+ feature_dim: int,
+ layer: int
+ ):
+ super().__init__(sample_rate=sr)
+
+ self.sr = sr
+ self.win = int(sr * win // 1000)
+ self.stride = self.win // 2
+ self.enc_dim = self.win // 2 + 1
+ self.feature_dim = feature_dim
+ self.eps = torch.finfo(torch.float32).eps
+
+ # 80 bands
+ bandwidth = int(self.win / 160)
+ self.band_width = [bandwidth] * 79
+ self.band_width.append(self.enc_dim - np.sum(self.band_width))
+ self.nband = len(self.band_width)
+ print(self.band_width, self.nband)
+
+ self.BN = nn.ModuleList([])
+ for i in range(self.nband):
+ self.BN.append(nn.Sequential(RMSNorm(self.band_width[i] * 2 + 1),
+ nn.Conv1d(self.band_width[i] * 2 + 1, self.feature_dim, 1))
+ )
+
+ self.net = []
+ for _ in range(layer):
+ self.net.append(BSNet(self.feature_dim))
+ self.net = nn.Sequential(*self.net)
+
+ self.output = nn.ModuleList([])
+ for i in range(self.nband):
+ self.output.append(nn.Sequential(RMSNorm(self.feature_dim),
+ nn.Conv1d(self.feature_dim, self.band_width[i] * 4, 1),
+ nn.GLU(dim=1)
+ )
+ )
+
+ def spec_band_split(self, input):
+
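+        # Split the STFT into the 80 sub-bands defined in __init__ (79 equal-width
+        # bands plus a remainder band); each band is normalized by its per-frame
+        # power, which is returned separately as `subband_power`.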
+ B, nch, nsample = input.shape
+
+ spec = torch.stft(input.view(B * nch, nsample), n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device), return_complex=True)
+
+ subband_spec = []
+ subband_spec_norm = []
+ subband_power = []
+ band_idx = 0
+ for i in range(self.nband):
+ this_spec = spec[:, band_idx:band_idx + self.band_width[i]]
+ subband_spec.append(this_spec) # B, BW, T
+ subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T
+ subband_spec_norm.append(
+ torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1])) # B, BW, T
+ band_idx += self.band_width[i]
+ subband_power = torch.cat(subband_power, 1) # B, nband, T
+
+ return subband_spec_norm, subband_power
+
+ def feature_extractor(self, input):
+
+ subband_spec_norm, subband_power = self.spec_band_split(input)
+
+ # normalization and bottleneck
+ subband_feature = []
+ for i in range(self.nband):
+ concat_spec = torch.cat(
+ [subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:, i].unsqueeze(1))], 1)
+ subband_feature.append(self.BN[i](concat_spec))
+ subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
+
+ return subband_feature
+
+ def forward(self, input):
+
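+        # Apollo forward: band-split features -> stacked BSNet blocks (Roformer
+        # across bands + ICB along time) -> per-band GLU heads predicting real/imag
+        # spectra -> iSTFT back to the waveform.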
+ B, nch, nsample = input.shape
+
+ subband_feature = self.feature_extractor(input)
+ feature = self.net(subband_feature)
+
+ est_spec = []
+ for i in range(self.nband):
+ this_RI = self.output[i](feature[:, i]).view(B * nch, 2, self.band_width[i], -1)
+ est_spec.append(torch.complex(this_RI[:, 0], this_RI[:, 1]))
+ est_spec = torch.cat(est_spec, 1)
+ est_spec = est_spec.to(dtype=torch.complex64)
+ output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1)
+
+ return output
+
+ def get_model_args(self):
+ model_args = {"n_sample_rate": 2}
+ return model_args
\ No newline at end of file
diff --git a/models/look2hear/models/base_model.py b/models/look2hear/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e24b23192141634c0496be04899b6ecdb683b6c5
--- /dev/null
+++ b/models/look2hear/models/base_model.py
@@ -0,0 +1,100 @@
+###
+# Author: Kai Li
+# Date: 2021-06-17 23:08:32
+# LastEditors: Please set LastEditors
+# LastEditTime: 2022-05-26 18:06:22
+###
+import torch
+import torch.nn as nn
+
+
+def _unsqueeze_to_3d(x):
+ """Normalize shape of `x` to [batch, n_chan, time]."""
+ if x.ndim == 1:
+ return x.reshape(1, 1, -1)
+ elif x.ndim == 2:
+ return x.unsqueeze(1)
+ else:
+ return x
+
+
+def pad_to_appropriate_length(x, lcm):
+ values_to_pad = int(x.shape[-1]) % lcm
+ if values_to_pad:
+ appropriate_shape = x.shape
+ padded_x = torch.zeros(
+ list(appropriate_shape[:-1])
+ + [appropriate_shape[-1] + lcm - values_to_pad],
+ dtype=torch.float32,
+ ).to(x.device)
+ padded_x[..., : x.shape[-1]] = x
+ return padded_x
+ return x
+
+
+class BaseModel(nn.Module):
+ def __init__(self, sample_rate, in_chan=1):
+ super().__init__()
+ self._sample_rate = sample_rate
+ self._in_chan = in_chan
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+    def sample_rate(self):
+ return self._sample_rate
+
+ @staticmethod
+ def load_state_dict_in_audio(model, pretrained_dict):
+ model_dict = model.state_dict()
+ update_dict = {}
+ for k, v in pretrained_dict.items():
+ if "audio_model" in k:
+ update_dict[k[12:]] = v
+ model_dict.update(update_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+ @staticmethod
+ def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
+ from . import get
+
+ conf = torch.load(
+ pretrained_model_conf_or_path, map_location="cpu"
+ ) # Attempt to find the model and instantiate it.
+
+ model_class = get(conf["model_name"])
+ # model_class = get("Conv_TasNet")
+ model = model_class(*args, **kwargs)
+ model.load_state_dict(conf["state_dict"])
+ return model
+
+ def apollo(*args, **kwargs):
+ from . import get
+ model_class = get('Apollo')
+ model = model_class(*args, **kwargs)
+ return model
+
+ def serialize(self):
+ import pytorch_lightning as pl # Not used in torch.hub
+
+ model_conf = dict(
+ model_name=self.__class__.__name__,
+ state_dict=self.get_state_dict(),
+ model_args=self.get_model_args(),
+ )
+ # Additional infos
+ infos = dict()
+ infos["software_versions"] = dict(
+ torch_version=torch.__version__, pytorch_lightning_version=pl.__version__,
+ )
+ model_conf["infos"] = infos
+ return model_conf
+
+ def get_state_dict(self):
+ """In case the state dict needs to be modified before sharing the model."""
+ return self.state_dict()
+
+ def get_model_args(self):
+ """Should return args to re-instantiate the class."""
+ raise NotImplementedError
diff --git a/models/mdx23c_tfc_tdf_v3.py b/models/mdx23c_tfc_tdf_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c2d91550a67115d598bbce8558018bcca11acae
--- /dev/null
+++ b/models/mdx23c_tfc_tdf_v3.py
@@ -0,0 +1,242 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from utils import prefer_target_instrument
+
+class STFT:
+ def __init__(self, config):
+ self.n_fft = config.n_fft
+ self.hop_length = config.hop_length
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.dim_f = config.dim_f
+
+ def __call__(self, x):
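+        # Complex-as-channels STFT: real and imaginary parts are stacked on the
+        # channel axis and the spectrogram is cropped to the first `dim_f` bins.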
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-2]
+ c, t = x.shape[-2:]
+ x = x.reshape([-1, t])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
+ return x[..., :self.dim_f, :]
+
+ def inverse(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-3]
+ c, f, t = x.shape[-3:]
+ n = self.n_fft // 2 + 1
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
+ x = torch.cat([x, f_pad], -2)
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
+ x = x.permute([0, 2, 3, 1])
+ x = x[..., 0] + x[..., 1] * 1.j
+ x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True)
+ x = x.reshape([*batch_dims, 2, -1])
+ return x
+
+
+def get_norm(norm_type):
+ def norm(c, norm_type):
+ if norm_type == 'BatchNorm':
+ return nn.BatchNorm2d(c)
+ elif norm_type == 'InstanceNorm':
+ return nn.InstanceNorm2d(c, affine=True)
+ elif 'GroupNorm' in norm_type:
+ g = int(norm_type.replace('GroupNorm', ''))
+ return nn.GroupNorm(num_groups=g, num_channels=c)
+ else:
+ return nn.Identity()
+
+ return partial(norm, norm_type=norm_type)
+
+
+def get_act(act_type):
+ if act_type == 'gelu':
+ return nn.GELU()
+ elif act_type == 'relu':
+ return nn.ReLU()
+ elif act_type[:3] == 'elu':
+ alpha = float(act_type.replace('elu', ''))
+ return nn.ELU(alpha)
+ else:
+        raise ValueError(f'Unknown act_type: {act_type}')
+
+
+class Upscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Downscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class TFC_TDF(nn.Module):
+ def __init__(self, in_c, c, l, f, bn, norm, act):
+ super().__init__()
+
+ self.blocks = nn.ModuleList()
+ for i in range(l):
+ block = nn.Module()
+
+ block.tfc1 = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
+ )
+ block.tdf = nn.Sequential(
+ norm(c),
+ act,
+ nn.Linear(f, f // bn, bias=False),
+ norm(c),
+ act,
+ nn.Linear(f // bn, f, bias=False),
+ )
+ block.tfc2 = nn.Sequential(
+ norm(c),
+ act,
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
+ )
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
+
+ self.blocks.append(block)
+ in_c = c
+
+ def forward(self, x):
+ for block in self.blocks:
+ s = block.shortcut(x)
+ x = block.tfc1(x)
+ x = x + block.tdf(x)
+ x = block.tfc2(x)
+ x = x + s
+ return x
+
+
+class TFC_TDF_net(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ norm = get_norm(norm_type=config.model.norm)
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ dim_c = self.num_subbands * config.audio.num_channels * 2
+ n = config.model.num_scales
+ scale = config.model.scale
+ l = config.model.num_blocks_per_scale
+ c = config.model.num_channels
+ g = config.model.growth
+ bn = config.model.bottleneck_factor
+ f = config.audio.dim_f // self.num_subbands
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.encoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
+ block.downscale = Downscale(c, c + g, scale, norm, act)
+ f = f // scale[1]
+ c += g
+ self.encoder_blocks.append(block)
+
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
+
+ self.decoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.upscale = Upscale(c, c - g, scale, norm, act)
+ f = f * scale[1]
+ c -= g
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
+ self.decoder_blocks.append(block)
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ )
+
+ self.stft = STFT(config.audio)
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, x):
+
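+        # STFT -> complex-as-channels sub-bands -> U-Net of TFC-TDF blocks around a
+        # bottleneck -> final conv producing one spectrogram per target -> iSTFT.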
+ x = self.stft(x)
+
+ mix = x = self.cac2cws(x)
+
+ first_conv_out = x = self.first_conv(x)
+
+ x = x.transpose(-1, -2)
+
+ encoder_outputs = []
+ for block in self.encoder_blocks:
+ x = block.tfc_tdf(x)
+ encoder_outputs.append(x)
+ x = block.downscale(x)
+
+ x = self.bottleneck_block(x)
+
+ for block in self.decoder_blocks:
+ x = block.upscale(x)
+ x = torch.cat([x, encoder_outputs.pop()], 1)
+ x = block.tfc_tdf(x)
+
+ x = x.transpose(-1, -2)
+
+ x = x * first_conv_out # reduce artifacts
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
+
+ x = self.stft.inverse(x)
+
+ return x
diff --git a/models/mdx23c_tfc_tdf_v3_with_STHT.py b/models/mdx23c_tfc_tdf_v3_with_STHT.py
new file mode 100644
index 0000000000000000000000000000000000000000..3261ec09bae3abdd29f5a66fe4ecf58074574833
--- /dev/null
+++ b/models/mdx23c_tfc_tdf_v3_with_STHT.py
@@ -0,0 +1,315 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from utils import prefer_target_instrument
+
+
+class ShortTimeHartleyTransform:
+ def __init__(self, *, n_fft: int, hop_length: int, center: bool = True,
+ pad_mode: str = "reflect") -> None:
+ self.n_fft = n_fft
+ self.hop_length = hop_length
+ self.center = center
+ self.pad_mode = pad_mode
+ self.window = torch.hamming_window(self.n_fft)
+
+ @staticmethod
+ def _hartley_transform(x: torch.Tensor) -> torch.Tensor:
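+        # Discrete Hartley transform via the FFT: H{x} = Re(FFT{x}) - Im(FFT{x}).
+        # The DHT is real-valued and, up to a 1/N factor, its own inverse, which
+        # _inverse_hartley_transform below relies on.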
+ fft = torch.fft.fft(x)
+ return fft.real - fft.imag
+
+ @staticmethod
+ def _inverse_hartley_transform(X: torch.Tensor) -> torch.Tensor:
+ N = X.size(-1)
+ return ShortTimeHartleyTransform._hartley_transform(X) / N
+
+ def transform(self, *, signal: torch.Tensor) -> torch.Tensor:
+ assert signal.dim() == 3, "Signal must be a 3D tensor (batch_size, channel, samples)"
+ self.window = self.window.to(signal.device)
+ batch_size, channels, samples = signal.shape
+
+ # Apply padding if center=True
+ if self.center:
+ pad_length = self.n_fft // 2
+ signal = F.pad(signal, (pad_length, pad_length), mode=self.pad_mode)
+ else:
+ pad_length = 0
+
+ # print(
+ # f"samples={samples}\n"
+ # f"self.hop_length={self.hop_length}\n"
+ # f"pad_length={pad_length}\n"
+ # f"signal_padded={signal.size(2)}"
+ # )
+
+ # Compute number of frames
+ num_frames = (signal.size(2) - self.n_fft) // self.hop_length + 1
+
+ # Apply window and compute Hartley transform
+ window = self.window.to(signal.device, signal.dtype).unsqueeze(0).unsqueeze(0)
+ stht_coeffs = []
+
+ for i in range(num_frames):
+ start = i * self.hop_length
+ end = start + self.n_fft
+ frame = signal[:, :, start:end] * window
+ stht_coeffs.append(self._hartley_transform(frame))
+
+ return torch.stack(stht_coeffs, dim=-1)
+
+ def inverse_transform(self, *, stht_coeffs: torch.Tensor, length: int) -> torch.Tensor:
+ self.window = self.window.to(stht_coeffs.device)
+ # print(stht_coeffs.shape)
+ batch_size, channels, n_fft, num_frames = stht_coeffs.shape
+ signal_length = length
+
+ # Initialize reconstruction
+ reconstructed_signal = torch.zeros((batch_size, channels, signal_length + (self.n_fft if self.center else 0)),
+ device=stht_coeffs.device, dtype=stht_coeffs.dtype)
+ normalization = torch.zeros(signal_length + (self.n_fft if self.center else 0),
+ device=stht_coeffs.device, dtype=stht_coeffs.dtype)
+
+ window = self.window.to(stht_coeffs.device, stht_coeffs.dtype).unsqueeze(0).unsqueeze(0)
+
+ for i in range(num_frames):
+ start = i * self.hop_length
+ end = start + self.n_fft
+
+ # Reconstruct frame and add to signal
+ frame = self._inverse_hartley_transform(stht_coeffs[:, :, :, i]) * window
+ reconstructed_signal[:, :, start:end] += frame
+ normalization[start:end] += (window ** 2).squeeze()
+
+ # Normalize the overlapping regions
+ eps = torch.finfo(normalization.dtype).eps
+ normalization = torch.clamp(normalization, min=eps)
+ reconstructed_signal /= normalization.unsqueeze(0).unsqueeze(0)
+
+ # Remove padding if center=True
+ if self.center:
+ pad_length = self.n_fft // 2
+ reconstructed_signal = reconstructed_signal[:, :, pad_length:-pad_length]
+
+ # Trim to the specified length
+ return reconstructed_signal[:, :, :signal_length]
+
+
+def get_norm(norm_type):
+ def norm(c, norm_type):
+ if norm_type == 'BatchNorm':
+ return nn.BatchNorm2d(c)
+ elif norm_type == 'InstanceNorm':
+ return nn.InstanceNorm2d(c, affine=True)
+ elif 'GroupNorm' in norm_type:
+ g = int(norm_type.replace('GroupNorm', ''))
+ return nn.GroupNorm(num_groups=g, num_channels=c)
+ else:
+ return nn.Identity()
+
+ return partial(norm, norm_type=norm_type)
+
+
+def get_act(act_type):
+ if act_type == 'gelu':
+ return nn.GELU()
+ elif act_type == 'relu':
+ return nn.ReLU()
+ elif act_type[:3] == 'elu':
+ alpha = float(act_type.replace('elu', ''))
+ return nn.ELU(alpha)
+ else:
+        raise ValueError(f'Unknown act_type: {act_type}')
+
+
+class Upscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Downscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class TFC_TDF(nn.Module):
+ def __init__(self, in_c, c, l, f, bn, norm, act):
+ super().__init__()
+
+ self.blocks = nn.ModuleList()
+ for i in range(l):
+ block = nn.Module()
+
+ block.tfc1 = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
+ )
+ block.tdf = nn.Sequential(
+ norm(c),
+ act,
+ nn.Linear(f, f // bn, bias=False),
+ norm(c),
+ act,
+ nn.Linear(f // bn, f, bias=False),
+ )
+ block.tfc2 = nn.Sequential(
+ norm(c),
+ act,
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
+ )
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
+
+ self.blocks.append(block)
+ in_c = c
+
+ def forward(self, x):
+ for block in self.blocks:
+ s = block.shortcut(x)
+ x = block.tfc1(x)
+ x = x + block.tdf(x)
+ x = block.tfc2(x)
+ x = x + s
+ return x
+
+
+class TFC_TDF_net(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ norm = get_norm(norm_type=config.model.norm)
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ # dim_c = self.num_subbands * config.audio.num_channels * 2
+ dim_c = self.num_subbands * config.audio.num_channels
+ n = config.model.num_scales
+ scale = config.model.scale
+ l = config.model.num_blocks_per_scale
+ c = config.model.num_channels
+ g = config.model.growth
+ bn = config.model.bottleneck_factor
+ f = config.audio.dim_f // (self.num_subbands // 2)
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.encoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act)
+ block.downscale = Downscale(c, c + g, scale, norm, act)
+ f = f // scale[1]
+ c += g
+ self.encoder_blocks.append(block)
+
+ self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act)
+
+ self.decoder_blocks = nn.ModuleList()
+ for i in range(n):
+ block = nn.Module()
+ block.upscale = Upscale(c, c - g, scale, norm, act)
+ f = f * scale[1]
+ c -= g
+ block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act)
+ self.decoder_blocks.append(block)
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ )
+
+ self.stft = ShortTimeHartleyTransform(n_fft=config.audio.n_fft, hop_length=config.audio.hop_length)
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, x):
+ length = x.shape[-1]
+ # print(x.shape)
+ x = self.stft.transform(signal=x)
+ # print(x.shape)
+
+ mix = x = self.cac2cws(x)
+
+ # print(x.shape)
+
+ first_conv_out = x = self.first_conv(x)
+
+ # print(x.shape)
+
+ x = x.transpose(-1, -2)
+
+ # print(x.shape)
+
+ encoder_outputs = []
+ for block in self.encoder_blocks:
+ # print(x.shape)
+ x = block.tfc_tdf(x)
+ # print(x.shape)
+ encoder_outputs.append(x)
+ x = block.downscale(x)
+ # print(x.shape)
+
+ x = self.bottleneck_block(x)
+ # print(x.shape)
+
+ for block in self.decoder_blocks:
+ # print(x.shape)
+ x = block.upscale(x)
+ # print(x.shape)
+ x = torch.cat([x, encoder_outputs.pop()], 1)
+ # print(x.shape)
+ x = block.tfc_tdf(x)
+ # print(x.shape)
+
+ x = x.transpose(-1, -2)
+ # print(x.shape)
+
+ x = x * first_conv_out # reduce artifacts
+
+ # print(x.shape)
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b * self.num_target_instruments, -1, f, t)
+ x = self.stft.inverse_transform(stht_coeffs=x, length=length)
+ x = x.reshape(b, self.num_target_instruments, x.shape[-2], x.shape[-1])
+ else:
+ x = self.stft.inverse_transform(stht_coeffs=x, length=length)
+ # print("!!!", x.shape)
+ return x
diff --git a/models/scnet/__init__.py b/models/scnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6ecefede9345237623066dd21ebd8253af1c60
--- /dev/null
+++ b/models/scnet/__init__.py
@@ -0,0 +1 @@
+from .scnet import SCNet
diff --git a/models/scnet/scnet.py b/models/scnet/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b27704dc922eb593dc76f3b9905aa8c0ea02507f
--- /dev/null
+++ b/models/scnet/scnet.py
@@ -0,0 +1,373 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import deque
+from .separation import SeparationNet
+import typing as tp
+import math
+
+
+class Swish(nn.Module):
+ def forward(self, x):
+ return x * x.sigmoid()
+
+
+class ConvolutionModule(nn.Module):
+ """
+ Convolution Module in SD block.
+
+ Args:
+ channels (int): input/output channels.
+        depth (int): number of layers in the residual branch. Each layer has its own weights.
+ compress (float): amount of channel compression.
+ kernel (int): kernel size for the convolutions.
+ """
+
+ def __init__(self, channels, depth=2, compress=4, kernel=3):
+ super().__init__()
+ assert kernel % 2 == 1
+ self.depth = abs(depth)
+ hidden_size = int(channels / compress)
+ norm = lambda d: nn.GroupNorm(1, d)
+ self.layers = nn.ModuleList([])
+ for _ in range(self.depth):
+ padding = (kernel // 2)
+ mods = [
+ norm(channels),
+ nn.Conv1d(channels, hidden_size * 2, kernel, padding=padding),
+ nn.GLU(1),
+ nn.Conv1d(hidden_size, hidden_size, kernel, padding=padding, groups=hidden_size),
+ norm(hidden_size),
+ Swish(),
+ nn.Conv1d(hidden_size, channels, 1),
+ ]
+ layer = nn.Sequential(*mods)
+ self.layers.append(layer)
+
+ def forward(self, x):
+ for layer in self.layers:
+ x = x + layer(x)
+ return x
+
+
+class FusionLayer(nn.Module):
+ """
+ A FusionLayer within the decoder.
+
+ Args:
+ - channels (int): Number of input channels.
+ - kernel_size (int, optional): Kernel size for the convolutional layer, defaults to 3.
+ - stride (int, optional): Stride for the convolutional layer, defaults to 1.
+ - padding (int, optional): Padding for the convolutional layer, defaults to 1.
+ """
+
+ def __init__(self, channels, kernel_size=3, stride=1, padding=1):
+ super(FusionLayer, self).__init__()
+ self.conv = nn.Conv2d(channels * 2, channels * 2, kernel_size, stride=stride, padding=padding)
+
+ def forward(self, x, skip=None):
+ if skip is not None:
+ x += skip
+ x = x.repeat(1, 2, 1, 1)
+ x = self.conv(x)
+ x = F.glu(x, dim=1)
+ return x
+
+
+class SDlayer(nn.Module):
+ """
+ Implements a Sparse Down-sample Layer for processing different frequency bands separately.
+
+ Args:
+ - channels_in (int): Input channel count.
+ - channels_out (int): Output channel count.
+ - band_configs (dict): A dictionary containing configuration for each frequency band.
+ Keys are 'low', 'mid', 'high' for each band, and values are
+ dictionaries with keys 'SR', 'stride', and 'kernel' for proportion,
+ stride, and kernel size, respectively.
+ """
+
+ def __init__(self, channels_in, channels_out, band_configs):
+ super(SDlayer, self).__init__()
+
+ # Initializing convolutional layers for each band
+ self.convs = nn.ModuleList()
+ self.strides = []
+ self.kernels = []
+ for config in band_configs.values():
+ self.convs.append(
+ nn.Conv2d(channels_in, channels_out, (config['kernel'], 1), (config['stride'], 1), (0, 0)))
+ self.strides.append(config['stride'])
+ self.kernels.append(config['kernel'])
+
+ # Saving rate proportions for determining splits
+ self.SR_low = band_configs['low']['SR']
+ self.SR_mid = band_configs['mid']['SR']
+
+ def forward(self, x):
+ B, C, Fr, T = x.shape
+ # Define splitting points based on sampling rates
+ splits = [
+ (0, math.ceil(Fr * self.SR_low)),
+ (math.ceil(Fr * self.SR_low), math.ceil(Fr * (self.SR_low + self.SR_mid))),
+ (math.ceil(Fr * (self.SR_low + self.SR_mid)), Fr)
+ ]
+
+ # Processing each band with the corresponding convolution
+ outputs = []
+ original_lengths = []
+ for conv, stride, kernel, (start, end) in zip(self.convs, self.strides, self.kernels, splits):
+ extracted = x[:, :, start:end, :]
+ original_lengths.append(end - start)
+ current_length = extracted.shape[2]
+
+ # padding
+ if stride == 1:
+ total_padding = kernel - stride
+ else:
+ total_padding = (stride - current_length % stride) % stride
+ pad_left = total_padding // 2
+ pad_right = total_padding - pad_left
+
+ padded = F.pad(extracted, (0, 0, pad_left, pad_right))
+
+ output = conv(padded)
+ outputs.append(output)
+
+ return outputs, original_lengths
+
+
+class SUlayer(nn.Module):
+ """
+ Implements a Sparse Up-sample Layer in decoder.
+
+ Args:
+ - channels_in: The number of input channels.
+ - channels_out: The number of output channels.
+    - band_configs: Dictionary containing the configurations for the transposed convolutions of each band.
+ """
+
+ def __init__(self, channels_in, channels_out, band_configs):
+ super(SUlayer, self).__init__()
+
+ # Initializing convolutional layers for each band
+ self.convtrs = nn.ModuleList([
+ nn.ConvTranspose2d(channels_in, channels_out, [config['kernel'], 1], [config['stride'], 1])
+ for _, config in band_configs.items()
+ ])
+
+ def forward(self, x, lengths, origin_lengths):
+ B, C, Fr, T = x.shape
+ # Define splitting points based on input lengths
+ splits = [
+ (0, lengths[0]),
+ (lengths[0], lengths[0] + lengths[1]),
+ (lengths[0] + lengths[1], None)
+ ]
+ # Processing each band with the corresponding convolution
+ outputs = []
+ for idx, (convtr, (start, end)) in enumerate(zip(self.convtrs, splits)):
+ out = convtr(x[:, :, start:end, :])
+ # Calculate the distance to trim the output symmetrically to original length
+ current_Fr_length = out.shape[2]
+ dist = abs(origin_lengths[idx] - current_Fr_length) // 2
+
+ # Trim the output to the original length symmetrically
+ trimmed_out = out[:, :, dist:dist + origin_lengths[idx], :]
+
+ outputs.append(trimmed_out)
+
+ # Concatenate trimmed outputs along the frequency dimension to return the final tensor
+ x = torch.cat(outputs, dim=2)
+
+ return x
+
+
+class SDblock(nn.Module):
+ """
+ Implements a simplified Sparse Down-sample block in encoder.
+
+ Args:
+ - channels_in (int): Number of input channels.
+ - channels_out (int): Number of output channels.
+ - band_config (dict): Configuration for the SDlayer specifying band splits and convolutions.
+ - conv_config (dict): Configuration for convolution modules applied to each band.
+ - depths (list of int): List specifying the convolution depths for low, mid, and high frequency bands.
+ """
+
+ def __init__(self, channels_in, channels_out, band_configs={}, conv_config={}, depths=[3, 2, 1], kernel_size=3):
+ super(SDblock, self).__init__()
+ self.SDlayer = SDlayer(channels_in, channels_out, band_configs)
+
+ # Dynamically create convolution modules for each band based on depths
+ self.conv_modules = nn.ModuleList([
+ ConvolutionModule(channels_out, depth, **conv_config) for depth in depths
+ ])
+ # Set the kernel_size to an odd number.
+ self.globalconv = nn.Conv2d(channels_out, channels_out, kernel_size, 1, (kernel_size - 1) // 2)
+
+ def forward(self, x):
+ bands, original_lengths = self.SDlayer(x)
+ # B, C, f, T = band.shape
+ bands = [
+ F.gelu(
+ conv(band.permute(0, 2, 1, 3).reshape(-1, band.shape[1], band.shape[3]))
+ .view(band.shape[0], band.shape[2], band.shape[1], band.shape[3])
+ .permute(0, 2, 1, 3)
+ )
+ for conv, band in zip(self.conv_modules, bands)
+
+ ]
+ lengths = [band.size(-2) for band in bands]
+ full_band = torch.cat(bands, dim=2)
+ skip = full_band
+
+ output = self.globalconv(full_band)
+
+ return output, skip, lengths, original_lengths
+
+
+class SCNet(nn.Module):
+ """
+ The implementation of SCNet: Sparse Compression Network for Music Source Separation. Paper: https://arxiv.org/abs/2401.13276.pdf
+
+ Args:
+ - sources (List[str]): List of sources to be separated.
+ - audio_channels (int): Number of audio channels.
+ - nfft (int): Number of FFTs to determine the frequency dimension of the input.
+ - hop_size (int): Hop size for the STFT.
+ - win_size (int): Window size for STFT.
+ - normalized (bool): Whether to normalize the STFT.
+ - dims (List[int]): List of channel dimensions for each block.
+ - band_SR (List[float]): The proportion of each frequency band.
+ - band_stride (List[int]): The down-sampling ratio of each frequency band.
+ - band_kernel (List[int]): The kernel sizes for down-sampling convolution in each frequency band
+ - conv_depths (List[int]): List specifying the number of convolution modules in each SD block.
+ - compress (int): Compression factor for convolution module.
+ - conv_kernel (int): Kernel size for convolution layer in convolution module.
+ - num_dplayer (int): Number of dual-path layers.
+ - expand (int): Expansion factor in the dual-path RNN, default is 1.
+
+ """
+
+ def __init__(self,
+ sources=['drums', 'bass', 'other', 'vocals'],
+ audio_channels=2,
+ # Main structure
+ dims=[4, 32, 64, 128], # dims = [4, 64, 128, 256] in SCNet-large
+ # STFT
+ nfft=4096,
+ hop_size=1024,
+ win_size=4096,
+ normalized=True,
+ # SD/SU layer
+ band_SR=[0.175, 0.392, 0.433],
+ band_stride=[1, 4, 16],
+ band_kernel=[3, 4, 16],
+ # Convolution Module
+ conv_depths=[3, 2, 1],
+ compress=4,
+ conv_kernel=3,
+ # Dual-path RNN
+ num_dplayer=6,
+ expand=1,
+ ):
+ super().__init__()
+ self.sources = sources
+ self.audio_channels = audio_channels
+ self.dims = dims
+ band_keys = ['low', 'mid', 'high']
+ self.band_configs = {band_keys[i]: {'SR': band_SR[i], 'stride': band_stride[i], 'kernel': band_kernel[i]} for i
+ in range(len(band_keys))}
+ self.hop_length = hop_size
+ self.conv_config = {
+ 'compress': compress,
+ 'kernel': conv_kernel,
+ }
+
+ self.stft_config = {
+ 'n_fft': nfft,
+ 'hop_length': hop_size,
+ 'win_length': win_size,
+ 'center': True,
+ 'normalized': normalized
+ }
+
+ self.encoder = nn.ModuleList()
+ self.decoder = nn.ModuleList()
+
+ for index in range(len(dims) - 1):
+ enc = SDblock(
+ channels_in=dims[index],
+ channels_out=dims[index + 1],
+ band_configs=self.band_configs,
+ conv_config=self.conv_config,
+ depths=conv_depths
+ )
+ self.encoder.append(enc)
+
+ dec = nn.Sequential(
+ FusionLayer(channels=dims[index + 1]),
+ SUlayer(
+ channels_in=dims[index + 1],
+ channels_out=dims[index] if index != 0 else dims[index] * len(sources),
+ band_configs=self.band_configs,
+ )
+ )
+ self.decoder.insert(0, dec)
+
+ self.separation_net = SeparationNet(
+ channels=dims[-1],
+ expand=expand,
+ num_layers=num_dplayer,
+ )
+
+ def forward(self, x):
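+        # SCNet forward: STFT -> sparse down-sample encoder (SD blocks) -> dual-path
+        # separation network -> sparse up-sample decoder with fused skips -> iSTFT,
+        # producing one waveform per source.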
+ # B, C, L = x.shape
+ B = x.shape[0]
+ # In the initial padding, ensure that the number of frames after the STFT (the length of the T dimension) is even,
+ # so that the RFFT operation can be used in the separation network.
+ padding = self.hop_length - x.shape[-1] % self.hop_length
+ if (x.shape[-1] + padding) // self.hop_length % 2 == 0:
+ padding += self.hop_length
+ x = F.pad(x, (0, padding))
+
+ # STFT
+ L = x.shape[-1]
+ x = x.reshape(-1, L)
+ x = torch.stft(x, **self.stft_config, return_complex=True)
+ x = torch.view_as_real(x)
+ x = x.permute(0, 3, 1, 2).reshape(x.shape[0] // self.audio_channels, x.shape[3] * self.audio_channels,
+ x.shape[1], x.shape[2])
+
+ B, C, Fr, T = x.shape
+
+ save_skip = deque()
+ save_lengths = deque()
+ save_original_lengths = deque()
+ # encoder
+ for sd_layer in self.encoder:
+ x, skip, lengths, original_lengths = sd_layer(x)
+ save_skip.append(skip)
+ save_lengths.append(lengths)
+ save_original_lengths.append(original_lengths)
+
+ # separation
+ x = self.separation_net(x)
+
+ # decoder
+ for fusion_layer, su_layer in self.decoder:
+ x = fusion_layer(x, save_skip.pop())
+ x = su_layer(x, save_lengths.pop(), save_original_lengths.pop())
+
+ # output
+ n = self.dims[0]
+ x = x.view(B, n, -1, Fr, T)
+ x = x.reshape(-1, 2, Fr, T).permute(0, 2, 3, 1)
+ x = torch.view_as_complex(x.contiguous())
+ x = torch.istft(x, **self.stft_config)
+ x = x.reshape(B, len(self.sources), self.audio_channels, -1)
+
+ x = x[:, :, :, :-padding]
+
+ return x
diff --git a/models/scnet/separation.py b/models/scnet/separation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d902dac4d947123d3ba1270dd065be0d8b4c5ed9
--- /dev/null
+++ b/models/scnet/separation.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.rnn import LSTM
+
+
+class FeatureConversion(nn.Module):
+ """
+ Integrates into the adjacent Dual-Path layer.
+
+ Args:
+ channels (int): Number of input channels.
+ inverse (bool): If True, uses ifft; otherwise, uses rfft.
+ """
+
+ def __init__(self, channels, inverse):
+ super().__init__()
+ self.inverse = inverse
+ self.channels = channels
+
+ def forward(self, x):
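+        # Alternate between time- and frequency-domain features along the last axis:
+        # the forward direction applies an rFFT over dim 3 and stacks real/imag on
+        # the channel axis; the inverse direction splits the channels and applies
+        # the irFFT.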
+ # B, C, F, T = x.shape
+ if self.inverse:
+ x = x.float()
+ x_r = x[:, :self.channels // 2, :, :]
+ x_i = x[:, self.channels // 2:, :, :]
+ x = torch.complex(x_r, x_i)
+ x = torch.fft.irfft(x, dim=3, norm="ortho")
+ else:
+ x = x.float()
+ x = torch.fft.rfft(x, dim=3, norm="ortho")
+ x_real = x.real
+ x_imag = x.imag
+ x = torch.cat([x_real, x_imag], dim=1)
+ return x
+
+
+class DualPathRNN(nn.Module):
+ """
+ Dual-Path RNN in Separation Network.
+
+ Args:
+ d_model (int): The number of expected features in the input (input_size).
+ expand (int): Expansion factor used to calculate the hidden_size of LSTM.
+ bidirectional (bool): If True, becomes a bidirectional LSTM.
+ """
+
+ def __init__(self, d_model, expand, bidirectional=True):
+ super(DualPathRNN, self).__init__()
+
+ self.d_model = d_model
+ self.hidden_size = d_model * expand
+ self.bidirectional = bidirectional
+ # Initialize LSTM layers and normalization layers
+ self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)])
+ self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)])
+ self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)])
+
+ def _init_lstm_layer(self, d_model, hidden_size):
+ return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True)
+
+ def forward(self, x):
+ B, C, F, T = x.shape
+
+ # Process dual-path rnn
+ original_x = x
+ # Frequency-path
+ x = self.norm_layers[0](x)
+ x = x.transpose(1, 3).contiguous().view(B * T, F, C)
+ x, _ = self.lstm_layers[0](x)
+ x = self.linear_layers[0](x)
+ x = x.view(B, T, F, C).transpose(1, 3)
+ x = x + original_x
+
+ original_x = x
+ # Time-path
+ x = self.norm_layers[1](x)
+ x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2)
+ x, _ = self.lstm_layers[1](x)
+ x = self.linear_layers[1](x)
+ x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2)
+ x = x + original_x
+
+ return x
+
+
+class SeparationNet(nn.Module):
+ """
+    Implements the Separation Network: a stack of Dual-Path RNN layers interleaved with feature conversions.
+
+ Args:
+    - channels (int): Number of input channels.
+ - expand (int): Expansion factor used to calculate the hidden_size of LSTM.
+ - num_layers (int): Number of dual-path layers.
+ """
+
+ def __init__(self, channels, expand=1, num_layers=6):
+ super(SeparationNet, self).__init__()
+
+ self.num_layers = num_layers
+
+ self.dp_modules = nn.ModuleList([
+ DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers)
+ ])
+
+ self.feature_conversion = nn.ModuleList([
+ FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers)
+ ])
+
+ def forward(self, x):
+ for i in range(self.num_layers):
+ x = self.dp_modules[i](x)
+ x = self.feature_conversion[i](x)
+ return x
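+# Channel bookkeeping (illustrative sketch, not upstream documentation): with
+# channels = C the stack alternates
+#   i even: DualPathRNN on C channels, then the rfft stacking doubles them to 2C
+#   i odd : DualPathRNN on 2C channels, then the inverse transform restores C
+# so for an even num_layers the overall shape is preserved, e.g.
+#   SeparationNet(channels=64, expand=1, num_layers=6)(torch.randn(2, 64, 16, 100))
+# returns a (2, 64, 16, 100) tensor.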
diff --git a/models/scnet_unofficial/__init__.py b/models/scnet_unofficial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d034d38a2ace2e81bd28d63dd8f25feb918f33d
--- /dev/null
+++ b/models/scnet_unofficial/__init__.py
@@ -0,0 +1 @@
+from models.scnet_unofficial.scnet import SCNet
\ No newline at end of file
diff --git a/models/scnet_unofficial/modules/__init__.py b/models/scnet_unofficial/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..69617bb15044d9bbfd0211fcdfa0fa605b01c048
--- /dev/null
+++ b/models/scnet_unofficial/modules/__init__.py
@@ -0,0 +1,3 @@
+from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN
+from models.scnet_unofficial.modules.sd_encoder import SDBlock
+from models.scnet_unofficial.modules.su_decoder import SUBlock
diff --git a/models/scnet_unofficial/modules/dualpath_rnn.py b/models/scnet_unofficial/modules/dualpath_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dfcdbcfc102a6fde5a2ff53a2a06f2d6caae196
--- /dev/null
+++ b/models/scnet_unofficial/modules/dualpath_rnn.py
@@ -0,0 +1,228 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as Func
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return Func.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+class MambaModule(nn.Module):
+ def __init__(self, d_model, d_state, d_conv, d_expand):
+ super().__init__()
+ # mamba_ssm is an optional dependency; import it lazily so the LSTM code
+ # path in this file still works when it is not installed
+ from mamba_ssm.modules.mamba_simple import Mamba
+ self.norm = RMSNorm(dim=d_model)
+ self.mamba = Mamba(
+ d_model=d_model,
+ d_state=d_state,
+ d_conv=d_conv,
+ expand=d_expand  # Mamba's expansion-factor keyword is `expand` (assumed from mamba_ssm's API)
+ )
+
+ def forward(self, x):
+ x = x + self.mamba(self.norm(x))
+ return x
+
+
+class RNNModule(nn.Module):
+ """
+ RNNModule class implements a recurrent neural network module with LSTM cells.
+
+ Args:
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
+ - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True.
+
+ Shapes:
+ - Input: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
+ """
+ Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
+ """
+ super().__init__()
+ self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
+ self.rnn = nn.LSTM(
+ input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
+ )
+ self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the RNNModule.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, T, D).
+ """
+ x = x.transpose(1, 2)
+ x = self.groupnorm(x)
+ x = x.transpose(1, 2)
+
+ x, (hidden, _) = self.rnn(x)
+ x = self.fc(x)
+ return x
+
+
+class RFFTModule(nn.Module):
+ """
+ RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
+ or its inverse on input tensors.
+
+ Args:
+ - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.
+
+ Shapes:
+ - Input: (B, F, T, D) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
+ (B, F, time_dim, D // 2) if performing inverse FFT.
+ """
+
+ def __init__(self, inverse: bool = False):
+ """
+ Initializes RFFTModule with inverse flag.
+ """
+ super().__init__()
+ self.inverse = inverse
+
+ def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
+ """
+ Performs forward or inverse FFT on the input tensor x.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
+ - time_dim (int): Input size of time dimension.
+
+ Returns:
+ - torch.Tensor: Output tensor after FFT or its inverse operation.
+ """
+ dtype = x.dtype
+ B, F, T, D = x.shape
+
+ # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision
+ x = x.float()
+
+ if not self.inverse:
+ x = torch.fft.rfft(x, dim=2)
+ x = torch.view_as_real(x)
+ x = x.reshape(B, F, T // 2 + 1, D * 2)
+ else:
+ x = x.reshape(B, F, T, D // 2, 2)
+ x = torch.view_as_complex(x)
+ x = torch.fft.irfft(x, n=time_dim, dim=2)
+
+ x = x.to(dtype)
+ return x
+
+ def extra_repr(self) -> str:
+ """
+ Returns extra representation string with module's configuration.
+ """
+ return f"inverse={self.inverse}"
+
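+# Shape sketch (illustrative, assumes the shapes documented above): the forward
+# transform folds the complex rfft output into the feature axis,
+#   RFFTModule(inverse=False)(torch.randn(2, 16, 100, 32), time_dim=100)
+# gives a (2, 16, 51, 64) tensor, while the inverse call
+#   RFFTModule(inverse=True)(that_tensor, time_dim=100)
+# returns a (2, 16, 100, 32) tensor.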
+
+class DualPathRNN(nn.Module):
+ """
+ DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.
+
+ Args:
+ - n_layers (int): Number of layers in the network.
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
+
+ Shapes:
+ - Input: (B, F, T, D) where
+ B is batch size,
+ F is the number of features (frequency dimension),
+ T is sequence length (time dimension),
+ D is input dimensionality (channel dimension).
+ - Output: (B, F, T, D) where
+ B is batch size,
+ F is the number of features (frequency dimension),
+ T is sequence length (time dimension),
+ D is input dimensionality (channel dimension).
+ """
+
+ def __init__(
+ self,
+ n_layers: int,
+ input_dim: int,
+ hidden_dim: int,
+
+ use_mamba: bool = False,
+ d_state: int = 16,
+ d_conv: int = 4,
+ d_expand: int = 2
+ ):
+ """
+ Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
+ """
+ super().__init__()
+
+ if use_mamba:
+ from mamba_ssm.modules.mamba_simple import Mamba
+ net = MambaModule
+ dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand}
+ ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2}
+ else:
+ net = RNNModule
+ dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim}
+ ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2}
+
+ self.layers = nn.ModuleList()
+ for i in range(1, n_layers + 1):
+ kwargs = dkwargs if i % 2 == 1 else ukwargs
+ layer = nn.ModuleList([
+ net(**kwargs),
+ net(**kwargs),
+ RFFTModule(inverse=(i % 2 == 0)),
+ ])
+ self.layers.append(layer)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the DualPathRNN.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, D).
+ """
+
+ time_dim = x.shape[2]
+
+ for time_layer, freq_layer, rfft_layer in self.layers:
+ B, F, T, D = x.shape
+
+ x = x.reshape((B * F), T, D)
+ x = time_layer(x)
+ x = x.reshape(B, F, T, D)
+ x = x.permute(0, 2, 1, 3)
+
+ x = x.reshape((B * T), F, D)
+ x = freq_layer(x)
+ x = x.reshape(B, T, F, D)
+ x = x.permute(0, 2, 1, 3)
+
+ x = rfft_layer(x, time_dim)
+
+ return x
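+# Note (illustrative, not upstream documentation): odd-numbered layers use the
+# base feature size D and end with a forward rfft that doubles D while roughly
+# halving T; even-numbered layers therefore run on 2*D features and end with the
+# inverse transform, which restores (B, F, T, D). With an even n_layers, e.g.
+#   DualPathRNN(n_layers=6, input_dim=64, hidden_dim=128)
+# the output shape matches the input shape.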
diff --git a/models/scnet_unofficial/modules/sd_encoder.py b/models/scnet_unofficial/modules/sd_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..742577f480693671437dc50358a1a65d251b6e9b
--- /dev/null
+++ b/models/scnet_unofficial/modules/sd_encoder.py
@@ -0,0 +1,285 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from models.scnet_unofficial.utils import create_intervals
+
+
+class Downsample(nn.Module):
+ """
+ Downsample class implements a module for downsampling input tensors using 2D convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - stride (int): Stride value for the convolution operation.
+
+ Shapes:
+ - Input: (B, C_in, F, T) where
+ B is batch size,
+ C_in is the number of input channels,
+ F is the frequency dimension,
+ T is the time dimension.
+ - Output: (B, C_out, F // stride, T) where
+ B is batch size,
+ C_out is the number of output channels,
+ F // stride is the downsampled frequency dimension.
+
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ stride: int,
+ ):
+ """
+ Initializes Downsample with input dimension, output dimension, and stride.
+ """
+ super().__init__()
+ self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the Downsample module.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
+
+ Returns:
+ - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
+ """
+ return self.conv(x)
+
+
+class ConvolutionModule(nn.Module):
+ """
+ ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.
+
+ Args:
+ - input_dim (int): Dimensionality of the input features.
+ - hidden_dim (int): Dimensionality of the hidden features.
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
+ - bias (bool, optional): If True, adds a learnable bias to the output. Default is False.
+
+ Shapes:
+ - Input: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ - Output: (B, T, D) where
+ B is batch size,
+ T is sequence length,
+ D is input dimensionality.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int,
+ kernel_sizes: List[int],
+ bias: bool = False,
+ ) -> None:
+ """
+ Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
+ """
+ super().__init__()
+ self.sequential = nn.Sequential(
+ nn.GroupNorm(num_groups=1, num_channels=input_dim),
+ nn.Conv1d(
+ input_dim,
+ 2 * hidden_dim,
+ kernel_sizes[0],
+ stride=1,
+ padding=(kernel_sizes[0] - 1) // 2,
+ bias=bias,
+ ),
+ nn.GLU(dim=1),
+ nn.Conv1d(
+ hidden_dim,
+ hidden_dim,
+ kernel_sizes[1],
+ stride=1,
+ padding=(kernel_sizes[1] - 1) // 2,
+ groups=hidden_dim,
+ bias=bias,
+ ),
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
+ nn.SiLU(),
+ nn.Conv1d(
+ hidden_dim,
+ input_dim,
+ kernel_sizes[2],
+ stride=1,
+ padding=(kernel_sizes[2] - 1) // 2,
+ bias=bias,
+ ),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the ConvolutionModule.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, T, D).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, T, D).
+ """
+ x = x.transpose(1, 2)
+ x = x + self.sequential(x)
+ x = x.transpose(1, 2)
+ return x
+
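+# Illustrative usage (not part of the upstream code): a Conformer-style
+# pointwise -> GLU -> depthwise -> SiLU -> pointwise stack with a residual
+# connection, so the (B, T, D) shape is preserved:
+#   cm = ConvolutionModule(input_dim=64, hidden_dim=16, kernel_sizes=[3, 3, 1])
+#   y = cm(torch.randn(2, 100, 64))   # -> (2, 100, 64)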
+
+class SDLayer(nn.Module):
+ """
+ SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.
+
+ Args:
+ - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels after downsampling.
+ - downsample_stride (int): Stride value for the downsampling operation.
+ - n_conv_modules (int): Number of convolutional modules.
+ - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
+ - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.
+
+ Shapes:
+ - Input: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of input subbands,
+ T is sequence length, and
+ Ci is the number of input channels.
+ - Output: (B, Fi+1, T, Ci+1) where
+ B is batch size,
+ Fi+1 is the number of output subbands,
+ T is sequence length,
+ Ci+1 is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ subband_interval: Tuple[float, float],
+ input_dim: int,
+ output_dim: int,
+ downsample_stride: int,
+ n_conv_modules: int,
+ kernel_sizes: List[int],
+ bias: bool = True,
+ ):
+ """
+ Initializes SDLayer with subband interval, input dimension,
+ output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
+ """
+ super().__init__()
+ self.subband_interval = subband_interval
+ self.downsample = Downsample(input_dim, output_dim, downsample_stride)
+ self.activation = nn.GELU()
+ conv_modules = [
+ ConvolutionModule(
+ input_dim=output_dim,
+ hidden_dim=output_dim // 4,
+ kernel_sizes=kernel_sizes,
+ bias=bias,
+ )
+ for _ in range(n_conv_modules)
+ ]
+ self.conv_modules = nn.Sequential(*conv_modules)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SDLayer.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
+ """
+ B, F, T, C = x.shape
+ x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
+ x = x.permute(0, 3, 1, 2)
+ x = self.downsample(x)
+ x = self.activation(x)
+ x = x.permute(0, 2, 3, 1)
+
+ B, F, T, C = x.shape
+ x = x.reshape((B * F), T, C)
+ x = self.conv_modules(x)
+ x = x.reshape(B, F, T, C)
+
+ return x
+
+
+class SDBlock(nn.Module):
+ """
+ SDBlock class implements a block with subband decomposition layers and global convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
+ - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.
+
+ Shapes:
+ - Input: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of input subbands,
+ T is sequence length,
+ Ci is the number of input channels.
+ - Output: (B, Fi+1, T, Ci+1) where
+ B is batch size,
+ Fi+1 is the number of output subbands,
+ T is sequence length,
+ Ci+1 is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_conv_modules: List[int],
+ kernel_sizes: List[int] = None,
+ ):
+ """
+ Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
+ """
+ super().__init__()
+ if kernel_sizes is None:
+ kernel_sizes = [3, 3, 1]
+ assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
+ subband_intervals = create_intervals(bandsplit_ratios)
+ self.sd_layers = nn.ModuleList(
+ SDLayer(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ subband_interval=sbi,
+ downsample_stride=dss,
+ n_conv_modules=ncm,
+ kernel_sizes=kernel_sizes,
+ )
+ for sbi, dss, ncm in zip(
+ subband_intervals, downsample_strides, n_conv_modules
+ )
+ )
+ self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Performs forward pass through the SDBlock.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
+
+ Returns:
+ - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
+ """
+ x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
+ x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ return x, x_skip
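+# Small worked example (illustrative only; the ratios and strides are arbitrary,
+# not the SCNet defaults): the frequency bins are split according to the ratios,
+# each subband is downsampled by its stride, and the results are concatenated:
+#   block = SDBlock(input_dim=4, output_dim=8,
+#                   bandsplit_ratios=[0.5, 0.5],
+#                   downsample_strides=[1, 2],
+#                   n_conv_modules=[1, 1])
+#   x, skip = block(torch.randn(2, 32, 100, 4))   # both (2, 24, 100, 8)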
diff --git a/models/scnet_unofficial/modules/su_decoder.py b/models/scnet_unofficial/modules/su_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..660c1fa6cbfd9b43bed73204a0bb6593524de272
--- /dev/null
+++ b/models/scnet_unofficial/modules/su_decoder.py
@@ -0,0 +1,241 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from models.scnet_unofficial.utils import get_convtranspose_output_padding
+
+
+class FusionLayer(nn.Module):
+ """
+ FusionLayer class implements a module for fusing two input tensors using convolutional operations.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3.
+ - stride (int, optional): Stride value for the convolutional layer. Default is 1.
+ - padding (int, optional): Padding value for the convolutional layer. Default is 1.
+
+ Shapes:
+ - Input: (B, F, T, C) and (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ - Output: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ """
+
+ def __init__(
+ self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
+ ):
+ """
+ Initializes FusionLayer with input dimension, kernel size, stride, and padding.
+ """
+ super().__init__()
+ self.conv = nn.Conv2d(
+ input_dim * 2,
+ input_dim * 2,
+ kernel_size=(kernel_size, 1),
+ stride=(stride, 1),
+ padding=(padding, 0),
+ )
+ self.activation = nn.GLU()
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the FusionLayer.
+
+ Args:
+ - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
+ - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, C).
+ """
+ x = x1 + x2
+ x = x.repeat(1, 1, 1, 2)
+ x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ x = self.activation(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ Upsample class implements a module for upsampling input tensors using transposed 2D convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - stride (int): Stride value for the transposed convolution operation.
+ - output_padding (int): Output padding value for the transposed convolution operation.
+
+ Shapes:
+ - Input: (B, C_in, F, T) where
+ B is batch size,
+ C_in is the number of input channels,
+ F is the frequency dimension,
+ T is the time dimension.
+ - Output: (B, C_out, F * stride + output_padding, T) where
+ B is batch size,
+ C_out is the number of output channels,
+ F * stride + output_padding is the upsampled frequency dimension.
+ """
+
+ def __init__(
+ self, input_dim: int, output_dim: int, stride: int, output_padding: int
+ ):
+ """
+ Initializes Upsample with input dimension, output dimension, stride, and output padding.
+ """
+ super().__init__()
+ self.conv = nn.ConvTranspose2d(
+ input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0)
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the Upsample module.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T).
+ """
+ return self.conv(x)
+
+
+class SULayer(nn.Module):
+ """
+ SULayer class implements a subband upsampling layer using transposed convolution.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - upsample_stride (int): Stride value for the upsampling operation.
+ - subband_shape (int): Shape of the subband.
+ - sd_interval (Tuple[int, int]): Start and end indices of the subband interval.
+
+ Shapes:
+ - Input: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ - Output: (B, F, T, C) where
+ B is batch size,
+ F is the number of features,
+ T is sequence length,
+ C is input dimensionality.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ upsample_stride: int,
+ subband_shape: int,
+ sd_interval: Tuple[int, int],
+ ):
+ """
+ Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval.
+ """
+ super().__init__()
+ sd_shape = sd_interval[1] - sd_interval[0]
+ upsample_output_padding = get_convtranspose_output_padding(
+ input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride
+ )
+ self.upsample = Upsample(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ stride=upsample_stride,
+ output_padding=upsample_output_padding,
+ )
+ self.sd_interval = sd_interval
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SULayer.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, F, T, C).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, F, T, C).
+ """
+ x = x[:, self.sd_interval[0] : self.sd_interval[1]]
+ x = x.permute(0, 3, 1, 2)
+ x = self.upsample(x)
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class SUBlock(nn.Module):
+ """
+ SUBlock class implements a block with fusion layer and subband upsampling layers.
+
+ Args:
+ - input_dim (int): Dimensionality of the input channels.
+ - output_dim (int): Dimensionality of the output channels.
+ - upsample_strides (List[int]): List of stride values for the upsampling operations.
+ - subband_shapes (List[int]): List of shapes for the subbands.
+ - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.
+
+ Shapes:
+ - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
+ B is batch size,
+ Fi-1 is the number of input subbands,
+ T is sequence length,
+ Ci-1 is the number of input channels.
+ - Output: (B, Fi, T, Ci) where
+ B is batch size,
+ Fi is the number of output subbands,
+ T is sequence length,
+ Ci is the number of output channels.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ upsample_strides: List[int],
+ subband_shapes: List[int],
+ sd_intervals: List[Tuple[int, int]],
+ ):
+ """
+ Initializes SUBlock with input dimension, output dimension,
+ upsample strides, subband shapes, and subband intervals.
+ """
+ super().__init__()
+ self.fusion_layer = FusionLayer(input_dim=input_dim)
+ self.su_layers = nn.ModuleList(
+ SULayer(
+ input_dim=input_dim,
+ output_dim=output_dim,
+ upsample_stride=uss,
+ subband_shape=sbs,
+ sd_interval=sdi,
+ )
+ for uss, sbs, sdi in zip(
+ upsample_strides, subband_shapes, sd_intervals
+ )
+ )
+
+ def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SUBlock.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
+ - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
+ """
+ x = self.fusion_layer(x, x_skip)
+ x = torch.concat([layer(x) for layer in self.su_layers], dim=1)
+ return x
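+# Small worked example (illustrative only; the values match the SDBlock sketch in
+# sd_encoder.py rather than the SCNet defaults): the fused features are split at
+# the sd_intervals, upsampled back to the subband_shapes and concatenated:
+#   block = SUBlock(input_dim=8, output_dim=4,
+#                   upsample_strides=[1, 2],
+#                   subband_shapes=[16, 16],
+#                   sd_intervals=[(0, 16), (16, 24)])
+#   x = torch.randn(2, 24, 100, 8)
+#   y = block(x, x)                     # -> (2, 32, 100, 4)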
diff --git a/models/scnet_unofficial/scnet.py b/models/scnet_unofficial/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d076f85f1d5ce1345dc9a8c56b6a5aef09f2facc
--- /dev/null
+++ b/models/scnet_unofficial/scnet.py
@@ -0,0 +1,249 @@
+'''
+SCNet - great paper, great implementation
+https://arxiv.org/pdf/2401.13276.pdf
+https://github.com/amanteur/SCNet-PyTorch
+'''
+
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+
+from models.scnet_unofficial.modules import DualPathRNN, SDBlock, SUBlock
+from models.scnet_unofficial.utils import compute_sd_layer_shapes, compute_gcr
+
+from einops import rearrange, pack, unpack
+from functools import partial
+
+from beartype.typing import Tuple, Optional, List, Callable
+from beartype import beartype
+
+def exists(val):
+ return val is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.scale = dim ** 0.5
+ self.gamma = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
+
+
+class BandSplit(nn.Module):
+ @beartype
+ def __init__(
+ self,
+ dim,
+ dim_inputs: Tuple[int, ...]
+ ):
+ super().__init__()
+ self.dim_inputs = dim_inputs
+ self.to_features = nn.ModuleList([])
+
+ for dim_in in dim_inputs:
+ net = nn.Sequential(
+ RMSNorm(dim_in),
+ nn.Linear(dim_in, dim)
+ )
+
+ self.to_features.append(net)
+
+ def forward(self, x):
+ x = x.split(self.dim_inputs, dim=-1)
+
+ outs = []
+ for split_input, to_feature in zip(x, self.to_features):
+ split_output = to_feature(split_input)
+ outs.append(split_output)
+
+ return torch.stack(outs, dim=-2)
+
+
+class SCNet(nn.Module):
+ """
+ SCNet class implements a source separation network,
+ which explicitly splits the spectrogram of the mixture into several subbands
+ and introduces a sparsity-based encoder to model different frequency bands.
+
+ Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION"
+ Authors: Weinan Tong, Jiaxu Zhu et al.
+ Link: https://arxiv.org/abs/2401.13276.pdf
+
+ Args:
+ - n_fft (int): FFT size; determines the frequency dimension of the input.
+ - dims (List[int]): List of channel dimensions for each block.
+ - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): List of stride values for downsampling in each block.
+ - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block.
+ - n_rnn_layers (int): Number of recurrent layers in the dual path RNN.
+ - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN.
+ - n_sources (int, optional): Number of sources to be separated. Default is 4.
+
+ Shapes:
+ - Input: (B, C, T) where
+ B is batch size,
+ C is channel dim (mono / stereo),
+ T is time dim.
+ - Output: (B, N, C, T) where
+ B is batch size,
+ N is the number of sources,
+ C is channel dim (mono / stereo),
+ T is time dim.
+ """
+ @beartype
+ def __init__(
+ self,
+ n_fft: int,
+ dims: List[int],
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_conv_modules: List[int],
+ n_rnn_layers: int,
+ rnn_hidden_dim: int,
+ n_sources: int = 4,
+ hop_length: int = 1024,
+ win_length: int = 4096,
+ stft_window_fn: Optional[Callable] = None,
+ stft_normalized: bool = False,
+ **kwargs
+ ):
+ """
+ Initializes SCNet with input parameters.
+ """
+ super().__init__()
+ self.assert_input_data(
+ bandsplit_ratios,
+ downsample_strides,
+ n_conv_modules,
+ )
+
+ n_blocks = len(dims) - 1
+ n_freq_bins = n_fft // 2 + 1
+ subband_shapes, sd_intervals = compute_sd_layer_shapes(
+ input_shape=n_freq_bins,
+ bandsplit_ratios=bandsplit_ratios,
+ downsample_strides=downsample_strides,
+ n_layers=n_blocks,
+ )
+ self.sd_blocks = nn.ModuleList(
+ SDBlock(
+ input_dim=dims[i],
+ output_dim=dims[i + 1],
+ bandsplit_ratios=bandsplit_ratios,
+ downsample_strides=downsample_strides,
+ n_conv_modules=n_conv_modules,
+ )
+ for i in range(n_blocks)
+ )
+ self.dualpath_blocks = DualPathRNN(
+ n_layers=n_rnn_layers,
+ input_dim=dims[-1],
+ hidden_dim=rnn_hidden_dim,
+ **kwargs
+ )
+ self.su_blocks = nn.ModuleList(
+ SUBlock(
+ input_dim=dims[i + 1],
+ output_dim=dims[i] if i != 0 else dims[i] * n_sources,
+ subband_shapes=subband_shapes[i],
+ sd_intervals=sd_intervals[i],
+ upsample_strides=downsample_strides,
+ )
+ for i in reversed(range(n_blocks))
+ )
+ self.gcr = compute_gcr(subband_shapes)
+
+ self.stft_kwargs = dict(
+ n_fft=n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ normalized=stft_normalized
+ )
+
+ self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length)
+ self.n_sources = n_sources
+ self.hop_length = hop_length
+
+ @staticmethod
+ def assert_input_data(*args):
+ """
+ Checks that the per-block configuration lists all have the same length.
+ """
+ for arg1 in args:
+ for arg2 in args:
+ if len(arg1) != len(arg2):
+ raise ValueError(
+ f"Shapes of input features {arg1} and {arg2} are not equal."
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Performs forward pass through the SCNet.
+
+ Args:
+ - x (torch.Tensor): Input tensor of shape (B, C, T).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (B, N, C, T).
+ """
+
+ device = x.device
+ stft_window = self.stft_window_fn(device=device)
+
+ if x.ndim == 2:
+ x = rearrange(x, 'b t -> b 1 t')
+
+ c = x.shape[1]
+
+ stft_pad = self.hop_length - x.shape[-1] % self.hop_length
+ x = F.pad(x, (0, stft_pad))
+
+ # stft
+ x, ps = pack_one(x, '* t')
+ x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True)
+ x = torch.view_as_real(x)
+ x = unpack_one(x, ps, '* c f t')
+ x = rearrange(x, 'b c f t r -> b f t (c r)')
+
+ # encoder part
+ x_skips = []
+ for sd_block in self.sd_blocks:
+ x, x_skip = sd_block(x)
+ x_skips.append(x_skip)
+
+ # separation part
+ x = self.dualpath_blocks(x)
+
+ # decoder part
+ for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)):
+ x = su_block(x, x_skip)
+
+ # istft
+ x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2)
+ x = x.contiguous()
+
+ x = torch.view_as_complex(x)
+ x = rearrange(x, 'b n c f t -> (b n c) f t')
+ x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False)
+ x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources)
+
+ x = x[..., :-stft_pad]
+
+ return x
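+# Forward-pass outline (illustrative summary of the code above): the waveform
+# (B, C, T) is padded to a hop multiple, turned into a real/imag spectrogram of
+# shape (B, F, T', 2*C), passed through the SD encoder (collecting skips), the
+# dual-path separator and the SU decoder, and finally an istft per source
+# produces the (B, n_sources, C, T) output.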
diff --git a/models/scnet_unofficial/utils.py b/models/scnet_unofficial/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aae1afcd52e8088926ea984e52c9b62ca68be65c
--- /dev/null
+++ b/models/scnet_unofficial/utils.py
@@ -0,0 +1,135 @@
+'''
+SCNet - great paper, great implementation
+https://arxiv.org/pdf/2401.13276.pdf
+https://github.com/amanteur/SCNet-PyTorch
+'''
+
+from typing import List, Tuple, Union
+
+import torch
+
+
+def create_intervals(
+ splits: List[Union[float, int]]
+) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
+ """
+ Create intervals based on splits provided.
+
+ Args:
+ - splits (List[Union[float, int]]): List of floats or integers representing splits.
+
+ Returns:
+ - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals.
+ """
+ start = 0
+ return [(start, start := start + split) for split in splits]
+
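+# e.g. (illustrative) create_intervals([0.5, 0.25, 0.25])
+#      -> [(0, 0.5), (0.5, 0.75), (0.75, 1.0)]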
+
+def get_conv_output_shape(
+ input_shape: int,
+ kernel_size: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ stride: int = 1,
+) -> int:
+ """
+ Compute the output shape of a convolutional layer.
+
+ Args:
+ - input_shape (int): Input shape.
+ - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
+ - padding (int, optional): Padding size. Default is 0.
+ - dilation (int, optional): Dilation factor. Default is 1.
+ - stride (int, optional): Stride value. Default is 1.
+
+ Returns:
+ - int: Output shape.
+ """
+ return int(
+ (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
+ )
+
+
+def get_convtranspose_output_padding(
+ input_shape: int,
+ output_shape: int,
+ kernel_size: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ stride: int = 1,
+) -> int:
+ """
+ Compute the output padding for a convolution transpose operation.
+
+ Args:
+ - input_shape (int): Input shape.
+ - output_shape (int): Desired output shape.
+ - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
+ - padding (int, optional): Padding size. Default is 0.
+ - dilation (int, optional): Dilation factor. Default is 1.
+ - stride (int, optional): Stride value. Default is 1.
+
+ Returns:
+ - int: Output padding.
+ """
+ return (
+ output_shape
+ - (input_shape - 1) * stride
+ + 2 * padding
+ - dilation * (kernel_size - 1)
+ - 1
+ )
+
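+# e.g. (illustrative) get_convtranspose_output_padding(input_shape=8, output_shape=16, stride=2)
+#      -> 1, i.e. a stride-2 ConvTranspose2d needs output_padding=1 to map 8 bins back to 16.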
+
+def compute_sd_layer_shapes(
+ input_shape: int,
+ bandsplit_ratios: List[float],
+ downsample_strides: List[int],
+ n_layers: int,
+) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
+ """
+ Compute the shapes for the subband layers.
+
+ Args:
+ - input_shape (int): Input shape.
+ - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands.
+ - downsample_strides (List[int]): Strides for downsampling in each layer.
+ - n_layers (int): Number of layers.
+
+ Returns:
+ - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Subband shapes and the corresponding downsampled-subband intervals for each layer.
+ """
+ bandsplit_shapes_list = []
+ conv2d_shapes_list = []
+ for _ in range(n_layers):
+ bandsplit_intervals = create_intervals(bandsplit_ratios)
+ bandsplit_shapes = [
+ int(right * input_shape) - int(left * input_shape)
+ for left, right in bandsplit_intervals
+ ]
+ conv2d_shapes = [
+ get_conv_output_shape(bs, stride=ds)
+ for bs, ds in zip(bandsplit_shapes, downsample_strides)
+ ]
+ input_shape = sum(conv2d_shapes)
+ bandsplit_shapes_list.append(bandsplit_shapes)
+ conv2d_shapes_list.append(create_intervals(conv2d_shapes))
+
+ return bandsplit_shapes_list, conv2d_shapes_list
+
+
+def compute_gcr(subband_shapes: List[List[int]]) -> float:
+ """
+ Compute the global compression ratio.
+
+ Args:
+ - subband_shapes (List[List[int]]): List of subband shapes.
+
+ Returns:
+ - float: Global compression ratio.
+ """
+ t = torch.Tensor(subband_shapes)
+ gcr = torch.stack(
+ [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)]
+ ).mean()
+ return float(gcr)
\ No newline at end of file
diff --git a/models/segm_models.py b/models/segm_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31ffaba234ccc6c833669f428c17214938adc83
--- /dev/null
+++ b/models/segm_models.py
@@ -0,0 +1,255 @@
+import torch
+import torch.nn as nn
+import segmentation_models_pytorch as smp
+from utils import prefer_target_instrument
+
+class STFT:
+ def __init__(self, config):
+ self.n_fft = config.n_fft
+ self.hop_length = config.hop_length
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.dim_f = config.dim_f
+
+ def __call__(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-2]
+ c, t = x.shape[-2:]
+ x = x.reshape([-1, t])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
+ return x[..., :self.dim_f, :]
+
+ def inverse(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-3]
+ c, f, t = x.shape[-3:]
+ n = self.n_fft // 2 + 1
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
+ x = torch.cat([x, f_pad], -2)
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
+ x = x.permute([0, 2, 3, 1])
+ x = x[..., 0] + x[..., 1] * 1.j
+ x = torch.istft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True
+ )
+ x = x.reshape([*batch_dims, 2, -1])
+ return x
+
+
+def get_act(act_type):
+ if act_type == 'gelu':
+ return nn.GELU()
+ elif act_type == 'relu':
+ return nn.ReLU()
+ elif act_type[:3] == 'elu':
+ alpha = float(act_type.replace('elu', ''))
+ return nn.ELU(alpha)
+ else:
+ raise ValueError(f"Unknown activation type: {act_type}")
+
+
+def get_decoder(config, c):
+ decoder = None
+ decoder_options = dict()
+ if config.model.decoder_type == 'unet':
+ try:
+ decoder_options = dict(config.decoder_unet)
+ except:
+ pass
+ decoder = smp.Unet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'fpn':
+ try:
+ decoder_options = dict(config.decoder_fpn)
+ except:
+ pass
+ decoder = smp.FPN(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'unet++':
+ try:
+ decoder_options = dict(config.decoder_unet_plus_plus)
+ except:
+ pass
+ decoder = smp.UnetPlusPlus(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'manet':
+ try:
+ decoder_options = dict(config.decoder_manet)
+ except:
+ pass
+ decoder = smp.MAnet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'linknet':
+ try:
+ decoder_options = dict(config.decoder_linknet)
+ except:
+ pass
+ decoder = smp.Linknet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'pspnet':
+ try:
+ decoder_options = dict(config.decoder_pspnet)
+ except:
+ pass
+ decoder = smp.PSPNet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'pan':
+ try:
+ decoder_options = dict(config.decoder_pan)
+ except:
+ pass
+ decoder = smp.PAN(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'deeplabv3':
+ try:
+ decoder_options = dict(config.decoder_deeplabv3)
+ except:
+ pass
+ decoder = smp.DeepLabV3(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'deeplabv3plus':
+ try:
+ decoder_options = dict(config.decoder_deeplabv3plus)
+ except:
+ pass
+ decoder = smp.DeepLabV3Plus(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ return decoder
+
+
+class Segm_Models_Net(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ dim_c = self.num_subbands * config.audio.num_channels * 2
+ c = config.model.num_channels
+ f = config.audio.dim_f // self.num_subbands
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.unet_model = get_decoder(config, c)
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ )
+
+ self.stft = STFT(config.audio)
+
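+ # cac2cws / cws2cac (illustrative note; the naming is inferred from the code):
+ # "complex-as-channels" <-> "channel-wise subbands". cac2cws folds
+ # num_subbands groups of frequency bins into the channel axis,
+ # (b, c, f, t) -> (b, c*k, f//k, t), and cws2cac reverses it.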
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, x):
+
+ x = self.stft(x)
+
+ mix = x = self.cac2cws(x)
+
+ first_conv_out = x = self.first_conv(x)
+
+ x = x.transpose(-1, -2)
+
+ x = self.unet_model(x)
+
+ x = x.transpose(-1, -2)
+
+ x = x * first_conv_out # reduce artifacts
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
+
+ x = self.stft.inverse(x)
+ return x
diff --git a/models/torchseg_models.py b/models/torchseg_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c895ad6a75c39e3d2c8621cf361f5b0e2b949e9
--- /dev/null
+++ b/models/torchseg_models.py
@@ -0,0 +1,255 @@
+import torch
+import torch.nn as nn
+import torchseg as smp
+from utils import prefer_target_instrument
+
+class STFT:
+ def __init__(self, config):
+ self.n_fft = config.n_fft
+ self.hop_length = config.hop_length
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.dim_f = config.dim_f
+
+ def __call__(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-2]
+ c, t = x.shape[-2:]
+ x = x.reshape([-1, t])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
+ return x[..., :self.dim_f, :]
+
+ def inverse(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-3]
+ c, f, t = x.shape[-3:]
+ n = self.n_fft // 2 + 1
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
+ x = torch.cat([x, f_pad], -2)
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
+ x = x.permute([0, 2, 3, 1])
+ x = x[..., 0] + x[..., 1] * 1.j
+ x = torch.istft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True
+ )
+ x = x.reshape([*batch_dims, 2, -1])
+ return x
+
+
+def get_act(act_type):
+ if act_type == 'gelu':
+ return nn.GELU()
+ elif act_type == 'relu':
+ return nn.ReLU()
+ elif act_type[:3] == 'elu':
+ alpha = float(act_type.replace('elu', ''))
+ return nn.ELU(alpha)
+ else:
+ raise ValueError(f"Unknown activation type: {act_type}")
+
+
+def get_decoder(config, c):
+ decoder = None
+ decoder_options = dict()
+ if config.model.decoder_type == 'unet':
+ try:
+ decoder_options = dict(config.decoder_unet)
+ except:
+ pass
+ decoder = smp.Unet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'fpn':
+ try:
+ decoder_options = dict(config.decoder_fpn)
+ except:
+ pass
+ decoder = smp.FPN(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'unet++':
+ try:
+ decoder_options = dict(config.decoder_unet_plus_plus)
+ except:
+ pass
+ decoder = smp.UnetPlusPlus(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'manet':
+ try:
+ decoder_options = dict(config.decoder_manet)
+ except:
+ pass
+ decoder = smp.MAnet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'linknet':
+ try:
+ decoder_options = dict(config.decoder_linknet)
+ except:
+ pass
+ decoder = smp.Linknet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'pspnet':
+ try:
+ decoder_options = dict(config.decoder_pspnet)
+ except:
+ pass
+ decoder = smp.PSPNet(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'pan':
+ try:
+ decoder_options = dict(config.decoder_pan)
+ except:
+ pass
+ decoder = smp.PAN(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'deeplabv3':
+ try:
+ decoder_options = dict(config.decoder_deeplabv3)
+ except:
+ pass
+ decoder = smp.DeepLabV3(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ elif config.model.decoder_type == 'deeplabv3plus':
+ try:
+ decoder_options = dict(config.decoder_deeplabv3plus)
+ except:
+ pass
+ decoder = smp.DeepLabV3Plus(
+ encoder_name=config.model.encoder_name,
+ encoder_weights="imagenet",
+ in_channels=c,
+ classes=c,
+ **decoder_options,
+ )
+ return decoder
+
+
+class Torchseg_Net(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ dim_c = self.num_subbands * config.audio.num_channels * 2
+ c = config.model.num_channels
+ f = config.audio.dim_f // self.num_subbands
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.unet_model = get_decoder(config, c)
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ )
+
+ self.stft = STFT(config.audio)
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
+
+ def forward(self, x):
+
+ x = self.stft(x)
+
+ mix = x = self.cac2cws(x)
+
+ first_conv_out = x = self.first_conv(x)
+
+ x = x.transpose(-1, -2)
+
+ x = self.unet_model(x)
+
+ x = x.transpose(-1, -2)
+
+ x = x * first_conv_out # reduce artifacts
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
+
+ x = self.stft.inverse(x)
+ return x
diff --git a/models/ts_bs_mamba2.py b/models/ts_bs_mamba2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a01a2a58c5418bfd846d6edb6900810392091189
--- /dev/null
+++ b/models/ts_bs_mamba2.py
@@ -0,0 +1,319 @@
+# https://github.com/Human9000/nd-Mamba2-torch
+
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.utils.checkpoint import checkpoint_sequential
+try:
+ from mamba_ssm.modules.mamba2 import Mamba2
+except Exception as e:
+ print('Exception while loading Mamba2 modules: {}'.format(str(e)))
+ print('Falling back to the local torch implementation.')
+ from .ex_bi_mamba2 import Mamba2
+
+
+class MambaBlock(nn.Module):
+ def __init__(self, in_channels):
+ super(MambaBlock, self).__init__()
+ self.forward_mamba2 = Mamba2(
+ d_model=in_channels,
+ d_state=128,
+ d_conv=4,
+ expand=4,
+ headdim=64,
+ )
+
+ self.backward_mamba2 = Mamba2(
+ d_model=in_channels,
+ d_state=128,
+ d_conv=4,
+ expand=4,
+ headdim=64,
+ )
+ def forward(self, input):
+ forward_f = input
+ forward_f_output = self.forward_mamba2(forward_f)
+ backward_f = torch.flip(input, [1])
+ backward_f_output = self.backward_mamba2(backward_f)
+ backward_f_output2 = torch.flip(backward_f_output, [1])
+ output = torch.cat([forward_f_output + input, backward_f_output2+input], -1)
+ return output
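+
+# Note on MambaBlock (illustrative, not upstream documentation): the forward and
+# time-reversed Mamba2 passes are concatenated on the feature axis, so an input
+# of shape (B, T, D) comes out as (B, T, 2 * D); ResMamba later projects this
+# back to D.
+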
+
+class TAC(nn.Module):
+ """
+ A transform-average-concatenate (TAC) module.
+ """
+ def __init__(self, input_size, hidden_size):
+ super(TAC, self).__init__()
+
+ self.input_size = input_size
+ self.eps = torch.finfo(torch.float32).eps
+
+ self.input_norm = nn.GroupNorm(1, input_size, self.eps)
+ self.TAC_input = nn.Sequential(nn.Linear(input_size, hidden_size),
+ nn.Tanh()
+ )
+ self.TAC_mean = nn.Sequential(nn.Linear(hidden_size, hidden_size),
+ nn.Tanh()
+ )
+ self.TAC_output = nn.Sequential(nn.Linear(hidden_size*2, input_size),
+ nn.Tanh()
+ )
+
+ def forward(self, input):
+ # input shape: batch, group, N, *
+
+ batch_size, G, N = input.shape[:3]
+ output = self.input_norm(input.view(batch_size*G, N, -1)).view(batch_size, G, N, -1)
+ T = output.shape[-1]
+
+ # transform
+ group_input = output # B, G, N, T
+ group_input = group_input.permute(0,3,1,2).contiguous().view(-1, N) # B*T*G, N
+ group_output = self.TAC_input(group_input).view(batch_size, T, G, -1) # B, T, G, H
+
+ # mean pooling
+ group_mean = group_output.mean(2).view(batch_size*T, -1) # B*T, H
+ group_mean = self.TAC_mean(group_mean).unsqueeze(1).expand(batch_size*T, G, group_mean.shape[-1]).contiguous() # B*T, G, H
+
+ # concatenate
+ group_output = group_output.view(batch_size*T, G, -1) # B*T, G, H
+ group_output = torch.cat([group_output, group_mean], 2) # B*T, G, 2H
+ group_output = self.TAC_output(group_output.view(-1, group_output.shape[-1])) # B*T*G, N
+ group_output = group_output.view(batch_size, T, G, -1).permute(0,2,3,1).contiguous() # B, G, N, T
+ output = input + group_output.view(input.shape)
+
+ return output
+
+class ResMamba(nn.Module):
+ def __init__(self, input_size, hidden_size, dropout=0., bidirectional=True):
+ super(ResMamba, self).__init__()
+
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.eps = torch.finfo(torch.float32).eps
+
+ self.norm = nn.GroupNorm(1, input_size, self.eps)
+ self.dropout = nn.Dropout(p=dropout)
+ self.rnn = MambaBlock(input_size)
+ # linear projection layer back to the input feature size
+ self.proj = nn.Linear(input_size * 2, input_size)
+
+ def forward(self, input):
+ # input shape: batch, dim, seq
+ rnn_output = self.rnn(self.dropout(self.norm(input)).transpose(1, 2).contiguous())
+ rnn_output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])).view(input.shape[0],
+ input.shape[2],
+ input.shape[1])
+
+ return input + rnn_output.transpose(1, 2).contiguous()
+
+class BSNet(nn.Module):
+ def __init__(self, in_channel, nband=7):
+ super(BSNet, self).__init__()
+
+ self.nband = nband
+ self.feature_dim = in_channel // nband
+
+ self.band_rnn = ResMamba(self.feature_dim, self.feature_dim*2)
+ self.band_comm = ResMamba(self.feature_dim, self.feature_dim*2)
+ self.channel_comm = TAC(self.feature_dim, self.feature_dim*3)
+
+ def forward(self, input):
+ # input shape: B, nch, nband*N, T
+ B, nch, N, T = input.shape
+
+ band_output = self.band_rnn(input.view(B*nch*self.nband, self.feature_dim, -1)).view(B*nch, self.nband, -1, T)
+
+ # band comm
+ band_output = band_output.permute(0,3,2,1).contiguous().view(B*nch*T, -1, self.nband)
+ output = self.band_comm(band_output).view(B*nch, T, -1, self.nband).permute(0,3,2,1).contiguous()
+
+ # channel comm
+ output = output.view(B, nch, self.nband, -1, T).transpose(1,2).contiguous().view(B*self.nband, nch, -1, T)
+ output = self.channel_comm(output).view(B, self.nband, nch, -1, T).transpose(1,2).contiguous()
+
+ return output.view(B, nch, N, T)
+
+class Separator(nn.Module):
+ def __init__(self, sr=44100, win=2048, stride=512, feature_dim=128, num_repeat_mask=8, num_repeat_map=4, num_output=4):
+ super(Separator, self).__init__()
+
+ self.sr = sr
+ self.win = win
+ self.stride = stride
+ self.group = self.win // 2
+ self.enc_dim = self.win // 2 + 1
+ self.feature_dim = feature_dim
+ self.num_output = num_output
+ self.eps = torch.finfo(torch.float32).eps
+
+ # subband widths: 0-1k in 50 Hz bands, 1k-2k in 100 Hz, 2k-4k in 250 Hz, 4k-8k in 500 Hz, 8k-16k in 1 kHz, 16k-20k in 2 kHz, plus one remainder band (20k-inf)
+ bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim))
+ bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim))
+ bandwidth_250 = int(np.floor(250 / (sr / 2.) * self.enc_dim))
+ bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim))
+ bandwidth_1k = int(np.floor(1000 / (sr / 2.) * self.enc_dim))
+ bandwidth_2k = int(np.floor(2000 / (sr / 2.) * self.enc_dim))
+ self.band_width = [bandwidth_50]*20
+ self.band_width += [bandwidth_100]*10
+ self.band_width += [bandwidth_250]*8
+ self.band_width += [bandwidth_500]*8
+ self.band_width += [bandwidth_1k]*8
+ self.band_width += [bandwidth_2k]*2
+ self.band_width.append(self.enc_dim - np.sum(self.band_width))
+ self.nband = len(self.band_width)
+ print(self.band_width)
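+ # For the defaults (sr=44100, win=2048 -> enc_dim=1025) the arithmetic above
+ # gives (illustrative): 20 bands of 2 bins, 10 of 4, 8 of 11, 8 of 23, 8 of 46,
+ # 2 of 92 and one remainder band of 121 bins, i.e. nband = 57.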
+
+ self.BN_mask = nn.ModuleList([])
+ for i in range(self.nband):
+ self.BN_mask.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i]*2, self.eps),
+ nn.Conv1d(self.band_width[i]*2, self.feature_dim, 1)
+ )
+ )
+
+ self.BN_map = nn.ModuleList([])
+ for i in range(self.nband):
+ self.BN_map.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
+ nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1)
+ )
+ )
+
+ self.separator_mask = []
+ for i in range(num_repeat_mask):
+ self.separator_mask.append(BSNet(self.nband*self.feature_dim, self.nband))
+ self.separator_mask = nn.Sequential(*self.separator_mask)
+
+ self.separator_map = []
+ for i in range(num_repeat_map):
+ self.separator_map.append(BSNet(self.nband * self.feature_dim, self.nband))
+ self.separator_map = nn.Sequential(*self.separator_map)
+
+ self.in_conv = nn.Conv1d(self.feature_dim*2, self.feature_dim, 1)
+ self.Tanh = nn.Tanh()
+ self.mask = nn.ModuleList([])
+ self.map = nn.ModuleList([])
+ for i in range(self.nband):
+ self.mask.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
+ nn.Conv1d(self.feature_dim, self.feature_dim*1*self.num_output, 1),
+ nn.Tanh(),
+ nn.Conv1d(self.feature_dim*1*self.num_output, self.feature_dim*1*self.num_output, 1, groups=self.num_output),
+ nn.Tanh(),
+ nn.Conv1d(self.feature_dim*1*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output)
+ )
+ )
+ self.map.append(nn.Sequential(nn.GroupNorm(1, self.feature_dim, torch.finfo(torch.float32).eps),
+ nn.Conv1d(self.feature_dim, self.feature_dim*1*self.num_output, 1),
+ nn.Tanh(),
+ nn.Conv1d(self.feature_dim*1*self.num_output, self.feature_dim*1*self.num_output, 1, groups=self.num_output),
+ nn.Tanh(),
+ nn.Conv1d(self.feature_dim*1*self.num_output, self.band_width[i]*4*self.num_output, 1, groups=self.num_output)
+ )
+ )
+
+ def pad_input(self, input, window, stride):
+ """
+ Zero-padding input according to window/stride size.
+ """
+ batch_size, nsample = input.shape
+
+ # pad the signals at the end for matching the window/stride size
+ rest = window - (stride + nsample % window) % window
+ if rest > 0:
+ pad = torch.zeros(batch_size, rest).type(input.type())
+ input = torch.cat([input, pad], 1)
+ pad_aux = torch.zeros(batch_size, stride).type(input.type())
+ input = torch.cat([pad_aux, input, pad_aux], 1)
+
+ return input, rest
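+
+ # Worked example (illustrative): with the defaults win=2048, stride=512 and a
+ # 44100-sample input, nsample % window = 1092, so rest = 2048 - (512 + 1092) % 2048
+ # = 444 and the padded signal has 444 + 44100 + 2 * 512 = 45568 samples.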
+
+ def forward(self, input):
+ # input shape: (B, C, T)
+
+ batch_size, nch, nsample = input.shape
+ input = input.view(batch_size*nch, -1)
+
+ # frequency-domain separation
+ spec = torch.stft(input, n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device).type(input.type()),
+ return_complex=True)
+
+ # concat real and imag, split to subbands
+ spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
+ subband_spec_RI = []
+ subband_spec = []
+ band_idx = 0
+ for i in range(len(self.band_width)):
+ subband_spec_RI.append(spec_RI[:,:,band_idx:band_idx+self.band_width[i]].contiguous())
+ subband_spec.append(spec[:,band_idx:band_idx+self.band_width[i]]) # B*nch, BW, T
+ band_idx += self.band_width[i]
+
+ # normalization and bottleneck
+ subband_feature_mask = []
+ for i in range(len(self.band_width)):
+ subband_feature_mask.append(self.BN_mask[i](subband_spec_RI[i].view(batch_size*nch, self.band_width[i]*2, -1)))
+ subband_feature_mask = torch.stack(subband_feature_mask, 1) # B, nband, N, T
+
+ subband_feature_map = []
+ for i in range(len(self.band_width)):
+ subband_feature_map.append(self.BN_map[i](subband_spec_RI[i].view(batch_size * nch, self.band_width[i] * 2, -1)))
+ subband_feature_map = torch.stack(subband_feature_map, 1) # B, nband, N, T
+ # separator
+ sep_output = checkpoint_sequential(self.separator_mask, 2, subband_feature_mask.view(batch_size, nch, self.nband*self.feature_dim, -1)) # B, nband*N, T
+ sep_output = sep_output.view(batch_size*nch, self.nband, self.feature_dim, -1)
+ combined = torch.cat((subband_feature_map,sep_output), dim=2)
+ combined1 = combined.reshape(batch_size * nch * self.nband,self.feature_dim*2,-1)
+ combined2 = self.Tanh(self.in_conv(combined1))
+ combined3 = combined2.reshape(batch_size * nch, self.nband,self.feature_dim,-1)
+ sep_output2 = checkpoint_sequential(self.separator_map, 2, combined3.view(batch_size, nch, self.nband*self.feature_dim, -1)) # B, nband*N, T
+ sep_output2 = sep_output2.view(batch_size * nch, self.nband, self.feature_dim, -1)
+
+ sep_subband_spec = []
+ sep_subband_spec_mask = []
+ for i in range(self.nband):
+ this_output = self.mask[i](sep_output[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1)
+ this_mask = this_output[:,0] * torch.sigmoid(this_output[:,1]) # B*nch, 2, K, BW, T
+ this_mask_real = this_mask[:,0] # B*nch, K, BW, T
+ this_mask_imag = this_mask[:,1] # B*nch, K, BW, T
+ # force mask sum to 1
+ this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
+ this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
+ est_spec_real = subband_spec[i].real.unsqueeze(1) * this_mask_real - subband_spec[i].imag.unsqueeze(1) * this_mask_imag # B*nch, K, BW, T
+ est_spec_imag = subband_spec[i].real.unsqueeze(1) * this_mask_imag + subband_spec[i].imag.unsqueeze(1) * this_mask_real # B*nch, K, BW, T
+
+ ##################################
+ this_output2 = self.map[i](sep_output2[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1)
+ this_map = this_output2[:,0] * torch.sigmoid(this_output2[:,1]) # B*nch, 2, K, BW, T
+ this_map_real = this_map[:,0] # B*nch, K, BW, T
+ this_map_imag = this_map[:,1] # B*nch, K, BW, T
+ est_spec_real2 = est_spec_real+this_map_real
+ est_spec_imag2 = est_spec_imag+this_map_imag
+
+ sep_subband_spec.append(torch.complex(est_spec_real2, est_spec_imag2))
+ sep_subband_spec_mask.append(torch.complex(est_spec_real, est_spec_imag))
+
+ sep_subband_spec = torch.cat(sep_subband_spec, 2)
+ est_spec_mask = torch.cat(sep_subband_spec_mask, 2)
+
+ output = torch.istft(sep_subband_spec.view(batch_size*nch*self.num_output, self.enc_dim, -1),
+ n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample)
+ output_mask = torch.istft(est_spec_mask.view(batch_size*nch*self.num_output, self.enc_dim, -1),
+ n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample)
+
+ output = output.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous()
+ output_mask = output_mask.view(batch_size, nch, self.num_output, -1).transpose(1,2).contiguous()
+ # return output, output_mask
+ return output
+
+
+if __name__ == '__main__':
+ model = Separator().cuda()
+ arr = np.zeros((1, 2, 3*44100), dtype=np.float32)
+ x = torch.from_numpy(arr).cuda()
+ res = model(x)
diff --git a/models/upernet_swin_transformers.py b/models/upernet_swin_transformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d0aac73bec755b9ae8539968e5ca6e37712c24
--- /dev/null
+++ b/models/upernet_swin_transformers.py
@@ -0,0 +1,228 @@
+from functools import partial
+import torch
+import torch.nn as nn
+from transformers import UperNetForSemanticSegmentation
+from utils import prefer_target_instrument
+
+class STFT:
+ def __init__(self, config):
+ self.n_fft = config.n_fft
+ self.hop_length = config.hop_length
+ self.window = torch.hann_window(window_length=self.n_fft, periodic=True)
+ self.dim_f = config.dim_f
+
+ def __call__(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-2]
+ c, t = x.shape[-2:]
+ x = x.reshape([-1, t])
+ x = torch.stft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True,
+ return_complex=True
+ )
+ x = torch.view_as_real(x)
+ x = x.permute([0, 3, 1, 2])
+ x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]])
+ return x[..., :self.dim_f, :]
+
+ def inverse(self, x):
+ window = self.window.to(x.device)
+ batch_dims = x.shape[:-3]
+ c, f, t = x.shape[-3:]
+ n = self.n_fft // 2 + 1
+ f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device)
+ x = torch.cat([x, f_pad], -2)
+ x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t])
+ x = x.permute([0, 2, 3, 1])
+ x = x[..., 0] + x[..., 1] * 1.j
+ x = torch.istft(
+ x,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ window=window,
+ center=True
+ )
+ x = x.reshape([*batch_dims, 2, -1])
+ return x
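+    # Shape sketch for this STFT helper (values are illustrative, not from a specific config):
+    # with dim_f=1024, a stereo waveform of shape (B, 2, T) becomes (B, 4, 1024, T') in
+    # __call__ (real/imag parts stacked per channel), and inverse() pads the trimmed bins back
+    # to n_fft // 2 + 1 before reconstructing a (B, 2, T') waveform.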
+
+
+def get_norm(norm_type):
+ def norm(c, norm_type):
+ if norm_type == 'BatchNorm':
+ return nn.BatchNorm2d(c)
+ elif norm_type == 'InstanceNorm':
+ return nn.InstanceNorm2d(c, affine=True)
+ elif 'GroupNorm' in norm_type:
+ g = int(norm_type.replace('GroupNorm', ''))
+ return nn.GroupNorm(num_groups=g, num_channels=c)
+ else:
+ return nn.Identity()
+
+ return partial(norm, norm_type=norm_type)
+
+
+def get_act(act_type):
+ if act_type == 'gelu':
+ return nn.GELU()
+ elif act_type == 'relu':
+ return nn.ReLU()
+ elif act_type[:3] == 'elu':
+ alpha = float(act_type.replace('elu', ''))
+ return nn.ELU(alpha)
+ else:
+ raise Exception
+
+
+class Upscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Downscale(nn.Module):
+ def __init__(self, in_c, out_c, scale, norm, act):
+ super().__init__()
+ self.conv = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False)
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class TFC_TDF(nn.Module):
+ def __init__(self, in_c, c, l, f, bn, norm, act):
+ super().__init__()
+
+ self.blocks = nn.ModuleList()
+ for i in range(l):
+ block = nn.Module()
+
+ block.tfc1 = nn.Sequential(
+ norm(in_c),
+ act,
+ nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
+ )
+ block.tdf = nn.Sequential(
+ norm(c),
+ act,
+ nn.Linear(f, f // bn, bias=False),
+ norm(c),
+ act,
+ nn.Linear(f // bn, f, bias=False),
+ )
+ block.tfc2 = nn.Sequential(
+ norm(c),
+ act,
+ nn.Conv2d(c, c, 3, 1, 1, bias=False),
+ )
+ block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
+
+ self.blocks.append(block)
+ in_c = c
+
+ def forward(self, x):
+ for block in self.blocks:
+ s = block.shortcut(x)
+ x = block.tfc1(x)
+ x = x + block.tdf(x)
+ x = block.tfc2(x)
+ x = x + s
+ return x
+
+
+class Swin_UperNet_Model(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ act = get_act(act_type=config.model.act)
+
+ self.num_target_instruments = len(prefer_target_instrument(config))
+ self.num_subbands = config.model.num_subbands
+
+ dim_c = self.num_subbands * config.audio.num_channels * 2
+ c = config.model.num_channels
+ f = config.audio.dim_f // self.num_subbands
+
+ self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False)
+
+ self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large")
+
+ self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1))
+ self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1))
+ self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4))
+
+ self.final_conv = nn.Sequential(
+ nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False),
+ act,
+ nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False)
+ )
+
+ self.stft = STFT(config.audio)
+
+ def cac2cws(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c, k, f // k, t)
+ x = x.reshape(b, c * k, f // k, t)
+ return x
+
+ def cws2cac(self, x):
+ k = self.num_subbands
+ b, c, f, t = x.shape
+ x = x.reshape(b, c // k, k, f, t)
+ x = x.reshape(b, c // k, f * k, t)
+ return x
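+    # Subband folding sketch (shapes are illustrative): with num_subbands=4, cac2cws maps a
+    # spectrogram of shape (B, 4, 1024, T) to (B, 16, 256, T) by splitting the frequency axis
+    # into 4 groups and stacking them along channels; cws2cac is the exact inverse.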
+
+ def forward(self, x):
+
+ x = self.stft(x)
+
+ mix = x = self.cac2cws(x)
+
+ first_conv_out = x = self.first_conv(x)
+
+ x = x.transpose(-1, -2)
+
+ x = self.swin_upernet_model(x).logits
+
+ x = x.transpose(-1, -2)
+
+ x = x * first_conv_out # reduce artifacts
+
+ x = self.final_conv(torch.cat([mix, x], 1))
+
+ x = self.cws2cac(x)
+
+ if self.num_target_instruments > 1:
+ b, c, f, t = x.shape
+ x = x.reshape(b, self.num_target_instruments, -1, f, t)
+
+ x = self.stft.inverse(x)
+ return x
+
+
+if __name__ == "__main__":
+ model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True)
+ print(model)
+ print(model.auxiliary_head.classifier)
+ print(model.decode_head.classifier)
+
+ x = torch.zeros((2, 16, 512, 512), dtype=torch.float32)
+ res = model(x)
+ print(res.logits.shape)
+ model.save_pretrained('./results/')
\ No newline at end of file
diff --git a/tests/admin_test.py b/tests/admin_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..36df717c588cd0d72a18aed2893197c0c181a680
--- /dev/null
+++ b/tests/admin_test.py
@@ -0,0 +1,179 @@
+import shutil
+from test import test_settings
+from scripts.redact_config import redact_config
+from utils import load_config
+from pathlib import Path
+import os
+import numpy as np
+import soundfile as sf
+from typing import List, Dict
+
+MODEL_CONFIGS = {
+ 'config_apollo.yaml': {'model_type': 'apollo'},
+ 'config_dnr_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'},
+ 'config_dnr_bandit_v2_mus64.yaml': {'model_type': 'bandit_v2'},
+ 'config_drumsep.yaml': {'model_type': 'htdemucs'},
+ 'config_htdemucs_6stems.yaml': {'model_type': 'htdemucs'},
+ 'config_musdb18_bs_roformer.yaml': {'model_type': 'bs_roformer'},
+ 'config_musdb18_demucs3_mmi.yaml': {'model_type': 'htdemucs'},
+ 'config_musdb18_htdemucs.yaml': {'model_type': 'htdemucs'},
+ 'config_musdb18_mdx23c.yaml': {'model_type': 'mdx23c'},
+ 'config_musdb18_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'},
+ 'config_musdb18_mel_band_roformer_all_stems.yaml': {'model_type': 'mel_band_roformer'},
+ 'config_musdb18_scnet.yaml': {'model_type': 'scnet'},
+ 'config_musdb18_scnet_large.yaml': {'model_type': 'scnet'},
+ # 'config_musdb18_scnet_large_starrytong.yaml': {'model_type': 'scnet'},
+ 'config_vocals_bandit_bsrnn_multi_mus64.yaml': {'model_type': 'bandit'},
+ 'config_vocals_bs_roformer.yaml': {'model_type': 'bs_roformer'},
+ 'config_vocals_htdemucs.yaml': {'model_type': 'htdemucs'},
+ 'config_vocals_mdx23c.yaml': {'model_type': 'mdx23c'},
+ 'config_vocals_mel_band_roformer.yaml': {'model_type': 'mel_band_roformer'},
+ 'config_vocals_scnet.yaml': {'model_type': 'scnet'},
+ 'config_vocals_scnet_large.yaml': {'model_type': 'scnet'},
+ 'config_vocals_scnet_unofficial.yaml': {'model_type': 'scnet_unofficial'},
+ 'config_vocals_segm_models.yaml': {'model_type': 'segm_models'},
+
+
+ # 'config_vocals_swin_upernet.yaml': {'model_type': 'swin_upernet'},
+ # 'config_musdb18_torchseg.yaml': {'model_type': 'torchseg'},
+ # 'config_musdb18_segm_models.yaml': {'model_type': 'segm_models'},
+ # 'config_musdb18_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
+ # 'config_vocals_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
+ # 'config_vocals_torchseg.yaml': {'model_type': 'torchseg'}
+}
+
+
+# Folders for tests
+ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+CONFIGS_DIR = ROOT_DIR / 'configs/'
+TEST_DIR = ROOT_DIR / "tests_cache/"
+TRAIN_DIR = TEST_DIR / "train_tracks/"
+VALID_DIR = TEST_DIR / "valid_tracks/"
+
+
+def create_dummy_tracks(directory: Path, num_tracks: int, instruments: List[str],
+ duration: float = 5.0, sample_rate: int = 44100) -> None:
+ """
+    Generates random audio tracks for the given stems in numbered subdirectories within the specified directory.
+
+ Parameters:
+ ----------
+ directory : Path
+ Path to the directory where the tracks will be saved.
+ num_tracks : int
+        Number of track folders to generate.
+ instruments : List[str]
+ List of instrument names (stems) to create.
+ duration : float, optional
+ Duration of each track in seconds. Default is 5.0.
+ sample_rate : int, optional
+ Sampling rate of the generated audio. Default is 44100 Hz.
+
+ Returns:
+ -------
+ None
+ """
+
+ os.makedirs(directory, exist_ok=True)
+
+ for folder_name in [str(i) for i in range(1, num_tracks+1)]:
+ folder_path = directory / folder_name
+ os.makedirs(folder_path, exist_ok=True)
+ for instrument in instruments:
+            # Generate random noise for each track
+ samples = int(duration * sample_rate)
+ track = np.random.uniform(-1.0, 1.0, (2, samples)).astype(np.float32)
+ file_path = folder_path / f"{instrument}.wav"
+ sf.write(file_path, track.T, sample_rate)
+
+
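+# Usage sketch for create_dummy_tracks (stem names here are hypothetical):
+#   create_dummy_tracks(TRAIN_DIR, num_tracks=2, instruments=['vocals', 'other', 'mixture'])
+# creates TRAIN_DIR/1 and TRAIN_DIR/2, each containing a 5-second stereo wav per stem.
+
+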
+def cleanup_test_tracks() -> None:
+    """
+    Removes all cached test tracks.
+
+    This function deletes the entire directory specified by the global `TEST_DIR` variable
+    if it exists.
+
+    Returns:
+    -------
+    None
+        This function does not return a value. It performs cleanup of test data.
+    """
+
+    if TEST_DIR.exists():
+        shutil.rmtree(TEST_DIR)
+
+
+def modify_configs() -> Dict[str, Path]:
+ """
+ Updates configuration files in the `configs` directory for use with test data.
+
+ This function processes configuration files defined in the global `MODEL_CONFIGS` dictionary,
+ modifies them to be compatible with test scenarios, and saves the updated configurations
+ in a test-specific directory.
+
+ Returns:
+ -------
+ Dict[str, Path]
+ A dictionary where the keys are the original configuration file names, and the values
+ are the paths to the updated configuration files.
+ """
+ config_dir = CONFIGS_DIR
+ updated_configs = {}
+ for config, args in MODEL_CONFIGS.items():
+ model_type = args['model_type']
+ config_path = config_dir / config
+ updated_config_path = redact_config({
+ 'orig_config': str(config_path),
+ 'model_type': model_type,
+ 'new_config': str(TEST_DIR / 'configs' / config)
+ })
+ updated_configs[config] = updated_config_path
+ return updated_configs
+
+
+def run_tests() -> None:
+ """
+ Executes validation tests for all configurations.
+
+ This function updates configurations, generates random dummy data for testing,
+ and runs a series of tests (training, validation, and inference checks) for each
+ model configuration specified in the global `MODEL_CONFIGS` dictionary.
+
+ Returns:
+ -------
+ None
+ """
+
+ updated_configs = modify_configs()
+
+ # For every config
+ for config, args in MODEL_CONFIGS.items():
+ model_type = args['model_type']
+ cfg = load_config(model_type=model_type, config_path=TEST_DIR / 'configs' / config)
+ # Random tracks
+ create_dummy_tracks(TRAIN_DIR, instruments=cfg.training.instruments+['mixture'], num_tracks=2)
+ create_dummy_tracks(VALID_DIR, instruments=cfg.training.instruments+['mixture'], num_tracks=2)
+
+ print(f"\nRunning tests for model: {model_type} (config: {config})")
+
+ test_args = {
+ 'check_train': False,
+ 'check_valid': True,
+ 'check_inference': True,
+ 'config_path': updated_configs[config],
+ 'data_path': str(TRAIN_DIR),
+ 'valid_path': str(VALID_DIR),
+ 'results_path': str(TEST_DIR / "results" / model_type),
+ 'store_dir': str(TEST_DIR / "inference_results" / model_type),
+ 'metrics': ['sdr', 'si_sdr', 'l1_freq']
+ }
+
+ test_args.update(args)
+
+ test_settings(test_args, 'admin')
+ print(f"Tests for model {model_type} completed successfully.")
+
+ # Remove test_cache
+ cleanup_test_tracks()
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/tests/test.py b/tests/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05e128674f0850d82a2583527493f77c9d435d4
--- /dev/null
+++ b/tests/test.py
@@ -0,0 +1,163 @@
+import os
+import sys
+import argparse
+
+# Add the repository root to the system path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from valid import check_validation
+from inference import proc_folder
+from train import train_model
+from scripts.redact_config import redact_config
+from scripts.valid_to_inference import copying_files
+from scripts.trim import trim_directory
+
+base_args = {
+ 'device_ids': '0',
+ 'model_type': '',
+ 'start_check_point': '',
+ 'config_path': '',
+ 'data_path': '',
+ 'valid_path': '',
+ 'results_path': 'tests/train_results',
+ 'store_dir': 'tests/valid_inference_result',
+ 'input_folder': '',
+ 'metrics': ['neg_log_wmse', 'l1_freq', 'si_sdr', 'sdr', 'aura_stft', 'aura_mrstft', 'bleedless', 'fullness'],
+ 'max_folders': 2
+}
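+# base_args are defaults; the dict passed to test_settings() overrides them, e.g. (illustrative values):
+#   test_settings({'model_type': 'htdemucs',
+#                  'config_path': 'configs/config_musdb18_htdemucs.yaml',
+#                  'start_check_point': 'model.ckpt',
+#                  'data_path': 'datasets/train', 'valid_path': 'datasets/valid',
+#                  'check_valid': True}, 'user')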
+
+
+def parse_args(dict_args):
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--check_train", action='store_true', help="Check train or not")
+ parser.add_argument("--check_valid", action='store_true', help="Check train or not")
+ parser.add_argument("--check_inference", action='store_true', help="Check train or not")
+ parser.add_argument('--device_ids', type=str, help='Device IDs for training/inference')
+ parser.add_argument('--model_type', type=str, help='Model type')
+ parser.add_argument('--start_check_point', type=str, help='Path to the checkpoint to start from')
+ parser.add_argument('--config_path', type=str, help='Path to the configuration file')
+ parser.add_argument('--data_path', type=str, help='Path to the training data')
+ parser.add_argument('--valid_path', type=str, help='Path to the validation data')
+ parser.add_argument('--results_path', type=str, help='Path to save training results')
+ parser.add_argument('--store_dir', type=str, help='Path to store validation/inference results')
+ parser.add_argument('--input_folder', type=str, help='Path to the input folder for inference')
+ parser.add_argument('--metrics', nargs='+', help='List of metrics to evaluate')
+ parser.add_argument('--max_folders', type=str, help='Maximum number of folders to process')
+ parser.add_argument("--dataset_type", type=int, default=1,
+ help="Dataset type. Must be one of: 1, 2, 3 or 4.")
+ parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
+ parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory")
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
+ parser.add_argument("--use_multistft_loss", action='store_true',
+ help="Use MultiSTFT Loss (from auraloss package)")
+ parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss")
+ parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
+ parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
+ parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
+ parser.add_argument("--metric_for_scheduler", default="sdr",
+ choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='Metric which will be used for scheduler.')
+ parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+ parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
+ parser.add_argument("--use_tta", action='store_true',
+ help="Flag adds test time augmentation during inference (polarity and channel inverse)."
+ " While this triples the runtime, it reduces noise and slightly improves prediction quality.")
+ parser.add_argument("--extract_instrumental", action='store_true',
+ help="invert vocals to get instrumental if provided")
+ parser.add_argument("--disable_detailed_pbar", action='store_true', help="disable detailed progress bar")
+ parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available")
+ parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav")
+ parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24',
+ help="PCM type for FLAC files (PCM_16 or PCM_24)")
+ parser.add_argument("--draw_spectro", type=float, default=0,
+ help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
+ " Value defines for how many seconds os track spectrogram will be generated.")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ return args
+
+
+def test_settings(dict_args, test_type):
+
+ # Parse from cmd
+ cli_args = parse_args(dict_args)
+
+ # If args from cmd, add or replace in base_args
+ for key, value in vars(cli_args).items():
+ if value is not None:
+ base_args[key] = value
+
+ if test_type == 'user':
+ # Check required arguments
+ missing_args = [arg for arg in ['model_type', 'config_path', 'start_check_point', 'data_path', 'valid_path'] if
+ not base_args[arg]]
+ if missing_args:
+ missing_args_str = ', '.join(f'--{arg}' for arg in missing_args)
+ raise ValueError(
+ f"The following arguments are required but missing: {missing_args_str}."
+ f" Please specify them either via command-line arguments or directly in `base_args`.")
+
+ # Replace config
+ base_args['config_path'] = redact_config({'orig_config': base_args['config_path'],
+ 'model_type': base_args['model_type'],
+ 'new_config': ''})
+
+ # Trim train
+ trim_args_train = {'input_directory': base_args['data_path'],
+ 'max_folders': base_args['max_folders']}
+ base_args['data_path'] = trim_directory(trim_args_train)
+ # Trim valid
+ trim_args_valid = {'input_directory': base_args['valid_path'],
+ 'max_folders': base_args['max_folders']}
+ base_args['valid_path'] = trim_directory(trim_args_valid)
+ # Valid to inference
+ if not base_args['input_folder']:
+ tests_dir = os.path.join(os.path.dirname(base_args['valid_path']), 'for_inference')
+ base_args['input_folder'] = tests_dir
+ val_to_inf_args = {'valid_path': base_args['valid_path'],
+ 'inference_dir': base_args['input_folder'],
+ 'max_mixtures': 1}
+ copying_files(val_to_inf_args)
+
+ if base_args['check_valid']:
+ valid_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point',
+ 'store_dir', 'device_ids', 'num_workers', 'pin_memory', 'extension',
+ 'use_tta', 'metrics', 'lora_checkpoint', 'draw_spectro']}
+ valid_args['valid_path'] = [base_args['valid_path']]
+ print('Start validation.')
+ check_validation(valid_args)
+ print(f'Validation ended. See results in {base_args["store_dir"]}')
+
+ if base_args['check_inference']:
+ inference_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point', 'input_folder',
+ 'store_dir', 'device_ids', 'extract_instrumental',
+ 'disable_detailed_pbar', 'force_cpu', 'flac_file', 'pcm_type',
+ 'use_tta', 'lora_checkpoint', 'draw_spectro']}
+
+ print('Start inference.')
+ proc_folder(inference_args)
+ print(f'Inference ended. See results in {base_args["store_dir"]}')
+
+ if base_args['check_train']:
+ train_args = {key: base_args[key] for key in ['model_type', 'config_path', 'start_check_point', 'results_path',
+ 'data_path', 'dataset_type', 'valid_path', 'num_workers', 'pin_memory',
+ 'seed', 'device_ids', 'use_multistft_loss', 'use_mse_loss',
+ 'use_l1_loss', 'wandb_key', 'pre_valid', 'metrics',
+ 'metric_for_scheduler', 'train_lora', 'lora_checkpoint']}
+
+ print('Start train.')
+ train_model(train_args)
+
+ print('End!')
+
+
+if __name__ == "__main__":
+ test_settings(None, 'user')
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a2ea0d9fcbc9c03ae2823fe1cb3ef07b335fba
--- /dev/null
+++ b/train.py
@@ -0,0 +1,531 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+__version__ = '1.0.4'
+
+import random
+import argparse
+from tqdm.auto import tqdm
+import os
+import torch
+import wandb
+import numpy as np
+import auraloss
+import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
+from torch.utils.data import DataLoader
+from torch.cuda.amp.grad_scaler import GradScaler
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from ml_collections import ConfigDict
+import torch.nn.functional as F
+from typing import List, Tuple, Dict, Union, Callable, Any
+
+from dataset import MSSDataset
+from utils import get_model_from_config
+from valid import valid_multi_gpu, valid
+
+from utils import bind_lora_to_model, load_start_checkpoint
+import loralib as lora
+
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
+ """
+ Parse command-line arguments for configuring the model, dataset, and training parameters.
+
+ Args:
+ dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv.
+
+ Returns:
+ Namespace object containing parsed arguments and their values.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c',
+ help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
+ parser.add_argument("--config_path", type=str, help="path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training")
+ parser.add_argument("--results_path", type=str,
+ help="path to folder where results will be stored (weights, metadata)")
+ parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
+ parser.add_argument("--dataset_type", type=int, default=1,
+ help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
+ parser.add_argument("--valid_path", nargs="+", type=str,
+ help="validation data paths. You can provide several folders.")
+ parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
+ parser.add_argument("--pin_memory", action='store_true', help="dataloader pin_memory")
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
+ parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='list of gpu ids')
+ parser.add_argument("--loss", type=str, nargs='+', choices=['masked_loss', 'mse_loss', 'l1_loss', 'multistft_loss'],
+ default=['masked_loss'], help="List of loss functions to use")
+ parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
+ parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
+ parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
+ choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='List of metrics to use.')
+ parser.add_argument("--metric_for_scheduler", default="sdr",
+ choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='Metric which will be used for scheduler.')
+ parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ if args.metric_for_scheduler not in args.metrics:
+ args.metrics += [args.metric_for_scheduler]
+
+ return args
+
+
+def manual_seed(seed: int) -> None:
+ """
+ Set the random seed for reproducibility across Python, NumPy, and PyTorch.
+
+ Args:
+ seed: The seed value to set.
+ """
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # if multi-GPU
+ torch.backends.cudnn.deterministic = True
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+
+def initialize_environment(seed: int, results_path: str) -> None:
+ """
+ Initialize the environment by setting the random seed, configuring PyTorch settings,
+ and creating the results directory.
+
+ Args:
+ seed: The seed value for reproducibility.
+ results_path: Path to the directory where results will be stored.
+ """
+
+ manual_seed(seed)
+ torch.backends.cudnn.deterministic = False
+ try:
+ torch.multiprocessing.set_start_method('spawn')
+    except Exception:
+        pass  # the start method may already have been set
+ os.makedirs(results_path, exist_ok=True)
+
+def wandb_init(args: argparse.Namespace, config: Dict, device_ids: List[int], batch_size: int) -> None:
+ """
+ Initialize the Weights & Biases (wandb) logging system.
+
+ Args:
+ args: Parsed command-line arguments containing the wandb key.
+ config: Configuration dictionary for the experiment.
+ device_ids: List of GPU device IDs used for training.
+ batch_size: Batch size for training.
+ """
+
+ if args.wandb_key is None or args.wandb_key.strip() == '':
+ wandb.init(mode='disabled')
+ else:
+ wandb.login(key=args.wandb_key)
+ wandb.init(project='msst', config={'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size })
+
+
+def prepare_data(config: Dict, args: argparse.Namespace, batch_size: int) -> DataLoader:
+ """
+ Prepare the training dataset and data loader.
+
+ Args:
+ config: Configuration dictionary for the dataset.
+ args: Parsed command-line arguments containing dataset paths and settings.
+ batch_size: Batch size for training.
+
+ Returns:
+ DataLoader object for the training dataset.
+ """
+
+ trainset = MSSDataset(
+ config,
+ args.data_path,
+ batch_size=batch_size,
+ metadata_path=os.path.join(args.results_path, f'metadata_{args.dataset_type}.pkl'),
+ dataset_type=args.dataset_type,
+ )
+
+ train_loader = DataLoader(
+ trainset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_memory
+ )
+ return train_loader
+
+
+def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[Union[torch.device, str], torch.nn.Module]:
+ """
+ Initialize the model and assign it to the appropriate device (GPU or CPU).
+
+ Args:
+ model: The PyTorch model to be initialized.
+ device_ids: List of GPU device IDs to use for parallel processing.
+
+ Returns:
+ A tuple containing the device and the model moved to that device.
+ """
+
+ if torch.cuda.is_available():
+ if len(device_ids) <= 1:
+ device = torch.device(f'cuda:{device_ids[0]}')
+ model = model.to(device)
+ else:
+ device = torch.device(f'cuda:{device_ids[0]}')
+ model = nn.DataParallel(model, device_ids=device_ids).to(device)
+ else:
+ device = 'cpu'
+ model = model.to(device)
+ print("CUDA is not available. Running on CPU.")
+
+ return device, model
+
+
+def get_optimizer(config: ConfigDict, model: torch.nn.Module) -> torch.optim.Optimizer:
+ """
+ Initializes an optimizer based on the configuration.
+
+ Args:
+ config: Configuration object containing training parameters.
+ model: PyTorch model whose parameters will be optimized.
+
+ Returns:
+ A PyTorch optimizer object configured based on the specified settings.
+ """
+
+ optim_params = dict()
+ if 'optimizer' in config:
+ optim_params = dict(config['optimizer'])
+ print(f'Optimizer params from config:\n{optim_params}')
+
+ name_optimizer = getattr(config.training, 'optimizer',
+ 'No optimizer in config')
+
+ if name_optimizer == 'adam':
+ optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'adamw':
+ optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'radam':
+ optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'rmsprop':
+ optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'prodigy':
+ from prodigyopt import Prodigy
+ # you can choose weight decay value based on your problem, 0 by default
+ # We recommend using lr=1.0 (default) for all networks.
+ optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'adamw8bit':
+ import bitsandbytes as bnb
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
+ elif name_optimizer == 'sgd':
+ print('Use SGD optimizer')
+ optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
+ else:
+ print(f'Unknown optimizer: {name_optimizer}')
+ exit()
+ return optimizer
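+
+# Config sketch for get_optimizer (the YAML keys below are assumptions about a typical config
+# file, not taken from a specific one):
+#   training:
+#     optimizer: adamw
+#     lr: 1.0e-4
+#   optimizer:          # optional top-level section with extra kwargs forwarded to the optimizer
+#     weight_decay: 0.01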
+
+
+def multistft_loss(y: torch.Tensor, y_: torch.Tensor,
+ loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> torch.Tensor:
+ if len(y_.shape) == 4:
+ y1_ = y_.reshape(y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
+ y1 = y.reshape(y.shape[0], y.shape[1] * y.shape[2], y.shape[3])
+ elif len(y_.shape) == 3:
+ y1_, y1 = y_, y
+ else:
+ raise ValueError(f"Invalid shape for predicted array: {y_.shape}. Expected 3 or 4 dimensions.")
+ return loss_multistft(y1_, y1)
+
+
+def masked_loss(y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True) -> torch.Tensor:
+ loss = torch.nn.MSELoss(reduction='none')(y_, y).transpose(0, 1)
+ if coarse:
+ loss = loss.mean(dim=(-1, -2))
+ loss = loss.reshape(loss.shape[0], -1)
+ quantile = torch.quantile(loss.detach(), q, interpolation='linear', dim=1, keepdim=True)
+ mask = loss < quantile
+ return (loss * mask).mean()
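+
+# Quantile-mask sketch (numbers are illustrative): with q=0.95 and coarse=True, per-(source,
+# batch-item) MSE values are computed, the largest 5% per source are zeroed out via `mask`,
+# and only the remaining chunks contribute to the mean. This keeps outlier chunks (e.g. silent
+# or mislabeled segments) from dominating the gradient.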
+
+
+def choice_loss(args: argparse.Namespace, config: ConfigDict) -> Callable[[Any, Any], int]:
+ """
+ Select and return the appropriate loss function based on the configuration and arguments.
+
+ Args:
+ args: Parsed command-line arguments containing flags for different loss functions.
+ config: Configuration object containing loss settings and parameters.
+
+ Returns:
+ A loss function that can be applied to the predicted and ground truth tensors.
+ """
+
+ print(f'Losses for training: {args.loss}')
+ loss_fns = []
+ if 'masked_loss' in args.loss:
+ loss_fns.append(
+ lambda y_, y: masked_loss(y_, y, q=config['training']['q'], coarse=config['training']['coarse_loss_clip']))
+ if 'mse_loss' in args.loss:
+ loss_fns.append(nn.MSELoss())
+ if 'l1_loss' in args.loss:
+ loss_fns.append(F.l1_loss)
+ if 'multistft_loss' in args.loss:
+ loss_options = dict(config.get('loss_multistft', {}))
+ loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
+ loss_fns.append(lambda y_, y: multistft_loss(y_, y, loss_multistft) / 1000)
+
+ def multi_loss(y_, y):
+ return sum(loss_fn(y_, y) for loss_fn in loss_fns)
+
+ return multi_loss
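+
+# Sketch: with args.loss = ['mse_loss', 'multistft_loss'] the returned callable computes
+# MSELoss(y_, y) + MultiResolutionSTFTLoss(y_, y) / 1000 in a single call; any subset of the
+# supported losses can be combined the same way.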
+
+
+def normalize_batch(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Normalize a batch of tensors (x and y) by subtracting the mean and dividing by the standard deviation.
+
+ Args:
+        x: Tensor whose mean and standard deviation are used for normalization (typically the mixture).
+        y: Tensor normalized with the same statistics as x (typically the target).
+
+ Returns:
+ A tuple of normalized tensors (x, y).
+ """
+
+ mean = x.mean()
+ std = x.std()
+ if std != 0:
+ x = (x - mean) / std
+ y = (y - mean) / std
+ return x, y
+
+
+def train_one_epoch(model: torch.nn.Module, config: ConfigDict, args: argparse.Namespace, optimizer: torch.optim.Optimizer,
+ device: torch.device, device_ids: List[int], epoch: int, use_amp: bool, scaler: torch.cuda.amp.GradScaler,
+ gradient_accumulation_steps: int, train_loader: torch.utils.data.DataLoader,
+ multi_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
+ """
+ Train the model for one epoch.
+
+ Args:
+ model: The model to train.
+ config: Configuration object containing training parameters.
+ args: Command-line arguments with specific settings (e.g., model type).
+ optimizer: Optimizer used for training.
+ device: Device to run the model on (CPU or GPU).
+ device_ids: List of GPU device IDs if using multiple GPUs.
+ epoch: The current epoch number.
+ use_amp: Whether to use automatic mixed precision (AMP) for training.
+ scaler: Scaler for AMP to manage gradient scaling.
+ gradient_accumulation_steps: Number of gradient accumulation steps before updating the optimizer.
+ train_loader: DataLoader for the training dataset.
+ multi_loss: The loss function to use during training.
+
+ Returns:
+ None
+ """
+
+ model.train().to(device)
+ print(f'Train epoch: {epoch} Learning rate: {optimizer.param_groups[0]["lr"]}')
+ loss_val = 0.
+ total = 0
+
+ normalize = getattr(config.training, 'normalize', False)
+
+ pbar = tqdm(train_loader)
+ for i, (batch, mixes) in enumerate(pbar):
+ x = mixes.to(device) # mixture
+ y = batch.to(device)
+
+ if normalize:
+ x, y = normalize_batch(x, y)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ if args.model_type in ['mel_band_roformer', 'bs_roformer']:
+ # loss is computed in forward pass
+ loss = model(x, y)
+ if isinstance(device_ids, (list, tuple)):
+ # If it's multiple GPUs sum partial loss
+ loss = loss.mean()
+ else:
+ y_ = model(x)
+ loss = multi_loss(y_, y)
+
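+        # Scale the loss so that gradients accumulated over `gradient_accumulation_steps`
+        # micro-batches match a single update on the combined batch.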
+ loss /= gradient_accumulation_steps
+ scaler.scale(loss).backward()
+ if config.training.grad_clip:
+ nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
+
+ if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad(set_to_none=True)
+
+ li = loss.item() * gradient_accumulation_steps
+ loss_val += li
+ total += 1
+ pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})
+ wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'i': i})
+ loss.detach()
+
+ print(f'Training loss: {loss_val / total}')
+ wandb.log({'train_loss': loss_val / total, 'epoch': epoch, 'learning_rate': optimizer.param_groups[0]['lr']})
+
+
+def save_weights(store_path, model, device_ids, train_lora):
+
+ if train_lora:
+ torch.save(lora.lora_state_dict(model), store_path)
+ else:
+ state_dict = model.state_dict() if len(device_ids) <= 1 else model.module.state_dict()
+ torch.save(
+ state_dict,
+ store_path
+ )
+
+
+def save_last_weights(args: argparse.Namespace, model: torch.nn.Module, device_ids: List[int]) -> None:
+ """
+ Save the model's state_dict to a file for later use.
+
+ Args:
+ args: Command-line arguments containing the results path and model type.
+ model: The model whose weights will be saved.
+ device_ids: List of GPU device IDs if using multiple GPUs.
+
+ Returns:
+ None
+ """
+
+ store_path = f'{args.results_path}/last_{args.model_type}.ckpt'
+ train_lora = args.train_lora
+ save_weights(store_path, model, device_ids, train_lora)
+
+
+def compute_epoch_metrics(model: torch.nn.Module, args: argparse.Namespace, config: ConfigDict,
+ device: torch.device, device_ids: List[int], best_metric: float,
+ epoch: int, scheduler: torch.optim.lr_scheduler._LRScheduler) -> float:
+ """
+ Compute and log the metrics for the current epoch, and save model weights if the metric improves.
+
+ Args:
+ model: The model to evaluate.
+ args: Command-line arguments containing configuration paths and other settings.
+ config: Configuration dictionary containing training settings.
+ device: The device (CPU or GPU) used for evaluation.
+ device_ids: List of GPU device IDs when using multiple GPUs.
+ best_metric: The best metric value seen so far.
+ epoch: The current epoch number.
+ scheduler: The learning rate scheduler to adjust the learning rate.
+
+ Returns:
+ The updated best_metric.
+ """
+
+ if torch.cuda.is_available() and len(device_ids) > 1:
+ metrics_avg, all_metrics = valid_multi_gpu(model, args, config, args.device_ids, verbose=False)
+ else:
+ metrics_avg, all_metrics = valid(model, args, config, device, verbose=False)
+ metric_avg = metrics_avg[args.metric_for_scheduler]
+ if metric_avg > best_metric:
+ store_path = f'{args.results_path}/model_{args.model_type}_ep_{epoch}_{args.metric_for_scheduler}_{metric_avg:.4f}.ckpt'
+ print(f'Store weights: {store_path}')
+ train_lora = args.train_lora
+ save_weights(store_path, model, device_ids, train_lora)
+ best_metric = metric_avg
+ scheduler.step(metric_avg)
+ wandb.log({'metric_main': metric_avg, 'best_metric': best_metric})
+ for metric_name in metrics_avg:
+ wandb.log({f'metric_{metric_name}': metrics_avg[metric_name]})
+
+ return best_metric
+
+
+def train_model(args: argparse.Namespace) -> None:
+ """
+ Trains the model based on the provided arguments, including data preparation, optimizer setup,
+ and loss calculation. The model is trained for multiple epochs with logging via wandb.
+
+ Args:
+ args: Command-line arguments containing configuration paths, hyperparameters, and other settings.
+
+ Returns:
+ None
+ """
+
+ args = parse_args(args)
+
+ initialize_environment(args.seed, args.results_path)
+ model, config = get_model_from_config(args.model_type, args.config_path)
+ use_amp = getattr(config.training, 'use_amp', True)
+ device_ids = args.device_ids
+ batch_size = config.training.batch_size * len(device_ids)
+
+ wandb_init(args, config, device_ids, batch_size)
+
+ train_loader = prepare_data(config, args, batch_size)
+
+ if args.start_check_point:
+ load_start_checkpoint(args, model, type_='train')
+
+ if args.train_lora:
+ model = bind_lora_to_model(config, model)
+ lora.mark_only_lora_as_trainable(model)
+
+ device, model = initialize_model_and_device(model, args.device_ids)
+
+ if args.pre_valid:
+ if torch.cuda.is_available() and len(device_ids) > 1:
+ valid_multi_gpu(model, args, config, args.device_ids, verbose=True)
+ else:
+ valid(model, args, config, device, verbose=True)
+
+ optimizer = get_optimizer(config, model)
+ gradient_accumulation_steps = int(getattr(config.training, 'gradient_accumulation_steps', 1))
+
+ # Reduce LR if no metric improvements for several epochs
+ scheduler = ReduceLROnPlateau(optimizer, 'max', patience=config.training.patience,
+ factor=config.training.reduce_factor)
+
+ multi_loss = choice_loss(args, config)
+ scaler = GradScaler()
+ best_metric = float('-inf')
+
+ print(
+ f"Instruments: {config.training.instruments}\n"
+ f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
+ f"Patience: {config.training.patience} "
+ f"Reduce factor: {config.training.reduce_factor}\n"
+ f"Batch size: {batch_size} "
+ f"Grad accum steps: {gradient_accumulation_steps} "
+ f"Effective batch size: {batch_size * gradient_accumulation_steps}\n"
+ f"Dataset type: {args.dataset_type}\n"
+ f"Optimizer: {config.training.optimizer}"
+ )
+
+ print(f'Train for: {config.training.num_epochs} epochs')
+
+ for epoch in range(config.training.num_epochs):
+
+ train_one_epoch(model, config, args, optimizer, device, device_ids, epoch,
+ use_amp, scaler, gradient_accumulation_steps, train_loader, multi_loss)
+ save_last_weights(args, model, device_ids)
+ best_metric = compute_epoch_metrics(model, args, config, device, device_ids, best_metric, epoch, scheduler)
+
+
+if __name__ == "__main__":
+ train_model(None)
\ No newline at end of file
diff --git a/train_accelerate.py b/train_accelerate.py
new file mode 100644
index 0000000000000000000000000000000000000000..af676b07c7b789d1574a64f0fa33f8b7d7cdb25f
--- /dev/null
+++ b/train_accelerate.py
@@ -0,0 +1,356 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+__version__ = '1.0.3'
+
+# Read more here:
+# https://huggingface.co/docs/accelerate/index
+
+import argparse
+import soundfile as sf
+import numpy as np
+import time
+import glob
+from tqdm.auto import tqdm
+import os
+import torch
+import wandb
+import auraloss
+import torch.nn as nn
+from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+import torch.nn.functional as F
+from accelerate import Accelerator
+
+from dataset import MSSDataset
+from utils import get_model_from_config, demix, sdr, prefer_target_instrument
+from train import masked_loss, manual_seed
+from utils import load_not_compatible_weights
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+def valid(model, valid_loader, args, config, device, verbose=False):
+ instruments = prefer_target_instrument(config)
+
+ all_sdr = dict()
+ for instr in instruments:
+ all_sdr[instr] = []
+
+ all_mixtures_path = valid_loader
+ if verbose:
+ all_mixtures_path = tqdm(valid_loader)
+
+ pbar_dict = {}
+ for path_list in all_mixtures_path:
+ path = path_list[0]
+ mix, sr = sf.read(path)
+ folder = os.path.dirname(path)
+ res = demix(config, model, mix.T, device, model_type=args.model_type) # mix.T
+ for instr in instruments:
+ if instr != 'other' or config.training.other_fix is False:
+ track, sr1 = sf.read(folder + '/{}.wav'.format(instr))
+ else:
+ # other is actually instrumental
+ track, sr1 = sf.read(folder + '/{}.wav'.format('vocals'))
+ track = mix - track
+ # sf.write("{}.wav".format(instr), res[instr].T, sr, subtype='FLOAT')
+ references = np.expand_dims(track, axis=0)
+ estimates = np.expand_dims(res[instr].T, axis=0)
+ sdr_val = sdr(references, estimates)[0]
+ single_val = torch.from_numpy(np.array([sdr_val])).to(device)
+ all_sdr[instr].append(single_val)
+ pbar_dict['sdr_{}'.format(instr)] = sdr_val
+ if verbose:
+ all_mixtures_path.set_postfix(pbar_dict)
+
+ return all_sdr
+
+
+class MSSValidationDataset(torch.utils.data.Dataset):
+ def __init__(self, args):
+ all_mixtures_path = []
+ for valid_path in args.valid_path:
+ part = sorted(glob.glob(valid_path + '/*/mixture.wav'))
+ if len(part) == 0:
+ print('No validation data found in: {}'.format(valid_path))
+ all_mixtures_path += part
+
+ self.list_of_files = all_mixtures_path
+
+ def __len__(self):
+ return len(self.list_of_files)
+
+ def __getitem__(self, index):
+ return self.list_of_files[index]
+
+
+def train_model(args):
+ accelerator = Accelerator()
+ device = accelerator.device
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
+ parser.add_argument("--config_path", type=str, help="path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training")
+ parser.add_argument("--results_path", type=str, help="path to folder where results will be stored (weights, metadata)")
+ parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
+ parser.add_argument("--dataset_type", type=int, default=1, help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
+ parser.add_argument("--valid_path", nargs="+", type=str, help="validation data paths. You can provide several folders.")
+ parser.add_argument("--num_workers", type=int, default=0, help="dataloader num_workers")
+ parser.add_argument("--pin_memory", type=bool, default=False, help="dataloader pin_memory")
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
+ parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='list of gpu ids')
+ parser.add_argument("--use_multistft_loss", action='store_true', help="Use MultiSTFT Loss (from auraloss package)")
+ parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss")
+ parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
+ parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
+ parser.add_argument("--pre_valid", action='store_true', help='Run validation before training')
+ if args is None:
+ args = parser.parse_args()
+ else:
+ args = parser.parse_args(args)
+
+ manual_seed(args.seed + int(time.time()))
+ # torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False # Fix possible slow down with dilation convolutions
+ torch.multiprocessing.set_start_method('spawn')
+
+ model, config = get_model_from_config(args.model_type, args.config_path)
+ accelerator.print("Instruments: {}".format(config.training.instruments))
+
+ os.makedirs(args.results_path, exist_ok=True)
+
+ device_ids = args.device_ids
+ batch_size = config.training.batch_size
+
+ # wandb
+ if accelerator.is_main_process and args.wandb_key is not None and args.wandb_key.strip() != '':
+ wandb.login(key = args.wandb_key)
+ wandb.init(project = 'msst-accelerate', config = { 'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size })
+ else:
+ wandb.init(mode = 'disabled')
+
+ # Fix for num of steps
+ config.training.num_steps *= accelerator.num_processes
+
+ trainset = MSSDataset(
+ config,
+ args.data_path,
+ batch_size=batch_size,
+ metadata_path=os.path.join(args.results_path, 'metadata_{}.pkl'.format(args.dataset_type)),
+ dataset_type=args.dataset_type,
+ verbose=accelerator.is_main_process,
+ )
+
+ train_loader = DataLoader(
+ trainset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_memory
+ )
+
+ validset = MSSValidationDataset(args)
+ valid_dataset_length = len(validset)
+
+ valid_loader = DataLoader(
+ validset,
+ batch_size=1,
+ shuffle=False,
+ )
+
+ valid_loader = accelerator.prepare(valid_loader)
+
+ if args.start_check_point != '':
+ accelerator.print('Start from checkpoint: {}'.format(args.start_check_point))
+        # Alternative (strict) loading: model.load_state_dict(torch.load(args.start_check_point))
+        load_not_compatible_weights(model, args.start_check_point, verbose=False)
+
+ optim_params = dict()
+ if 'optimizer' in config:
+ optim_params = dict(config['optimizer'])
+ accelerator.print('Optimizer params from config:\n{}'.format(optim_params))
+
+ if config.training.optimizer == 'adam':
+ optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'adamw':
+ optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'radam':
+ optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'rmsprop':
+ optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'prodigy':
+ from prodigyopt import Prodigy
+ # you can choose weight decay value based on your problem, 0 by default
+ # We recommend using lr=1.0 (default) for all networks.
+ optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'adamw8bit':
+ import bitsandbytes as bnb
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
+ elif config.training.optimizer == 'sgd':
+ accelerator.print('Use SGD optimizer')
+ optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
+ else:
+ accelerator.print('Unknown optimizer: {}'.format(config.training.optimizer))
+ exit()
+
+ if accelerator.is_main_process:
+ print('Processes GPU: {}'.format(accelerator.num_processes))
+ print("Patience: {} Reduce factor: {} Batch size: {} Optimizer: {}".format(
+ config.training.patience,
+ config.training.reduce_factor,
+ batch_size,
+ config.training.optimizer,
+ ))
+ # Reduce LR if no SDR improvements for several epochs
+ scheduler = ReduceLROnPlateau(
+ optimizer,
+ 'max',
+ # patience=accelerator.num_processes * config.training.patience, # This is strange place...
+ patience=config.training.patience,
+ factor=config.training.reduce_factor
+ )
+
+ if args.use_multistft_loss:
+ try:
+ loss_options = dict(config.loss_multistft)
+ except:
+ loss_options = dict()
+ accelerator.print('Loss options: {}'.format(loss_options))
+ loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(
+ **loss_options
+ )
+
+ model, optimizer, train_loader, scheduler = accelerator.prepare(model, optimizer, train_loader, scheduler)
+
+ if args.pre_valid:
+ sdr_list = valid(model, valid_loader, args, config, device, verbose=accelerator.is_main_process)
+ sdr_list = accelerator.gather(sdr_list)
+ accelerator.wait_for_everyone()
+
+ # print(sdr_list)
+
+ sdr_avg = 0.0
+ instruments = prefer_target_instrument(config)
+
+ for instr in instruments:
+ # print(sdr_list[instr])
+ sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
+ sdr_val = sdr_data.mean()
+ accelerator.print("Valid length: {}".format(valid_dataset_length))
+ accelerator.print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
+ sdr_val = sdr_data[:valid_dataset_length].mean()
+ accelerator.print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
+ sdr_avg += sdr_val
+ sdr_avg /= len(instruments)
+ if len(instruments) > 1:
+ accelerator.print('SDR Avg: {:.4f}'.format(sdr_avg))
+ sdr_list = None
+
+ accelerator.print('Train for: {}'.format(config.training.num_epochs))
+ best_sdr = -100
+ for epoch in range(config.training.num_epochs):
+ model.train().to(device)
+ accelerator.print('Train epoch: {} Learning rate: {}'.format(epoch, optimizer.param_groups[0]['lr']))
+ loss_val = 0.
+ total = 0
+
+ pbar = tqdm(train_loader, disable=not accelerator.is_main_process)
+ for i, (batch, mixes) in enumerate(pbar):
+ y = batch
+ x = mixes
+
+ if args.model_type in ['mel_band_roformer', 'bs_roformer']:
+ # loss is computed in forward pass
+ loss = model(x, y)
+ else:
+ y_ = model(x)
+ if args.use_multistft_loss:
+ y1_ = torch.reshape(y_, (y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3]))
+ y1 = torch.reshape(y, (y.shape[0], y.shape[1] * y.shape[2], y.shape[3]))
+ loss = loss_multistft(y1_, y1)
+ # We can use many losses at the same time
+ if args.use_mse_loss:
+ loss += 1000 * nn.MSELoss()(y1_, y1)
+ if args.use_l1_loss:
+ loss += 1000 * F.l1_loss(y1_, y1)
+ elif args.use_mse_loss:
+ loss = nn.MSELoss()(y_, y)
+ elif args.use_l1_loss:
+ loss = F.l1_loss(y_, y)
+ else:
+ loss = masked_loss(
+ y_,
+ y,
+ q=config.training.q,
+ coarse=config.training.coarse_loss_clip
+ )
+
+ accelerator.backward(loss)
+ if config.training.grad_clip:
+ accelerator.clip_grad_norm_(model.parameters(), config.training.grad_clip)
+
+ optimizer.step()
+ optimizer.zero_grad()
+ li = loss.item()
+ loss_val += li
+ total += 1
+ if accelerator.is_main_process:
+ wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'total': total, 'loss_val': loss_val, 'i': i })
+ pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})
+
+ if accelerator.is_main_process:
+ print('Training loss: {:.6f}'.format(loss_val / total))
+ wandb.log({'train_loss': loss_val / total, 'epoch': epoch})
+
+ # Save last
+ store_path = args.results_path + '/last_{}.ckpt'.format(args.model_type)
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unwrapped_model = accelerator.unwrap_model(model)
+ accelerator.save(unwrapped_model.state_dict(), store_path)
+
+ sdr_list = valid(model, valid_loader, args, config, device, verbose=accelerator.is_main_process)
+ sdr_list = accelerator.gather(sdr_list)
+ accelerator.wait_for_everyone()
+
+ sdr_avg = 0.0
+ instruments = prefer_target_instrument(config)
+
+ for instr in instruments:
+        # print(sdr_list[instr])  # debug output, disabled
+ sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
+ # sdr_val = sdr_data.mean()
+ sdr_val = sdr_data[:valid_dataset_length].mean()
+ if accelerator.is_main_process:
+ print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
+ wandb.log({ f'{instr}_sdr': sdr_val })
+ sdr_avg += sdr_val
+ sdr_avg /= len(instruments)
+ if len(instruments) > 1:
+ if accelerator.is_main_process:
+ print('SDR Avg: {:.4f}'.format(sdr_avg))
+ wandb.log({'sdr_avg': sdr_avg, 'best_sdr': best_sdr})
+
+ if accelerator.is_main_process:
+ if sdr_avg > best_sdr:
+ store_path = args.results_path + '/model_{}_ep_{}_sdr_{:.4f}.ckpt'.format(args.model_type, epoch, sdr_avg)
+ print('Store weights: {}'.format(store_path))
+ unwrapped_model = accelerator.unwrap_model(model)
+ accelerator.save(unwrapped_model.state_dict(), store_path)
+ best_sdr = sdr_avg
+
+ scheduler.step(sdr_avg)
+
+ sdr_list = None
+ accelerator.wait_for_everyone()
+
+
+if __name__ == "__main__":
+ train_model(None)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..054b3be317c2901020ad5e5e8a86981a5abb2ba5
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,664 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import yaml
+import os
+import soundfile as sf
+import matplotlib.pyplot as plt
+from ml_collections import ConfigDict
+from omegaconf import OmegaConf
+from tqdm.auto import tqdm
+from typing import Dict, List, Tuple, Any, Union
+import loralib as lora
+
+
+def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
+ """
+ Load the configuration from the specified path based on the model type.
+
+ Parameters:
+ ----------
+ model_type : str
+ The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.).
+ config_path : str
+ The path to the YAML or OmegaConf configuration file.
+
+ Returns:
+ -------
+ config : Any
+ The loaded configuration, which can be in different formats (e.g., OmegaConf or ConfigDict).
+
+ Raises:
+ ------
+ FileNotFoundError:
+ If the configuration file at `config_path` is not found.
+ ValueError:
+ If there is an error loading the configuration file.
+ """
+ try:
+ with open(config_path, 'r') as f:
+ if model_type == 'htdemucs':
+ config = OmegaConf.load(config_path)
+ else:
+ config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
+ return config
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found at {config_path}")
+ except Exception as e:
+ raise ValueError(f"Error loading configuration: {e}")
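+
+# Usage sketch (the path is illustrative):
+#   config = load_config('mdx23c', 'configs/config_musdb18_mdx23c.yaml')
+#   print(config.training.batch_size)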
+
+
+def get_model_from_config(model_type: str, config_path: str) -> Tuple:
+ """
+ Load the model specified by the model type and configuration file.
+
+ Parameters:
+ ----------
+ model_type : str
+ The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.).
+ config_path : str
+ The path to the configuration file (YAML or OmegaConf format).
+
+ Returns:
+ -------
+ model : nn.Module or None
+ The initialized model based on the `model_type`, or None if the model type is not recognized.
+ config : Any
+ The configuration used to initialize the model. This could be in different formats
+ depending on the model type (e.g., OmegaConf, ConfigDict).
+
+ Raises:
+ ------
+ ValueError:
+ If the `model_type` is unknown or an error occurs during model initialization.
+ """
+
+ config = load_config(model_type, config_path)
+
+ if model_type == 'mdx23c':
+ from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
+ model = TFC_TDF_net(config)
+ elif model_type == 'htdemucs':
+ from models.demucs4ht import get_model
+ model = get_model(config)
+ elif model_type == 'segm_models':
+ from models.segm_models import Segm_Models_Net
+ model = Segm_Models_Net(config)
+ elif model_type == 'torchseg':
+ from models.torchseg_models import Torchseg_Net
+ model = Torchseg_Net(config)
+ elif model_type == 'mel_band_roformer':
+ from models.bs_roformer import MelBandRoformer
+ model = MelBandRoformer(**dict(config.model))
+ elif model_type == 'bs_roformer':
+ from models.bs_roformer import BSRoformer
+ model = BSRoformer(**dict(config.model))
+ elif model_type == 'swin_upernet':
+ from models.upernet_swin_transformers import Swin_UperNet_Model
+ model = Swin_UperNet_Model(config)
+ elif model_type == 'bandit':
+ from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
+ model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
+ elif model_type == 'bandit_v2':
+ from models.bandit_v2.bandit import Bandit
+ model = Bandit(**config.kwargs)
+ elif model_type == 'scnet_unofficial':
+ from models.scnet_unofficial import SCNet
+ model = SCNet(**config.model)
+ elif model_type == 'scnet':
+ from models.scnet import SCNet
+ model = SCNet(**config.model)
+ elif model_type == 'apollo':
+ from models.look2hear.models import BaseModel
+ model = BaseModel.apollo(**config.model)
+ elif model_type == 'bs_mamba2':
+ from models.ts_bs_mamba2 import Separator
+ model = Separator(**config.model)
+ elif model_type == 'experimental_mdx23c_stht':
+ from models.mdx23c_tfc_tdf_v3_with_STHT import TFC_TDF_net
+ model = TFC_TDF_net(config)
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+
+ return model, config
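+
+
+# A minimal usage sketch for get_model_from_config(). The config path below is an
+# assumption for illustration and may not match the files in this repository.
+def _example_build_model() -> None:
+    model, config = get_model_from_config('mdx23c', 'configs/config_musdb18_mdx23c.yaml')
+    n_params = sum(p.numel() for p in model.parameters())
+    print(f'Model type: mdx23c, parameters: {n_params / 1e6:.2f}M')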
+
+
+def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
+ """
+ Reads an audio file, ensuring mono audio is converted to two-dimensional format,
+ and transposes the data to have channels as the first dimension.
+ Parameters
+ ----------
+    path : str
+        Path to the audio file.
+    instr : str, optional
+        Name of the instrument; used only in the warning message when a stem is missing.
+    skip_err : bool
+        If True, do not raise an error when the file cannot be read; return (None, None) instead.
+ Returns
+ -------
+ Tuple[np.ndarray, int]
+ A tuple containing:
+ - Transposed audio data as a NumPy array with shape (channels, length).
+ For mono audio, the shape will be (1, length).
+ - Sampling rate (int), e.g., 44100.
+ """
+
+ try:
+ mix, sr = sf.read(path)
+ except Exception as e:
+ if skip_err:
+ print(f"No stem {instr}: skip!")
+ return None, None
+ else:
+ raise RuntimeError(f"Error reading the file at {path}: {e}")
+ else:
+ if len(mix.shape) == 1: # For mono audio
+ mix = np.expand_dims(mix, axis=-1)
+ return mix.T, sr
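+
+
+# A minimal usage sketch for read_audio_transposed(); the path is a placeholder,
+# not a file shipped with this repository.
+def _example_read_mixture(path: str = 'dataset/valid/song/mixture.wav') -> None:
+    mix, sr = read_audio_transposed(path)
+    # Audio is returned channels-first: shape (channels, length), e.g. (2, n_samples) for stereo
+    print(f'Sample rate: {sr}, shape: {mix.shape}')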
+
+
+def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
+ """
+ Normalize an audio signal by subtracting the mean and dividing by the standard deviation.
+
+ Parameters:
+ ----------
+ audio : np.ndarray
+ Input audio array with shape (channels, time) or (time,).
+
+ Returns:
+ -------
+ tuple[np.ndarray, dict[str, float]]
+ - Normalized audio array with the same shape as the input.
+ - Dictionary containing the mean and standard deviation of the original audio.
+ """
+
+ mono = audio.mean(0)
+ mean, std = mono.mean(), mono.std()
+ return (audio - mean) / std, {"mean": mean, "std": std}
+
+
+def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
+ """
+ Denormalize an audio signal by reversing the normalization process (multiplying by the standard deviation
+ and adding the mean).
+
+ Parameters:
+ ----------
+ audio : np.ndarray
+ Normalized audio array to be denormalized.
+ norm_params : dict[str, float]
+ Dictionary containing the 'mean' and 'std' values used for normalization.
+
+ Returns:
+ -------
+ np.ndarray
+ Denormalized audio array with the same shape as the input.
+ """
+
+ return audio * norm_params["std"] + norm_params["mean"]
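+
+
+# A small self-check sketch (not part of the original pipeline): verifies that
+# denormalize_audio() inverts normalize_audio() on random stereo audio.
+def _example_normalize_roundtrip() -> bool:
+    rng = np.random.default_rng(0)
+    audio = rng.standard_normal((2, 44100)).astype(np.float32)  # (channels, time)
+    normalized, norm_params = normalize_audio(audio)
+    restored = denormalize_audio(normalized, norm_params)
+    return bool(np.allclose(restored, audio, atol=1e-4))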
+
+
+def apply_tta(
+ config,
+ model: torch.nn.Module,
+ mix: torch.Tensor,
+ waveforms_orig: Dict[str, torch.Tensor],
+ device: torch.device,
+ model_type: str
+) -> Dict[str, torch.Tensor]:
+ """
+ Apply Test-Time Augmentation (TTA) for source separation.
+
+ This function processes the input mixture with test-time augmentations, including
+ channel inversion and polarity inversion, to enhance the separation results. The
+ results from all augmentations are averaged to produce the final output.
+
+ Parameters:
+ ----------
+ config : Any
+ Configuration object containing model and processing parameters.
+ model : torch.nn.Module
+ The trained model used for source separation.
+ mix : torch.Tensor
+ The mixed audio tensor with shape (channels, time).
+ waveforms_orig : Dict[str, torch.Tensor]
+ Dictionary of original separated waveforms (before TTA) for each instrument.
+ device : torch.device
+ Device (CPU or CUDA) on which the model will be executed.
+ model_type : str
+ Type of the model being used (e.g., "demucs", "custom_model").
+
+ Returns:
+ -------
+ Dict[str, torch.Tensor]
+ Updated dictionary of separated waveforms after applying TTA.
+ """
+ # Create augmentations: channel inversion and polarity inversion
+ track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()]
+
+ # Process each augmented mixture
+ for i, augmented_mix in enumerate(track_proc_list):
+ waveforms = demix(config, model, augmented_mix, device, model_type=model_type)
+ for el in waveforms:
+ if i == 0:
+ waveforms_orig[el] += waveforms[el][::-1].copy()
+ else:
+ waveforms_orig[el] -= waveforms[el]
+
+ # Average the results across augmentations
+ for el in waveforms_orig:
+ waveforms_orig[el] /= len(track_proc_list) + 1
+
+ return waveforms_orig
+
+
+def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
+ """
+ Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end.
+
+ This function creates a window of size `window_size` where the first `fade_size` elements
+ linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease
+ from 1 to 0 (fade-out). The middle part of the window is filled with ones.
+
+ Parameters:
+ ----------
+ window_size : int
+ The total size of the window.
+ fade_size : int
+ The size of the fade-in and fade-out regions.
+
+ Returns:
+ -------
+ torch.Tensor
+ A tensor of shape (window_size,) containing the generated windowing array.
+
+ Example:
+ -------
+ If `window_size=10` and `fade_size=3`, the output will be:
+ tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000])
+ """
+
+ fadein = torch.linspace(0, 1, fade_size)
+ fadeout = torch.linspace(1, 0, fade_size)
+
+ window = torch.ones(window_size)
+ window[-fade_size:] = fadeout
+ window[:fade_size] = fadein
+ return window
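+
+
+# Illustrative sketch reproducing the docstring example above: a 10-sample window
+# with 3-sample linear fade-in and fade-out.
+def _example_windowing_array() -> torch.Tensor:
+    window = _getWindowingArray(window_size=10, fade_size=3)
+    # Expected: tensor([0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0])
+    return window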
+
+
+def demix(
+ config: ConfigDict,
+ model: torch.nn.Module,
+ mix: torch.Tensor,
+ device: torch.device,
+ model_type: str,
+ pbar: bool = False
+) -> Union[Dict[str, np.ndarray], np.ndarray]:
+ """
+ Unified function for audio source separation with support for multiple processing modes.
+
+ This function separates audio into its constituent sources using either a generic custom logic
+ or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking
+ for efficient and artifact-free separation.
+
+ Parameters:
+ ----------
+ config : ConfigDict
+ Configuration object containing audio and inference settings.
+ model : torch.nn.Module
+ The trained model used for audio source separation.
+ mix : torch.Tensor
+ Input audio tensor with shape (channels, time).
+ device : torch.device
+ The computation device (CPU or CUDA).
+    model_type : str
+        Type of the model (e.g., 'htdemucs', 'mdx23c'). The value 'htdemucs' selects the
+        Demucs-specific chunking logic; all other model types use the generic logic.
+ pbar : bool, optional
+ If True, displays a progress bar during chunk processing. Default is False.
+
+ Returns:
+ -------
+ Union[Dict[str, np.ndarray], np.ndarray]
+ - A dictionary mapping target instruments to separated audio sources if multiple instruments are present.
+ - A numpy array of the separated source if only one instrument is present.
+ """
+
+ mix = torch.tensor(mix, dtype=torch.float32)
+
+ if model_type == 'htdemucs':
+ mode = 'demucs'
+ else:
+ mode = 'generic'
+ # Define processing parameters based on the mode
+    if mode == 'demucs':
+        chunk_size = config.training.samplerate * config.training.segment
+        num_instruments = len(config.training.instruments)
+        num_overlap = config.inference.num_overlap
+ else:
+ chunk_size = config.audio.chunk_size
+ num_instruments = len(prefer_target_instrument(config))
+ num_overlap = config.inference.num_overlap
+
+ fade_size = chunk_size // 10
+ step = chunk_size // num_overlap
+ border = chunk_size - step
+ length_init = mix.shape[-1]
+ windowing_array = _getWindowingArray(chunk_size, fade_size)
+    # Pad both ends of the mix (reflect) to reduce edge artifacts from chunked processing
+ if length_init > 2 * border and border > 0:
+ mix = nn.functional.pad(mix, (border, border), mode="reflect")
+
+ batch_size = config.inference.batch_size
+
+ use_amp = getattr(config.training, 'use_amp', True)
+
+ with torch.cuda.amp.autocast(enabled=use_amp):
+ with torch.inference_mode():
+ # Initialize result and counter tensors
+ req_shape = (num_instruments,) + mix.shape
+ result = torch.zeros(req_shape, dtype=torch.float32)
+ counter = torch.zeros(req_shape, dtype=torch.float32)
+
+ i = 0
+ batch_data = []
+ batch_locations = []
+ progress_bar = tqdm(
+ total=mix.shape[1], desc="Processing audio chunks", leave=False
+ ) if pbar else None
+
+ while i < mix.shape[1]:
+ # Extract chunk and apply padding if necessary
+ part = mix[:, i:i + chunk_size].to(device)
+ chunk_len = part.shape[-1]
+ if mode == "generic" and chunk_len > chunk_size // 2:
+ pad_mode = "reflect"
+ else:
+ pad_mode = "constant"
+ part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)
+
+ batch_data.append(part)
+ batch_locations.append((i, chunk_len))
+ i += step
+
+ # Process batch if it's full or the end is reached
+ if len(batch_data) >= batch_size or i >= mix.shape[1]:
+ arr = torch.stack(batch_data, dim=0)
+ x = model(arr)
+
+ if mode == "generic":
+ window = windowing_array.clone() # using clone() fixes the clicks at chunk edges when using batch_size=1
+ if i - step == 0: # First audio chunk, no fadein
+ window[:fade_size] = 1
+ elif i >= mix.shape[1]: # Last audio chunk, no fadeout
+ window[-fade_size:] = 1
+
+ for j, (start, seg_len) in enumerate(batch_locations):
+ if mode == "generic":
+ result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len]
+ counter[..., start:start + seg_len] += window[..., :seg_len]
+ else:
+ result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu()
+ counter[..., start:start + seg_len] += 1.0
+
+ batch_data.clear()
+ batch_locations.clear()
+
+ if progress_bar:
+ progress_bar.update(step)
+
+ if progress_bar:
+ progress_bar.close()
+
+ # Compute final estimated sources
+ estimated_sources = result / counter
+ estimated_sources = estimated_sources.cpu().numpy()
+ np.nan_to_num(estimated_sources, copy=False, nan=0.0)
+
+ # Remove padding for generic mode
+ if mode == "generic":
+ if length_init > 2 * border and border > 0:
+ estimated_sources = estimated_sources[..., border:-border]
+
+ # Return the result as a dictionary or a single array
+ if mode == "demucs":
+ instruments = config.training.instruments
+ else:
+ instruments = prefer_target_instrument(config)
+
+ ret_data = {k: v for k, v in zip(instruments, estimated_sources)}
+
+ if mode == "demucs" and num_instruments <= 1:
+ return estimated_sources
+ else:
+ return ret_data
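+
+
+# A minimal inference sketch for demix(), assuming a non-Demucs model type ('mdx23c')
+# and a config providing audio.chunk_size and the inference settings used above.
+def _example_demix(model: torch.nn.Module, config: ConfigDict, mixture_path: str) -> Dict[str, np.ndarray]:
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    mix, sr = read_audio_transposed(mixture_path)  # numpy array, shape (channels, length)
+    model = model.to(device).eval()
+    return demix(config, model, mix, device, model_type='mdx23c', pbar=True)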
+
+
+def prefer_target_instrument(config: ConfigDict) -> List[str]:
+ """
+ Return the list of target instruments based on the configuration.
+ If a specific target instrument is specified in the configuration,
+ it returns a list with that instrument. Otherwise, it returns the list of instruments.
+
+ Parameters:
+ ----------
+ config : ConfigDict
+ Configuration object containing the list of instruments or the target instrument.
+
+ Returns:
+ -------
+ List[str]
+ A list of target instruments.
+ """
+ if getattr(config.training, 'target_instrument', None):
+ return [config.training.target_instrument]
+ else:
+ return config.training.instruments
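+
+
+# Illustrative sketch: with training.target_instrument set, only that stem is returned.
+def _example_prefer_target_instrument() -> None:
+    cfg = ConfigDict({'training': {'instruments': ['vocals', 'drums', 'bass', 'other'],
+                                   'target_instrument': 'vocals'}})
+    print(prefer_target_instrument(cfg))  # -> ['vocals']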
+
+
+def load_not_compatible_weights(model: torch.nn.Module, weights: str, verbose: bool = False) -> None:
+ """
+ Load weights into a model, handling mismatched shapes and dimensions.
+
+ Args:
+ model: PyTorch model into which the weights will be loaded.
+ weights: Path to the weights file.
+ verbose: If True, prints detailed information about matching and mismatched layers.
+ """
+
+ new_model = model.state_dict()
+ old_model = torch.load(weights)
+ if 'state' in old_model:
+ # Fix for htdemucs weights loading
+ old_model = old_model['state']
+ if 'state_dict' in old_model:
+ # Fix for apollo weights loading
+ old_model = old_model['state_dict']
+
+ for el in new_model:
+ if el in old_model:
+ if verbose:
+ print(f'Match found for {el}!')
+ if new_model[el].shape == old_model[el].shape:
+ if verbose:
+ print('Action: Just copy weights!')
+ new_model[el] = old_model[el]
+ else:
+ if len(new_model[el].shape) != len(old_model[el].shape):
+ if verbose:
+                        print('Action: Different number of dimensions! Skip this layer.')
+ else:
+ if verbose:
+ print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}')
+ ln = len(new_model[el].shape)
+ max_shape = []
+ slices_old = []
+ slices_new = []
+ for i in range(ln):
+ max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i]))
+ slices_old.append(slice(0, old_model[el].shape[i]))
+ slices_new.append(slice(0, new_model[el].shape[i]))
+ # print(max_shape)
+ # print(slices_old, slices_new)
+ slices_old = tuple(slices_old)
+ slices_new = tuple(slices_new)
+                    max_matrix = np.zeros(max_shape, dtype=np.float32)
+                    # A single assignment copies the old weights into the zero-padded matrix
+                    max_matrix[slices_old] = old_model[el].cpu().numpy()
+                    max_matrix = torch.from_numpy(max_matrix)
+ new_model[el] = max_matrix[slices_new]
+ else:
+ if verbose:
+ print(f'Match not found for {el}!')
+ model.load_state_dict(
+ new_model
+ )
+
+
+def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None:
+ """
+ Load LoRA weights into a model.
+ This function updates the given model with LoRA-specific weights from the specified checkpoint file.
+ It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated.
+
+ Parameters:
+ ----------
+ model : Module
+ The PyTorch model into which the LoRA weights will be loaded.
+ lora_path : str
+ Path to the LoRA checkpoint file.
+ device : str, optional
+ The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'.
+
+ Returns:
+ -------
+ None
+ The model is updated in place.
+ """
+ lora_state_dict = torch.load(lora_path, map_location=device)
+ model.load_state_dict(lora_state_dict, strict=False)
+
+
+def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module, type_='train') -> None:
+ """
+ Load the starting checkpoint for a model.
+
+ Args:
+ args: Parsed command-line arguments containing the checkpoint path.
+ model: PyTorch model to load the checkpoint into.
+ type_: how to load weights - for train we can load not fully compatible weights
+ """
+
+ print(f'Start from checkpoint: {args.start_check_point}')
+    if type_ in ['train']:
+        # For training, allow loading weights that do not fully match the model.
+        # For strict loading use: model.load_state_dict(torch.load(args.start_check_point))
+        load_not_compatible_weights(model, args.start_check_point, verbose=False)
+ else:
+        device = 'cpu'
+ if args.model_type in ['htdemucs', 'apollo']:
+ state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
+ # Fix for htdemucs pretrained models
+ if 'state' in state_dict:
+ state_dict = state_dict['state']
+ # Fix for apollo pretrained models
+ if 'state_dict' in state_dict:
+ state_dict = state_dict['state_dict']
+ else:
+ state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
+ model.load_state_dict(state_dict)
+
+ if args.lora_checkpoint:
+ print(f"Loading LoRA weights from: {args.lora_checkpoint}")
+ load_lora_weights(model, args.lora_checkpoint)
+
+
+def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
+ """
+ Replaces specific layers in the model with LoRA-extended versions.
+
+ Parameters:
+ ----------
+ config : Dict[str, Any]
+ Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`.
+ model : nn.Module
+ The original model in which the layers will be replaced.
+
+ Returns:
+ -------
+ nn.Module
+ The modified model with the replaced layers.
+ """
+
+ if 'lora' not in config:
+ raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
+
+ replaced_layers = 0 # Counter for replaced layers
+
+ for name, module in model.named_modules():
+ hierarchy = name.split('.')
+ layer_name = hierarchy[-1]
+
+        # Check whether this is a linear layer to replace (optionally restrict further, e.g. layer_name == 'to_qkv')
+ if isinstance(module, nn.Linear):
+ try:
+ # Get the parent module
+ parent_module = model
+ for submodule_name in hierarchy[:-1]:
+ parent_module = getattr(parent_module, submodule_name)
+
+ # Replace the module with LoRA-enabled layer
+ setattr(
+ parent_module,
+ layer_name,
+ lora.MergedLinear(
+ in_features=module.in_features,
+ out_features=module.out_features,
+ bias=module.bias is not None,
+ **config['lora']
+ )
+ )
+ replaced_layers += 1 # Increment the counter
+
+ except Exception as e:
+ print(f"Error replacing layer {name}: {e}")
+
+ if replaced_layers == 0:
+ print("Warning: No layers were replaced. Check the model structure and configuration.")
+ else:
+ print(f"Number of layers replaced with LoRA: {replaced_layers}")
+
+ return model
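+
+
+# A minimal sketch of wiring LoRA into a model; the 'lora' parameters below are
+# example values for loralib.MergedLinear, not settings taken from a shipped config.
+def _example_bind_lora(model: nn.Module) -> nn.Module:
+    lora_config = {'lora': {'r': 8, 'lora_alpha': 16, 'enable_lora': [True]}}
+    model = bind_lora_to_model(lora_config, model)
+    lora.mark_only_lora_as_trainable(model)  # freeze everything except the LoRA parameters
+    return model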
+
+
+def draw_spectrogram(waveform, sample_rate, length, output_file):
+ import librosa.display
+
+    # Cut only the required part of the waveform (first `length` seconds)
+ x = waveform[:int(length * sample_rate), :]
+ X = librosa.stft(x.mean(axis=-1)) # perform short-term fourier transform on mono signal
+ Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max) # convert an amplitude spectrogram to dB-scaled spectrogram.
+ fig, ax = plt.subplots()
+ # plt.figure(figsize=(30, 10)) # initialize the fig size
+ img = librosa.display.specshow(
+ Xdb,
+ cmap='plasma',
+ sr=sample_rate,
+ x_axis='time',
+ y_axis='linear',
+ ax=ax
+ )
+ ax.set(title='File: ' + os.path.basename(output_file))
+ fig.colorbar(img, ax=ax, format="%+2.f dB")
+ if output_file is not None:
+ plt.savefig(output_file)
diff --git a/valid.py b/valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db3001bd88bb3bd17cf13697fde52e01a4966a3
--- /dev/null
+++ b/valid.py
@@ -0,0 +1,673 @@
+# coding: utf-8
+__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
+
+import argparse
+import time
+import os
+import glob
+import torch
+import librosa
+import numpy as np
+import soundfile as sf
+from tqdm.auto import tqdm
+from ml_collections import ConfigDict
+from typing import Tuple, Dict, List, Union
+from utils import demix, get_model_from_config, prefer_target_instrument, draw_spectrogram
+from utils import normalize_audio, denormalize_audio, apply_tta, read_audio_transposed, load_start_checkpoint
+from metrics import get_metrics
+import warnings
+
+warnings.filterwarnings("ignore")
+
+
+def logging(logs: List[str], text: str, verbose_logging: bool = False) -> None:
+ """
+ Log validation information by printing the text and appending it to a log list.
+
+ Parameters:
+ ----------
+    logs : List[str]
+        List to which the text is appended when verbose_logging is True.
+    text : str
+        The text to be printed and optionally added to the logs list.
+    verbose_logging : bool
+        If True, the text is also appended to the logs list.
+
+ Returns:
+ -------
+ None
+ This function modifies the logs list in place and prints the text.
+ """
+
+ print(text)
+ if verbose_logging:
+ logs.append(text)
+
+
+def write_results_in_file(store_dir: str, logs: List[str]) -> None:
+ """
+ Write the list of results into a file in the specified directory.
+
+ Parameters:
+ ----------
+ store_dir : str
+ The directory where the results file will be saved.
+    logs : List[str]
+        A list of result strings to be written to the file.
+
+ Returns:
+ -------
+ None
+ """
+ with open(f'{store_dir}/results.txt', 'w') as out:
+ for item in logs:
+ out.write(item + "\n")
+
+
+def get_mixture_paths(
+ args,
+ verbose: bool,
+ config: ConfigDict,
+ extension: str
+) -> List[str]:
+ """
+ Retrieve paths to mixture files in the specified validation directories.
+
+ Parameters:
+ ----------
+    args : argparse.Namespace
+        Arguments object; its `valid_path` attribute lists the directories to search for validation mixtures.
+ verbose : bool
+ If True, prints detailed information about the search process.
+ config : ConfigDict
+ Configuration object containing parameters like `inference.num_overlap` and `inference.batch_size`.
+ extension : str
+ File extension of the mixture files (e.g., 'wav').
+
+ Returns:
+ -------
+ List[str]
+ A list of file paths to the mixture files.
+ """
+ try:
+ valid_path = args.valid_path
+ except Exception as e:
+ print('No valid path in args')
+ raise e
+
+ all_mixtures_path = []
+ for path in valid_path:
+ part = sorted(glob.glob(f"{path}/*/mixture.{extension}"))
+ if len(part) == 0:
+ if verbose:
+ print(f'No validation data found in: {path}')
+ all_mixtures_path += part
+ if verbose:
+ print(f'Total mixtures: {len(all_mixtures_path)}')
+ print(f'Overlap: {config.inference.num_overlap} Batch size: {config.inference.batch_size}')
+
+ return all_mixtures_path
+
+
+def update_metrics_and_pbar(
+ track_metrics: Dict,
+ all_metrics: Dict,
+ instr: str,
+ pbar_dict: Dict,
+ mixture_paths: Union[List[str], tqdm],
+ verbose: bool = False
+) -> None:
+ """
+ Update metrics dictionary and progress bar with new metric values.
+
+ Parameters:
+ ----------
+ track_metrics : Dict
+ Dictionary with metric names as keys and their computed values as values.
+ all_metrics : Dict
+ Dictionary to store all metrics, organized by metric name and instrument.
+ instr : str
+ Name of the instrument for which the metrics are being computed.
+ pbar_dict : Dict
+ Dictionary for progress bar updates.
+    mixture_paths : Union[List[str], tqdm]
+        Iterable of mixture paths; if it is a tqdm progress bar, its postfix is updated with the new metrics.
+ verbose : bool, optional
+ If True, prints metric values to the console. Default is False.
+ """
+ for metric_name, metric_value in track_metrics.items():
+ if verbose:
+ print(f"Metric {metric_name:11s} value: {metric_value:.4f}")
+ all_metrics[metric_name][instr].append(metric_value)
+ pbar_dict[f'{metric_name}_{instr}'] = metric_value
+
+ if mixture_paths is not None:
+ try:
+ mixture_paths.set_postfix(pbar_dict)
+ except Exception:
+ pass
+
+
+def process_audio_files(
+ mixture_paths: List[str],
+ model: torch.nn.Module,
+ args,
+ config,
+ device: torch.device,
+ verbose: bool = False,
+ is_tqdm: bool = True
+) -> Dict[str, Dict[str, List[float]]]:
+ """
+ Process a list of audio files, perform source separation, and evaluate metrics.
+
+ Parameters:
+ ----------
+ mixture_paths : List[str]
+ List of file paths to the audio mixtures.
+ model : torch.nn.Module
+ The trained model used for source separation.
+ args : Any
+ Argument object containing user-specified options like metrics, model type, etc.
+ config : Any
+ Configuration object containing model and processing parameters.
+ device : torch.device
+ Device (CPU or CUDA) on which the model will be executed.
+ verbose : bool, optional
+ If True, prints detailed logs for each processed file. Default is False.
+ is_tqdm : bool, optional
+ If True, displays a progress bar for file processing. Default is True.
+
+ Returns:
+ -------
+ Dict[str, Dict[str, List[float]]]
+ A nested dictionary where the outer keys are metric names,
+ the inner keys are instrument names, and the values are lists of metric scores.
+ """
+ instruments = prefer_target_instrument(config)
+
+ use_tta = getattr(args, 'use_tta', False)
+ # dir to save files, if empty no saving
+ store_dir = getattr(args, 'store_dir', '')
+ # codec to save files
+ if 'extension' in config['inference']:
+ extension = config['inference']['extension']
+ else:
+ extension = getattr(args, 'extension', 'wav')
+
+ # Initialize metrics dictionary
+ all_metrics = {
+ metric: {instr: [] for instr in config.training.instruments}
+ for metric in args.metrics
+ }
+
+ if is_tqdm:
+ mixture_paths = tqdm(mixture_paths)
+
+ for path in mixture_paths:
+ start_time = time.time()
+ mix, sr = read_audio_transposed(path)
+ mix_orig = mix.copy()
+ folder = os.path.dirname(path)
+
+ if 'sample_rate' in config.audio:
+ if sr != config.audio['sample_rate']:
+ orig_length = mix.shape[-1]
+ if verbose:
+ print(f'Warning: sample rate is different. In config: {config.audio["sample_rate"]} in file {path}: {sr}')
+ mix = librosa.resample(mix, orig_sr=sr, target_sr=config.audio['sample_rate'], res_type='kaiser_best')
+
+ if verbose:
+ folder_name = os.path.abspath(folder)
+ print(f'Song: {folder_name} Shape: {mix.shape}')
+
+ if 'normalize' in config.inference:
+ if config.inference['normalize'] is True:
+ mix, norm_params = normalize_audio(mix)
+
+ waveforms_orig = demix(config, model, mix.copy(), device, model_type=args.model_type)
+
+ if use_tta:
+ waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
+
+ pbar_dict = {}
+
+ for instr in instruments:
+ if verbose:
+ print(f"Instr: {instr}")
+
+ if instr != 'other' or config.training.other_fix is False:
+ track, sr1 = read_audio_transposed(f"{folder}/{instr}.{extension}", instr, skip_err=True)
+ if track is None:
+ continue
+ else:
+ # if track=vocal+other
+ track, sr1 = read_audio_transposed(f"{folder}/vocals.{extension}")
+ track = mix_orig - track
+
+ estimates = waveforms_orig[instr]
+
+ if 'sample_rate' in config.audio:
+ if sr != config.audio['sample_rate']:
+ estimates = librosa.resample(estimates, orig_sr=config.audio['sample_rate'], target_sr=sr,
+ res_type='kaiser_best')
+ estimates = librosa.util.fix_length(estimates, size=orig_length)
+
+ if 'normalize' in config.inference:
+ if config.inference['normalize'] is True:
+ estimates = denormalize_audio(estimates, norm_params)
+
+ if store_dir:
+ os.makedirs(store_dir, exist_ok=True)
+ out_wav_name = f"{store_dir}/{os.path.basename(folder)}_{instr}.wav"
+ sf.write(out_wav_name, estimates.T, sr, subtype='FLOAT')
+ if args.draw_spectro > 0:
+ out_img_name = f"{store_dir}/{os.path.basename(folder)}_{instr}.jpg"
+ draw_spectrogram(estimates.T, sr, args.draw_spectro, out_img_name)
+ out_img_name_orig = f"{store_dir}/{os.path.basename(folder)}_{instr}_orig.jpg"
+ draw_spectrogram(track.T, sr, args.draw_spectro, out_img_name_orig)
+
+ track_metrics = get_metrics(
+ args.metrics,
+ track,
+ estimates,
+ mix_orig,
+ device=device,
+ )
+
+ update_metrics_and_pbar(
+ track_metrics,
+ all_metrics,
+ instr, pbar_dict,
+ mixture_paths=mixture_paths,
+ verbose=verbose
+ )
+
+ if verbose:
+ print(f"Time for song: {time.time() - start_time:.2f} sec")
+
+ return all_metrics
+
+
+def compute_metric_avg(
+ store_dir: str,
+ args,
+ instruments: List[str],
+ config: ConfigDict,
+ all_metrics: Dict[str, Dict[str, List[float]]],
+ start_time: float
+) -> Dict[str, float]:
+ """
+ Calculate and log the average metrics for each instrument, including per-instrument metrics and overall averages.
+
+ Parameters:
+ ----------
+ store_dir : str
+ Directory to store the logs. If empty, logs are not stored.
+ args : dict
+ Dictionary containing the arguments, used for logging.
+ instruments : List[str]
+ List of instruments to process.
+ config : ConfigDict
+ Configuration dictionary containing the inference settings.
+ all_metrics : Dict[str, Dict[str, List[float]]]
+ A dictionary containing metric values for each instrument.
+ The structure is {metric_name: {instrument_name: [metric_values]}}.
+ start_time : float
+ The starting time for calculating elapsed time.
+
+ Returns:
+ -------
+ Dict[str, float]
+ A dictionary with the average value for each metric across all instruments.
+ """
+
+ logs = []
+ if store_dir:
+ logs.append(str(args))
+ verbose_logging = True
+ else:
+ verbose_logging = False
+
+ logging(logs, text=f"Num overlap: {config.inference.num_overlap}", verbose_logging=verbose_logging)
+
+ metric_avg = {}
+ for instr in instruments:
+ for metric_name in all_metrics:
+ metric_values = np.array(all_metrics[metric_name][instr])
+
+ mean_val = metric_values.mean()
+ std_val = metric_values.std()
+
+ logging(logs, text=f"Instr {instr} {metric_name}: {mean_val:.4f} (Std: {std_val:.4f})", verbose_logging=verbose_logging)
+ if metric_name not in metric_avg:
+ metric_avg[metric_name] = 0.0
+ metric_avg[metric_name] += mean_val
+ for metric_name in all_metrics:
+ metric_avg[metric_name] /= len(instruments)
+
+ if len(instruments) > 1:
+ for metric_name in metric_avg:
+ logging(logs, text=f'Metric avg {metric_name:11s}: {metric_avg[metric_name]:.4f}', verbose_logging=verbose_logging)
+ logging(logs, text=f"Elapsed time: {time.time() - start_time:.2f} sec", verbose_logging=verbose_logging)
+
+ if store_dir:
+ write_results_in_file(store_dir, logs)
+
+ return metric_avg
+
+
+def valid(
+ model: torch.nn.Module,
+ args,
+ config: ConfigDict,
+ device: torch.device,
+ verbose: bool = False
+) -> Tuple[dict, dict]:
+ """
+ Validate a trained model on a set of audio mixtures and compute metrics.
+
+ This function performs validation by separating audio sources from mixtures,
+ computing evaluation metrics, and optionally saving results to a file.
+
+ Parameters:
+ ----------
+ model : torch.nn.Module
+ The trained model for source separation.
+ args : Namespace
+ Command-line arguments or equivalent object containing configurations.
+ config : dict
+ Configuration dictionary with model and processing parameters.
+ device : torch.device
+ The device (CPU or CUDA) to run the model on.
+ verbose : bool, optional
+ If True, enables verbose output during processing. Default is False.
+
+ Returns:
+ -------
+    Tuple[dict, dict]
+        A tuple of (average metrics across all instruments, per-track metrics for each instrument).
+ """
+
+ start_time = time.time()
+ model.eval().to(device)
+
+ # dir to save files, if empty no saving
+ store_dir = getattr(args, 'store_dir', '')
+ # codec to save files
+ if 'extension' in config['inference']:
+ extension = config['inference']['extension']
+ else:
+ extension = getattr(args, 'extension', 'wav')
+
+ all_mixtures_path = get_mixture_paths(args, verbose, config, extension)
+ all_metrics = process_audio_files(all_mixtures_path, model, args, config, device, verbose, not verbose)
+ instruments = prefer_target_instrument(config)
+
+ return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics
+
+
+def validate_in_subprocess(
+ proc_id: int,
+ queue: torch.multiprocessing.Queue,
+ all_mixtures_path: List[str],
+ model: torch.nn.Module,
+ args,
+ config: ConfigDict,
+ device: str,
+ return_dict
+) -> None:
+ """
+ Perform validation on a subprocess with multi-processing support. Each process handles inference on a subset of the mixture files
+ and updates the shared metrics dictionary.
+
+ Parameters:
+ ----------
+ proc_id : int
+ The process ID (used to assign metrics to the correct key in `return_dict`).
+ queue : torch.multiprocessing.Queue
+ Queue to receive paths to the mixture files for processing.
+ all_mixtures_path : List[str]
+        List of all mixture paths (used only to size the progress bar; the paths to process are received via `queue`).
+ model : torch.nn.Module
+ The model to be used for inference.
+ args : dict
+ Dictionary containing various argument configurations (e.g., metrics to calculate).
+ config : ConfigDict
+ Configuration object containing model settings and training parameters.
+ device : str
+ The device to use for inference (e.g., 'cpu', 'cuda:0').
+ return_dict : torch.multiprocessing.Manager().dict
+ Shared dictionary to store the results from each process.
+
+ Returns:
+ -------
+ None
+ The function modifies the `return_dict` in place, but does not return any value.
+ """
+
+ m1 = model.eval().to(device)
+ if proc_id == 0:
+ progress_bar = tqdm(total=len(all_mixtures_path))
+
+ # Initialize metrics dictionary
+ all_metrics = {
+ metric: {instr: [] for instr in config.training.instruments}
+ for metric in args.metrics
+ }
+
+ while True:
+ current_step, path = queue.get()
+ if path is None: # check for sentinel value
+ break
+ single_metrics = process_audio_files([path], m1, args, config, device, False, False)
+ pbar_dict = {}
+ for instr in config.training.instruments:
+ for metric_name in all_metrics:
+ all_metrics[metric_name][instr] += single_metrics[metric_name][instr]
+ if len(single_metrics[metric_name][instr]) > 0:
+ pbar_dict[f"{metric_name}_{instr}"] = f"{single_metrics[metric_name][instr][0]:.4f}"
+ if proc_id == 0:
+ progress_bar.update(current_step - progress_bar.n)
+ progress_bar.set_postfix(pbar_dict)
+ # print(f"Inference on process {proc_id}", all_sdr)
+ return_dict[proc_id] = all_metrics
+ return
+
+
+def run_parallel_validation(
+ verbose: bool,
+ all_mixtures_path: List[str],
+ config: ConfigDict,
+ model: torch.nn.Module,
+ device_ids: List[int],
+ args,
+ return_dict
+) -> None:
+ """
+ Run parallel validation using multiple processes. Each process handles a subset of the mixture files and computes the metrics.
+ The results are stored in a shared dictionary.
+
+ Parameters:
+ ----------
+ verbose : bool
+ Flag to print detailed information about the validation process.
+ all_mixtures_path : List[str]
+ List of paths to the mixture files to be processed.
+ config : ConfigDict
+ Configuration object containing model settings and validation parameters.
+ model : torch.nn.Module
+ The model to be used for inference.
+ device_ids : List[int]
+ List of device IDs (for multi-GPU setups) to use for validation.
+ args : dict
+        Dictionary containing various argument configurations (e.g., metrics to calculate).
+    return_dict : torch.multiprocessing.Manager().dict
+        Shared dictionary that each process fills with its computed metrics.
+
+    Returns:
+    -------
+    None
+        Results are written into `return_dict` in place (one entry per process).
+ """
+
+ model = model.to('cpu')
+ try:
+        # For multi-GPU training, extract the underlying single model
+        model = model.module
+    except AttributeError:
+        pass
+
+ queue = torch.multiprocessing.Queue()
+ processes = []
+
+ for i, device in enumerate(device_ids):
+ if torch.cuda.is_available():
+ device = f'cuda:{device}'
+ else:
+ device = 'cpu'
+ p = torch.multiprocessing.Process(
+ target=validate_in_subprocess,
+ args=(i, queue, all_mixtures_path, model, args, config, device, return_dict)
+ )
+ p.start()
+ processes.append(p)
+ for i, path in enumerate(all_mixtures_path):
+ queue.put((i, path))
+ for _ in range(len(device_ids)):
+ queue.put((None, None)) # sentinel value to signal subprocesses to exit
+ for p in processes:
+ p.join() # wait for all subprocesses to finish
+
+ return
+
+
+def valid_multi_gpu(
+ model: torch.nn.Module,
+ args,
+ config: ConfigDict,
+ device_ids: List[int],
+ verbose: bool = False
+) -> Tuple[Dict[str, float], dict]:
+ """
+ Perform validation across multiple GPUs, processing mixtures and computing metrics using parallel processes.
+ The results from each GPU are aggregated and the average metrics are computed.
+
+ Parameters:
+ ----------
+ model : torch.nn.Module
+ The model to be used for inference.
+ args : dict
+ Dictionary containing various argument configurations, such as file saving directory and codec settings.
+ config : ConfigDict
+ Configuration object containing model settings and validation parameters.
+ device_ids : List[int]
+ List of device IDs (for multi-GPU setups) to use for validation.
+ verbose : bool, optional
+ Flag to print detailed information about the validation process. Default is False.
+
+ Returns:
+ -------
+    Tuple[Dict[str, float], dict]
+        A tuple of (average value for each metric across instruments, per-track metrics for each instrument).
+ """
+
+ start_time = time.time()
+
+ # dir to save files, if empty no saving
+ store_dir = getattr(args, 'store_dir', '')
+ # codec to save files
+ if 'extension' in config['inference']:
+ extension = config['inference']['extension']
+ else:
+ extension = getattr(args, 'extension', 'wav')
+
+ all_mixtures_path = get_mixture_paths(args, verbose, config, extension)
+
+ return_dict = torch.multiprocessing.Manager().dict()
+
+ run_parallel_validation(verbose, all_mixtures_path, config, model, device_ids, args, return_dict)
+
+ all_metrics = dict()
+ for metric in args.metrics:
+ all_metrics[metric] = dict()
+ for instr in config.training.instruments:
+ all_metrics[metric][instr] = []
+ for i in range(len(device_ids)):
+ all_metrics[metric][instr] += return_dict[i][metric][instr]
+
+ instruments = prefer_target_instrument(config)
+
+ return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics
+
+
+def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
+ """
+ Parse command-line arguments for configuring the model, dataset, and training parameters.
+
+ Args:
+ dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv.
+
+ Returns:
+ Namespace object containing parsed arguments and their values.
+ """
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default='mdx23c',
+ help="One of mdx23c, htdemucs, segm_models, mel_band_roformer,"
+ " bs_roformer, swin_upernet, bandit")
+ parser.add_argument("--config_path", type=str, help="Path to config file")
+ parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint"
+ " to valid weights")
+ parser.add_argument("--valid_path", nargs="+", type=str, help="Validate path")
+ parser.add_argument("--store_dir", type=str, default="", help="Path to store results as wav file")
+ parser.add_argument("--draw_spectro", type=float, default=0,
+                        help="If --store_dir is set, the code will also generate spectrograms for the resulting stems."
+                             " The value defines for how many seconds of the track the spectrogram is generated.")
+    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='List of GPU ids')
+ parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
+ parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory")
+ parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
+ parser.add_argument("--use_tta", action='store_true',
+                        help="Enable test-time augmentation during inference (polarity and channel inversion)."
+                             " While this triples the runtime, it reduces noise and slightly improves prediction quality.")
+ parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
+ choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
+ 'fullness'], help='List of metrics to use.')
+ parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")
+
+ if dict_args is not None:
+ args = parser.parse_args([])
+ args_dict = vars(args)
+ args_dict.update(dict_args)
+ args = argparse.Namespace(**args_dict)
+ else:
+ args = parser.parse_args()
+
+ return args
+
+
+def check_validation(dict_args):
+ args = parse_args(dict_args)
+ torch.backends.cudnn.benchmark = True
+ try:
+ torch.multiprocessing.set_start_method('spawn')
+    except RuntimeError:
+        # The start method has already been set
+        pass
+ model, config = get_model_from_config(args.model_type, args.config_path)
+
+ if args.start_check_point:
+ load_start_checkpoint(args, model, type_='valid')
+
+ print(f"Instruments: {config.training.instruments}")
+
+ device_ids = args.device_ids
+ if torch.cuda.is_available():
+ device = torch.device(f'cuda:{device_ids[0]}')
+ else:
+ device = 'cpu'
+ print('CUDA is not available. Run validation on CPU. It will be very slow...')
+
+ if torch.cuda.is_available() and len(device_ids) > 1:
+ valid_multi_gpu(model, args, config, device_ids, verbose=False)
+ else:
+ valid(model, args, config, device, verbose=True)
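+
+
+# A minimal sketch of calling the validation programmatically via dict_args; every path
+# below is a placeholder and must point to real files in your setup.
+def _example_check_validation() -> None:
+    check_validation({
+        'model_type': 'mdx23c',
+        'config_path': 'configs/config_musdb18_mdx23c.yaml',
+        'start_check_point': 'results/model.ckpt',
+        'valid_path': ['dataset/valid'],
+        'metrics': ['sdr', 'si_sdr'],
+        'device_ids': [0],
+    })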
+
+
+if __name__ == "__main__":
+ check_validation(None)