import streamlit as st import os from pytube import YouTube import torch, torchaudio import yaml # yaml 라이브러리를 불러옵니다. from src.models import models # 모델 설정 및 로딩 device = "cuda" if torch.cuda.is_available() else "cpu" with open('ko_model__whisper_specrnet.yaml', 'r') as f: model_config = yaml.safe_load(f) model_paths = model_config["checkpoint"]["path"] model_name, model_parameters = model_config["model"]["name"], model_config["model"]["parameters"] model = models.get_model( model_name=model_name, config=model_parameters, device=device, ) model.load_state_dict(torch.load(model_paths)) model = model.to(device) model.eval() # YouTube 비디오 다운로드 및 오디오 추출 함수 def download_youtube_audio(youtube_url, output_path="temp"): yt = YouTube(youtube_url) audio_stream = yt.streams.get_audio_only() output_file = audio_stream.download(output_path=output_path) title = audio_stream.default_filename return output_file, title # URL로부터 예측 def pred_from_url(youtube_url): global model audio_path, title = download_youtube_audio(youtube_url) waveform, sample_rate = torchaudio.load(audio_path, normalize=True) waveform, sample_rate = apply_preprocessing(waveform, sample_rate) pred = model(waveform.unsqueeze(0).to(device)) pred = torch.sigmoid(pred) os.remove(audio_path) # 파일 삭제 코드 위치 수정 return (f"{title}\n\n{(pred[0][0]*100):.2f}% 확률로 fake입니다.") # 파일로부터 예측 def pred_from_file(file_path): global model waveform, sample_rate = torchaudio.load(file_path, normalize=True) waveform, sample_rate = apply_preprocessing(waveform, sample_rate) pred = model(waveform.unsqueeze(0).to(device)) pred = torch.sigmoid(pred) return f"{(pred[0][0]*100):.2f}% 확률로 fake입니다." # Streamlit UI st.title("DeepFake Detection Demo") st.markdown("whisper-specrnet (using MLAAD, MAILABS, aihub 감성 및 발화스타일 동시 고려 음성합성 데이터, 자체 수집 및 생성한 KoAAD)") st.markdown("original code from https://github.com/piotrkawa/deepfake-whisper-features") tab1, tab2 = st.tabs(["YouTube URL", "파일 업로드"]) with tab1: youtube_url = st.text_input("YouTube URL") if st.button("RUN URL"): result = pred_from_url(youtube_url) st.text_area("결과", value=result, height=150) with tab2: file = st.file_uploader("오디오 파일 업로드", type=['mp3', 'wav']) if file is not None and st.button("RUN 파일"): # 임시 파일 저장 with open(file.name, "wb") as f: f.write(file.getbuffer()) result = pred_from_file(file.name) st.text_area("결과", value=result, height=150) os.remove(file.name) # 임시 파일 삭제