Qwen3-Coder-30B-A3B-Instruct-ScatterMoE / modeling_qwen3_shared_moe_monkeypatch_liger_cce.py
# coding=utf-8
# Copyright 2025 Charles O. Goddard, The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The following monkeypatches were applied by Doctor Shotgun:
#
# Liger Kernel (https://github.com/linkedin/Liger-Kernel):
# 1. Liger RMSNorm
# 2. Liger RoPE
# 3. Liger SwiGLUMLP
#
# Cut Cross-Entropy (https://github.com/apple/ml-cross-entropy):
# 1. Cut Cross-Entropy
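#
# NOTE: the patches below run at module import time, so loading this model as remote
# code is enough to apply them. A minimal load sketch (the checkpoint path is a
# placeholder, not the actual repo id):
#
#   model = transformers.AutoModelForCausalLM.from_pretrained(
#       "path/to/Qwen3-Coder-30B-A3B-Instruct-ScatterMoE",
#       trust_remote_code=True,
#   )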
"""PyTorch Qwen3 model with shared expert support."""
from typing import Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
# CCE Patch #
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.utils import (
PatchOptions,
apply_lce,
)
_PATCH_OPTS = PatchOptions(
impl=LCE_IMPL_DEFAULT,
reduction="mean",
filter_eps="auto",
accum_e_fp32=False,
accum_c_fp32=False,
filter_e_grad=True,
filter_c_grad=True,
train_only=False,
)
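# The options above, summarized (descriptive only; see the cut-cross-entropy
# documentation for authoritative definitions): `impl` selects the kernel
# implementation, `reduction="mean"` mirrors torch.nn.functional.cross_entropy,
# `filter_eps`/`filter_e_grad`/`filter_c_grad` control CCE's gradient filtering,
# `accum_*_fp32` toggle fp32 accumulation, and `train_only=False` means the patched
# loss is also used in eval mode whenever labels are provided.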
# CCE Patch #
# Liger Patch #
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers.rope import liger_rotary_pos_emb
import transformers.models.qwen3_moe.modeling_qwen3_moe
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
# Liger Patch #
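# NOTE: the three assignments above must run before the qwen3_moe import further down,
# so that the Qwen3MoeRMSNorm/Qwen3MoeMLP names imported there already resolve to the
# Liger classes, and so that Qwen3MoeAttention picks up liger_rotary_pos_emb via its
# module-level lookup of apply_rotary_pos_emb.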
from transformers.modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.models.mixtral.modeling_mixtral import (
load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeMLP,
Qwen3MoeRMSNorm,
Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeModel,
Qwen3MoeForCausalLM,
)
from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig
import scattermoe
logger = logging.get_logger(__name__)
class Qwen3SharedMoeSparseMoeBlock(nn.Module):
def __init__(self, config: Qwen3SharedMoeConfig):
super().__init__()
self.config = config
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
if config.shared_expert_intermediate_size is not None:
self.shared_expert = Qwen3MoeMLP(
config, intermediate_size=config.shared_expert_intermediate_size
)
else:
self.shared_expert = None
self.moe_mlp = scattermoe.mlp.GLUMLP(
input_size=self.config.hidden_size,
hidden_size=self.config.moe_intermediate_size,
num_experts=self.config.num_experts,
top_k=self.config.num_experts_per_tok,
activation=ACT2FN[config.hidden_act],
)
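        # scattermoe's GLUMLP fuses all routed experts into grouped scatter/gather
        # GEMMs, replacing the per-expert nn.ModuleList loop of the stock
        # Qwen3MoeSparseMoeBlock (an unfused reference sketch follows this class).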
    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# handling of gate/router logits copied from Qwen3MoeSparseMoeBlock
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.config.num_experts_per_tok, dim=-1
)
        if self.config.norm_topk_prob:  # only difference from the Mixtral sparse MoE block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# modified here to use scattermoe + shared_expert
hs_0 = self.moe_mlp(hidden_states, routing_weights, selected_experts)
if self.shared_expert is not None:
shared_res = self.shared_expert(hidden_states)
res = hs_0 + shared_res
else:
res = hs_0
res = res.reshape(batch_size, sequence_length, hidden_dim)
return res, router_logits
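# The following reference implementation is an illustrative sketch only (it is not
# called anywhere in this file): it spells out, with the plain per-expert loop used by
# Mixtral/Qwen3MoeSparseMoeBlock, the routed-expert computation that the fused
# scattermoe.mlp.GLUMLP call in Qwen3SharedMoeSparseMoeBlock.forward performs.
def _reference_unfused_moe(hidden_states, gate, experts, top_k, norm_topk_prob):
    """Unfused equivalent of the routed-expert path (shared expert omitted)."""
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat = hidden_states.view(-1, hidden_dim)
    router_logits = gate(flat)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    if norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(flat.dtype)
    out = torch.zeros_like(flat)
    for expert_idx, expert in enumerate(experts):  # experts: iterable of per-expert MLPs
        token_idx, k_idx = torch.where(selected_experts == expert_idx)
        if token_idx.numel() == 0:
            continue
        expert_out = expert(flat[token_idx]) * routing_weights[token_idx, k_idx, None]
        out.index_add_(0, token_idx, expert_out.to(out.dtype))
    return out.view(batch_size, sequence_length, hidden_dim), router_logits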
class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.hidden_size = config.hidden_size
self.self_attn = Qwen3MoeAttention(config, layer_idx)
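        # Same layer-selection rule as Qwen3MoeDecoderLayer: use the sparse MoE block
        # (with optional shared expert) unless this layer index is in mlp_only_layers
        # or does not land on the decoder_sparse_step stride.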
if (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
else:
self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
self.input_layernorm = Qwen3MoeRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Qwen3MoeRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
class Qwen3SharedMoeModel(Qwen3MoeModel):
config_class = Qwen3SharedMoeConfig
def __init__(self, config: Qwen3SharedMoeConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[
Qwen3SharedMoeDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
config_class = Qwen3SharedMoeConfig
def __init__(self, config):
super().__init__(config)
self.model = Qwen3SharedMoeModel(config)
self.num_experts = config.num_experts
# CCE Patch #
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if hidden_states is None:
raise ValueError("hidden_states is None")
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
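            # On this path the lm_head projection and full-vocab logits are never
            # materialized; `logits` stays None in the returned output.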
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
                )  # make sure the aux loss is on the same device as the main loss
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
# CCE Patch #
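# A minimal training-step sketch (illustrative only; the checkpoint path and `batch`
# are placeholders, everything else is defined or imported in this file):
#
#   model = Qwen3SharedMoeForCausalLM.from_pretrained(
#       "path/to/checkpoint", torch_dtype=torch.bfloat16
#   )
#   out = model(
#       input_ids=batch["input_ids"],
#       labels=batch["labels"],
#       output_router_logits=True,
#   )
#   out.loss.backward()  # CCE loss plus router aux loss; out.logits is None on this path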