Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import os | |
| from pytube import YouTube | |
| import torch, torchaudio | |
| import yaml # yaml ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ถ๋ฌ์ต๋๋ค. | |
| from src.models import models | |
# Model configuration and loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

with open("ko_model__whisper_specrnet.yaml", "r", encoding="utf-8") as f:
    model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]

# Streamlit re-executes this whole script on every widget interaction, so without
# caching the checkpoint would be re-read and the network rebuilt on every click.
# st.cache_resource (Streamlit >= 1.18) keeps one model per server process; on older
# Streamlit versions fall back to an identity decorator (no caching, as before).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(name, parameters, checkpoint_path, target_device):
    """Build the detector, load its weights onto *target_device*, and set eval mode.

    Args:
        name: model name from the YAML config (``model.name``).
        parameters: model parameters from the YAML config (``model.parameters``).
        checkpoint_path: path of the state-dict checkpoint (``checkpoint.path``).
        target_device: ``"cuda"`` or ``"cpu"``.

    Returns:
        The model, moved to *target_device* and switched to eval mode.
    """
    net = models.get_model(
        model_name=name,
        config=parameters,
        device=target_device,
    )
    # map_location remaps tensors saved from CUDA onto the current device; without it
    # torch.load fails on a CPU-only host. torch.load unpickles the file, so only
    # point the config at trusted checkpoints.
    net.load_state_dict(torch.load(checkpoint_path, map_location=target_device))
    net = net.to(target_device)
    net.eval()
    return net


model = _load_model(model_name, model_parameters, model_paths, device)
# Download a YouTube video's audio-only stream to disk.
def download_youtube_audio(youtube_url, output_path="temp"):
    """Download the audio-only stream of *youtube_url* into *output_path*.

    Args:
        youtube_url: URL of the YouTube video.
        output_path: directory the audio file is written to.

    Returns:
        ``(output_file, title)``: the path of the downloaded file and the stream's
        ``default_filename``, which callers display as the title.

    Raises:
        ValueError: if the video exposes no audio-only stream.
    """
    import uuid  # local import: only used to make download filenames unique

    yt = YouTube(youtube_url)
    audio_stream = yt.streams.get_audio_only()
    if audio_stream is None:
        # get_audio_only() returns None when nothing matches; fail with a clear
        # message instead of an AttributeError on .download().
        raise ValueError(f"No audio-only stream available for {youtube_url}")
    # Streamlit serves concurrent sessions from one process. A per-call prefix keeps
    # two sessions that request the same video from sharing (and deleting) one file.
    output_file = audio_stream.download(
        output_path=output_path,
        filename_prefix=f"{uuid.uuid4().hex}_",
    )
    title = audio_stream.default_filename
    return output_file, title
# Predict from a YouTube URL.
def pred_from_url(youtube_url):
    """Download the audio of *youtube_url*, score it, and return a result message.

    The message is the title returned by ``download_youtube_audio``, a blank line,
    and the fake-probability sentence produced by ``pred_from_file``.
    """
    audio_path, title = download_youtube_audio(youtube_url)
    try:
        verdict = pred_from_file(audio_path)
    finally:
        # Remove the download even when loading or inference fails, so failed
        # requests do not leave audio files behind.
        os.remove(audio_path)
    return f"{title}\n\n{verdict}"
| # ํ์ผ๋ก๋ถํฐ ์์ธก | |
| def pred_from_file(file_path): | |
| global model | |
| waveform, sample_rate = torchaudio.load(file_path, normalize=True) | |
| waveform, sample_rate = apply_preprocessing(waveform, sample_rate) | |
| pred = model(waveform.unsqueeze(0).to(device)) | |
| pred = torch.sigmoid(pred) | |
| return f"{(pred[0][0]*100):.2f}% ํ๋ฅ ๋ก fake์ ๋๋ค." | |
# Streamlit UI
import tempfile  # used below to give each upload its own temporary file

st.title("DeepFake Detection Demo")
st.markdown("whisper-specrnet (using MLAAD, MAILABS, aihub 감성 및 발화스타일 동시 고려 음성합성 데이터, 자체 수집 및 생성한 KoAAD)")
st.markdown("original code from https://github.com/piotrkawa/deepfake-whisper-features")

tab1, tab2 = st.tabs(["YouTube URL", "파일 업로드"])

with tab1:
    youtube_url = st.text_input("YouTube URL")
    if st.button("RUN URL"):
        if not youtube_url.strip():
            # An empty URL would otherwise surface as a pytube parsing traceback.
            st.warning("Please enter a YouTube URL.")
        else:
            result = pred_from_url(youtube_url.strip())
            st.text_area("결과", value=result, height=150)

with tab2:
    file = st.file_uploader("오디오 파일 업로드", type=['mp3', 'wav'])
    if file is not None and st.button("RUN 파일"):
        # Save the upload to a uniquely named temporary file rather than to the
        # working directory under the client-supplied name, which could collide with
        # another session's upload or with the app's own files. The original
        # extension is kept in case the audio decoder relies on it.
        suffix = os.path.splitext(file.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(file.getbuffer())
            tmp_path = tmp.name
        try:
            result = pred_from_file(tmp_path)
        finally:
            # Delete the temporary file even if decoding or inference fails.
            os.remove(tmp_path)
        st.text_area("결과", value=result, height=150)