# Hugging Face Space app.py, commit 69ee1b6 ("Update app.py", verified) by OmniAICreator.
import re
import unicodedata

import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import librosa
import gradio as gr
from snac import SNAC
# Load the three models used by the app and move each one to the GPU at import
# time: the Orpheus language model (bf16), the SNAC 24 kHz codec, and a
# Whisper ASR pipeline (fp16) used to transcribe reference audio.
orpheus_model_id = 'NandemoGHS/Galgame-Orpheus-3B'
tokenizer = AutoTokenizer.from_pretrained(orpheus_model_id)
model = AutoModelForCausalLM.from_pretrained(
    orpheus_model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)
model.eval()
model.cuda()

snac_model_id = 'hubertsiuzdak/snac_24khz'
snac_model = SNAC.from_pretrained(snac_model_id)
snac_model.eval()
snac_model.cuda()

whisper_turbo_pipe = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    device='cuda',
    torch_dtype=torch.float16,
)
# Text-boundary token IDs.
SOT_ID = 128000  # Start of Text (not used in this file)
EOT_ID = 128009  # End of Text

# Prompt/turn control tokens: six consecutive IDs starting at 128257.
(
    SOS_ID,  # Start of Speech
    EOS_ID,  # End of Speech
    SOH_ID,  # Start of Human
    EOH_ID,  # End of Human
    SOA_ID,  # Start of AI
    EOA_ID,  # End of AI
) = range(128257, 128263)
# Regex -> replacement rules that ``normalize`` applies in insertion order.
# Non-ASCII search targets are written as \uXXXX escapes so that editors or
# copy/paste tools that fold full-width characters cannot silently turn them
# into their ASCII look-alikes. For example, a bare "?" is an invalid regex that
# raises re.error, and a duplicated " " key silently drops an entry.
REPLACE_MAP: dict[str, str] = {
    r"\t": "",
    r"\[n\]": "",  # literal "[n]" markers
    r" ": "",  # ASCII space
    r"\u3000": "",  # ideographic (full-width) space
    r"[;▼♀♂《》≪≫①②③④⑤⑥]": "",
    # Dash / horizontal-line characters.
    r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
    r"[\uff5e\u301C]": "ー",  # full-width tilde / wave dash -> long-vowel mark
    r"\uff1f": "?",  # full-width question mark -> ASCII
    r"\uff01": "!",  # full-width exclamation mark -> ASCII
    r"[●◯〇]": "○",
    r"♥": "♡",
}
# Full-width Latin letters (U+FF21-FF3A, U+FF41-FF5A) -> ASCII A-Z / a-z.
FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(
            list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)),
            list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)),
        )
    }
)
# Half-width katakana block (U+FF61-FF9F) -> full-width equivalents.
# The half-width and full-width blocks are not laid out in the same order
# (e.g. U+FF71 HALFWIDTH A must become U+30A2 ア, and U+FF61 must become 。),
# so a fixed code-point offset is wrong. Use each character's NFKC mapping
# instead. The half-width (semi-)voiced marks map to the combining marks
# U+3099/U+309A; apply NFC afterwards to get precomposed kana such as ガ.
HALFWIDTH_KATAKANA_TO_FULLWIDTH = str.maketrans(
    {
        chr(half): unicodedata.normalize("NFKC", chr(half))
        for half in range(0xFF61, 0xFFA0)
    }
)
# Full-width digits (U+FF10-FF19) -> ASCII 0-9.
FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))
    }
)
# Matches any character outside the allowed set: hiragana, katakana, CJK
# ideographs (incl. Extension A), 々, ASCII letters and digits, and the listed
# punctuation marks.
# NOTE(review): nothing in this file applies this pattern; confirm whether
# normalize() is meant to strip these characters.
INVALID_PATTERN = re.compile(
    r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
    r"\u0041-\u005A\u0061-\u007A"
    r"\u0030-\u0039"
    r"。、!?…♪♡○]"
)
def normalize(text: str) -> str:
    """Normalize Japanese input text before tokenization.

    Steps, in order:
      1. apply every regex rule in ``REPLACE_MAP``;
      2. convert full-width Latin letters and digits to ASCII;
      3. convert half-width katakana to full-width, then apply NFC so that a kana
         followed by a combining (semi-)voiced mark becomes one precomposed
         character (this also composes NFD input such as text copied from macOS);
      4. collapse runs of three or more "…" into "……".

    Args:
        text: Raw user- or ASR-provided text.

    Returns:
        The normalized text (possibly empty).
    """
    for pattern, replacement in REPLACE_MAP.items():
        text = re.sub(pattern, replacement, text)
    text = text.translate(FULLWIDTH_ALPHA_TO_HALFWIDTH)
    text = text.translate(FULLWIDTH_DIGITS_TO_HALFWIDTH)
    text = text.translate(HALFWIDTH_KATAKANA_TO_FULLWIDTH)
    text = unicodedata.normalize("NFC", text)
    text = re.sub(r"…{3,}", "……", text)
    return text
