# Zonos / zonos/config.py
# Author: Steveeeeeeen (HF staff)
# Commit 0af138e (verified): "ZeroGPU (#2)"
import copy
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class BackboneConfig:
    """Hyperparameters describing the model backbone.

    Every field has a default, so ``BackboneConfig()`` is valid on its own.
    The container-typed fields are built by factories, so each instance
    receives its own fresh, empty ``dict``/``list`` rather than one shared
    across instances.
    """

    # Dimensions / depth.
    d_model: int = 1024
    d_intermediate: int = 0
    attn_mlp_d_intermediate: int = 0
    n_layer: int = 16

    # Sub-module options. Presumably forwarded to the layer constructors --
    # TODO confirm against the model code that consumes this config.
    ssm_cfg: dict = field(default_factory=dict)
    # Presumably the indices of layers that use attention -- TODO confirm.
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)

    # Normalisation / numerics switches.
    rms_norm: bool = False
    residual_in_fp32: bool = False
    norm_epsilon: float = 1e-5
@dataclass
class PrefixConditionerConfig:
    """Settings for the prefix conditioner.

    Both fields are required; neither has a default value.
    """

    # One dict per conditioner. The keys each dict must carry are defined by
    # the code that consumes this config (not visible here) -- TODO document.
    conditioners: list[dict]
    # Kind of projection to use. The ``Literal`` annotation lists the intended
    # values, but ``@dataclass`` does not check it at runtime.
    projection: Literal["none", "linear", "mlp"]
@dataclass
class ZonosConfig:
    """Top-level model configuration.

    Bundles the backbone and prefix-conditioner sub-configs together with
    the special token ids.
    """

    backbone: BackboneConfig
    prefix_conditioner: PrefixConditionerConfig
    eos_token_id: int = 1024
    masked_token_id: int = 1025

    @classmethod
    def from_dict(cls, d: dict) -> "ZonosConfig":
        """Build a ``ZonosConfig`` from a plain (e.g. JSON-decoded) dict.

        Args:
            d: Mapping with a ``"backbone"`` entry and a
                ``"prefix_conditioner"`` entry, each holding keyword
                arguments for the corresponding sub-config. Every remaining
                key is passed as a keyword argument to ``ZonosConfig``
                itself (e.g. ``eos_token_id``, ``masked_token_id``).

        Returns:
            A new ``ZonosConfig``. ``d`` is left unmodified, and the result
            shares no mutable (nested) objects with it.

        Raises:
            KeyError: If ``"backbone"`` or ``"prefix_conditioner"`` is missing.
            TypeError: If a section contains a key the target dataclass does
                not accept, or omits a required one.
        """
        # Deep copy rather than ``d.copy()``. A shallow copy only protects the
        # top level. It would leave nested lists/dicts (ssm_cfg, attn_cfg,
        # attn_layer_idx, conditioners, ...) aliased between the caller's dict
        # and the returned config, so mutating one would mutate the other.
        d = copy.deepcopy(d)
        backbone_config = BackboneConfig(**d.pop("backbone"))
        prefix_conditioner_config = PrefixConditionerConfig(**d.pop("prefix_conditioner"))
        return cls(backbone_config, prefix_conditioner_config, **d)