Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from torch import nn | |
from transformers import AutoConfig | |
from flashcosyvoice.config import CosyVoice2LLMConfig | |
from flashcosyvoice.modules.qwen2_components.layers import ( | |
ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding) | |
class Qwen2Model(nn.Module):
    """Qwen2 decoder backbone: token embedding, decoder layers, final norm.

    The number of decoder layers is ``config.num_hidden_layers``; the
    embedding table is ``config.vocab_size`` x ``config.hidden_size``.
    """

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            config: Supplies ``vocab_size``, ``hidden_size``,
                ``num_hidden_layers`` and ``rms_norm_eps``; it is also handed
                unchanged to every ``Qwen2DecoderLayer``.
        """
        super().__init__()
        width = config.hidden_size
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, width)
        decoder_stack = [
            Qwen2DecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = RMSNorm(width, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids passed to the embedding.
            positions: Passed as the first argument to every decoder layer.

        Returns:
            The first element of the final norm's output.
        """
        states = self.embed_tokens(input_ids)
        # The first layer receives ``None`` as its residual; each layer hands
        # back an updated (hidden states, residual) pair for the next one.
        carry = None
        for block in self.layers:
            states, carry = block(positions, states, carry)
        # The final norm takes the last residual too; only its first output
        # is kept.
        normed, _ = self.norm(states, carry)
        return normed
class Qwen2ForCausalLM(nn.Module):
    """Qwen2 backbone plus an LM head for speech-token or text-token logits.

    When ``config`` exposes ``speech_vocab_size`` the head is sized to it and
    ``model_type`` is ``"speech_llm"``; otherwise the head is sized to
    ``config.vocab_size`` (without bias) and ``model_type`` is ``"text_llm"``.
    """

    # Maps an unfused projection name to (fused module name, sub-slot id).
    # NOTE(review): presumably read by a weight loader outside this file to
    # route checkpoint tensors into the fused parameters — confirm.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: CosyVoice2LLMConfig | AutoConfig
    ):
        """Build the backbone and the LM head, optionally tying weights.

        Args:
            config: Model configuration. ``speech_vocab_size`` (optional)
                selects the speech head; ``lm_head_bias`` (optional, default
                ``True``) controls that head's bias; ``tie_word_embeddings``
                requests sharing the embedding weight with the head.

        Raises:
            ValueError: If weights are tied for a speech head whose
                ``speech_vocab_size`` differs from ``vocab_size``.
        """
        super().__init__()
        self.model = Qwen2Model(config)
        if hasattr(config, "speech_vocab_size"):
            self.lm_head = ParallelLMHead(
                config.speech_vocab_size,
                config.hidden_size,
                bias=getattr(config, "lm_head_bias", True),
            )
            self.model_type = "speech_llm"
        else:
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
            self.model_type = "text_llm"
        self.tie_word_embeddings = config.tie_word_embeddings
        if self.tie_word_embeddings:
            if self.model_type == "speech_llm":
                # Explicit raise instead of `assert`, so the check is not
                # stripped when Python runs with -O.
                self._validate_tied_vocab(config)
            # Point the head's weight at the embedding's underlying tensor.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    @staticmethod
    def _validate_tied_vocab(config: CosyVoice2LLMConfig | AutoConfig) -> None:
        """Raise ``ValueError`` unless text and speech vocab sizes match."""
        if config.vocab_size != config.speech_vocab_size:
            raise ValueError(
                "vocab_size and speech_vocab_size must be the same when "
                "tie_word_embeddings is True "
                f"(got vocab_size={config.vocab_size}, "
                f"speech_vocab_size={config.speech_vocab_size})"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Return the backbone's output for ``input_ids`` / ``positions``.

        The LM head is not applied here; call ``compute_logits`` on the
        result to obtain logits.
        """
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the LM head to ``hidden_states`` and return its output."""
        logits = self.lm_head(hidden_states)
        return logits