def tokenize_audio(waveform):
    """Encode a waveform into the flat, interleaved Orpheus audio-token list.

    ``waveform`` gets a leading batch dimension and is moved to the SNAC
    model's device before encoding. In this file it is a float32 tensor of
    shape (1, samples) at 24 kHz (see ``infer``); TODO confirm for other callers.

    SNAC returns three code layers. Layer 2 has 2 codes and layer 3 has 4 codes
    per layer-1 code. For each layer-1 index ``i``, seven tokens are emitted in
    this order::

        L1[i], L2[2i], L3[4i], L3[4i+1], L2[2i+1], L3[4i+2], L3[4i+3]

    Slot ``k`` of each 7-token group is shifted by ``128266 + k * 4096``, so
    every slot occupies its own ID range. ``redistribute_codes`` undoes the
    per-slot shift once the 128266 base has been removed.

    Returns:
        list[int]: audio token IDs, seven per layer-1 code.
    """
    audio_token_offset = 128266
    slot_stride = 4096

    # The SNAC model lives on the GPU; encoding a CPU tensor would fail.
    device = next(snac_model.parameters()).device
    waveform = waveform.unsqueeze(0).to(device)
    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    # Copy each layer to host once. Calling .item() per element forces a
    # device synchronization for every token.
    layer_1 = codes[0][0].tolist()
    layer_2 = codes[1][0].tolist()
    layer_3 = codes[2][0].tolist()

    all_codes = []
    for i, code_1 in enumerate(layer_1):
        frame = (
            code_1,
            layer_2[2 * i],
            layer_3[4 * i],
            layer_3[4 * i + 1],
            layer_2[2 * i + 1],
            layer_3[4 * i + 2],
            layer_3[4 * i + 3],
        )
        all_codes.extend(
            audio_token_offset + slot * slot_stride + code
            for slot, code in enumerate(frame)
        )
    return all_codes
