import argparse
import re
import time
import pandas
import numpy as np
from tqdm import tqdm
import random
import os
import gradio as gr
import json
from utils import combine_audio, save_audio, batch_split, normalize_zh
from tts_model import load_chat_tts_model, clear_cuda_cache, deterministic, generate_audio_for_seed
import spaces

parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
parser.add_argument("--local_path", type=str, help="Path to local model if source is 'local'.")
parser.add_argument("--share", default=False, action="store_true", help="Share the server publicly.")
args = parser.parse_args()

# Directory that stores the saved audio-seed file
SAVED_DIR = "saved_seeds"
if not os.path.exists(SAVED_DIR):
    os.makedirs(SAVED_DIR)

# Path of the JSON file holding the saved seeds
SAVED_SEEDS_FILE = os.path.join(SAVED_DIR, "saved_seeds.json")

# Index of the seed currently selected in the DataFrame
SELECTED_SEED_INDEX = -1

# Initialize the JSON file if it does not exist yet
if not os.path.exists(SAVED_SEEDS_FILE):
    with open(SAVED_SEEDS_FILE, "w") as f:
        f.write("[]")

chat = load_chat_tts_model(source=args.source, local_path=args.local_path)
# chat = None
# chat = load_chat_tts_model(source="local", local_path="models")

# Maximum number of seed "draws" shown at once
max_audio_components = 10

# print("loading ChatTTS model...")
# chat = ChatTTS.Chat()
# chat.load_models(source="local", local_path="models")
# torch.cuda.empty_cache()


def load_seeds():
    """Load the saved seeds from disk into the global list."""
    with open(SAVED_SEEDS_FILE, "r") as f:
        global saved_seeds
        saved_seeds = json.load(f)
    return saved_seeds


def display_seeds():
    """Return the saved seeds as List[List] rows for the DataFrame component."""
    seeds = load_seeds()
    return [[i, s['seed'], s['name']] for i, s in enumerate(seeds)]


saved_seeds = load_seeds()
num_seeds_default = 2


def save_seeds():
    """Write the global seed list back to disk and reload it."""
    global saved_seeds
    with open(SAVED_SEEDS_FILE, "w") as f:
        json.dump(saved_seeds, f)
    saved_seeds = load_seeds()


def add_seed(seed, name, save=True):
    """Add a seed; return False if it is already saved."""
    for s in saved_seeds:
        if s['seed'] == seed:
            return False
    saved_seeds.append({
        'seed': seed,
        'name': name
    })
    if save:
        save_seeds()


def modify_seed(seed, name, save=True):
    """Rename a saved seed; return True on success."""
    for s in saved_seeds:
        if s['seed'] == seed:
            s['name'] = name
            if save:
                save_seeds()
            return True
    return False


def delete_seed(seed, save=True):
    """Delete a saved seed; return True on success."""
    for s in saved_seeds:
        if s['seed'] == seed:
            saved_seeds.remove(s)
            if save:
                save_seeds()
            return True
    return False


@spaces.GPU
def generate_seeds(num_seeds, texts, tq):
    """
    Generate preview audio for a batch of random seeds.
    :param num_seeds: how many random seeds to draw
    :param texts: newline-separated test sentences
    :param tq: progress wrapper (falls back to tqdm)
    :return: list of (filename, seed) tuples
    """
    seeds = []
    sample_rate = 24000
    # Split the text by line and normalize digits/punctuation
    texts = [normalize_zh(_) for _ in texts.split('\n') if _.strip()]
    print(texts)
    if not tq:
        tq = tqdm
    for _ in tq(range(num_seeds), desc="随机音色生成中..."):
        seed = np.random.randint(0, 9999)
        filename = generate_audio_for_seed(chat, seed, texts, 1, 5, "[oral_2][laugh_0][break_4]", 0.3, 0.7, 20)
        seeds.append((filename, seed))
    clear_cuda_cache()
    return seeds


def do_save_seed(seed):
    """Save the seed encoded in the clicked button's label."""
    seed = seed.replace('保存种子 ', '').strip()
    if not seed:
        return
    add_seed(int(seed), seed)
    gr.Info(f"Seed {seed} has been saved.")


def do_save_seeds(seeds):
    """Persist the (possibly edited) seed DataFrame."""
    assert isinstance(seeds, pandas.DataFrame)
    seeds = seeds.drop(columns=['Index'])
    # Convert the DataFrame to a list of dicts with lowercase keys
    result = [{k.lower(): v for k, v in row.items()} for row in seeds.to_dict(orient='records')]
    print(result)
    if result:
        global saved_seeds
        saved_seeds = result
        save_seeds()
        gr.Info("Seeds have been saved.")
    return result


def do_delete_seed(val):
    """Delete the seed whose row index is encoded as [idx] in the delete button's label."""
    index = re.search(r'\[(\d+)\]', val)
    global saved_seeds
    if index:
        index = int(index.group(1))
        seed = saved_seeds[index]['seed']
        delete_seed(seed)
        gr.Info(f"Seed {seed} has been deleted.")
    return display_seeds()

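# For reference, saved_seeds.json is a flat JSON list of {"seed": <int>, "name": <str>}
# records, which is the only layout load_seeds()/save_seeds() read and write.
# Below is a minimal helper sketch (not wired into the UI) showing how those records
# can be queried programmatically; it assumes nothing beyond that layout.
def find_seed_by_name(name):
    """Return the first saved seed whose name matches, or None if there is none."""
    for s in saved_seeds:
        if s.get('name') == name:
            return s.get('seed')
    return None
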
def seed_change_btn():
    """Update the delete button label to show the selected row index."""
    global SELECTED_SEED_INDEX
    if SELECTED_SEED_INDEX == -1:
        return '删除'
    return f'删除 idx=[{SELECTED_SEED_INDEX[0]}]'


def audio_interface(num_seeds, texts, progress=gr.Progress()):
    """
    Generate audio for the requested number of random seeds.
    :param num_seeds: number of seeds to draw
    :param texts: newline-separated test sentences
    :param progress: Gradio progress tracker
    :return: interleaved [audio, button-label, ...] outputs for all components
    """
    seeds = generate_seeds(num_seeds, texts, progress.tqdm)
    wavs = [_[0] for _ in seeds]
    seeds = [f"保存种子 {_[1]}" for _ in seeds]
    # Pad the unused component slots
    all_wavs = wavs + [None] * (max_audio_components - len(wavs))
    all_seeds = seeds + [''] * (max_audio_components - len(seeds))
    return [item for pair in zip(all_wavs, all_seeds) for item in pair]


def audio_interface_empty(num_seeds, texts, progress=gr.Progress(track_tqdm=True)):
    """Clear all audio/button slots before a new draw."""
    return [None, ""] * max_audio_components


def update_audio_components(slider_value):
    """Toggle visibility of the Audio/Button pairs according to the slider value."""
    k = int(slider_value)
    audios = [gr.Audio(visible=True)] * k + [gr.Audio(visible=False)] * (max_audio_components - k)
    tbs = [gr.Textbox(visible=True)] * k + [gr.Textbox(visible=False)] * (max_audio_components - k)
    print(f'k={k}, audios={len(audios)}')
    return [item for pair in zip(audios, tbs) for item in pair]


def seed_change(evt: gr.SelectData):
    """Remember which DataFrame cell was selected."""
    # print(f"You selected {evt.value} at {evt.index} from {evt.target}")
    global SELECTED_SEED_INDEX
    SELECTED_SEED_INDEX = evt.index
    return evt.index


@spaces.GPU
def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature,
                       top_P, top_K, progress=gr.Progress()):
    from tts_model import generate_audio_for_seed
    from utils import split_text
    if seed in [0, -1, None]:
        seed = random.randint(1, 9999)
    content = ''
    if os.path.isfile(text_file):
        # A file path was passed (the .txt upload widget below is currently disabled): read its content.
        with open(text_file, "r", encoding="utf-8") as f:
            content = f.read()
    elif isinstance(text_file, str):
        content = text_file
    texts = split_text(content, min_length=min_length)
    print(texts)
    if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
        raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")
    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"
    try:
        output_files = generate_audio_for_seed(chat, seed, texts, batch_size, speed, refine_text_prompt, temperature,
                                               top_P, top_K, progress.tqdm)
        return output_files
    except Exception as e:
        return str(e)


def generate_seed():
    """Return an update carrying a fresh random seed value."""
    new_seed = random.randint(1, 9999)
    return {
        "__type__": "update",
        "value": new_seed
    }


def update_label(text):
    """Refresh the textbox label with the current character count."""
    word_count = len(text)
    return gr.update(label=f"朗读文本(字数: {word_count})")


with gr.Blocks() as demo:
    with gr.Tab("音色抽卡"):
        with gr.Row():
            with gr.Column(scale=1):
                texts = [
                    "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
                    "我是一个充满活力的人,喜欢运动,喜欢旅行,喜欢尝试新鲜事物。我喜欢挑战自己,不断突破自己的极限,让自己变得更加强大。",
                    "罗森宣布将于7月24日退市,在华门店超6000家!",
                ]
                # gr.Markdown("### 随机音色抽卡")
                gr.Markdown("""
                在相同的 seed 和 温度等参数下,音色具有一定的一致性。点击下面的“随机音色生成”按钮将生成多个 seed。找到满意的音色后,点击音频下方“保存”按钮。
                **注意:不同机器使用相同种子生成的音频音色可能不同,同一机器使用相同种子多次生成的音频音色也可能变化。**
                """)
                input_text = gr.Textbox(label="测试文本",
                                        info="**每行文本**都会生成一段音频,最终输出的音频是将这些音频段合成后的结果。建议使用**多行文本**进行测试,以确保音色稳定性。",
                                        lines=4, placeholder="请输入文本...", value='\n'.join(texts))

                num_seeds = gr.Slider(minimum=1, maximum=max_audio_components, step=1, label="seed生成数量",
                                      value=num_seeds_default)

                generate_button = gr.Button("随机音色抽卡🎲", variant="primary")

                # Saved-seed management
                gr.Markdown("### 种子管理界面")
                seed_list = gr.DataFrame(
                    label="种子列表",
                    headers=["Index", "Seed", "Name"],
                    datatype=["number", "number", "str"],
                    interactive=True,
                    col_count=(3, "fixed"),
                    value=display_seeds()
                )

                with gr.Row():
                    refresh_button = gr.Button("刷新")
                    save_button = gr.Button("保存")
                    del_button = gr.Button("删除")

                # Wire the buttons to their callbacks
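                # Seed-management flow, for reference: selecting a row in seed_list stores the
                # (row, col) index via seed_change, seed_change_btn then rewrites the delete
                # button label to "删除 idx=[row]", and do_delete_seed parses that index back
                # out of the label to know which record to remove.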
                refresh_button.click(display_seeds, outputs=seed_list)

                seed_list.select(seed_change).success(seed_change_btn, outputs=[del_button])

                save_button.click(do_save_seeds, inputs=[seed_list], outputs=None)

                del_button.click(do_delete_seed, inputs=del_button, outputs=seed_list)

            with gr.Column(scale=1):
                audio_components = []
                for i in range(max_audio_components):
                    visible = i < num_seeds_default
                    a = gr.Audio(label=f"Audio {i}", visible=visible)
                    t = gr.Button("Seed", visible=visible)
                    t.click(do_save_seed, inputs=[t], outputs=None).success(display_seeds, outputs=seed_list)
                    audio_components.append(a)
                    audio_components.append(t)

                num_seeds.change(update_audio_components, inputs=num_seeds, outputs=audio_components)
                # output = gr.Column()
                # audio = gr.Audio(label="Output Audio")

        generate_button.click(
            audio_interface_empty,
            inputs=[num_seeds, input_text],
            outputs=audio_components
        ).success(audio_interface, inputs=[num_seeds, input_text], outputs=audio_components)

    with gr.Tab("长音频生成"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### 文本")
                # gr.Markdown("请上传要转换的文本文件(.txt 格式)。")
                # text_file_input = gr.File(label="文本文件", file_types=[".txt"])
                default_text = "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。"
                text_file_input = gr.Textbox(label=f"朗读文本(字数: {len(default_text)})", lines=4,
                                             placeholder="Please Input Text...", value=default_text)
                # Call update_label whenever the textbox content changes
                text_file_input.change(update_label, inputs=text_file_input, outputs=text_file_input)

            with gr.Column():
                gr.Markdown("### 配置参数")
                gr.Markdown("根据需要配置以下参数来生成音频。")
                with gr.Row():
                    num_seeds_input = gr.Number(label="生成音频的数量", value=1, precision=0, visible=False)
                    seed_input = gr.Number(label="指定种子(留空则随机)", value=None, precision=0)
                    generate_audio_seed = gr.Button("\U0001F3B2")

                with gr.Row():
                    speed_input = gr.Slider(label="语速", minimum=1, maximum=10, value=5, step=1)
                    oral_input = gr.Slider(label="口语化", minimum=0, maximum=9, value=2, step=1)
                    laugh_input = gr.Slider(label="笑声", minimum=0, maximum=2, value=0, step=1)
                    bk_input = gr.Slider(label="停顿", minimum=0, maximum=7, value=4, step=1)
                # gr.Markdown("### 文本参数")
                with gr.Row():
                    min_length_input = gr.Number(label="文本分段长度", info="大于这个数值进行分段", value=120, precision=0)
                    batch_size_input = gr.Number(label="批大小", info="同时处理的批次 越高越快 太高爆显存", value=5, precision=0)
                with gr.Accordion("其他参数", open=False):
                    with gr.Row():
                        # temperature / top_P / top_K
                        temperature_input = gr.Slider(label="温度", minimum=0.01, maximum=1.0, step=0.01, value=0.3)
                        top_P_input = gr.Slider(label="top_P", minimum=0.1, maximum=0.9, step=0.05, value=0.7)
                        top_K_input = gr.Slider(label="top_K", minimum=1, maximum=20, step=1, value=20)
                        # Reset button for the sampling parameters
                        reset_button = gr.Button("重置")

        with gr.Row():
            generate_button = gr.Button("生成音频", variant="primary")

        with gr.Row():
            output_audio = gr.Audio(label="生成的音频文件")

        generate_audio_seed.click(generate_seed, inputs=[], outputs=seed_input)

        # Reset temperature / top_P / top_K to their defaults
        reset_button.click(
            lambda: [0.3, 0.7, 20],
            inputs=None,
            outputs=[temperature_input, top_P_input, top_K_input]
        )

        generate_button.click(
            fn=generate_tts_audio,
            inputs=[
                text_file_input,
                num_seeds_input,
                seed_input,
                speed_input,
                oral_input,
                laugh_input,
                bk_input,
                min_length_input,
                batch_size_input,
                temperature_input,
                top_P_input,
                top_K_input,
            ],
            outputs=[output_audio]
        )

    with gr.Tab("角色扮演"):
        def txt_2_script(text):
            """Parse '角色::文本' lines into a list of {'character', 'txt'} dicts."""
            lines = text.split("\n")
            data = []
            for line in lines:
                if not line.strip():
                    continue
                parts = line.split("::")
                if len(parts) != 2:
                    continue
                data.append({
                    "character": parts[0],
                    "txt": parts[1]
                })
            return data

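        # For reference, txt_2_script turns script text such as
        #   旁白::在一个风和日丽的下午,小红帽准备去森林里看望她的奶奶。
        #   年轻女性::我要给奶奶带点好吃的。
        # into [{"character": "旁白", "txt": "..."}, {"character": "年轻女性", "txt": "..."}],
        # and script_2_txt below is its inverse.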
        def script_2_txt(data):
            """Render a list of {'character', 'txt'} dicts back into '角色::文本' lines."""
            assert isinstance(data, list)
            result = []
            for item in data:
                txt = item['txt'].replace('\n', ' ')
                result.append(f"{item['character']}::{txt}")
            return "\n".join(result)


        def get_characters(lines):
            """Extract the unique characters from a parsed script, preserving order."""
            assert isinstance(lines, list)
            characters = list([_["character"] for _ in lines])
            unique_characters = list(dict.fromkeys(characters))
            print([[character, 0] for character in unique_characters])
            return [[character, 0] for character in unique_characters]


        def get_txt_characters(text):
            return get_characters(txt_2_script(text))


        def llm_change(model):
            """Return the default API base URL for the selected model."""
            llm_setting = {
                "gpt-3.5-turbo-0125": ["https://api.openai.com/v1"],
                "gpt-4o": ["https://api.openai.com/v1"],
                "deepseek-chat": ["https://api.deepseek.com"],
                "yi-large": ["https://api.lingyiwanwu.com/v1"]
            }
            if model in llm_setting:
                return llm_setting[model][0]
            else:
                gr.Error("Model not found.")
                return None


        def ai_script_generate(model, api_base, api_key, text, progress=gr.Progress(track_tqdm=True)):
            from llm_utils import llm_operation
            from config import LLM_PROMPT
            scripts = llm_operation(api_base, api_key, model, LLM_PROMPT, text, required_keys=["txt", "character"])
            return script_2_txt(scripts)


        @spaces.GPU
        def generate_script_audio(text, models_seeds, progress=gr.Progress()):
            """Generate one audio file for a whole script, voicing each character with its seed."""
            scripts = txt_2_script(text)  # Parse the script text
            characters = get_characters(scripts)  # Extract the characters appearing in it

            import pandas as pd
            from collections import defaultdict
            import itertools
            from tts_model import generate_audio_for_seed
            from utils import combine_audio, save_audio, normalize_zh
            from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P

            assert isinstance(models_seeds, pd.DataFrame)

            # Yield fixed-size batches from an iterable
            def batch(iterable, batch_size):
                it = iter(iterable)
                while True:
                    chunk = list(itertools.islice(it, batch_size))
                    if not chunk:
                        break
                    yield chunk

            models_seeds = models_seeds.to_dict(orient='records')

            # Make sure every character has a seed assigned
            for character, _ in characters:
                if not any(seed['Character'] == character for seed in models_seeds):
                    gr.Info(f"角色 {character} 没有种子,请先设置种子。")
                    return None

            # Map each character to its seed
            character_seeds = {character: [seed['Seed'] for seed in models_seeds if seed['Character'] == character][0]
                               for character, _ in characters}
            # TODO: make this configurable, ideally per character
            refine_text_prompt = "[oral_2][laugh_0][break_4]"

            all_wavs = []

            # Group the lines by character to speed up inference
            grouped_lines = defaultdict(list)
            for line in scripts:
                grouped_lines[line["character"]].append(line)

            batch_results = {character: [] for character in grouped_lines}

            batch_size = 5  # Batch size per character

            # Process each character's lines
            for character, lines in progress.tqdm(grouped_lines.items(), desc="生成剧本音频"):
                seed = character_seeds.get(character, 0)
                # Process the lines batch by batch
                for batch_lines in batch(lines, batch_size):
                    texts = [normalize_zh(line["txt"]) for line in batch_lines]
                    print(f"seed={seed} t={texts} c={character}")
                    wavs = generate_audio_for_seed(chat, int(seed), texts, DEFAULT_BATCH_SIZE, DEFAULT_SPEED,
                                                   refine_text_prompt, DEFAULT_TEMPERATURE, DEFAULT_TOP_P,
                                                   DEFAULT_TOP_K, skip_save=True)
                    batch_results[character].extend(wavs)

            # Restore the original line order
            for line in scripts:
                character = line["character"]
                all_wavs.append(batch_results[character].pop(0))

            # Concatenate everything into a single audio file
            audio = combine_audio(all_wavs)
            fname = f"script_{int(time.time())}.wav"
            save_audio(fname, audio)
            return fname

        # Design note: grouping by character means each seed is passed to
        # generate_audio_for_seed once per batch rather than once per line,
        # and popping from batch_results in script order restores the mix order.

        script_example = {
            "lines": [{
                "txt": "在一个风和日丽的下午,小红帽准备去森林里看望她的奶奶。",
                "character": "旁白"
            }, {
                "txt": "小红帽说",
                "character": "旁白"
            }, {
                "txt": "我要给奶奶带点好吃的。",
                "character": "年轻女性"
            }, {
                "txt": "在森林里,小红帽遇到了狡猾的大灰狼。",
                "character": "旁白"
            }, {
                "txt": "大灰狼说",
                "character": "旁白"
            }, {
                "txt": "小红帽,你的篮子里装的是什么?",
                "character": "中年男性"
            }, {
                "txt": "小红帽回答",
                "character": "旁白"
            }, {
                "txt": "这是给奶奶的蛋糕和果酱。",
                "character": "年轻女性"
            }, {
                "txt": "大灰狼心生一计,决定先到奶奶家等待小红帽。",
                "character": "旁白"
            }, {
                "txt": "当小红帽到达奶奶家时,她发现大灰狼伪装成了奶奶。",
                "character": "旁白"
            }, {
                "txt": "小红帽疑惑地问",
                "character": "旁白"
            }, {
                "txt": "奶奶,你的耳朵怎么这么尖?",
                "character": "年轻女性"
            }, {
                "txt": "大灰狼慌张地回答",
                "character": "旁白"
            }, {
                "txt": "哦,这是为了更好地听你说话。",
                "character": "中年男性"
            }, {
                "txt": "小红帽越发觉得不对劲,最终发现了大灰狼的诡计。",
                "character": "旁白"
            }, {
                "txt": "她大声呼救,森林里的猎人听到后赶来救了她和奶奶。",
                "character": "旁白"
            }, {
                "txt": "从此,小红帽再也没有单独进入森林,而是和家人一起去看望奶奶。",
                "character": "旁白"
            }]
        }

        ai_text_default = "武侠小说《花木兰大战周树人》 要符合人物背景"

        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                gr.Markdown("### AI脚本")
                gr.Markdown("""
                为确保生成效果稳定,仅支持与 GPT-4 相当的模型,推荐使用 4o yi-large deepseek。
                如果没有反应,请检查日志中的错误信息。如果提示格式错误,请重试几次。国内模型可能会受到风控影响,建议更换文本内容后再试。
                申请渠道(免费额度):
                - [https://platform.deepseek.com/](https://platform.deepseek.com/)
                - [https://platform.lingyiwanwu.com/](https://platform.lingyiwanwu.com/)
                """)
                with gr.Row(equal_height=True):
                    # Model choices: gpt-4o, deepseek-chat, yi-large
                    model_select = gr.Radio(label="选择模型", choices=["gpt-4o", "deepseek-chat", "yi-large"],
                                            value="gpt-4o", interactive=True)
                with gr.Row(equal_height=True):
                    openai_api_base_input = gr.Textbox(label="OpenAI API Base URL",
                                                       placeholder="请输入API Base URL",
                                                       value=r"https://api.openai.com/v1")
                    openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
                                                      value="sk-xxxxxxx")
                # Prompt describing the story to be scripted
                ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
                                           value=ai_text_default)

                # Button that triggers LLM script generation
                ai_script_generate_button = gr.Button("AI脚本生成")

            with gr.Column(scale=3):
                gr.Markdown("### 脚本")
                gr.Markdown("脚本可以手工编写,也可以用左侧的“AI脚本生成”按钮生成。脚本格式 **角色::文本**,一行为一句,注意是::")
                script_text = "\n".join(
                    [f"{_.get('character', '')}::{_.get('txt', '')}" for _ in script_example['lines']])

                script_text_input = gr.Textbox(label="脚本格式 “角色::文本 一行为一句” 注意是::",
                                               placeholder="请输入文本...", lines=12, value=script_text)
                script_translate_button = gr.Button("步骤①:提取角色")

            with gr.Column(scale=1):
                gr.Markdown("### 角色种子")
                # Default character/seed table; seeds can be taken from the 音色抽卡 tab
                default_data = [
                    ["旁白", 2222],
                    ["年轻女性", 2],
                    ["中年男性", 2424]
                ]

                script_data = gr.DataFrame(
                    value=default_data,
                    label="角色对应的音色种子,从抽卡那获取",
                    headers=["Character", "Seed"],
                    datatype=["str", "number"],
                    interactive=True,
                    col_count=(2, "fixed"),
                )
                # Button that generates the script audio
                script_generate_audio = gr.Button("步骤②:生成音频")

        # Output audio for the generated script
        script_audio = gr.Audio(label="AI生成的音频", interactive=False)

        # Script-related events
        # Extract the characters from the script text
        script_translate_button.click(
            get_txt_characters,
            inputs=[script_text_input],
            outputs=script_data
        )
        # Switch the API base URL when the model changes
        model_select.change(
            llm_change,
            inputs=[model_select],
            outputs=[openai_api_base_input]
        )
        # Generate the script with the LLM
        ai_script_generate_button.click(
            ai_script_generate,
            inputs=[model_select, openai_api_base_input, openai_api_key_input, ai_text_input],
            outputs=[script_text_input]
        )
        # Generate the audio
        script_generate_audio.click(
            generate_script_audio,
            inputs=[script_text_input, script_data],
            outputs=[script_audio]
        )

demo.launch(share=args.share)