# streamlit_demo / app.py
# Added by ldhldh in commit cc8419d ("Create app.py", verified); 2.84 kB.
import os
import tempfile

import streamlit as st
import torch
import torchaudio
import yaml  # YAML parser for the model config file
from pytube import YouTube

from src.models import models
# ๋ชจ๋ธ ์„ค์ • ๋ฐ ๋กœ๋”ฉ
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('ko_model__whisper_specrnet.yaml', 'r') as f:
model_config = yaml.safe_load(f)
model_paths = model_config["checkpoint"]["path"]
model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"]
model = models.get_model(
model_name=model_name,
config=model_parameters,
device=device,
)
model.load_state_dict(torch.load(model_paths))
model = model.to(device)
model.eval()
# YouTube ๋น„๋””์˜ค ๋‹ค์šด๋กœ๋“œ ๋ฐ ์˜ค๋””์˜ค ์ถ”์ถœ ํ•จ์ˆ˜
def download_youtube_audio(youtube_url, output_path="temp"):
yt = YouTube(youtube_url)
audio_stream = yt.streams.get_audio_only()
output_file = audio_stream.download(output_path=output_path)
title = audio_stream.default_filename
return output_file, title
# Prediction from a YouTube URL.
def pred_from_url(youtube_url):
    """Download a YouTube video's audio and score it with the deepfake detector.

    Inference is delegated to ``pred_from_file`` so the URL and upload paths
    share one code path. The downloaded file is always deleted, even when
    loading or inference raises, so failed runs do not leave audio on disk.

    Args:
        youtube_url: URL of the video to analyse.

    Returns:
        ``"<title>\\n\\n<pred_from_file result>"``, where ``title`` is the
        name reported by ``download_youtube_audio``.
    """
    audio_path, title = download_youtube_audio(youtube_url)
    try:
        result = pred_from_file(audio_path)
    finally:
        os.remove(audio_path)
    return f"{title}\n\n{result}"
# ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ ์˜ˆ์ธก
def pred_from_file(file_path):
global model
waveform, sample_rate = torchaudio.load(file_path, normalize=True)
waveform, sample_rate = apply_preprocessing(waveform, sample_rate)
pred = model(waveform.unsqueeze(0).to(device))
pred = torch.sigmoid(pred)
return f"{(pred[0][0]*100):.2f}% ํ™•๋ฅ ๋กœ fake์ž…๋‹ˆ๋‹ค."
# Streamlit UI
st.title("DeepFake Detection Demo")
st.markdown("whisper-specrnet (using MLAAD, MAILABS, aihub ๊ฐ์„ฑ ๋ฐ ๋ฐœํ™”์Šคํƒ€์ผ ๋™์‹œ ๊ณ ๋ ค ์Œ์„ฑํ•ฉ์„ฑ ๋ฐ์ดํ„ฐ, ์ž์ฒด ์ˆ˜์ง‘ ๋ฐ ์ƒ์„ฑํ•œ KoAAD)")
st.markdown("original code from https://github.com/piotrkawa/deepfake-whisper-features")
tab1, tab2 = st.tabs(["YouTube URL", "ํŒŒ์ผ ์—…๋กœ๋“œ"])
with tab1:
youtube_url = st.text_input("YouTube URL")
if st.button("RUN URL"):
result = pred_from_url(youtube_url)
st.text_area("๊ฒฐ๊ณผ", value=result, height=150)
with tab2:
file = st.file_uploader("์˜ค๋””์˜ค ํŒŒ์ผ ์—…๋กœ๋“œ", type=['mp3', 'wav'])
if file is not None and st.button("RUN ํŒŒ์ผ"):
# ์ž„์‹œ ํŒŒ์ผ ์ €์žฅ
with open(file.name, "wb") as f:
f.write(file.getbuffer())
result = pred_from_file(file.name)
st.text_area("๊ฒฐ๊ณผ", value=result, height=150)
os.remove(file.name) # ์ž„์‹œ ํŒŒ์ผ ์‚ญ์ œ