def redistribute_codes(code_list):
    """Decode a flat list of audio codes back into a waveform with SNAC.

    ``code_list`` holds audio token IDs with the 128266 base offset already
    removed (as done in ``infer``), so slot ``k`` of each 7-token frame still
    carries a ``k * 4096`` shift. This reverses the interleaving performed by
    ``tokenize_audio``. Trailing tokens that do not complete a frame are dropped.

    Args:
        code_list: list[int] of de-offset audio tokens.

    Returns:
        The tensor returned by ``snac_model.decode``, or ``None`` in two cases:
        there is no complete 7-token frame, or some code falls outside
        ``[0, 4096)`` after its slot shift is removed.
    """
    slot_stride = 4096
    num_frames = len(code_list) // 7
    if num_frames == 0:
        return None

    layer_1, layer_2, layer_3 = [], [], []
    for i in range(num_frames):
        frame = code_list[7 * i : 7 * i + 7]
        # Remove the per-slot shift k * 4096 added during tokenization.
        c0, c1, c2, c3, c4, c5, c6 = (
            code - slot * slot_stride for slot, code in enumerate(frame)
        )
        layer_1.append(c0)
        layer_2.extend((c1, c4))
        layer_3.extend((c2, c3, c5, c6))

    # A code outside [0, 4096) means the LM emitted a token in the wrong slot.
    # Passing it to the codec would index outside its codebooks, which are
    # assumed to have 4096 entries to match the slot stride; TODO confirm. The
    # caller treats None as a decode failure, so reject the sequence instead.
    if any(
        not 0 <= code < slot_stride
        for layer in (layer_1, layer_2, layer_3)
        for code in layer
    ):
        return None

    device = next(snac_model.parameters()).device
    codes = [
        torch.tensor(layer, device=device).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat
def _build_reference_prompt(sample_audio_path, progress):
    """Build the voice-cloning prefix of the prompt from a reference clip.

    The clip is loaded at 24 kHz and trimmed to its first 15 seconds. It is
    transcribed with Whisper, the transcript is normalized like the target
    text, and the audio is SNAC-encoded.

    Returns:
        list[int]: ``[SOH, transcript..., EOT, EOH, SOA, SOS, audio..., EOS, EOA]``.
    """
    progress(0, 'Loading and trimming audio...')
    audio_array, sample_rate = librosa.load(sample_audio_path, sr=24000)
    if len(audio_array) / sample_rate > 15:
        gr.Warning("Trimming audio to first 15secs.")
        audio_array = audio_array[:int(sample_rate * 15)]
    prompt_wav = torch.from_numpy(audio_array).unsqueeze(0).to(dtype=torch.float32)

    progress(0.2, 'Transcribing reference audio...')
    # The ASR pipeline treats a bare numpy array as already being at its
    # feature extractor's sampling rate, but this clip is at 24 kHz.
    asr_rate = whisper_turbo_pipe.feature_extractor.sampling_rate
    asr_audio = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=asr_rate)
    prompt_text = whisper_turbo_pipe(asr_audio)['text'].strip()
    # The transcript sits in the same prompt as the target text, so it gets
    # the same normalization.
    prompt_text = normalize(prompt_text)

    progress(0.4, 'Transcribed! Encoding audio...')
    voice_tokens = tokenize_audio(prompt_wav)
    ref_text_ids = tokenizer(prompt_text, return_tensors="pt").input_ids[0].tolist()
    return [
        SOH_ID, *ref_text_ids, EOT_ID, EOH_ID,
        SOA_ID, SOS_ID, *voice_tokens, EOS_ID, EOA_ID,
    ]


def _extract_speech_tokens(generated_ids):
    """Return the generated audio codes with the 128266 base offset removed.

    Takes everything after the last SOS_ID, cuts it at the first EOS_ID, and
    keeps only IDs >= 128266 (the audio-token range used by ``tokenize_audio``).

    Raises:
        gr.Error: if no SOS_ID is present, or nothing follows it.
    """
    try:
        last_sos_idx = len(generated_ids) - 1 - generated_ids[::-1].index(SOS_ID)
    except ValueError:
        raise gr.Error("Audio generation failed: Could not find start of speech token.") from None
    speech_tokens = generated_ids[last_sos_idx + 1:]
    if EOS_ID in speech_tokens:
        speech_tokens = speech_tokens[:speech_tokens.index(EOS_ID)]
    if not speech_tokens:
        raise gr.Error("Audio generation failed: No speech tokens were generated.")
    base_offset = 128266
    return [token - base_offset for token in speech_tokens if token >= base_offset]


