from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)

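# This module bundles the three pieces that feed raw audio to the VQGAN
# stage: VQGANDataset (loads and slices clips), VQGANCollator (pads a batch),
# and VQGANDataModule (wraps both for Lightning).
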
class VQGANDataset(Dataset):
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        # Paths in the filelist are resolved relative to the filelist's own directory
        self.files = [
            root / line.strip()
            for line in filelist.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)
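    # One feature frame spans `hop_length` samples, so a slice of
    # `slice_frames` frames covers slice_frames * hop_length audio samples
    # (illustrative: 64 frames * 640 hop = 40960 samples, 1.28 s at 32 kHz).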
    def get_item(self, idx):
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Randomly crop long clips to slice_frames * hop_length samples
        if (
            self.slice_frames is not None
            and audio.shape[0] > self.slice_frames * self.hop_length
        ):
            start = np.random.randint(
                0, audio.shape[0] - self.slice_frames * self.hop_length
            )
            audio = audio[start : start + self.slice_frames * self.hop_length]

        if len(audio) == 0:
            return None

        # Rescale clips whose peak exceeds 1.0 to avoid clipping
        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
        }
    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None

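# A minimal standalone sketch, assuming a filelist with one audio path per
# line, resolved relative to the filelist itself (the path below is hypothetical):
#
#   dataset = VQGANDataset("data/demo/vq_filelist.txt")
#   sample = dataset[0]  # {"audio": float32 tensor of shape (num_samples,)} or None
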
class VQGANCollator:
    def __call__(self, batch):
        # Drop samples that failed to load
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        # Right-pad every clip with zeros to the longest clip in the batch
        audios = []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }

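# Worked example of the collator (values illustrative): for clips of 1600 and
# 2048 samples, audio_lengths == tensor([1600, 2048]); the shorter clip is
# right-padded with zeros, so "audios" is a (2, 2048) float tensor and
# downstream code can mask out the padding using "audio_lengths".
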
class VQGANDataModule(LightningDataModule):
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            persistent_workers=True,
        )

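# In training this datamodule would normally be handed to a Lightning
# Trainer, which calls train_dataloader()/val_dataloader() itself; a hedged
# sketch (`model` stands in for any LightningModule):
#
#   import lightning as L
#   datamodule = VQGANDataModule(train_dataset, val_dataset, batch_size=16)
#   L.Trainer(max_epochs=1).fit(model, datamodule=datamodule)
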
if __name__ == "__main__":
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # The collator only returns "audios" and "audio_lengths"
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break