Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import copy | |
| from concurrent.futures import ThreadPoolExecutor, Future | |
| from dataclasses import dataclass, fields | |
| from contextlib import ExitStack | |
| import gzip | |
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import random | |
| import sys | |
| import typing as tp | |
| import torch | |
| import torch.nn.functional as F | |
| from .audio import audio_read, audio_info | |
| from .audio_utils import convert_audio | |
| from .zip import PathInZip | |
| try: | |
| import dora | |
| except ImportError: | |
| dora = None # type: ignore | |
@dataclass(order=True)
class BaseInfo:
    """Base class for the metadata dataclasses of this module.

    Provides conversion from / to plain dictionaries. Subclasses must be dataclasses,
    since `dataclasses.fields` is used to enumerate the attributes to (de)serialize.
    """

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the entries of `dictionary` whose key is a field of `cls`.

        Keys that are not fields of `cls` are dropped.
        """
        return {
            field.name: dictionary[field.name]
            for field in fields(cls) if field.name in dictionary
        }

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance of `cls` from a dict, ignoring keys that are not fields."""
        _dictionary = cls._dict2fields(dictionary)
        return cls(**_dictionary)

    def to_dict(self):
        """Return a shallow `{field_name: value}` dict of this instance."""
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata of a single audio file, as stored in a jsonl manifest.

    Instances are orderable (field by field, starting with `path`) so that lists of
    AudioMeta can be sorted.
    """
    path: str  # path of the audio file
    duration: float  # duration of the file, in seconds
    sample_rate: int  # native sample rate of the file
    amplitude: tp.Optional[float] = None  # peak absolute amplitude; only filled when non-minimal meta is computed
    weight: tp.Optional[float] = None  # optional sampling weight (see AudioDataset's `sample_on_weight`)
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from a dict, converting a serialized `info_path` back to PathInZip."""
        base = cls._dict2fields(dictionary)
        if 'info_path' in base and base['info_path'] is not None:
            base['info_path'] = PathInZip(base['info_path'])
        return cls(**base)

    def to_dict(self):
        """Serialize to a JSON-friendly dict; `info_path` is stored as its string form."""
        d = super().to_dict()
        if d['info_path'] is not None:
            d['info_path'] = str(d['info_path'])
        return d
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Information about one audio item returned by AudioDataset when `return_info=True`."""
    meta: AudioMeta  # metadata of the file the audio was read from
    seek_time: float  # start offset within the file, in seconds (0 when the full file is read)
    n_frames: int  # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int  # actual sample rate
# Module-level logger.
logger = logging.getLogger(__name__)

# Extensions (lower-case, with leading dot) recognised as audio files when scanning folders;
# file suffixes are lower-cased before being compared against this list.
DEFAULT_EXTS = [f'.{ext}' for ext in ('wav', 'mp3', 'flac', 'ogg', 'm4a')]
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): If True, only the information from `audio_info` (duration, sample rate)
            is collected. If False, the whole file is also decoded to measure its peak absolute
            amplitude, which takes longer.
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    header = audio_info(file_path)
    # Decoding the full file is only needed to measure the peak amplitude.
    peak: tp.Optional[float] = (
        None if minimal else audio_read(file_path)[0].abs().max().item()
    )
    return AudioMeta(file_path, header.duration, header.sample_rate, peak)
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in an AudioMeta. This method is expected to be used when loading meta from file.

    Both `m.path` and, when present, `m.info_path.zip_path` are made absolute with
    `dora.git_save.to_absolute_path`. `m` is modified in place and returned. Without dora
    it is returned unchanged.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
            Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(p) -> bool:
        if fast:
            # POSIX-only shortcut; `startswith` also copes with an empty path.
            return str(p).startswith('/')
        return os.path.isabs(str(p))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path itself (not the audio path).
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Recursively scan `path` and return the sorted AudioMeta of every audio file found.

    Files whose metadata cannot be read (or resolved) are reported on stderr and skipped.

    Args:
        path (str or Path): Folder to scan; symbolic links to folders are followed.
        exts (list of str): Lower-case extensions (with leading dot) identifying audio files.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to print progress while collecting files and metadata.
        workers (int): Number of threads used to read metadata; if 0, use only the current thread.
    Returns:
        List[AudioMeta]: Metadata of every readable audio file, sorted.
    """
    candidates: tp.List[Path] = []
    pending: tp.List[Future] = []
    executor: tp.Optional[ThreadPoolExecutor] = None
    results: tp.List[AudioMeta] = []
    with ExitStack() as stack:
        if workers > 0:
            executor = ThreadPoolExecutor(workers)
            stack.enter_context(executor)

        if progress:
            print("Finding audio files...")
        for folder, _subfolders, names in os.walk(path, followlinks=True):
            for name in names:
                candidate = Path(folder) / name
                if candidate.suffix.lower() not in exts:
                    continue
                candidates.append(candidate)
                if executor is not None:
                    # Submitting now lets metadata reading overlap with the directory walk.
                    pending.append(executor.submit(_get_audio_meta, str(candidate), minimal))
                if progress:
                    print(format(len(candidates), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        total = len(candidates)
        for position, candidate in enumerate(candidates):
            try:
                if executor is not None:
                    entry = pending[position].result()
                else:
                    entry = _get_audio_meta(str(candidate), minimal)
                if resolve:
                    entry = _resolve_audio_meta(entry)
            except Exception as err:
                print("Error with", str(candidate), err, file=sys.stderr)
                continue
            results.append(entry)
            if progress:
                print(format((position + 1) / total, " 3.1%"), end='\r', file=sys.stderr)
    results.sort()
    return results
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed json file.

    The file holds one JSON object per line and is read with gzip when `path` ends with
    `.gz` (case-insensitive). Blank lines are skipped.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        List[AudioMeta]: One AudioMeta per non-blank line, in file order.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    meta = []
    with open_fn(path, 'rb') as fp:  # type: ignore
        # Stream the lines instead of materializing them all with readlines().
        for line in fp:
            # json.loads would fail on an empty line, e.g. a stray newline in a hand-edited manifest.
            if not line.strip():
                continue
            d = json.loads(line)
            m = AudioMeta.from_dict(d)
            if resolve:
                m = _resolve_audio_meta(m, fast=fast)
            meta.append(m)
    return meta
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save audio metadata to a jsonl file, one JSON object per line.

    Missing parent directories are created. The output is gzip-compressed when `path`
    ends with `.gz` (case-insensitive).

    Args:
        path (str or Path): Path to the output JSON lines file.
        meta (list of AudioMeta): Audio metadata to save; each entry is serialized with `to_dict()`.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'wb') as fp:  # type: ignore
        for m in meta:
            # The file is opened in binary mode, so each line is encoded to UTF-8 bytes.
            json_str = json.dumps(m.to_dict()) + '\n'
            fp.write(json_str.encode('utf-8'))
class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Args:
        meta (tp.List[AudioMeta]): List of audio files metadata.
        segment_duration (float): Optional segment duration of audio to load, in seconds.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Only used when `segment_duration` is provided. If True, the random generator of
            each item is seeded with its index plus extra randomness from Python's `random` module, so a
            given index yields different segments across calls. If False, the index alone is the seed,
            which makes sampling deterministic.
        num_samples (int): Length of the dataset when `segment_duration` is provided. Ignored otherwise,
            the length then being the number of audio files kept after duration filtering.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): If True, segments shorter than `segment_duration` are zero-padded to the target length,
            and, when `segment_duration` is None, `collater` pads the wavs of a batch to the longest one.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): The seek position is drawn so that at least
            `min_segment_ratio * segment_duration` seconds of audio remain after it (when the file is
            long enough), which avoids empty segments for files shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None
                 ):
        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        # Use the module logger, consistently with the warning emitted in __getitem__.
        logger.debug('sample_on_duration: %s', sample_on_duration)
        logger.debug('sample_on_weight: %s', sample_on_weight)
        logger.debug('pad: %s', pad)
        logger.debug('min_segment_ratio: %s', min_segment_ratio)

        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta), 'All audio meta were filtered out by min/max_audio_duration.'  # Fail fast.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            # One item per file when returning full files.
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`.

        Each file scores 1, multiplied by its `weight` (if `sample_on_weight` and the weight is set)
        and by its duration (if `sample_on_duration`). Scores are divided by their sum when `normalized`.
        """
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    def sample_file(self, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        """
        if not self.sample_on_weight and not self.sample_on_duration:
            # All scores are 1: uniform sampling.
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        """Return the wav (and its SegmentInfo when `return_info` is True) for `index`.

        Without `segment_duration`, `index` selects a file and the whole file is returned. Otherwise
        `index` only seeds the random choice of a file and of a seek position within it.
        Reading is retried with a newly sampled file up to `max_read_retry` times; the last failure is re-raised.
        """
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate)
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use index, plus extra randomness
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                # We only use index
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(rng)
                # We add some variance in the file position even if audio file is smaller than segment
                # without ending up with empty segments
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate)
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Returns the wav and additional information on the wave segment
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Args:
            samples (list): Items returned by `__getitem__`: bare wav tensors, or
                (wav, SegmentInfo) tuples when `return_info` is True.
        Returns:
            The stacked wavs, plus the list of deep-copied SegmentInfo when `return_info` is True.
            When `segment_duration` is None and `pad` is True, wavs are right-padded with zeros to the
            longest one and each SegmentInfo's `total_frames` is updated accordingly.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            # Items are (wav, info) tuples only when return_info=True. Bare tensors must not be
            # unpacked: that would iterate over their first dimension instead.
            max_len = max((s[0] if self.return_info else s).shape[-1] for s in samples)

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            # Copy so that updating total_frames does not alter the caller's objects.
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filter out audio files that are too short or too long.

        Keeps files whose duration lies within [min_audio_duration, max_audio_duration];
        a bound set to None is not applied.
        """
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        # Guard against an empty input list to avoid a ZeroDivisionError.
        removed_percentage = 100 * (1 - float(filtered_len) / orig_len) if orig_len else 0.
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a manifest of AudioMeta stored as a jsonl file.

        Args:
            root (str or Path): Either the manifest file itself, or a directory containing a
                `data.jsonl` or, failing that, a `data.jsonl.gz` manifest.
            kwargs: Additional keyword arguments for the AudioDataset.
        Raises:
            ValueError: If `root` is a directory containing neither manifest file.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        If `root` is a file, it is loaded as a jsonl manifest instead of being scanned.

        Args:
            root (str or Path): Path to root folder containing audio files, or to a manifest file.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
def main():
    """Command-line entry point: scan a folder for audio files and write a jsonl manifest."""
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata (gzip-compressed if it ends with .gz).')
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even the ones that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve relative paths to absolute ones (only effective when dora is installed).')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of worker threads used to read metadata (0 to use the current thread only).')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)
if __name__ == '__main__':
    # Run the manifest-generation command-line tool when this module is executed as a script.
    main()