import argparse
import os
import sys
from pathlib import Path
from typing import List

import streamlit as st
import torch as th
from dora.log import fatal

from demucs.apply import apply_model, BagOfModels
from demucs.audio import save_audio
from demucs.pretrained import get_model_from_args, ModelLoadingError
from demucs.separate import load_track
def separator(
    tracks: List[Path],
    out: Path,
    model: str,
    shifts: int,
    overlap: float,
    stem: str,
    int24: bool,
    float32: bool,
    clip_mode: str,
    mp3: bool,
    mp3_bitrate: int,
    verbose: bool,
    *args,
    **kwargs,
):
| """Separate the sources for the given tracks | |
| Args: | |
| tracks (Path): Path to tracks | |
| out (Path): Folder where to put extracted tracks. A subfolder with the model name will be | |
| created. | |
| model (str): Model name | |
| shifts (int): Number of random shifts for equivariant stabilization. | |
| Increase separation time but improves quality for Demucs. | |
| 10 was used in the original paper. | |
| overlap (float): Overlap | |
| stem (str): Only separate audio into {STEM} and no_{STEM}. | |
| int24 (bool): Save wav output as 24 bits wav. | |
| float32 (bool): Save wav output as float32 (2x bigger). | |
| clip_mode (str): Strategy for avoiding clipping: rescaling entire signal if necessary | |
| (rescale) or hard clipping (clamp). | |
| mp3 (bool): Convert the output wavs to mp3. | |
| mp3_bitrate (int): Bitrate of converted mp3. | |
| verbose (bool): Verbose | |
| """ | |
    if os.environ.get("LIMIT_CPU", False):
        th.set_num_threads(1)
        jobs = 1
    else:
        # Number of jobs. This can increase memory usage but will be much faster when
        # multiple cores are available.
        jobs = os.cpu_count()
    if th.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    args = argparse.Namespace()
    args.tracks = tracks
    args.out = out
    args.model = model
    args.device = device
    args.shifts = shifts
    args.overlap = overlap
    args.stem = stem
    args.int24 = int24
    args.float32 = float32
    args.clip_mode = clip_mode
    args.mp3 = mp3
    args.mp3_bitrate = mp3_bitrate
    args.jobs = jobs
    args.verbose = verbose
    args.filename = "{track}/{stem}.{ext}"
    args.split = True
    args.segment = None
    args.name = model
    args.repo = None
    try:
        model = get_model_from_args(args)
    except ModelLoadingError as error:
        fatal(error.args[0])

    if args.segment is not None and args.segment < 8:
        fatal("Segment must be at least 8.")
    if ".." in args.filename.replace("\\", "/").split("/"):
        fatal('".." must not appear in filename.')
    if isinstance(model, BagOfModels):
        print(
            f"Selected model is a bag of {len(model.models)} models. "
            "You will see that many progress bars per track."
        )
        if args.segment is not None:
            for sub in model.models:
                sub.segment = args.segment
    else:
        if args.segment is not None:
            model.segment = args.segment

    model.cpu()
    model.eval()

    if args.stem is not None and args.stem not in model.sources:
        fatal(
            'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
                stem=args.stem, sources=", ".join(model.sources)
            )
        )

    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")
    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                'please try again after surrounding the entire path with quotes "".',
                file=sys.stderr,
            )
            continue
        print(f"Separating track {track}")
        wav = load_track(track, model.audio_channels, model.samplerate)
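        # Normalize with the statistics of the mono reference; the inverse
        # transform is applied to the separated sources below.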
        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        sources = apply_model(
            model,
            wav[None],
            device=args.device,
            shifts=args.shifts,
            split=args.split,
            overlap=args.overlap,
            progress=True,
            num_workers=args.jobs,
        )[0]
        sources = sources * ref.std() + ref.mean()
        if args.mp3:
            ext = "mp3"
        else:
            ext = "wav"
        kwargs = {
            "samplerate": model.samplerate,
            "bitrate": args.mp3_bitrate,
            "clip": args.clip_mode,
            "as_float": args.float32,
            "bits_per_sample": 24 if args.int24 else 16,
        }
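        # Without a stem, write every source; with a stem, write the selected
        # stem plus a complementary no_{stem} mix of the remaining sources.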
        if args.stem is None:
            for source, name in zip(sources, model.sources):
                stem = out / args.filename.format(
                    track=track.name.rsplit(".", 1)[0],
                    trackext=track.name.rsplit(".", 1)[-1],
                    stem=name,
                    ext=ext,
                )
                stem.parent.mkdir(parents=True, exist_ok=True)
                save_audio(source, str(stem), **kwargs)
        else:
            sources = list(sources)
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem=args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs)
            # Warning: after popping the stem, the selected stem is no longer in the list 'sources'.
            other_stem = th.zeros_like(sources[0])
            for i in sources:
                other_stem += i
            stem = out / args.filename.format(
                track=track.name.rsplit(".", 1)[0],
                trackext=track.name.rsplit(".", 1)[-1],
                stem="no_" + args.stem,
                ext=ext,
            )
            stem.parent.mkdir(parents=True, exist_ok=True)
            save_audio(other_stem, str(stem), **kwargs)
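
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original app): the file paths and
    # option values below are illustrative assumptions, and "htdemucs" must be
    # a model name available to demucs.pretrained.get_model_from_args.
    separator(
        tracks=[Path("test.mp3")],
        out=Path("separated"),
        model="htdemucs",
        shifts=1,
        overlap=0.25,
        stem=None,
        int24=False,
        float32=False,
        clip_mode="rescale",
        mp3=False,
        mp3_bitrate=320,
        verbose=True,
    )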