sbv2-gradio-app / app.py
buchi-stdesign's picture
Update app.py
fa548a8 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
# モデルとトークナイザーをロード
tokenizer = AutoTokenizer.from_pretrained("buchi-stdesign/style-bert-vits2-demo")
model = GPT2LMHeadModel.from_pretrained("buchi-stdesign/style-bert-vits2-demo")
# 音声生成関数
def text_to_speech(text, speaker, emotion):
if not text:
return "Error: 入力テキストが空です"
# テキストをトークナイズして、パディングとトランケートを行い attention_mask を生成
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask'] # Attention maskを取得
# モデルの generate メソッドに attention_mask を渡す
try:
generated_output = model.generate(
input_ids=input_ids, # トークナイズされたテキスト
attention_mask=attention_mask, # 明示的に attention mask を渡す
max_length=50, # 最大長
num_return_sequences=1, # 生成するシーケンス数
do_sample=True, # サンプリングを使用
top_k=50, # トークンの上位50個からサンプリング
top_p=0.95 # トークンの累積確率0.95までのトークンを使用
)
except Exception as e:
return f"Error in model generation: {str(e)}"
# 生成されたトークンをデコード
try:
if len(generated_output) > 0:
generated_text = tokenizer.decode(generated_output[0], skip_special_tokens=True)
else:
return "Error: モデルからの出力がありません"
except Exception as e:
return f"Error in decoding: {str(e)}"
return f"Generated text: {generated_text}"
# Gradioインターフェース
demo = gr.Interface(fn=text_to_speech,
inputs=["text", gr.Dropdown(["Anneli", "Amitaro"]), gr.Dropdown(["Neutral", "Happy", "Sad", "Angry"])],
outputs="text")
# Gradioアプリを起動
demo.launch() # share=True を削除