import torch
import torch.nn as nn
from mamba_ssm.models.mixer_seq_simple import create_block
from mamba_ssm.ops.triton.layer_norm import layer_norm_fn
from mamba_ssm.utils.generation import InferenceParams

from zonos.config import BackboneConfig


class ZonosBackbone(nn.Module):
    """Stack of ``create_block`` layers followed by a final normalization.

    The constructor builds ``config.n_layer`` blocks, and ``forward`` threads a
    ``(hidden_states, residual)`` pair through them before applying
    ``layer_norm_fn`` with the parameters held by ``norm_f``.
    """

    def __init__(self, config: BackboneConfig):
        """Build the layer stack and the final norm parameters from *config*.

        Args:
            config: Backbone hyper-parameters. The fields read here are
                ``d_model``, ``d_intermediate``, ``attn_mlp_d_intermediate``,
                ``ssm_cfg``, ``attn_layer_idx``, ``attn_cfg``, ``norm_epsilon``,
                ``residual_in_fp32``, ``rms_norm`` and ``n_layer``.
        """
        super().__init__()
        self.config = config

        blocks = []
        for layer_idx in range(config.n_layer):
            # Layers whose index is listed in ``attn_layer_idx`` get their own
            # intermediate width; every other layer uses the default one.
            # NOTE(review): presumably those indices are the attention layers
            # as interpreted by ``create_block`` -- confirm against mamba_ssm.
            if layer_idx in config.attn_layer_idx:
                intermediate_width = config.attn_mlp_d_intermediate
            else:
                intermediate_width = config.d_intermediate

            blocks.append(
                create_block(
                    d_model=config.d_model,
                    d_intermediate=intermediate_width,
                    ssm_cfg=config.ssm_cfg,
                    layer_idx=layer_idx,
                    attn_layer_idx=config.attn_layer_idx,
                    attn_cfg=config.attn_cfg,
                    norm_epsilon=config.norm_epsilon,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=True,
                    rms_norm=config.rms_norm,
                )
            )
        self.layers = nn.ModuleList(blocks)

        # ``forward`` never calls this module directly; it only reads its
        # weight, bias and eps and hands them to ``layer_norm_fn``.
        self.norm_f = nn.LayerNorm(config.d_model, eps=config.norm_epsilon)

    def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams | None = None):
        """Run *hidden_states* through every layer, then the final norm.

        Args:
            hidden_states: Input activations passed to the first layer.
            inference_params: Optional object forwarded unchanged to each
                layer call.

        Returns:
            The result of ``layer_norm_fn`` applied to the last layer's
            ``hidden_states`` and ``residual``.
        """
        # The first layer receives no residual; each layer returns the pair
        # that is fed to the next one.
        residual = None
        for block in self.layers:
            hidden_states, residual = block(hidden_states, residual, inference_params)

        final_norm = self.norm_f
        return layer_norm_fn(
            hidden_states,
            final_norm.weight,
            final_norm.bias,
            residual,
            eps=final_norm.eps,
            residual_in_fp32=self.config.residual_in_fp32,
            is_rms_norm=self.config.rms_norm,
        )