Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,254 Bytes
5db4ee9 0a63b23 719b808 b82b421 88a8fb2 719b808 b82b421 ca6274c e7b9619 0a63b23 4a81ee5 719b808 b82b421 c24b00b 719b808 4a81ee5 719b808 4a81ee5 719b808 4a81ee5 4151dd8 4a81ee5 0aed679 4151dd8 0aed679 b82b421 4a81ee5 0aed679 4a81ee5 b82b421 0aed679 4a81ee5 0aed679 4a81ee5 0aed679 4a81ee5 719b808 4a81ee5 0aed679 88a8fb2 b82b421 0aed679 34146f0 4151dd8 0aed679 af54972 b82b421 34146f0 b82b421 0aed679 88a8fb2 34146f0 b82b421 4151dd8 afbf6ef 87c8d2e 4a81ee5 ca6274c afbf6ef ca6274c 88a8fb2 7c8310f ca6274c 88a8fb2 18c21b7 88a8fb2 efcd998 88a8fb2 efcd998 e9c3646 ca6274c 88a8fb2 ca6274c 88a8fb2 719b808 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import spaces
import gradio as gr
import os
import torch
from model import SpoofVerificationModel # 自定义模型模块
import dataset # 自定义数据集模块
from huggingface_hub import hf_hub_download
from transformers import AutoFeatureExtractor
@spaces.GPU
def dummy():
    """Placeholder GPU-decorated function that intentionally does nothing.

    NOTE(review): presumably kept so that a ``@spaces.GPU`` function exists
    at import time; confirm before removing.
    """
    return None
# Fetch the detector checkpoint from the Hugging Face Hub.
def load_model():
    """Download the SpoofVerification checkpoint and return its local path.

    Returns:
        Local filesystem path reported by ``hf_hub_download``.

    Raises:
        FileNotFoundError: if the reported path does not exist on disk.
    """
    local_path = hf_hub_download(
        repo_id="amphion/deepfake_detection",
        filename="checkpoints_w2v-bert_SpoofVerification_MultiDataset/model_checkpoint_4_new.pth",
        repo_type="model",
    )
    if os.path.exists(local_path):
        return local_path
    raise FileNotFoundError(f"Checkpoint not found: {local_path}")


# Resolved once at import time; detect_on_gpu reads this module-level global.
checkpoint_path = load_model()
def _align_state_dict_prefix(model, state_dict):
    """Make ``state_dict`` keys agree with ``model`` on the ``'module.'`` prefix.

    If ``model`` has a ``module`` attribute, every key must start with
    ``'module.'``. Otherwise no key may start with it. The input mapping is
    not mutated; a new dict is returned when keys are rewritten.
    """
    prefix = "module."
    has_prefix = any(key.startswith(prefix) for key in state_dict)
    wrapped = hasattr(model, "module")
    if wrapped and not has_prefix:
        print("添加 'module.' 前缀到状态字典的 key")
        return {prefix + key: value for key, value in state_dict.items()}
    if not wrapped and has_prefix:
        print("移除状态字典 key 中的 'module.' 前缀")
        # Strip only the leading prefix. str.replace would also rewrite any
        # 'module.' segment that appears later inside a key.
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
    return state_dict


