# Hugging Face Space "ireneminhee/speech-to-depression" — app.py (commit da304d0)
import os
import re

import gradio as gr
import pandas as pd
import torch
from pydub import AudioSegment
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
# ์—…๋กœ๋“œํ•œ ๋ชจ๋ธ ๋กœ๋“œ
repo_name = "ireneminhee/speech-to-depression"
model = WhisperForConditionalGeneration.from_pretrained(repo_name)
processor = WhisperProcessor.from_pretrained(repo_name)
# ์Œ์„ฑ์„ ํ…์ŠคํŠธ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜
def transcribe(audio):
inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
generated_ids = model.generate(inputs.input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return transcription
# ์šฐ์šธ์ฆ ์˜ˆ์ธก ๋ชจ๋ธ ๋กœ๋“œ
def load_model_from_safetensors(model_name, safetensors_path):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=model_name)
state_dict = torch.load(safetensors_path) # safetensors๋ฅผ ๋ชจ๋ธ๋กœ ๋กœ๋“œ
model.load_state_dict(state_dict)
model.eval()
return model, tokenizer
# Per-sentence prediction function.
def predict_depression(sentences, model, tokenizer):
    """Predict a depression label for each sentence.

    Args:
        sentences: Iterable of sentence strings. A single bare string is
            treated as one sentence (previously it was iterated character
            by character, classifying each character).
        model: Sequence-classification model; called as ``model(**inputs)``
            and expected to expose ``.logits``.
        tokenizer: Callable producing ``pt`` tensors for one sentence.

    Returns:
        List of ``(sentence, predicted_label)`` tuples, where the label is
        the argmax over the model's logits.
    """
    # Guard against a bare string: iterating it would classify characters.
    if isinstance(sentences, str):
        sentences = [sentences]
    results = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():  # inference only — no autograd graph needed
            outputs = model(**inputs)
        prediction = torch.argmax(outputs.logits, dim=-1).item()
        results.append((sentence, prediction))
    return results
# Function that runs the whole pipeline: speech -> text -> sentence scores.
def process_audio_and_predict(audio):
    """Transcribe audio and return an average depression score string.

    Args:
        audio: Audio input accepted by ``transcribe`` (i.e. whatever the
            Whisper processor accepts).

    Returns:
        Summary string with the mean predicted label over all sentences.
    """
    # 1. Speech-to-text with the Whisper model.
    #    BUGFIX: was calling the undefined name ``transcribe_audio``.
    text = transcribe(audio)
    # 2. Split the transcript into sentences on end punctuation so the
    #    classifier scores each sentence (previously the raw string was
    #    passed through, which iterated character by character).
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    # 3. Load the fine-tuned classifier.
    safetensors_path = "./model/model.safetensors"  # SafeTensors checkpoint path
    model_name = "klue/bert-base"  # tokenizer / config source on the Hub
    model, tokenizer = load_model_from_safetensors(model_name, safetensors_path)
    # 4. Per-sentence prediction.
    results = predict_depression(sentences, model, tokenizer)
    # 5. Aggregate: mean of the predicted 0/1 labels.
    df_result = pd.DataFrame(results, columns=["Sentence", "Depression_Prediction"])
    average_probability = df_result["Depression_Prediction"].mean()
    return f"Average Depression Probability: {average_probability:.2f}"
# Gradio ์ธํ„ฐํŽ˜์ด์Šค๋กœ ์—ฐ๊ฒฐํ•  ํ•จ์ˆ˜
def gradio_process_audio(audio_data):
# ์‚ฌ์šฉ์ž๊ฐ€ ๋งˆ์ดํฌ๋กœ ์ž…๋ ฅํ•œ ์Œ์„ฑ์„ ์ž„์‹œ ํŒŒ์ผ๋กœ ์ €์žฅ
temp_audio_path = "temp_audio.wav"
with open(temp_audio_path, "wb") as f:
f.write(audio_data)
# ์˜ค๋””์˜ค ์ฒ˜๋ฆฌ ๋ฐ ์˜ˆ์ธก
average_probability, df_result = process_audio_and_detect_depression(temp_audio_path, safetensors_path, model_name)
# ๊ฒฐ๊ณผ ์ถœ๋ ฅ
return f"Average Depression Probability: {average_probability:.2f}", df_result
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์ •์˜
interface = gr.Interface(
fn=gradio_process_audio, # Gradio์—์„œ ํ˜ธ์ถœํ•  ํ•จ์ˆ˜
inputs=gr.Audio(type="numpy"), # ์‚ฌ์šฉ์ž ์Œ์„ฑ ์ž…๋ ฅ (๋งˆ์ดํฌ)
outputs=[
gr.Textbox(label="Depression Probability"), # ํ‰๊ท  ํ™•๋ฅ 
gr.Dataframe(label="Sentence-wise Analysis") # ์ƒ์„ธ ๋ถ„์„ ๊ฒฐ๊ณผ
],
title="Depression Detection from Audio",
description="Record your voice, and the model will analyze the text for depression likelihood."
)
# Gradio ์‹คํ–‰
interface.launch(share=True)