# Spaces: Running on Zero
import torch.nn as nn

from transformers import Wav2Vec2BertModel
class SpoofVerificationModel(nn.Module):
    """Audio spoof/deepfake detector built on a Wav2Vec2-BERT backbone.

    The pretrained encoder's last hidden state is mean-pooled over time and
    fed to two independent heads:

    * a binary deepfake head (real vs. fake, 2 logits), and
    * a spoof-type head (``num_types`` logits).

    Each head has its own 1024-d linear projection so the two tasks learn
    separate embeddings; both embeddings are also returned for downstream
    use (e.g. verification / retrieval).
    """

    def __init__(self, w2v_path='facebook/w2v-bert-2.0', num_types=59):
        """
        Args:
            w2v_path: Hugging Face hub id or local path of the
                Wav2Vec2-BERT checkpoint to load.
            num_types: number of spoof-type classes for the type head.
        """
        super().__init__()
        # output_hidden_states=True so intermediate layers are available on
        # the encoder output, even though forward() only uses the last one.
        self.wav2vec2 = Wav2Vec2BertModel.from_pretrained(
            w2v_path, output_hidden_states=True
        )
        self.wav2vec_config = self.wav2vec2.config
        hidden_size = self.wav2vec_config.hidden_size
        # Task-specific 1024-d projections (kept separate per task).
        self.deepfake_embed = nn.Linear(hidden_size, 1024)
        self.type_embed = nn.Linear(hidden_size, 1024)
        # Heads apply ReLU on the projected embedding before the final linear.
        self.deepfake_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, 2),
        )
        self.type_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(1024, num_types),
        )

    def forward(self, audio_features):
        """Encode audio and run both classification heads.

        Args:
            audio_features: mapping of encoder inputs, unpacked directly
                into the Wav2Vec2-BERT model (presumably the output of the
                matching feature extractor/processor — TODO confirm keys
                against the caller).

        Returns:
            dict with:
                'deepfake_logits': (B, 2) real/fake logits,
                'type_logits': (B, num_types) spoof-type logits,
                'embeddings': (B, D) mean-pooled encoder features,
                'deepfake_embed': (B, 1024) deepfake-head embedding,
                'type_embed': (B, 1024) type-head embedding.
        """
        encoder_out = self.wav2vec2(**audio_features)
        # Mean-pool the last hidden state over time: (B, T, D) -> (B, D).
        pooled = encoder_out.last_hidden_state.mean(dim=1)
        deepfake_emb = self.deepfake_embed(pooled)
        type_emb = self.type_embed(pooled)
        deepfake_logits = self.deepfake_classifier(deepfake_emb)
        type_logits = self.type_classifier(type_emb)
        return {
            'deepfake_logits': deepfake_logits,
            'type_logits': type_logits,
            'embeddings': pooled,
            'deepfake_embed': deepfake_emb,  # embedding outputs for downstream use
            'type_embed': type_emb,
        }

    def print_parameters_info(self):
        """Print parameter counts (in millions) of the main submodules."""
        print(f"wav2vec2 parameters: {sum(p.numel() for p in self.wav2vec2.parameters())/1e6:.2f}M")
        print(f"deepfake_classifier parameters: {sum(p.numel() for p in self.deepfake_classifier.parameters())/1e6:.2f}M")
        print(f"type_classifier parameters: {sum(p.numel() for p in self.type_classifier.parameters())/1e6:.2f}M")