@spaces.GPU(duration=60)
def infer(sample_audio_path, target_text, temperature, top_p, repetition_penalty, progress=gr.Progress()):
    """Synthesize ``target_text`` with Orpheus, optionally cloning a reference voice.

    Args:
        sample_audio_path: Path to a reference clip (Gradio ``filepath``), or a
            falsy value to synthesize without a voice prompt.
        target_text: Text to speak. It is truncated to 300 characters and then
            passed through ``normalize``.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding, because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.
        repetition_penalty: Passed through to ``model.generate``.
        progress: Gradio progress tracker.

    Returns:
        ``(24000, waveform)`` for the single ``gr.Audio`` output, or ``None``
        when there is no usable text.

    Raises:
        gr.Error: if generation yields no decodable speech.
    """
    if not target_text or not target_text.strip():
        gr.Warning("Please input text to generate audio.")
        return None
    if len(target_text) > 300:
        gr.Warning("Text is too long. Please keep it under 300 characters.")
        target_text = target_text[:300]
    target_text = normalize(target_text)
    if not target_text:
        gr.Warning("The text contains no characters that can be synthesized.")
        return None

    with torch.no_grad():
        prompt_ids = _build_reference_prompt(sample_audio_path, progress) if sample_audio_path else []

        progress(0.6, "Generating audio...")
        target_text_ids = tokenizer(target_text, return_tensors="pt").input_ids[0].tolist()
        prompt_ids += [SOH_ID, *target_text_ids, EOT_ID, EOH_ID, SOA_ID, SOS_ID]
        input_ids = torch.tensor([prompt_ids], dtype=torch.int64, device=model.device)

        # Sampling with temperature 0 raises in transformers; the UI slider
        # allows 0, so fall back to greedy decoding in that case.
        do_sample = temperature > 0
        sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}
        outputs = model.generate(
            input_ids,
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=2048,
            eos_token_id=EOS_ID,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            **sampling_kwargs,
        )
        generated_ids = outputs[0].tolist()

        progress(0.8, "Decoding generated audio...")
        speech_codes = _extract_speech_tokens(generated_ids)
        gen_wav_tensor = redistribute_codes(speech_codes)
        if gen_wav_tensor is None:
            raise gr.Error("Audio decoding failed.")
        gen_wav = gen_wav_tensor.cpu().squeeze()

    progress(1, 'Synthesized!')
    return (24000, gen_wav.numpy())
# Main TTS tab: reference clip, target text, sampling controls, and output.
with gr.Blocks() as app_tts:
    gr.Markdown(value="# Galgame Orpheus 3B")
    ref_audio_input = gr.Audio(type="filepath", label="Reference Audio")
    gen_text_input = gr.Textbox(lines=10, label="Text to Generate")
    with gr.Row():
        temperature_slider = gr.Slider(
            label="Temperature", minimum=0.0, maximum=1.0, step=0.05, value=0.8
        )
        top_p_slider = gr.Slider(
            label="Top-p", minimum=0.0, maximum=1.0, step=0.05, value=1.0
        )
        repetition_penalty_slider = gr.Slider(
            label="Repetition Penalty", minimum=1.0, maximum=1.5, step=0.05, value=1.1
        )
    generate_btn = gr.Button(value="Synthesize", variant="primary")
    audio_output = gr.Audio(label="Synthesized Audio")

    # Input order must match infer(sample_audio_path, target_text,
    # temperature, top_p, repetition_penalty).
    generate_btn.click(
        fn=infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=[audio_output],
    )
# Attribution tab listing the projects this demo builds on.
with gr.Blocks() as app_credits:
    gr.Markdown(
        value="""
# Credits
* [canopyai](https://github.com/canopyai) for the original [repo](https://github.com/canopyai/Orpheus-TTS)
* [mrfakename](https://huggingface.co/mrfakename) for the [gradio demo code](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
* [SunderAli17](https://huggingface.co/SunderAli17) for the [gradio demo code](https://huggingface.co/spaces/SunderAli17/llasa-3b-tts)
"""
    )
# Root UI: header text plus the tabbed TTS and credits pages, then launch.
with gr.Blocks() as app:
    gr.Markdown(
        """
# Galgame Orpheus 3B
This is a local web UI for Galgame Orpheus 3B TTS model. You can check out the model [here](https://huggingface.co/NandemoGHS/Galgame-Orpheus-3B).
The model is fine-tuned on Japanese audio data.
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
"""
    )
    # Mount the credits Blocks as its own tab; otherwise it is never rendered.
    gr.TabbedInterface([app_tts, app_credits], ["TTS", "Credits"])
app.launch()