# NOTE(review): removed extraction artifacts ("Spaces:" / "Runtime error") that were not part of the module.
| import logging | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from omegaconf import OmegaConf | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.table import Table | |
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# Shared rich console used by the pretty-printing helpers below.
console = Console()
| def _make_stft_cfg(hop_length, win_length=None): | |
| if win_length is None: | |
| win_length = 4 * hop_length | |
| n_fft = 2 ** (win_length - 1).bit_length() | |
| return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) | |
def _build_rich_table(rows, columns, title=None):
    """Render *rows* under capitalized *columns* as a rich Table wrapped in a Panel.

    Every cell is stringified before insertion, so rows may contain
    arbitrary objects.
    """
    table = Table(title=title, header_style=None)
    for col in columns:
        table.add_column(col.capitalize(), justify="left")
    for row in rows:
        cells = [str(cell) for cell in row]
        table.add_row(*cells)
    return Panel(table, expand=False)
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Pretty-print mapping *d* as a two-column (key, value) table on the console."""
    panel = _build_rich_table(d.items(), [key, value], title)
    console.print(panel)
| class HParams: | |
| # Dataset | |
| fg_dir: Path = Path("data/fg") | |
| bg_dir: Path = Path("data/bg") | |
| rir_dir: Path = Path("data/rir") | |
| load_fg_only: bool = False | |
| praat_augment_prob: float = 0 | |
| # Audio settings | |
| wav_rate: int = 44_100 | |
| n_fft: int = 2048 | |
| win_size: int = 2048 | |
| hop_size: int = 420 # 9.5ms | |
| num_mels: int = 128 | |
| stft_magnitude_min: float = 1e-4 | |
| preemphasis: float = 0.97 | |
| mix_alpha_range: tuple[float, float] = (0.2, 0.8) | |
| # Training | |
| nj: int = 64 | |
| training_seconds: float = 1.0 | |
| batch_size_per_gpu: int = 16 | |
| min_lr: float = 1e-5 | |
| max_lr: float = 1e-4 | |
| warmup_steps: int = 1000 | |
| max_steps: int = 1_000_000 | |
| gradient_clipping: float = 1.0 | |
| def deepspeed_config(self): | |
| return { | |
| "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, | |
| "optimizer": { | |
| "type": "Adam", | |
| "params": {"lr": float(self.min_lr)}, | |
| }, | |
| "scheduler": { | |
| "type": "WarmupDecayLR", | |
| "params": { | |
| "warmup_min_lr": float(self.min_lr), | |
| "warmup_max_lr": float(self.max_lr), | |
| "warmup_num_steps": self.warmup_steps, | |
| "total_num_steps": self.max_steps, | |
| "warmup_type": "linear", | |
| }, | |
| }, | |
| "gradient_clipping": self.gradient_clipping, | |
| } | |
| def stft_cfgs(self): | |
| assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}" | |
| return [_make_stft_cfg(h) for h in (100, 256, 512)] | |
| def from_yaml(cls, path: Path) -> "HParams": | |
| logger.info(f"Reading hparams from {path}") | |
| # First merge to fix types (e.g., str -> Path) | |
| return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path)))) | |
| def save_if_not_exists(self, run_dir: Path): | |
| path = run_dir / "hparams.yaml" | |
| if path.exists(): | |
| logger.info(f"{path} already exists, not saving") | |
| return | |
| path.parent.mkdir(parents=True, exist_ok=True) | |
| OmegaConf.save(asdict(self), str(path)) | |
| def load(cls, run_dir, yaml: Path | None = None): | |
| hps = [] | |
| if (run_dir / "hparams.yaml").exists(): | |
| hps.append(cls.from_yaml(run_dir / "hparams.yaml")) | |
| if yaml is not None: | |
| hps.append(cls.from_yaml(yaml)) | |
| if len(hps) == 0: | |
| hps.append(cls()) | |
| for hp in hps[1:]: | |
| if hp != hps[0]: | |
| errors = {} | |
| for k, v in asdict(hp).items(): | |
| if getattr(hps[0], k) != v: | |
| errors[k] = f"{getattr(hps[0], k)} != {v}" | |
| raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}") | |
| return hps[0] | |
| def print(self): | |
| _rich_print_dict(asdict(self), title="HParams") | |