Upload 7 files
- audiocraft/data/audio.py +231 -0
- audiocraft/data/audio_dataset.py +587 -0
- audiocraft/data/audio_utils.py +176 -0
- audiocraft/data/info_audio_dataset.py +110 -0
- audiocraft/data/music_dataset.py +270 -0
- audiocraft/data/sound_dataset.py +330 -0
- audiocraft/data/zip.py +76 -0
audiocraft/data/audio.py
ADDED

@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Audio IO methods are defined in this module (info, read, write).
We rely on the av library for faster reads when possible, otherwise on torchaudio.
"""

from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp

import numpy as np
import soundfile
import torch
from torch.nn import functional as F

import av
import subprocess as sp

from .audio_utils import f32_pcm, normalize_audio


_av_initialized = False


def _init_av():
    global _av_initialized
    if _av_initialized:
        return
    logger = logging.getLogger('libav.mp3')
    logger.setLevel(logging.ERROR)
    _av_initialized = True


@dataclass(frozen=True)
class AudioFileInfo:
    sample_rate: int
    duration: float
    channels: int


def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sample_rate = stream.codec_context.sample_rate
        duration = float(stream.duration * stream.time_base)
        channels = stream.channels
        return AudioFileInfo(sample_rate, duration, channels)


def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    info = soundfile.info(filepath)
    return AudioFileInfo(info.samplerate, info.duration, info.channels)


def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    # torchaudio no longer returns useful duration information for some formats like mp3s.
    filepath = Path(filepath)
    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
        # ffmpeg has some weird issue with flac.
        return _soundfile_info(filepath)
    else:
        return _av_info(filepath)

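For reference, a minimal usage sketch of the backend dispatch above (the sample path is hypothetical):

    from audiocraft.data.audio import audio_info

    info = audio_info('samples/track.mp3')  # mp3 is routed to the PyAV backend
    print(info.sample_rate, info.duration, info.channels)
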
            +
            def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
         | 
| 73 | 
            +
                """FFMPEG-based audio file reading using PyAV bindings.
         | 
| 74 | 
            +
                Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                Args:
         | 
| 77 | 
            +
                    filepath (str or Path): Path to audio file to read.
         | 
| 78 | 
            +
                    seek_time (float): Time at which to start reading in the file.
         | 
| 79 | 
            +
                    duration (float): Duration to read from the file. If set to -1, the whole file is read.
         | 
| 80 | 
            +
                Returns:
         | 
| 81 | 
            +
                    tuple of torch.Tensor, int: Tuple containing audio data and sample rate
         | 
| 82 | 
            +
                """
         | 
| 83 | 
            +
                _init_av()
         | 
| 84 | 
            +
                with av.open(str(filepath)) as af:
         | 
| 85 | 
            +
                    stream = af.streams.audio[0]
         | 
| 86 | 
            +
                    sr = stream.codec_context.sample_rate
         | 
| 87 | 
            +
                    num_frames = int(sr * duration) if duration >= 0 else -1
         | 
| 88 | 
            +
                    frame_offset = int(sr * seek_time)
         | 
| 89 | 
            +
                    # we need a small negative offset otherwise we get some edge artifact
         | 
| 90 | 
            +
                    # from the mp3 decoder.
         | 
| 91 | 
            +
                    af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
         | 
| 92 | 
            +
                    frames = []
         | 
| 93 | 
            +
                    length = 0
         | 
| 94 | 
            +
                    for frame in af.decode(streams=stream.index):
         | 
| 95 | 
            +
                        current_offset = int(frame.rate * frame.pts * frame.time_base)
         | 
| 96 | 
            +
                        strip = max(0, frame_offset - current_offset)
         | 
| 97 | 
            +
                        buf = torch.from_numpy(frame.to_ndarray())
         | 
| 98 | 
            +
                        if buf.shape[0] != stream.channels:
         | 
| 99 | 
            +
                            buf = buf.view(-1, stream.channels).t()
         | 
| 100 | 
            +
                        buf = buf[:, strip:]
         | 
| 101 | 
            +
                        frames.append(buf)
         | 
| 102 | 
            +
                        length += buf.shape[1]
         | 
| 103 | 
            +
                        if num_frames > 0 and length >= num_frames:
         | 
| 104 | 
            +
                            break
         | 
| 105 | 
            +
                    assert frames
         | 
| 106 | 
            +
                    # If the above assert fails, it is likely because we seeked past the end of file point,
         | 
| 107 | 
            +
                    # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
         | 
| 108 | 
            +
                    # This will need proper debugging, in due time.
         | 
| 109 | 
            +
                    wav = torch.cat(frames, dim=1)
         | 
| 110 | 
            +
                    assert wav.shape[0] == stream.channels
         | 
| 111 | 
            +
                    if num_frames > 0:
         | 
| 112 | 
            +
                        wav = wav[:, :num_frames]
         | 
| 113 | 
            +
                    return f32_pcm(wav), sr
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            wav = torch.unsqueeze(wav, 0)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr

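Likewise, a minimal usage sketch of `audio_read` reading a fixed-length segment (the path is hypothetical):

    from audiocraft.data.audio import audio_read

    # Read 5 seconds starting at t=10s; zero-pad if the file ends early.
    wav, sr = audio_read('samples/track.mp3', seek_time=10.0, duration=5.0, pad=True)
    assert wav.dim() == 2                  # [channels, time]
    assert wav.shape[-1] == int(5.0 * sr)  # padded/trimmed to the requested duration
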
def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
    assert wav.dim() == 2, wav.shape
    command = [
        'ffmpeg',
        '-loglevel', 'error',
        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
        '-i', '-'] + flags + [str(out_path)]
    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
    sp.run(command, input=input_, check=True)

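A minimal sketch of this helper in isolation, assuming `ffmpeg` is on the PATH (the output name is made up): it encodes one second of stereo silence as a 320 kbps mp3, with the raw little-endian float32 PCM piped over stdin.

    import torch

    silence = torch.zeros(2, 32000)  # [channels, samples] at 32 kHz
    _piping_to_ffmpeg('out.mp3', silence, 32000,
                      ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', '320k'])
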
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
                normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Audio data to save.
        sample_rate (int): Sample rate of audio data.
        format (str): Either "wav", "mp3", "ogg", or "flac".
        mp3_rate (int): kbps when using mp3s.
        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms', or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): Whether to add the format suffix to the provided stem name.
    Returns:
        Path: Path of the saved audio.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimensions.")
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
                          log_clipping=log_clipping, sample_rate=sample_rate,
                          stem_name=str(stem_name))
    if format == 'mp3':
        suffix = '.mp3'
        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
    elif format == 'wav':
        suffix = '.wav'
        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
    elif format == 'ogg':
        suffix = '.ogg'
        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
        if ogg_rate is not None:
            flags += ['-b:a', f'{ogg_rate}k']
    elif format == 'flac':
        suffix = '.flac'
        flags = ['-f', 'flac']
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg and flac are supported.")
    if not add_suffix:
        suffix = ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        _piping_to_ffmpeg(path, wav, sample_rate, flags)
    except Exception:
        if path.exists():
            # we do not want to leave half written files around.
            path.unlink()
        raise
    return path
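And a usage sketch of `audio_write` itself (the output stem is hypothetical):

    import torch
    from audiocraft.data.audio import audio_write

    wav = 0.1 * torch.randn(1, 32000)  # 1 second of mono noise at 32 kHz
    out_path = audio_write('outputs/demo', wav, 32000, format='mp3',
                           normalize=True, strategy='peak')
    print(out_path)  # outputs/demo.mp3
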
    	
audiocraft/data/audio_dataset.py
ADDED

@@ -0,0 +1,587 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""AudioDataset support. In order to handle a large number of files
without having to scan the folders again, we precompute some metadata
(filename, sample rate, duration), and use that to efficiently sample audio segments.
"""
import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
from functools import lru_cache
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp

import torch
import torch.nn.functional as F

from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip

try:
    import dora
except ImportError:
    dora = None  # type: ignore


@dataclass(order=True)
class BaseInfo:

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        return {
            field.name: dictionary[field.name]
            for field in fields(cls) if field.name in dictionary
        }

    @classmethod
    def from_dict(cls, dictionary: dict):
        _dictionary = cls._dict2fields(dictionary)
        return cls(**_dictionary)

    def to_dict(self):
        return {
            field.name: self.__getattribute__(field.name)
            for field in fields(self)
        }


@dataclass(order=True)
class AudioMeta(BaseInfo):
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        base = cls._dict2fields(dictionary)
        if 'info_path' in base and base['info_path'] is not None:
            base['info_path'] = PathInZip(base['info_path'])
        return cls(**base)

    def to_dict(self):
        d = super().to_dict()
        if d['info_path'] is not None:
            d['info_path'] = str(d['info_path'])
        return d

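A small sketch of the dict round trip used by the JSONL manifests below (field values made up):

    meta = AudioMeta(path='/data/track.wav', duration=12.5, sample_rate=44100)
    d = meta.to_dict()                     # plain dict, ready for json.dumps
    assert AudioMeta.from_dict(d) == meta  # info_path is re-wrapped as PathInZip when present
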
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    meta: AudioMeta
    seek_time: float
    # The following values are given once the audio is processed, e.g.
    # at the target sample rate and target number of channels.
    n_frames: int      # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int   # actual sample rate
    channels: int      # number of audio channels.


DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

logger = logging.getLogger(__name__)


def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """AudioMeta from a path to an audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    amplitude: tp.Optional[float] = None
    if not minimal:
        wav, sr = audio_read(file_path)
        amplitude = wav.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)


def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in a list of AudioMeta. This method is expected to be used when loading meta from file.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file
            is already absolute or not. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(m):
        if fast:
            return str(m)[0] == '/'
        else:
            return os.path.isabs(str(m))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m


def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
        resolve (bool): Whether to resolve relative paths in the collected meta.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        list of AudioMeta: List of audio files with their metadata.
    """
    audio_files = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta


def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed json file.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): activates some tricks to make things faster.
    Returns:
        list of AudioMeta: List of audio file paths with their metadata.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'rb') as fp:  # type: ignore
        lines = fp.readlines()
    meta = []
    for line in lines:
        d = json.loads(line)
        m = AudioMeta.from_dict(d)
        if resolve:
            m = _resolve_audio_meta(m, fast=fast)
        meta.append(m)
    return meta


def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save the audio metadata to the given path as JSON lines.

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'wb') as fp:  # type: ignore
        for m in meta:
            json_str = json.dumps(m.to_dict()) + '\n'
            json_bytes = json_str.encode('utf-8')
            fp.write(json_bytes)

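Putting the helpers above together, a typical manifest workflow looks like this (a sketch; the audio directory and manifest paths are hypothetical):

    # One-time scan: collect metadata and cache it as gzipped JSON lines.
    meta = find_audio_files('/data/music', minimal=True, progress=True, workers=8)
    save_audio_meta('egs/music/data.jsonl.gz', meta)

    # Subsequent runs load the cached manifest instead of re-scanning.
    meta = load_audio_meta('egs/music/data.jsonl.gz')
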
| 244 | 
            +
            class AudioDataset:
         | 
| 245 | 
            +
                """Base audio dataset.
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
         | 
| 248 | 
            +
                and potentially additional information, by creating random segments from the list of audio
         | 
| 249 | 
            +
                files referenced in the metadata and applying minimal data pre-processing such as resampling,
         | 
| 250 | 
            +
                mixing of channels, padding, etc.
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                If no segment_duration value is provided, the AudioDataset will return the full wav for each
         | 
| 253 | 
            +
                audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
         | 
| 254 | 
            +
                duration, applying padding if required.
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
         | 
| 257 | 
            +
                allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
         | 
| 258 | 
            +
                original audio meta.
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                Note that you can call `start_epoch(epoch)` in order to get
         | 
| 261 | 
            +
                a deterministic "randomization" for `shuffle=True`.
         | 
| 262 | 
            +
                For a given epoch and dataset index, this will always return the same extract.
         | 
| 263 | 
            +
                You can get back some diversity by setting the `shuffle_seed` param.
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                Args:
         | 
| 266 | 
            +
                    meta (list of AudioMeta): List of audio files metadata.
         | 
| 267 | 
            +
                    segment_duration (float, optional): Optional segment duration of audio to load.
         | 
| 268 | 
            +
                        If not specified, the dataset will load the full audio segment from the file.
         | 
| 269 | 
            +
                    shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
         | 
| 270 | 
            +
                    sample_rate (int): Target sample rate of the loaded audio samples.
         | 
| 271 | 
            +
                    channels (int): Target number of channels of the loaded audio samples.
         | 
| 272 | 
            +
                    sample_on_duration (bool): Set to `True` to sample segments with probability
         | 
| 273 | 
            +
                        dependent on audio file duration. This is only used if `segment_duration` is provided.
         | 
| 274 | 
            +
                    sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
         | 
| 275 | 
            +
                        `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
         | 
| 276 | 
            +
                        of the file duration and file weight. This is only used if `segment_duration` is provided.
         | 
| 277 | 
            +
                    min_segment_ratio (float): Minimum segment ratio to use when the audio file
         | 
| 278 | 
            +
                        is shorter than the desired segment.
         | 
| 279 | 
            +
                    max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
         | 
| 280 | 
            +
                    return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
         | 
| 281 | 
            +
                    min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
         | 
| 282 | 
            +
                        audio shorter than this will be filtered out.
         | 
| 283 | 
            +
                    max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
         | 
| 284 | 
            +
                        audio longer than this will be filtered out.
         | 
| 285 | 
            +
                    shuffle_seed (int): can be used to further randomize
         | 
| 286 | 
            +
                    load_wav (bool): if False, skip loading the wav but returns a tensor of 0
         | 
| 287 | 
            +
                        with the expected segment_duration (which must be provided if load_wav is False).
         | 
| 288 | 
            +
                    permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
         | 
| 289 | 
            +
                        are False. Will ensure a permutation on files when going through the dataset.
         | 
| 290 | 
            +
                        In that case the epoch number must be provided in order for the model
         | 
| 291 | 
            +
                        to continue the permutation across epochs. In that case, it is assumed
         | 
| 292 | 
            +
                        that `num_samples = total_batch_size * num_updates_per_epoch`, with
         | 
| 293 | 
            +
                        `total_batch_size` the overall batch size accounting for all gpus.
         | 
| 294 | 
            +
                """
         | 
| 295 | 
            +
                def __init__(self,
         | 
| 296 | 
            +
                             meta: tp.List[AudioMeta],
         | 
| 297 | 
            +
                             segment_duration: tp.Optional[float] = None,
         | 
| 298 | 
            +
                             shuffle: bool = True,
         | 
| 299 | 
            +
                             num_samples: int = 10_000,
         | 
| 300 | 
            +
                             sample_rate: int = 48_000,
         | 
| 301 | 
            +
                             channels: int = 2,
         | 
| 302 | 
            +
                             pad: bool = True,
         | 
| 303 | 
            +
                             sample_on_duration: bool = True,
         | 
| 304 | 
            +
                             sample_on_weight: bool = True,
         | 
| 305 | 
            +
                             min_segment_ratio: float = 0.5,
         | 
| 306 | 
            +
                             max_read_retry: int = 10,
         | 
| 307 | 
            +
                             return_info: bool = False,
         | 
| 308 | 
            +
                             min_audio_duration: tp.Optional[float] = None,
         | 
| 309 | 
            +
                             max_audio_duration: tp.Optional[float] = None,
         | 
| 310 | 
            +
                             shuffle_seed: int = 0,
         | 
| 311 | 
            +
                             load_wav: bool = True,
         | 
| 312 | 
            +
                             permutation_on_files: bool = False,
         | 
| 313 | 
                 ):
        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info
        self.shuffle_seed = shuffle_seed
        self.current_epoch: tp.Optional[int] = None
        self.load_wav = load_wav
        if not load_wav:
            assert segment_duration is not None
        self.permutation_on_files = permutation_on_files
        if permutation_on_files:
            assert not self.sample_on_duration
            assert not self.sample_on_weight
            assert self.shuffle

    def start_epoch(self, epoch: int):
        self.current_epoch = epoch

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`."""
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

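For instance, with `sample_on_weight` and `sample_on_duration` both enabled, each file's score is the product of its weight and its duration. A minimal standalone sketch of that scoring rule (the (duration, weight) pairs below are hypothetical stand-ins for AudioMeta entries):

import torch

# Hypothetical (duration, weight) pairs standing in for AudioMeta entries.
files = [(10.0, 1.0), (30.0, 1.0), (60.0, 0.5)]
scores = torch.tensor([duration * weight for duration, weight in files])
probabilities = scores / scores.sum()
print(probabilities)  # tensor([0.1429, 0.4286, 0.4286])
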
    @staticmethod
    @lru_cache(16)
    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
        # Used to keep the most recent file permutations in memory implicitly.
        # This will work unless someone is using a lot of Datasets in parallel.
        rng = torch.Generator()
        rng.manual_seed(base_seed + permutation_index)
        return torch.randperm(num_files, generator=rng)

    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        You can also make use of the accessed index.
        """
        if self.permutation_on_files:
            assert self.current_epoch is not None
            total_index = self.current_epoch * len(self) + index
            permutation_index = total_index // len(self.meta)
            relative_index = total_index % len(self.meta)
            permutation = AudioDataset._get_file_permutation(
                len(self.meta), permutation_index, self.shuffle_seed)
            file_index = permutation[relative_index]
            return self.meta[file_index]

        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

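The arithmetic in the `permutation_on_files` branch guarantees that within each full pass over the dataset every file is visited exactly once before any repeats, with a fresh deterministic shuffle per pass. A standalone sketch of the same indexing, with hypothetical `num_files` and `base_seed`:

import torch

num_files, base_seed = 5, 42

def file_permutation(permutation_index: int) -> torch.Tensor:
    # Mirrors _get_file_permutation: one fixed shuffle per pass over the files.
    rng = torch.Generator()
    rng.manual_seed(base_seed + permutation_index)
    return torch.randperm(num_files, generator=rng)

# total_index = current_epoch * len(dataset) + index, as in sample_file.
# Within each block of num_files consecutive indices, every file appears once.
for total_index in range(10):
    perm = file_permutation(total_index // num_files)
    print(total_index, int(perm[total_index % num_files]))
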
    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
        # Override this method in subclass if needed.
        if self.load_wav:
            return audio_read(path, seek_time, duration, pad=False)
        else:
            assert self.segment_duration is not None
            n_frames = int(self.sample_rate * self.segment_duration)
            return torch.zeros(self.channels, n_frames), self.sample_rate

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate, channels=out.shape[0])
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use the index plus extra randomness: fully random if we don't know
                # the epoch, otherwise derived from the epoch number and the optional shuffle_seed.
                if self.current_epoch is None:
                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
                else:
                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
            else:
                # We only use the index.
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(index, rng)
                # We add some variance to the start position, even if the audio file is
                # shorter than the segment, without ending up with empty segments.
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate, channels=out.shape[0])
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Return the wav along with additional information on the segment.
            return out, segment_info
        else:
            return out

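To make the seek logic concrete: `min_segment_ratio` bounds how much of a sampled segment can end up as padding. A small worked example with hypothetical values:

duration, segment_duration, min_segment_ratio = 40.0, 30.0, 0.5
max_seek = max(0, duration - segment_duration * min_segment_ratio)
print(max_seek)  # 25.0 -> even the latest seek keeps 15 s (half a segment) of real audio
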
    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True, in order to properly collate
        the samples of a batch.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            max_len = max([wav.shape[-1] for wav, _ in samples])

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

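A hedged usage sketch, where `dataset` is assumed to be an AudioDataset built with return_info=True. The collater must be passed explicitly, since the default PyTorch collate function cannot batch SegmentInfo objects:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.collater)
wav, infos = next(iter(loader))
print(wav.shape, len(infos))  # e.g. torch.Size([16, 1, 960000]), 16
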
    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filter out audio files whose durations will not allow sampling examples from them."""
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        removed_percentage = 100 * (1 - float(filtered_len) / orig_len)
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        if removed_percentage < 10:
            logger.debug(msg)
        else:
            logger.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Path to the manifest file, or to a directory containing
                a data.jsonl or data.jsonl.gz manifest.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)

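A usage sketch of the two constructors; the paths and keyword values here are illustrative:

# Scan a folder of audio files directly:
dataset = AudioDataset.from_path(
    '/data/my_audio', minimal_meta=True,
    segment_duration=30., sample_rate=32000, channels=1, return_info=True)

# Or reuse a pre-computed data.jsonl(.gz) manifest (see the CLI below):
dataset = AudioDataset.from_meta(
    'egs/my_audio', segment_duration=30., sample_rate=32000, channels=1)
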

def main():
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata.')
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even those that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)


if __name__ == '__main__':
    main()
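For reference, a typical invocation of this entry point (the paths are illustrative):

python -m audiocraft.data.audio_dataset /data/my_audio egs/my_audio/data.jsonl.gz --workers 20
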
    	
        audiocraft/data/audio_utils.py
    ADDED
    
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various utilities for audio conversion (PCM format, sample rate and channels),
and volume normalization."""
import sys
import typing as tp

import julius
import torch
import torchaudio

def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert audio to the given number of channels.

    Args:
        wav (torch.Tensor): Audio wave of shape [B, C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked for 1-channel audio, and the stream has multiple
        # channels: downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file has
        # a single channel: replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file has
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has fewer channels than requested but is not mono.')
    return wav


def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to a new sample rate and number of audio channels."""
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    wav = convert_audio_channels(wav, to_channels)
    return wav

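A quick usage sketch of the two helpers above, resampling one second of random stereo audio to mono at 32 kHz:

import torch

stereo = torch.randn(2, 44100)  # 1 s of stereo audio at 44.1 kHz
mono = convert_audio(stereo, from_rate=44100, to_rate=32000, to_channels=1)
print(mono.shape)  # torch.Size([1, 32000])
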
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Normalize an input signal to a target loudness in dB LUFS.
    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        torch.Tensor: Loudness normalized output data.
    """
    energy = wav.pow(2).mean().sqrt().item()
    if energy < energy_floor:
        return wav
    transform = torchaudio.transforms.Loudness(sample_rate)
    input_loudness_db = transform(wav).item()
    # Calculate the gain needed to scale to the desired loudness level.
    delta_loudness = -loudness_headroom_db - input_loudness_db
    gain = 10.0 ** (delta_loudness / 20.0)
    output = gain * wav
    if loudness_compressor:
        output = torch.tanh(output)
    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
    return output

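The gain computation is plain dB arithmetic. For example, with a measured integrated loudness of -30 LUFS and the default 14 dB headroom (target -14 LUFS):

input_loudness_db = -30.0
loudness_headroom_db = 14.0
delta_loudness = -loudness_headroom_db - input_loudness_db  # -14 - (-30) = 16 dB
gain = 10.0 ** (delta_loudness / 20.0)
print(round(gain, 2))  # 6.31, i.e. the waveform is amplified ~6.3x
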
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
    """Utility function to clip the audio with logging if specified."""
    max_scale = wav.abs().max()
    if log_clipping and max_scale > 1:
        clamp_prob = (wav.abs() > 1).float().mean().item()
        print(f"CLIPPING {stem_name or ''} happening with probability (a bit of clipping is okay):",
              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
    wav.clamp_(-1, 1)

def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. 'rms' normalizes by the root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (str, optional): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        rescaling = (scale_peak / wav.abs().max())
        if normalize or rescaling < 1:
            wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rescaling = scale_rms / mono.pow(2).mean().sqrt()
        if normalize or rescaling < 1:
            wav = wav * rescaling
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        assert wav.abs().max() < 1
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
    return wav

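A small usage sketch: peak normalization with the default 1 dB headroom scales the waveform so its largest absolute value lands at 10 ** (-1 / 20) ≈ 0.89:

import torch

wav = torch.randn(1, 32000) * 0.5
out = normalize_audio(wav, strategy='peak', peak_clip_headroom_db=1)
print(out.abs().max().item())  # ~0.891, i.e. 10 ** (-1 / 20)
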
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format.
    """
    if wav.dtype.is_floating_point:
        return wav
    elif wav.dtype == torch.int16:
        return wav.float() / 2**15
    elif wav.dtype == torch.int32:
        return wav.float() / 2**31
    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
    """
    if wav.dtype.is_floating_point:
        assert wav.abs().max() <= 1
        candidate = (wav * 2 ** 15).round()
        if candidate.max() >= 2 ** 15:  # clipping would occur
            candidate = (wav * (2 ** 15 - 1)).round()
        return candidate.short()
    else:
        assert wav.dtype == torch.int16
        return wav
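A round-trip sketch: values that are exactly representable in int16 survive the conversion unchanged, which is the common case the warning above is hedging against:

import torch

x = torch.tensor([0.5, -0.25, 0.0])
ints = i16_pcm(x)            # tensor([ 16384,  -8192,      0], dtype=torch.int16)
back = f32_pcm(ints)
print(torch.equal(back, x))  # True: these values have exact int16 representations
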
    	
        audiocraft/data/info_audio_dataset.py
    ADDED
    
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base classes for the datasets that also provide non-audio metadata,
e.g. description, text transcription etc.
"""
from dataclasses import dataclass
import logging
import math
import re
import typing as tp

import torch

from .audio_dataset import AudioDataset, AudioMeta
from ..environment import AudioCraftEnvironment
from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes


logger = logging.getLogger(__name__)

def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Monkey-patch meta to match cluster specificities."""
    meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
    if meta.info_path is not None:
        meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
    return meta


def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Monkey-patch all meta to match cluster specificities."""
    return [_clusterify_meta(m) for m in meta]

@dataclass
class AudioInfo(SegmentWithAttributes):
    """Dummy SegmentInfo with empty attributes.

    The InfoAudioDataset is expected to return metadata that inherits
    from the SegmentWithAttributes class and can return conditioning attributes.

    This basically guarantees all datasets will be compatible with the current
    solvers that contain conditioners requiring this.
    """
    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using a cached batch for training a LM.

    def to_condition_attributes(self) -> ConditioningAttributes:
        return ConditioningAttributes()

class InfoAudioDataset(AudioDataset):
    """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.

    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        super().__init__(clusterify_all_meta(meta), **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        if not self.return_info:
            wav = super().__getitem__(index)
            assert isinstance(wav, torch.Tensor)
            return wav
        wav, meta = super().__getitem__(index)
        return wav, AudioInfo(**meta.to_dict())

def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a single keyword or possibly a list of keywords."""
    if isinstance(value, list):
        return get_keyword_list(value)
    else:
        return get_keyword(value)


def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single string."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    else:
        return value.strip()


def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    else:
        return value.strip().lower()


def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords."""
    if isinstance(values, str):
        values = [v.strip() for v in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        values = []
    if not isinstance(values, list):
        logger.debug(f"Unexpected keyword list {values}")
        values = [str(values)]

    kws = [get_keyword(v) for v in values]
    kw_list = [k for k in kws if k is not None]
    if len(kw_list) == 0:
        return None
    else:
        return kw_list
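The helpers differ only in casing and list handling; a few concrete calls:

print(get_string('  Rock '))                    # 'Rock'
print(get_keyword('  Rock '))                   # 'rock'
print(get_keyword_list('rock, indie  guitar'))  # ['rock', 'indie', 'guitar']
print(get_keyword_list('None'))                 # None
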
    	
        audiocraft/data/music_dataset.py
    ADDED
    
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of music tracks with rich metadata.
"""
from dataclasses import dataclass, field, fields, replace
import gzip
import json
import logging
from pathlib import Path
import random
import typing as tp

import torch

from .info_audio_dataset import (
    InfoAudioDataset,
    AudioInfo,
    get_keyword_list,
    get_keyword,
    get_string
)
from ..modules.conditioners import (
    ConditioningAttributes,
    JointEmbedCondition,
    WavCondition,
)
from ..utils.utils import warn_once


logger = logging.getLogger(__name__)

@dataclass
class MusicInfo(AudioInfo):
    """Segment info augmented with music metadata.
    """
    # music-specific metadata
    title: tp.Optional[str] = None
    artist: tp.Optional[str] = None  # anonymized artist id, used to ensure no overlap between splits
    key: tp.Optional[str] = None
    bpm: tp.Optional[float] = None
    genre: tp.Optional[str] = None
    moods: tp.Optional[list] = None
    keywords: tp.Optional[list] = None
    description: tp.Optional[str] = None
    name: tp.Optional[str] = None
    instrument: tp.Optional[str] = None
    # original wav accompanying the metadata
    self_wav: tp.Optional[WavCondition] = None
    # dict mapping attribute names to tuples of wav, text and metadata
    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

    @property
    def has_music_meta(self) -> bool:
        return self.name is not None

    def to_condition_attributes(self) -> ConditioningAttributes:
        out = ConditioningAttributes()
        for _field in fields(self):
            key, value = _field.name, getattr(self, _field.name)
            if key == 'self_wav':
                out.wav[key] = value
            elif key == 'joint_embed':
                for embed_attribute, embed_cond in value.items():
                    out.joint_embed[embed_attribute] = embed_cond
            else:
                if isinstance(value, list):
                    value = ' '.join(value)
                out.text[key] = value
        return out

    @staticmethod
    def attribute_getter(attribute):
        if attribute == 'bpm':
            preprocess_func = get_bpm
        elif attribute == 'key':
            preprocess_func = get_musical_key
        elif attribute in ['moods', 'keywords']:
            preprocess_func = get_keyword_list
        elif attribute in ['genre', 'name', 'instrument']:
            preprocess_func = get_keyword
        elif attribute in ['title', 'artist', 'description']:
            preprocess_func = get_string
        else:
            preprocess_func = None
        return preprocess_func

| 91 | 
            +
                @classmethod
         | 
| 92 | 
            +
                def from_dict(cls, dictionary: dict, fields_required: bool = False):
         | 
| 93 | 
            +
                    _dictionary: tp.Dict[str, tp.Any] = {}
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    # allow a subset of attributes to not be loaded from the dictionary
         | 
| 96 | 
            +
                    # these attributes may be populated later
         | 
| 97 | 
            +
                    post_init_attributes = ['self_wav', 'joint_embed']
         | 
| 98 | 
            +
                    optional_fields = ['keywords']
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    for _field in fields(cls):
         | 
| 101 | 
            +
                        if _field.name in post_init_attributes:
         | 
| 102 | 
            +
                            continue
         | 
| 103 | 
            +
                        elif _field.name not in dictionary:
         | 
| 104 | 
            +
                            if fields_required and _field.name not in optional_fields:
         | 
| 105 | 
            +
                                raise KeyError(f"Unexpected missing key: {_field.name}")
         | 
| 106 | 
            +
                        else:
         | 
| 107 | 
            +
                            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
         | 
| 108 | 
            +
                            value = dictionary[_field.name]
         | 
| 109 | 
            +
                            if preprocess_func:
         | 
| 110 | 
            +
                                value = preprocess_func(value)
         | 
| 111 | 
            +
                            _dictionary[_field.name] = value
         | 
| 112 | 
            +
                    return cls(**_dictionary)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
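For context, a minimal sketch of how this `from_dict` path is exercised by `MusicDataset.__getitem__` further below; the metadata values are made up, and `info` stands for the segment info returned by the base dataset:

music_data = {'key': 'C, G', 'bpm': '120', 'moods': ['calm', 'dark']}  # hypothetical JSON content
music_data.update(info.to_dict())  # merge segment-level fields (meta, sample_rate, ...)
music_info = MusicInfo.from_dict(music_data, fields_required=False)
assert music_info.bpm == 120.0   # coerced to float by get_bpm
assert music_info.key is None    # multi-key strings are discarded by get_musical_key
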
def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
                                   drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
    """Augment MusicInfo description with additional metadata fields and potential dropout.
    Additional textual attributes are added with probability merge_text_p and
    the original textual description is dropped from the augmented description with probability drop_desc_p.

    Args:
        music_info (MusicInfo): The music metadata to augment.
        merge_text_p (float): Probability of merging additional metadata into the description.
            If the provided value is 0, no merging is performed.
        drop_desc_p (float): Probability of dropping the original description on text merge.
            If the provided value is 0, no dropout is performed.
        drop_other_p (float): Probability of dropping each of the other fields used for text augmentation.
    Returns:
        MusicInfo: The MusicInfo with augmented textual description.
    """
    def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
        valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
        valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
        # keep the field unless it is dropped, which happens with probability drop_other_p
        keep_field = random.uniform(0, 1) >= drop_other_p
        return valid_field_name and valid_field_value and keep_field

    def process_value(v: tp.Any) -> str:
        if isinstance(v, (int, float, str)):
            return str(v)
        if isinstance(v, list):
            return ", ".join(v)
        else:
            raise ValueError(f"Unknown type for text value! ({type(v)}, {v})")

    description = music_info.description

    metadata_text = ""
    if random.uniform(0, 1) < merge_text_p:
        meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
                      for _field in fields(music_info)
                      if is_valid_field(_field.name, getattr(music_info, _field.name))]
        random.shuffle(meta_pairs)
        metadata_text = ". ".join(meta_pairs)
        if random.uniform(0, 1) < drop_desc_p:
            description = None
        logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")

    if description is None:
        description = metadata_text if len(metadata_text) > 1 else None
    else:
        description = ". ".join([description.rstrip('.'), metadata_text])
    description = description.strip() if description else None

    music_info = replace(music_info)
    music_info.description = description
    return music_info

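As an illustration, with made-up probabilities, a call could look like this; the printed description is one possible outcome:

augmented = augment_music_info_description(music_info, merge_text_p=0.5,
                                           drop_desc_p=0.1, drop_other_p=0.1)
print(augmented.description)
# e.g. "A slow tune. bpm: 120.0. moods: calm, dark"
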
            +
            class Paraphraser:
         | 
| 168 | 
            +
                def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
         | 
| 169 | 
            +
                    self.paraphrase_p = paraphrase_p
         | 
| 170 | 
            +
                    open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
         | 
| 171 | 
            +
                    with open_fn(paraphrase_source, 'rb') as f:  # type: ignore
         | 
| 172 | 
            +
                        self.paraphrase_source = json.loads(f.read())
         | 
| 173 | 
            +
                    logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def sample_paraphrase(self, audio_path: str, description: str):
         | 
| 176 | 
            +
                    if random.random() >= self.paraphrase_p:
         | 
| 177 | 
            +
                        return description
         | 
| 178 | 
            +
                    info_path = Path(audio_path).with_suffix('.json')
         | 
| 179 | 
            +
                    if info_path not in self.paraphrase_source:
         | 
| 180 | 
            +
                        warn_once(logger, f"{info_path} not in paraphrase source!")
         | 
| 181 | 
            +
                        return description
         | 
| 182 | 
            +
                    new_desc = random.choice(self.paraphrase_source[info_path])
         | 
| 183 | 
            +
                    logger.debug(f"{description} -> {new_desc}")
         | 
| 184 | 
            +
                    return new_desc
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
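A sketch of the expected paraphrase source layout; the file name, paths, and texts here are hypothetical:

# paraphrases.json (or paraphrases.json.gz):
# {
#     "some/dataset/track_001.json": ["A mellow piano piece", "Soft solo piano"],
#     "some/dataset/track_002.json": ["Upbeat electronic dance track"]
# }
paraphraser = Paraphraser('paraphrases.json', paraphrase_p=0.3)
desc = paraphraser.sample_paraphrase('some/dataset/track_001.wav', 'A piano tune')
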

class MusicDataset(InfoAudioDataset):
    """Music dataset is an AudioDataset with music-related metadata.

    Args:
        info_fields_required (bool): Whether to enforce having required fields.
        merge_text_p (float): Probability of merging additional metadata into the description.
        drop_desc_p (float): Probability of dropping the original description on text merge.
        drop_other_p (float): Probability of dropping the other fields used for text augmentation.
        joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
        paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
            paraphrases for the description. The json should be a dict whose keys are the
            original info paths (e.g. track_path.json) and each value is a list of possible
            paraphrases.
        paraphrase_p (float): Probability of taking a paraphrase.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(self, *args, info_fields_required: bool = True,
                 merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
                 joint_embed_attributes: tp.List[str] = [],
                 paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
                 **kwargs):
        kwargs['return_info'] = True  # We require the info for each song of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.merge_text_p = merge_text_p
        self.drop_desc_p = drop_desc_p
        self.drop_other_p = drop_other_p
        self.joint_embed_attributes = joint_embed_attributes
        self.paraphraser = None
        if paraphrase_source is not None:
            self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        music_info_path = Path(info.meta.path).with_suffix('.json')

        if Path(music_info_path).exists():
            with open(music_info_path, 'r') as json_file:
                music_data = json.load(json_file)
                music_data.update(info_data)
                music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
            if self.paraphraser is not None:
                music_info.description = self.paraphraser.sample_paraphrase(music_info.meta.path, music_info.description)
            if self.merge_text_p:
                music_info = augment_music_info_description(
                    music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
        else:
            music_info = MusicInfo.from_dict(info_data, fields_required=False)

        music_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        for att in self.joint_embed_attributes:
            att_value = getattr(music_info, att)
            joint_embed_cond = JointEmbedCondition(
                wav[None], [att_value], torch.tensor([info.n_frames]),
                sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
            music_info.joint_embed[att] = joint_embed_cond

        return wav, music_info

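For illustration, a hypothetical instantiation, assuming the `from_meta` constructor inherited from `AudioDataset`; the root and sizes are made up:

dataset = MusicDataset.from_meta(
    'egs/music', segment_duration=30, sample_rate=32000, channels=1,
    merge_text_p=0.25, drop_desc_p=0.5, drop_other_p=0.5)
wav, music_info = dataset[0]  # wav: [C, T], music_info: MusicInfo
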

def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess key keywords, discarding them if multiple keys are defined."""
    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
        return None
    elif ',' in value:
        # For now, we discard the value when multiple keys are defined, separated with commas
        return None
    else:
        return value.strip().lower()


def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
    """Preprocess the BPM value to a float."""
    if value is None:
        return None
    try:
        return float(value)
    except ValueError:
        return None
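A few illustrative inputs and outputs for these two helpers:

assert get_musical_key(' C# Minor ') == 'c# minor'
assert get_musical_key('C, G') is None   # multiple keys are discarded
assert get_bpm('120.5') == 120.5
assert get_bpm('fast') is None           # non-numeric values are dropped
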
    	
        audiocraft/data/sound_dataset.py
    ADDED
    
    | @@ -0,0 +1,330 @@ | |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Dataset of audio with a simple description.
"""

from dataclasses import dataclass, fields, replace
import json
from pathlib import Path
import random
import typing as tp

import numpy as np
import torch

from .info_audio_dataset import (
    InfoAudioDataset,
    get_keyword_or_keyword_list
)
from ..modules.conditioners import (
    ConditioningAttributes,
    SegmentWithAttributes,
    WavCondition,
)


EPS = torch.finfo(torch.float32).eps
TARGET_LEVEL_LOWER = -35
TARGET_LEVEL_UPPER = -15

            +
            @dataclass
         | 
| 35 | 
            +
            class SoundInfo(SegmentWithAttributes):
         | 
| 36 | 
            +
                """Segment info augmented with Sound metadata.
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                description: tp.Optional[str] = None
         | 
| 39 | 
            +
                self_wav: tp.Optional[torch.Tensor] = None
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                @property
         | 
| 42 | 
            +
                def has_sound_meta(self) -> bool:
         | 
| 43 | 
            +
                    return self.description is not None
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def to_condition_attributes(self) -> ConditioningAttributes:
         | 
| 46 | 
            +
                    out = ConditioningAttributes()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    for _field in fields(self):
         | 
| 49 | 
            +
                        key, value = _field.name, getattr(self, _field.name)
         | 
| 50 | 
            +
                        if key == 'self_wav':
         | 
| 51 | 
            +
                            out.wav[key] = value
         | 
| 52 | 
            +
                        else:
         | 
| 53 | 
            +
                            out.text[key] = value
         | 
| 54 | 
            +
                    return out
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                @staticmethod
         | 
| 57 | 
            +
                def attribute_getter(attribute):
         | 
| 58 | 
            +
                    if attribute == 'description':
         | 
| 59 | 
            +
                        preprocess_func = get_keyword_or_keyword_list
         | 
| 60 | 
            +
                    else:
         | 
| 61 | 
            +
                        preprocess_func = None
         | 
| 62 | 
            +
                    return preprocess_func
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                @classmethod
         | 
| 65 | 
            +
                def from_dict(cls, dictionary: dict, fields_required: bool = False):
         | 
| 66 | 
            +
                    _dictionary: tp.Dict[str, tp.Any] = {}
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # allow a subset of attributes to not be loaded from the dictionary
         | 
| 69 | 
            +
                    # these attributes may be populated later
         | 
| 70 | 
            +
                    post_init_attributes = ['self_wav']
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    for _field in fields(cls):
         | 
| 73 | 
            +
                        if _field.name in post_init_attributes:
         | 
| 74 | 
            +
                            continue
         | 
| 75 | 
            +
                        elif _field.name not in dictionary:
         | 
| 76 | 
            +
                            if fields_required:
         | 
| 77 | 
            +
                                raise KeyError(f"Unexpected missing key: {_field.name}")
         | 
| 78 | 
            +
                        else:
         | 
| 79 | 
            +
                            preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
         | 
| 80 | 
            +
                            value = dictionary[_field.name]
         | 
| 81 | 
            +
                            if preprocess_func:
         | 
| 82 | 
            +
                                value = preprocess_func(value)
         | 
| 83 | 
            +
                            _dictionary[_field.name] = value
         | 
| 84 | 
            +
                    return cls(**_dictionary)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 

class SoundDataset(InfoAudioDataset):
    """Sound audio dataset: Audio dataset with environmental sound-specific metadata.

    Args:
        info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
        external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
            The metadata files contained in this folder are expected to match the stem of the audio file with
            a json extension.
        aug_p (float): Probability of performing audio mixing augmentation on the batch.
        mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
        mix_snr_low (int): Lower bound for the SNR value sampled for mixing augmentation.
        mix_snr_high (int): Upper bound for the SNR value sampled for mixing augmentation.
        mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
        kwargs: Additional arguments for AudioDataset.

    See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
    """
    def __init__(
        self,
        *args,
        info_fields_required: bool = True,
        external_metadata_source: tp.Optional[str] = None,
        aug_p: float = 0.,
        mix_p: float = 0.,
        mix_snr_low: int = -5,
        mix_snr_high: int = 5,
        mix_min_overlap: float = 0.5,
        **kwargs
    ):
        kwargs['return_info'] = True  # We require the info for each sound of the dataset.
        super().__init__(*args, **kwargs)
        self.info_fields_required = info_fields_required
        self.external_metadata_source = external_metadata_source
        self.aug_p = aug_p
        self.mix_p = mix_p
        if self.aug_p > 0:
            assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
            assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
        self.mix_snr_low = mix_snr_low
        self.mix_snr_high = mix_snr_high
        self.mix_min_overlap = mix_min_overlap

    def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
        """Get path of JSON with metadata (description, etc.).
        If there exists a JSON with the same name as 'path.name', then it will be used.
        Otherwise, such a JSON will be searched for in the external json source folder if it exists.
        """
        info_path = Path(path).with_suffix('.json')
        if Path(info_path).exists():
            return info_path
        elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
            return Path(self.external_metadata_source) / info_path.name
        else:
            raise Exception(f"Unable to find a metadata JSON for path: {path}")

    def __getitem__(self, index):
        wav, info = super().__getitem__(index)
        info_data = info.to_dict()
        info_path = self._get_info_path(info.meta.path)
        if Path(info_path).exists():
            with open(info_path, 'r') as json_file:
                sound_data = json.load(json_file)
                sound_data.update(info_data)
                sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
                # if there are multiple descriptions, sample one randomly
                if isinstance(sound_info.description, list):
                    sound_info.description = random.choice(sound_info.description)
        else:
            sound_info = SoundInfo.from_dict(info_data, fields_required=False)

        sound_info.self_wav = WavCondition(
            wav=wav[None], length=torch.tensor([info.n_frames]),
            sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])

        return wav, sound_info

    def collater(self, samples):
        # when training, audio mixing is performed in the collate function
        wav, sound_info = super().collater(samples)  # SoundDataset always returns infos
        if self.aug_p > 0:
            wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
                                          snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
                                          min_overlap=self.mix_min_overlap)
        return wav, sound_info

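A hypothetical setup (paths and sizes made up), again assuming the inherited `from_meta` constructor; note that the mixing augmentation runs in `collater`, so it must be passed to the DataLoader:

dataset = SoundDataset.from_meta(
    'egs/sounds', segment_duration=10, sample_rate=16000, channels=1,
    aug_p=0.5, mix_p=0.5, mix_snr_low=-5, mix_snr_high=5)
loader = torch.utils.data.DataLoader(dataset, batch_size=16,
                                     collate_fn=dataset.collater)
wavs, sound_infos = next(iter(loader))  # wavs: [B, C, T]
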

def rms_f(x: torch.Tensor) -> torch.Tensor:
    """Per-item RMS of a batch of waveforms of shape [B, T]."""
    return (x ** 2).mean(1).pow(0.5)


def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
    """Normalize the signal to the target level (in dBFS)."""
    rms = rms_f(audio)
    scalar = 10 ** (target_level / 20) / (rms + EPS)
    audio = audio * scalar.unsqueeze(1)
    return audio


def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
    return (abs(audio) > clipping_threshold).any(1)


def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
    start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
    remainder = src.shape[1] - start
    if dst.shape[1] > remainder:
        src[:, start:] = src[:, start:] + dst[:, :remainder]
    else:
        src[:, start:start + dst.shape[1]] = src[:, start:start + dst.shape[1]] + dst
    return src

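A toy check of the overlap guarantee for equal-length inputs (tensors are made up); note that `mix_pair` adds `dst` into `src` in place:

src = torch.zeros(1, 100)
dst = torch.ones(1, 100)
out = mix_pair(src, dst, min_overlap=0.5)
# start was drawn from [0, 50], so dst covers at least 50 samples of src
assert (out != 0).sum() >= 50
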

def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
              target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
    """Mix clean speech and noise at the given SNR level.

    Args:
        clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
        noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
        snr (int): SNR level when mixing.
        min_overlap (float): Minimum overlap between the two mixed sources.
        target_level (int): Gain level in dB.
        clipping_threshold (float): Threshold for clipping the audio.
    Returns:
        torch.Tensor: The mixed audio, of shape [B, T].
    """
    if clean.shape[1] > noise.shape[1]:
        noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
    else:
        noise = noise[:, :clean.shape[1]]

    # normalizing both sources to target_level dBFS (-25 dBFS by default)
    clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
    clean = normalize(clean, target_level)
    rmsclean = rms_f(clean)

    noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
    noise = normalize(noise, target_level)
    rmsnoise = rms_f(noise)

    # set the noise level for a given SNR
    noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
    noisenewlevel = noise * noisescalar

    # mix noise and clean speech
    noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)

    # randomly select an RMS value between -35 dBFS and -15 dBFS and normalize noisyspeech with that value;
    # there is a small chance of clipping, which is not a major issue.
    noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
    rmsnoisy = rms_f(noisyspeech)
    scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
    noisyspeech = noisyspeech * scalarnoisy
    clean = clean * scalarnoisy
    noisenewlevel = noisenewlevel * scalarnoisy

    # final check: if any amplitudes exceed +/- 1, normalize the affected signals accordingly
    clipped = is_clipped(noisyspeech)
    if clipped.any():
        noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
        noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel

    return noisyspeech

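To see why `noisescalar` yields the requested SNR, recall that SNR in dB is 20 * log10(rms_clean / rms_noise); a quick sanity check with made-up RMS values:

rmsclean, rmsnoise = torch.tensor([0.10]), torch.tensor([0.25])
snr = 10
noisescalar = rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)
achieved = 20 * torch.log10(rmsclean / (rmsnoise * noisescalar))
assert torch.isclose(achieved, torch.tensor([10.0]), atol=1e-4)
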

def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
    if snr_low == snr_high:
        snr = snr_low
    else:
        # note: np.random.randint samples from [snr_low, snr_high), i.e. the upper bound is exclusive
        snr = np.random.randint(snr_low, snr_high)
    mix = snr_mixer(src, dst, snr, min_overlap)
    return mix


def mix_text(src_text: str, dst_text: str):
    """Mix text from different sources by concatenating them."""
    if src_text == dst_text:
        return src_text
    return src_text + " " + dst_text


def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
                snr_low: int, snr_high: int, min_overlap: float):
    """Mix samples within a batch, summing the waveforms and concatenating the text infos.

    Args:
        wavs (torch.Tensor): Audio tensors of shape [B, C, T].
        infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
        aug_p (float): Augmentation probability.
        mix_p (float): Proportion of items in the batch to mix (and merge) together.
        snr_low (int): Lower bound for sampling SNR.
        snr_high (int): Upper bound for sampling SNR.
        min_overlap (float): Minimum overlap between mixed samples.
    Returns:
        tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
            and mixed SoundInfo for the given batch.
    """
    # no mixing to perform within the batch
    if mix_p == 0:
        return wavs, infos

    if random.uniform(0, 1) < aug_p:
        # perform all augmentations on waveforms as [B, T]
        # randomly picking pairs of audio to mix
        assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
        wavs = wavs.mean(dim=1, keepdim=False)
        B, T = wavs.shape
        k = int(mix_p * B)
        mixed_sources_idx = torch.randperm(B)[:k]
        mixed_targets_idx = torch.randperm(B)[:k]
        aug_wavs = snr_mix(
            wavs[mixed_sources_idx],
            wavs[mixed_targets_idx],
            snr_low,
            snr_high,
            min_overlap,
        )
        # mixing textual descriptions in metadata
        descriptions = [info.description for info in infos]
        aug_infos = []
        for i, j in zip(mixed_sources_idx, mixed_targets_idx):
            text = mix_text(descriptions[i], descriptions[j])
            m = replace(infos[i])
            m.description = text
            aug_infos.append(m)

        # back to [B, C, T]
        aug_wavs = aug_wavs.unsqueeze(1)
        assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
        assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
        assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"

        return aug_wavs, aug_infos  # [B, C, T]
    else:
        # randomly pick samples in the batch to match
        # the batch size when performing audio mixing
        B, C, T = wavs.shape
        k = int(mix_p * B)
        wav_idx = torch.randperm(B)[:k]
        wavs = wavs[wav_idx]
        infos = [infos[i] for i in wav_idx]
        assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"

        return wavs, infos  # [B, C, T]
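A minimal sketch of the batch-level behavior with stand-in metadata (the real code passes `SoundInfo` items; only `description` and dataclass semantics matter here). With `aug_p=1.0` the mixing branch is (almost) always taken, and the batch shrinks to `k = int(mix_p * B)` items:

from dataclasses import dataclass

@dataclass
class FakeInfo:  # hypothetical stand-in for SoundInfo
    description: str

wavs = torch.randn(4, 1, 16000)
infos = [FakeInfo(f"sound {i}") for i in range(4)]
out_wavs, out_infos = mix_samples(wavs, infos, aug_p=1.0, mix_p=0.5,
                                  snr_low=-5, snr_high=5, min_overlap=0.5)
assert out_wavs.shape == (2, 1, 16000) and len(out_infos) == 2
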
    	
        audiocraft/data/zip.py
    ADDED
    
    | @@ -0,0 +1,76 @@ | |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utility for reading some info from inside a zip file.
"""

import typing
import zipfile

from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal


DEFAULT_SIZE = 32
MODE = Literal['r', 'w', 'x', 'a']


@dataclass(order=True)
class PathInZip:
    """Holds the path to a file inside a zip file.

    Args:
        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            then we expect path = "/some/location/foo.zip:/data/file1.json".
    """

    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        split_path = path.split(self.INFO_PATH_SEP)
        assert len(split_path) == 2
        self.zip_path, self.file_path = split_path

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        return cls(zip_path + cls.INFO_PATH_SEP + file_path)

    def __str__(self) -> str:
        return self.zip_path + self.INFO_PATH_SEP + self.file_path

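A usage sketch following the convention described in the docstring above:

p = PathInZip("/some/location/foo.zip:/data/file1.json")
assert p.zip_path == "/some/location/foo.zip"
assert p.file_path == "/data/file1.json"
assert str(PathInZip.from_paths(p.zip_path, p.file_path)) == str(p)
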

def _open_zip(path: str, mode: MODE = 'r'):
    return zipfile.ZipFile(path, mode)


_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)


def set_zip_cache_size(max_size: int):
    """Sets the maximal LRU caching for zip file opening.

    Args:
        max_size (int): the maximal number of zip files kept open in the LRU cache.
    """
    global _cached_open_zip
    _cached_open_zip = lru_cache(max_size)(_open_zip)


def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    Args:
        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
        mode (str): The mode in which to open the file.
    Returns:
        A file-like object for PathInZip.
    """
    zf = _cached_open_zip(path_in_zip.zip_path)
    return zf.open(path_in_zip.file_path, mode)
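Finally, a sketch tying the helpers together; the archive path and JSON member are hypothetical:

import json

path = PathInZip("/some/location/foo.zip:/data/file1.json")
with open_file_in_zip(path) as f:
    meta = json.loads(f.read())
set_zip_cache_size(64)  # repeated opens of the same archive reuse cached ZipFile handles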