# Spaces: Running on Zero
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class BackboneConfig:
    """Configuration values for the backbone.

    Every field has a default, so ``BackboneConfig()`` is valid and any
    subset of fields may be overridden by keyword. ``ZonosConfig.from_dict``
    builds this class by unpacking the ``"backbone"`` sub-dict as keyword
    arguments, so the field names below are also the accepted dict keys.
    """

    # NOTE(review): the meaning of each hyperparameter is not established by
    # this file; the names suggest layer widths/counts -- confirm against the
    # model code that consumes this config.
    d_model: int = 1024
    d_intermediate: int = 0
    attn_mlp_d_intermediate: int = 0
    n_layer: int = 16
    # Mutable defaults go through default_factory so each instance gets its
    # own fresh dict/list instead of sharing one object across instances.
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = False
    residual_in_fp32: bool = False
    norm_epsilon: float = 1e-5
@dataclass
class PrefixConditionerConfig:
    """Configuration values for the prefix conditioner.

    Both fields are required (no defaults). ``ZonosConfig.from_dict`` builds
    this class by unpacking the ``"prefix_conditioner"`` sub-dict as keyword
    arguments, so that dict must supply exactly these two keys.
    """

    # One dict per conditioner.
    # NOTE(review): the schema of each dict is not visible in this file --
    # confirm against the code that consumes it.
    conditioners: list[dict]
    # Annotated as one of "none", "linear" or "mlp"; the annotation is not
    # enforced at runtime.
    projection: Literal["none", "linear", "mlp"]
@dataclass
class ZonosConfig:
    """Top-level configuration bundling the backbone and prefix-conditioner configs.

    ``backbone`` and ``prefix_conditioner`` are required; the two token ids
    have defaults.
    """

    backbone: BackboneConfig
    prefix_conditioner: PrefixConditionerConfig
    eos_token_id: int = 1024
    masked_token_id: int = 1025

    @classmethod
    def from_dict(cls, d: dict) -> "ZonosConfig":
        """Build a config from a plain dict.

        Args:
            d: Mapping with a ``"backbone"`` entry and a
                ``"prefix_conditioner"`` entry, each a dict of keyword
                arguments for the matching config class. Any remaining
                top-level keys are forwarded as keyword arguments to this
                class's constructor.

        Returns:
            The constructed ``ZonosConfig``.

        Raises:
            KeyError: If ``"backbone"`` or ``"prefix_conditioner"`` is missing.
            TypeError: If any dict contains a key that is not a field of the
                class it is unpacked into.
        """
        # Shallow copy so the pops below do not mutate the caller's dict.
        remaining = d.copy()
        backbone_config = BackboneConfig(**remaining.pop("backbone"))
        prefix_conditioner_config = PrefixConditionerConfig(**remaining.pop("prefix_conditioner"))
        return cls(backbone_config, prefix_conditioner_config, **remaining)