import logging
import time
from datetime import datetime

import numpy as np
import soundfile as sf
import torch
from torch.utils.data import DataLoader, DistributedSampler

from dist_utils import get_rank, get_world_size, is_main_process


def now():
    """Return a compact timestamp (YYYYMMDDHHMM) for naming runs and outputs."""
    return datetime.now().strftime("%Y%m%d%H%M")
|
def setup_logger():
    """Configure logging so that only the main process emits INFO-level logs."""
    logging.basicConfig(
        level=logging.INFO if is_main_process() else logging.WARNING,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
|
def get_dataloader(dataset, config, is_train=True, use_distributed=True):
    """Build a DataLoader; in distributed mode, shard the dataset across ranks."""
    if use_distributed:
        sampler = DistributedSampler(
            dataset,
            shuffle=is_train,
            num_replicas=get_world_size(),
            rank=get_rank(),
        )
    else:
        sampler = None

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size_train if is_train else config.batch_size_eval,
        num_workers=config.num_workers,
        pin_memory=True,
        sampler=sampler,
        # DataLoader rejects shuffle=True when a sampler is supplied, so only
        # shuffle here in the non-distributed training case.
        shuffle=sampler is None and is_train,
        collate_fn=dataset.collater,
        drop_last=is_train,
    )

    if is_train:
        # Wrap the training loader so it can be iterated past epoch boundaries.
        loader = IterLoader(loader, use_distributed=use_distributed)

    return loader
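
# Example usage (a sketch): any dataset passed here must expose a `collater`
# method, since it is wired in as `collate_fn` above. `dataset`, `config`, and
# `max_steps` stand in for this repo's actual objects:
#
#   train_loader = get_dataloader(dataset, config, is_train=True, use_distributed=True)
#   for step in range(max_steps):
#       batch = next(train_loader)  # IterLoader yields batches indefinitely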
|
def apply_to_sample(f, sample):
    """Recursively apply `f` to every tensor in a (possibly nested) sample."""
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)
|
def move_to_cuda(sample):
    """Move every tensor in `sample` to the current CUDA device."""

    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)
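
# Example (sketch): `apply_to_sample` recurses through dicts and lists, so an
# arbitrarily nested batch is moved in one call (shapes are illustrative):
#
#   batch = {"spectrogram": torch.zeros(1, 80, 3000),
#            "meta": {"ids": [torch.tensor(0), torch.tensor(1)]}}
#   batch = move_to_cuda(batch)  # every tensor, however nested, is now on GPU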
|
def prepare_sample(samples, cuda_enabled=True):
    """Optionally move a collated batch to GPU before it is fed to the model."""
    if cuda_enabled:
        samples = move_to_cuda(samples)

    return samples
|
class IterLoader:
    """
    A wrapper that turns a DataLoader into an infinite iterator.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
    """

    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch = 0

    @property
    def epoch(self) -> int:
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # The underlying loader is exhausted: advance the epoch, reshuffle
            # the distributed sampler, and restart iteration.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # brief pause to avoid a possible deadlock at the epoch boundary
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)
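
# Example (sketch): wrapped this way, a loader can drive a step-based training
# loop without per-epoch bookkeeping; `set_epoch` is called on the sampler at
# each restart so distributed shuffling differs between epochs. `dataset` is a
# placeholder here:
#
#   loader = IterLoader(DataLoader(dataset, batch_size=4), use_distributed=False)
#   for _ in range(10 * len(loader)):  # roughly ten epochs, no StopIteration
#       batch = next(loader)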
|
def prepare_one_sample(wav_path, wav_processor, cuda_enabled=True):
    """Load a single wav file and pack it into a model-ready batch of size 1."""
    audio, sr = sf.read(wav_path)
    if len(audio.shape) == 2:
        # Stereo file: keep only the first channel.
        audio = audio[:, 0]
    if len(audio) < sr:
        # Pad with silence so the clip is at least one second long.
        sil = np.zeros(sr - len(audio), dtype=float)
        audio = np.concatenate((audio, sil), axis=0)
    # Truncate to at most 30 seconds.
    audio = audio[: sr * 30]

    spectrogram = wav_processor(audio, sampling_rate=sr, return_tensors="pt")["input_features"]

    samples = {
        "spectrogram": spectrogram,
        "raw_wav": torch.from_numpy(audio).unsqueeze(0),
        "padding_mask": torch.zeros(len(audio), dtype=torch.bool).unsqueeze(0),
    }
    if cuda_enabled:
        samples = move_to_cuda(samples)

    return samples
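
# Example (sketch): `wav_processor` is assumed to be a feature extractor whose
# call returns a dict-like object with an "input_features" entry, e.g. a
# Whisper feature extractor from Hugging Face (the checkpoint name below is
# illustrative, and Whisper extractors expect 16 kHz input, so `sr` should match):
#
#   from transformers import WhisperFeatureExtractor
#   processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")
#   samples = prepare_one_sample("speech.wav", processor, cuda_enabled=torch.cuda.is_available())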