import torch.nn as nn
from transformers import Wav2Vec2BertModel


class SpoofVerificationModel(nn.Module):
    """Two-headed classifier on top of a pretrained ``Wav2Vec2BertModel``.

    The encoder output is mean-pooled over dim 1 and fed to two independent
    branches, each a ``Linear -> ReLU -> Linear`` stack:

    * deepfake branch: ``deepfake_embed`` -> ``deepfake_classifier`` (2 logits)
    * type branch:     ``type_embed``     -> ``type_classifier``
      (``num_types`` logits)

    The first ``Linear`` of each branch is kept as a separate module so that
    its 1024-dimensional output can be returned alongside the logits.
    """

    def __init__(self, w2v_path='facebook/w2v-bert-2.0', num_types=59):
        """Load the pretrained encoder and build both branches.

        Args:
            w2v_path: Model identifier or path handed to
                ``Wav2Vec2BertModel.from_pretrained``.
            num_types: Number of output logits of the type branch.
        """
        super().__init__()

        # NOTE(review): hidden states of every layer are requested here, but
        # ``forward`` only reads ``last_hidden_state``. Left untouched because
        # code outside this class may rely on the flag being set.
        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(
            w2v_path, output_hidden_states=True
        )
        self.wav2vec_config = self.wav2vec2.config
        hidden_size = self.wav2vec2.config.hidden_size

        # Layers are created in the same order as before so that seeded
        # weight initialisation is unaffected.
        self.deepfake_embed = nn.Linear(hidden_size, 1024)
        self.type_embed = nn.Linear(hidden_size, 1024)

        self.deepfake_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 2),
        )
        self.type_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, num_types),
        )

    def forward(self, audio_features):
        """Run the encoder and both branches.

        Args:
            audio_features: Mapping of keyword arguments; it is unpacked
                directly into the ``Wav2Vec2BertModel`` call.

        Returns:
            dict with keys:
                ``'deepfake_logits'``: output of ``deepfake_classifier``.
                ``'type_logits'``: output of ``type_classifier``.
                ``'embeddings'``: the mean-pooled encoder output.
                ``'deepfake_embed'``: output of ``deepfake_embed``.
                ``'type_embed'``: output of ``type_embed``.
        """
        encoder_out = self.wav2vec2(**audio_features)
        hidden = encoder_out.last_hidden_state  # assumed (B, T, D)

        # Plain mean over dim 1 -> assumed (B, D).
        # NOTE(review): no attention mask is applied, so if a batch contains
        # padded positions they are averaged in as well — confirm intended.
        audio_features = hidden.mean(dim=1)

        deepfake_emb = self.deepfake_embed(audio_features)
        type_emb = self.type_embed(audio_features)

        deepfake_logits = self.deepfake_classifier(deepfake_emb)
        type_logits = self.type_classifier(type_emb)

        return {
            'deepfake_logits': deepfake_logits,
            'type_logits': type_logits,
            'embeddings': audio_features,
            'deepfake_embed': deepfake_emb,  # branch embedding output
            'type_embed': type_emb,  # branch embedding output
        }

    def print_parameters_info(self):
        """Print the parameter count, in millions, of every sub-module.

        The two ``*_embed`` projections are reported as well: they are
        separate modules from the ``*_classifier`` stacks, so listing only
        the classifiers would leave them out of the summary.
        """
        module_names = (
            'wav2vec2',
            'deepfake_embed',
            'type_embed',
            'deepfake_classifier',
            'type_classifier',
        )
        for name in module_names:
            n_params = sum(p.numel() for p in getattr(self, name).parameters())
            print(f"{name} parameters: {n_params/1e6:.2f}M")