# Detection runs inside the @spaces.GPU-decorated function, so model loading
# and inference both happen within the GPU call.
@spaces.GPU
def detect_on_gpu(audio_path):
    """Run audio deepfake detection on a single audio file.

    Builds ``dataset.DemoDataset`` from ``audio_path``, loads
    ``SpoofVerificationModel`` weights from the module-level
    ``checkpoint_path``, and scores only the first item the dataset yields.

    Args:
        audio_path: path of the audio file to analyse.

    Returns:
        dict with ``"is_fake"`` (bool: fake probability > 0.5) and
        ``"confidence"`` (Python float: the fake probability when judged fake,
        otherwise one minus it).

    Raises:
        ValueError: if the dataset yields no items for ``audio_path``.
    """
    print("\n=== 开始音频检测 ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    # Dataset construction is done inside the GPU function.
    audio_dataset = dataset.DemoDataset(audio_path)

    print("正在初始化模型...")
    model = SpoofVerificationModel().to(device)
    print(f"正在加载模型权重: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model_state_dict = checkpoint['model_state_dict']
    threshold = 0.5
    print(f"检测阈值设置为: {threshold}")
    # Reconcile a possible 'module.' key prefix between checkpoint and model.
    model_state_dict = _align_state_dict_prefix(model, model_state_dict)
    model.load_state_dict(model_state_dict)
    model.eval()
    print("模型加载完成,进入评估模式")
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

    print("\n开始处理音频数据...")
    # Only the first item is scored; previously a loop broke after one batch.
    try:
        batch = next(iter(audio_dataset))
    except StopIteration:
        raise ValueError(f"No audio data could be read from: {audio_path}") from None

    with torch.no_grad():
        print("\n处理批次 1")
        waveforms = batch['waveforms']
        if len(waveforms.shape) == 1:
            # Add a leading batch dimension for a single 1-D waveform.
            waveforms = waveforms.unsqueeze(0)
        print('shape:', waveforms.shape)
        features = feature_extractor(
            waveforms.numpy(),
            sampling_rate=16000,
            return_attention_mask=True,
            padding_value=0,
            return_tensors="pt",
        ).to(device)
        outputs = model(features)
        # Column 1 of the softmax over deepfake_logits is treated as "fake".
        fake_probs = outputs['deepfake_logits'].float().softmax(dim=-1)[:, 1]
        # Convert to a plain Python float so no device tensor leaves this call.
        fake_score = fake_probs[0].item()

    is_fake = fake_score > threshold
    result = {
        "is_fake": is_fake,
        "confidence": fake_score if is_fake else 1.0 - fake_score,
    }
    print("\n=== 检测完成 ===")
    return result
def audio_deepfake_detection(audio_path):
    """Classify one audio file and format the result for the Gradio UI.

    Args:
        audio_path: filepath from the Gradio Audio component. An empty value
            (e.g. ``None``) means no file was provided.

    Returns:
        dict mapping "是否为AI生成/Is AI Generated" to "是/Yes" or "否/No", and
        "检测可信度/Confidence" to a percentage string with two decimals.

    Raises:
        gr.Error: if no audio file was provided.
    """
    # Refuse empty input with a user-facing message instead of handing it on
    # to detect_on_gpu.
    if not audio_path:
        raise gr.Error("请先上传音频 / Please upload an audio file first.")
    # The GPU function takes the file path directly; no preprocessing here.
    result = detect_on_gpu(audio_path)
    is_fake = "是/Yes" if result["is_fake"] else "否/No"
    # float() accepts both a Python float and a 0-d tensor confidence value.
    confidence = f"{100 * float(result['confidence']):.2f}%"
    return {
        "是否为AI生成/Is AI Generated": is_fake,
        "检测可信度/Confidence": confidence,
    }
# Gradio interface
def gradio_ui():
    """Build the single-input Gradio interface for audio deepfake detection.

    Returns:
        gr.Interface that takes one uploaded audio file (passed as a filepath)
        and displays the dict returned by ``audio_deepfake_detection`` as JSON.
    """

    def detection_wrapper(query_audio):
        return audio_deepfake_detection(query_audio)

    interface = gr.Interface(
        fn=detection_wrapper,
        inputs=[
            gr.Audio(sources=["upload"], type="filepath", label="测试音频 / Test Audio")
        ],
        outputs=gr.JSON(label="检测结果 / Detection Results"),
        title="音频伪造检测系统 / Audio Deepfake Detection System",
        description="上传一个测试音频以检测该音频是否为AI生成。/ Upload a test audio to detect whether the audio is AI-generated.",
        article=(
            "由香港中文大学(深圳)武执政教授团队开发。"
            "Developed by a team led by Prof Zhizheng Wu from the Chinese University of Hong Kong, Shenzhen."
            "\n\n"
            "本系统用于检测音频是否为AI生成,适用于研究和教育目的。"
            "This system is designed to detect whether an audio is AI-generated, "
            "and is intended for research and educational purposes."
        ),
    )
    return interface


if __name__ == "__main__":
    demo = gradio_ui()
    demo.launch()
|