Upload app.py

app.py (ADDED)
@@ -0,0 +1,363 @@
import os

gpt_path = os.environ.get(
    "gpt_path", "models/Carol/Carol-e15.ckpt"
)
sovits_path = os.environ.get("sovits_path", "models/Carol/Carol_e40_s2160.pth")
cnhubert_base_path = os.environ.get(
    "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
bert_path = os.environ.get(
    "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
)
infer_ttswebui = int(os.environ.get("infer_ttswebui", 9872))
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
# Parse the half-precision flag; the original used eval() on the raw
# environment variable, which is unsafe.
is_half = os.environ.get("is_half", "True").lower() == "true"
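
# All of the paths above can be overridden from the environment before launch.
# Hypothetical example (placeholder model names, not shipped with this Space):
#   export gpt_path=models/MyVoice/MyVoice-e15.ckpt
#   export sovits_path=models/MyVoice/MyVoice_e40_s2160.pth
#   export is_half=False  # force fp32, e.g. on GPUs without fast fp16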
import gradio as gr
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
import librosa
import torch
from feature_extractor import cnhubert

cnhubert.cnhubert_base_path = cnhubert_base_path

from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from time import time as ttime
from module.mel_processing import spectrogram_torch
from my_utils import load_audio

device = "cuda"
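# Note: device is hard-coded to CUDA, so every .to(device) below assumes a GPU.
# On a CPU-only host a fallback would be needed, e.g.
#   device = "cuda" if torch.cuda.is_available() else "cpu"
# together with is_half=False, since fp16 inference is impractical on CPU.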
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
    bert_model = bert_model.half().to(device)
else:
    bert_model = bert_model.to(device)


def get_bert_feature(text, word2ph):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for i in inputs:
            # The inputs are integer (long) tensors, so precision is not a
            # concern here; compute precision follows bert_model itself.
            inputs[i] = inputs[i].to(device)
        res = bert_model(**inputs, output_hidden_states=True)
        # Keep only the third-to-last hidden layer and drop [CLS]/[SEP].
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        # Repeat each character-level vector once per phone of that character.
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature.T
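
# Shape sketch (assuming one BERT token per Chinese character): for a string
# of N characters whose word2ph entries sum to P phones, res is (N, 1024) and
# the returned feature is (1024, P), e.g.
#   get_bert_feature("你好", [2, 2])  # -> torch.Size([1024, 4])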


n_semantic = 1024

dict_s2 = torch.load(sovits_path, map_location="cpu")
hps = dict_s2["config"]


class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")


hps = DictToAttrRecursive(hps)
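
# The wrapper allows both attribute and key access on the loaded config, e.g.
#   cfg = DictToAttrRecursive({"data": {"hop_length": 640}})
#   cfg.data.hop_length == cfg["data"]["hop_length"]  # True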

hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu")
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half:
    ssl_model = ssl_model.half().to(device)
else:
    ssl_model = ssl_model.to(device)

vq_model = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model
)
if is_half:
    vq_model = vq_model.half().to(device)
else:
    vq_model = vq_model.to(device)
vq_model.eval()
# strict=False because the checkpoint may not cover every module; the returned
# report of missing/unexpected keys is printed for inspection.
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config["data"]["max_sec"]
# TODO: t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
#     checkpoint_path=gpt_path, config=config, map_location="cpu")
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
    t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum(param.nelement() for param in t2s_model.parameters())
print("Number of parameters: %.2fM" % (total / 1e6))
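
# Two checkpoints drive the pipeline: dict_s1 (gpt_path) holds the GPT
# text-to-semantic model and its inference config, while dict_s2 (sovits_path)
# holds the SoVITS synthesizer weights plus the hps config used above.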
            +
             | 
| 128 | 
            +
             | 
| 129 | 
            +
            def get_spepc(hps, filename):
         | 
| 130 | 
            +
                audio = load_audio(filename, int(hps.data.sampling_rate))
         | 
| 131 | 
            +
                audio = torch.FloatTensor(audio)
         | 
| 132 | 
            +
                audio_norm = audio
         | 
| 133 | 
            +
                audio_norm = audio_norm.unsqueeze(0)
         | 
| 134 | 
            +
                spec = spectrogram_torch(
         | 
| 135 | 
            +
                    audio_norm,
         | 
| 136 | 
            +
                    hps.data.filter_length,
         | 
| 137 | 
            +
                    hps.data.sampling_rate,
         | 
| 138 | 
            +
                    hps.data.hop_length,
         | 
| 139 | 
            +
                    hps.data.win_length,
         | 
| 140 | 
            +
                    center=False,
         | 
| 141 | 
            +
                )
         | 
| 142 | 
            +
                return spec
         | 
| 143 | 
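
# Presumably the returned linear spectrogram has shape
# (1, hps.data.filter_length // 2 + 1, n_frames); it is the "refer"
# conditioning input that vq_model.decode consumes below.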
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
         | 
| 149 | 
            +
                t0 = ttime()
         | 
| 150 | 
            +
                prompt_text = prompt_text.strip("\n")
         | 
| 151 | 
            +
                prompt_language, text = prompt_language, text.strip("\n")
         | 
| 152 | 
            +
                with torch.no_grad():
         | 
| 153 | 
            +
                    wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # 派蒙
         | 
| 154 | 
            +
                    wav16k = torch.from_numpy(wav16k)
         | 
| 155 | 
            +
                    if is_half == True:
         | 
| 156 | 
            +
                        wav16k = wav16k.half().to(device)
         | 
| 157 | 
            +
                    else:
         | 
| 158 | 
            +
                        wav16k = wav16k.to(device)
         | 
| 159 | 
            +
                    ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
         | 
| 160 | 
            +
                        "last_hidden_state"
         | 
| 161 | 
            +
                    ].transpose(
         | 
| 162 | 
            +
                        1, 2
         | 
| 163 | 
            +
                    )  # .float()
         | 
| 164 | 
            +
                    codes = vq_model.extract_latent(ssl_content)
         | 
| 165 | 
            +
                    prompt_semantic = codes[0, 0]
         | 
| 166 | 
            +
                t1 = ttime()
         | 
| 167 | 
            +
                prompt_language = dict_language[prompt_language]
         | 
| 168 | 
            +
                text_language = dict_language[text_language]
         | 
| 169 | 
            +
                phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
         | 
| 170 | 
            +
                phones1 = cleaned_text_to_sequence(phones1)
         | 
| 171 | 
            +
                texts = text.split("\n")
         | 
| 172 | 
            +
                audio_opt = []
         | 
| 173 | 
            +
                zero_wav = np.zeros(
         | 
| 174 | 
            +
                    int(hps.data.sampling_rate * 0.3),
         | 
| 175 | 
            +
                    dtype=np.float16 if is_half == True else np.float32,
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
                for text in texts:
         | 
| 178 | 
            +
                    # 解决输入目标文本的空行导致报错的问题
         | 
| 179 | 
            +
                    if (len(text.strip()) == 0):
         | 
| 180 | 
            +
                        continue
         | 
| 181 | 
            +
                    phones2, word2ph2, norm_text2 = clean_text(text, text_language)
         | 
| 182 | 
            +
                    phones2 = cleaned_text_to_sequence(phones2)
         | 
| 183 | 
            +
                    if prompt_language == "zh":
         | 
| 184 | 
            +
                        bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
         | 
| 185 | 
            +
                    else:
         | 
| 186 | 
            +
                        bert1 = torch.zeros(
         | 
| 187 | 
            +
                            (1024, len(phones1)),
         | 
| 188 | 
            +
                            dtype=torch.float16 if is_half == True else torch.float32,
         | 
| 189 | 
            +
                        ).to(device)
         | 
| 190 | 
            +
                    if text_language == "zh":
         | 
| 191 | 
            +
                        bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
         | 
| 192 | 
            +
                    else:
         | 
| 193 | 
            +
                        bert2 = torch.zeros((1024, len(phones2))).to(bert1)
         | 
| 194 | 
            +
                    bert = torch.cat([bert1, bert2], 1)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
         | 
| 197 | 
            +
                    bert = bert.to(device).unsqueeze(0)
         | 
| 198 | 
            +
                    all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
         | 
| 199 | 
            +
                    prompt = prompt_semantic.unsqueeze(0).to(device)
         | 
| 200 | 
            +
                    t2 = ttime()
         | 
| 201 | 
            +
                    with torch.no_grad():
         | 
| 202 | 
            +
                        # pred_semantic = t2s_model.model.infer(
         | 
| 203 | 
            +
                        pred_semantic, idx = t2s_model.model.infer_panel(
         | 
| 204 | 
            +
                            all_phoneme_ids,
         | 
| 205 | 
            +
                            all_phoneme_len,
         | 
| 206 | 
            +
                            prompt,
         | 
| 207 | 
            +
                            bert,
         | 
| 208 | 
            +
                            # prompt_phone_len=ph_offset,
         | 
| 209 | 
            +
                            top_k=config["inference"]["top_k"],
         | 
| 210 | 
            +
                            early_stop_num=hz * max_sec,
         | 
| 211 | 
            +
                        )
         | 
| 212 | 
            +
                    t3 = ttime()
         | 
| 213 | 
            +
                    # print(pred_semantic.shape,idx)
         | 
| 214 | 
            +
                    pred_semantic = pred_semantic[:, -idx:].unsqueeze(
         | 
| 215 | 
            +
                        0
         | 
| 216 | 
            +
                    )  # .unsqueeze(0)#mq要多unsqueeze一次
         | 
| 217 | 
            +
                    refer = get_spepc(hps, ref_wav_path)  # .to(device)
         | 
| 218 | 
            +
                    if is_half == True:
         | 
| 219 | 
            +
                        refer = refer.half().to(device)
         | 
| 220 | 
            +
                    else:
         | 
| 221 | 
            +
                        refer = refer.to(device)
         | 
| 222 | 
            +
                    # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
         | 
| 223 | 
            +
                    audio = (
         | 
| 224 | 
            +
                        vq_model.decode(
         | 
| 225 | 
            +
                            pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
         | 
| 226 | 
            +
                        )
         | 
| 227 | 
            +
                        .detach()
         | 
| 228 | 
            +
                        .cpu()
         | 
| 229 | 
            +
                        .numpy()[0, 0]
         | 
| 230 | 
            +
                    )  ###试试重建不带上prompt部分
         | 
| 231 | 
            +
                    audio_opt.append(audio)
         | 
| 232 | 
            +
                    audio_opt.append(zero_wav)
         | 
| 233 | 
            +
                    t4 = ttime()
         | 
| 234 | 
            +
                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
         | 
| 235 | 
            +
                yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
         | 
| 236 | 
            +
                    np.int16
         | 
| 237 | 
            +
                )
         | 
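
# The tab-separated timing line above reports the four pipeline stages:
# reference-audio SSL/VQ encoding (t1 - t0), text front-end and BERT features
# (t2 - t1), autoregressive semantic-token prediction (t3 - t2), and vocoder
# decoding (t4 - t3, for the last line synthesized).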


splits = {
    ",",
    "。",
    "?",
    "!",
    ",",
    ".",
    "?",
    "!",
    "~",
    ":",
    ":",
    "—",
    "…",
}  # full- and half-width sentence punctuation used as split points


def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            # The text is guaranteed to end with punctuation, so the final
            # segment was already appended in an earlier iteration.
            break
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts
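
# Example (assumed behavior): a trailing "。" is appended when missing, so
#   split("你好。再见")  # -> ["你好。", "再见。"]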


def cut1(inp):
    # Group every five punctuation-delimited segments onto one line.
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx = list(range(0, len(inps), 5))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
    else:
        opts = [inp]
    return "\n".join(opts)


def cut2(inp):
    # Accumulate segments until roughly 50 characters, then start a new line.
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        # Return a plain string, consistent with the branch below (the
        # original returned a list here).
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    # If the last chunk is too short, merge it into the previous one; the
    # len(opts) > 1 guard avoids an IndexError when only one chunk exists.
    if len(opts) > 1 and len(opts[-1]) < 50:
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    return "\n".join(opts)
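
# The cutters return newline-separated text; get_tts_wav splits its input on
# "\n" and synthesizes each line separately, so pre-cutting long text trades a
# little prosodic continuity for more stable generation.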


def cut3(inp):
    # Split at the Chinese full stop, re-appending "。" to each piece.
    inp = inp.strip("\n")
    return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])


with gr.Blocks(title="GPT-SoVITS WebUI") as app:
    gr.Markdown(
        value="This software is open-sourced under the MIT license. The author has no control over the software, and users of the software or distributors of audio exported from it bear full responsibility. <br>If you do not accept these terms, you may not use or reference any code or files inside this package. See <b>LICENSE</b> in the root directory for details."
    )
    # with gr.Tabs():
    #     with gr.TabItem(i18n("Vocal separation & de-reverb & de-echo")):
    with gr.Group():
        gr.Markdown(value="*Please upload the reference audio and fill in its details")
        with gr.Row():
            inp_ref = gr.Audio(label="Upload the reference audio", type="filepath")
            prompt_text = gr.Textbox(label="Transcript of the reference audio", value="")
            prompt_language = gr.Dropdown(
                label="Language of the reference audio",
                choices=["Chinese", "English", "Japanese"],
                value="Chinese",
            )
        gr.Markdown(value="*Please fill in the target text to synthesize")
        with gr.Row():
            text = gr.Textbox(label="Text to synthesize", value="")
            text_language = gr.Dropdown(
                label="Language of the text",
                choices=["Chinese", "English", "Japanese"],
                value="Chinese",
            )
            inference_button = gr.Button("Synthesize", variant="primary")
            output = gr.Audio(label="Output audio")
        inference_button.click(
            get_tts_wav,
            [inp_ref, prompt_text, prompt_language, text, text_language],
            [output],
        )

        gr.Markdown(value="Text-splitting tool. Very long text does not always synthesize well, so splitting it first is recommended. Synthesis runs once per line of the input and the results are concatenated.")
        with gr.Row():
            text_inp = gr.Textbox(label="Text to split before synthesis", value="")
            button1 = gr.Button("Split every five sentences", variant="primary")
            button2 = gr.Button("Split every ~50 characters", variant="primary")
            button3 = gr.Button("Split at Chinese periods (。)", variant="primary")
            text_opt = gr.Textbox(label="Split text", value="")
            button1.click(cut1, [text_inp], [text_opt])
            button2.click(cut2, [text_inp], [text_opt])
            button3.click(cut3, [text_inp], [text_opt])
        gr.Markdown(value="Mixed-language text input will be supported in a future update.")

app.queue(concurrency_count=511, max_size=1022).launch(
    server_name="0.0.0.0",
    inbrowser=True,
    server_port=infer_ttswebui,
    quiet=True,
)
