import streamlit as st
import os
from pytube import YouTube
import torch, torchaudio
import yaml  # PyYAML, used to read the model config file
from src.models import models
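
# `apply_preprocessing` is called below but is neither imported nor defined in this
# script. The upstream deepfake-whisper-features project preprocesses audio before
# inference; the minimal sketch below is an assumption (16 kHz mono, fixed-length
# pad/trim) and should be replaced with the exact pipeline the checkpoint was
# trained with.
SAMPLING_RATE = 16_000   # assumed target sample rate
TARGET_LENGTH = 64_600   # assumed fixed number of samples fed to the model

def apply_preprocessing(waveform, sample_rate):
    # Mix multi-channel audio down to mono.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample to the assumed model sample rate.
    if sample_rate != SAMPLING_RATE:
        waveform = torchaudio.functional.resample(waveform, sample_rate, SAMPLING_RATE)
        sample_rate = SAMPLING_RATE
    # Trim or zero-pad to a fixed length.
    waveform = waveform.squeeze(0)
    if waveform.shape[0] >= TARGET_LENGTH:
        waveform = waveform[:TARGET_LENGTH]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, TARGET_LENGTH - waveform.shape[0]))
    return waveform, sample_rate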

# Model configuration and loading
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('ko_model__whisper_specrnet.yaml', 'r') as f:
    model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]
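# The config file is expected to provide at least these keys (values are illustrative):
#
#   checkpoint:
#     path: "weights/ko_whisper_specrnet.pth"   # location of the trained state dict
#   model:
#     name: "whisper_specrnet"
#     parameters: {...}                          # kwargs forwarded to models.get_model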

model = models.get_model(
    model_name=model_name,
    config=model_parameters,
    device=device,
)
model.load_state_dict(torch.load(model_paths, map_location=device))  # map_location keeps CPU-only machines working
model = model.to(device)
model.eval()

# Download a YouTube video and extract its audio track
def download_youtube_audio(youtube_url, output_path="temp"):
    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    output_file = audio_stream.download(output_path=output_path)
    title = audio_stream.default_filename
    return output_file, title

# Predict from a YouTube URL
def pred_from_url(youtube_url): 
    global model
    audio_path, title = download_youtube_audio(youtube_url)
    waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
    waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

    # Run inference without tracking gradients.
    with torch.no_grad():
        pred = model(waveform.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)

    os.remove(audio_path)  # delete the downloaded audio file
    return f"{title}\n\n{(pred[0][0]*100):.2f}% probability of being fake."

# Predict from an uploaded audio file
def pred_from_file(file_path): 
    global model
    waveform, sample_rate = torchaudio.load(file_path, normalize=True)
    waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

    # Run inference without tracking gradients.
    with torch.no_grad():
        pred = model(waveform.unsqueeze(0).to(device))
        pred = torch.sigmoid(pred)

    return f"{(pred[0][0]*100):.2f}% probability of being fake."

# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-specrnet (using MLAAD, MAILABS, aihub ๊ฐ์„ฑ ๋ฐ ๋ฐœํ™”์Šคํƒ€์ผ ๋™์‹œ ๊ณ ๋ ค ์Œ์„ฑํ•ฉ์„ฑ ๋ฐ์ดํ„ฐ, ์ž์ฒด ์ˆ˜์ง‘ ๋ฐ ์ƒ์„ฑํ•œ KoAAD)")
st.markdown("original code from https://github.com/piotrkawa/deepfake-whisper-features")

tab1, tab2 = st.tabs(["YouTube URL", "File Upload"])

with tab1:
    youtube_url = st.text_input("YouTube URL")
    if st.button("RUN URL"):
        result = pred_from_url(youtube_url)
        st.text_area("Result", value=result, height=150)

with tab2:
    file = st.file_uploader("Upload an audio file", type=['mp3', 'wav'])
    if file is not None and st.button("RUN File"):
        # Save the upload to a temporary file so torchaudio can read it from disk
        with open(file.name, "wb") as f:
            f.write(file.getbuffer())
        result = pred_from_file(file.name)
        st.text_area("Result", value=result, height=150)
        os.remove(file.name)  # delete the temporary file
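
# A minimal way to launch this demo locally (the filename is whatever this script
# is saved as, e.g. app.py):
#
#   pip install streamlit pytube torch torchaudio pyyaml
#   streamlit run app.py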