# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time
from datetime import datetime

import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader, DistributedSampler

from dist_utils import is_main_process, get_world_size, get_rank


def now():
    """Return a timestamp string in YYYYMMDDHHMM form, e.g. "202401011230"."""
    return datetime.now().strftime("%Y%m%d%H%M")


def setup_logger():
    # Only the main process logs at INFO; other ranks stay quiet at WARN.
    logging.basicConfig(
        level=logging.INFO if is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )


def get_dataloader(dataset, config, is_train=True, use_distributed=True):
    """Build a (optionally distributed) DataLoader; training loaders are
    wrapped in IterLoader so they can be drawn from indefinitely."""
    if use_distributed:
        sampler = DistributedSampler(
            dataset,
            shuffle=is_train,
            num_replicas=get_world_size(),
            rank=get_rank(),
        )
    else:
        sampler = None

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size_train if is_train else config.batch_size_eval,
        num_workers=config.num_workers,
        pin_memory=True,
        sampler=sampler,
        shuffle=sampler is None and is_train,
        collate_fn=dataset.collater,
        drop_last=is_train,
    )

    if is_train:
        loader = IterLoader(loader, use_distributed=use_distributed)

    return loader


def apply_to_sample(f, sample):
    """Recursively apply `f` to every tensor in a (possibly nested) sample."""
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)


def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # Exhausted the underlying loader: advance the epoch, reseed the
            # distributed sampler's shuffle, and start a fresh iterator.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)


def prepare_one_sample(wav_path, wav_processor, cuda_enabled=True):
    """Load a single wav file and build a batch of size 1 containing its
    spectrogram, raw waveform, and padding mask."""
    audio, sr = sf.read(wav_path)
    if len(audio.shape) == 2:  # stereo to mono
        audio = audio[:, 0]
    if len(audio) < sr:  # pad audio to at least 1s
        sil = np.zeros(sr - len(audio), dtype=float)
        audio = np.concatenate((audio, sil), axis=0)
    audio = audio[: sr * 30]  # truncate audio to at most 30s

    spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]

    samples = {
        "spectrogram": spectrogram,
        "raw_wav": torch.from_numpy(audio).unsqueeze(0),
        "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
    }
    if cuda_enabled:
        samples = move_to_cuda(samples)

    return samples
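if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a local
    # 16 kHz "example.wav" exists and uses transformers' WhisperFeatureExtractor
    # as the wav_processor; any processor whose __call__ accepts
    # (audio, sampling_rate=..., return_tensors="pt") and returns a dict with an
    # "input_features" tensor would work the same way.
    from transformers import WhisperFeatureExtractor

    wav_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")
    samples = prepare_one_sample("example.wav", wav_processor, cuda_enabled=False)
    print({key: tuple(value.shape) for key, value in samples.items()})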