Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		zhzluke96
		
	commited on
		
		
					Commit 
							
							·
						
						d5d0921
	
1
								Parent(s):
							
							2be0618
								
update
Browse files- data/speakers/Bob_ft10.pt +3 -0
 - modules/ChatTTS/ChatTTS/core.py +1 -1
 - modules/SynthesizeSegments.py +40 -7
 - modules/api/app_config.py +2 -2
 - modules/api/impl/google_api.py +66 -107
 - modules/api/impl/handler/AudioHandler.py +37 -0
 - modules/api/impl/handler/SSMLHandler.py +94 -0
 - modules/api/impl/handler/TTSHandler.py +97 -0
 - modules/api/impl/model/audio_model.py +14 -0
 - modules/api/impl/model/chattts_model.py +19 -0
 - modules/api/impl/model/enhancer_model.py +11 -0
 - modules/api/impl/openai_api.py +57 -56
 - modules/api/impl/refiner_api.py +1 -0
 - modules/api/impl/ssml_api.py +30 -25
 - modules/api/impl/tts_api.py +58 -31
 - modules/api/impl/xtts_v2_api.py +52 -6
 - modules/api/utils.py +2 -11
 - modules/devices/devices.py +7 -1
 - modules/finetune/train_speaker.py +18 -11
 - modules/prompts/news_oral_prompt.txt +14 -0
 - modules/prompts/podcast_prompt.txt +1 -0
 - modules/ssml_parser/SSMLParser.py +1 -4
 - modules/webui/speaker/speaker_editor.py +1 -1
 - modules/webui/speaker/speaker_merger.py +2 -6
 
    	
        data/speakers/Bob_ft10.pt
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:91015b82a99c40034048090228b6d647ab99fd7b86e8babd6a7c3a9236e8d800
         
     | 
| 3 | 
         
            +
            size 4508
         
     | 
    	
        modules/ChatTTS/ChatTTS/core.py
    CHANGED
    
    | 
         @@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code 
     | 
|
| 17 | 
         | 
| 18 | 
         
             
            from huggingface_hub import snapshot_download
         
     | 
| 19 | 
         | 
| 20 | 
         
            -
            logging.basicConfig(level=logging. 
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
             
            class Chat:
         
     | 
| 
         | 
|
| 17 | 
         | 
| 18 | 
         
             
            from huggingface_hub import snapshot_download
         
     | 
| 19 | 
         | 
| 20 | 
         
            +
            logging.basicConfig(level=logging.INFO)
         
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
             
            class Chat:
         
     | 
    	
        modules/SynthesizeSegments.py
    CHANGED
    
    | 
         @@ -1,4 +1,5 @@ 
     | 
|
| 1 | 
         
             
            import copy
         
     | 
| 
         | 
|
| 2 | 
         
             
            from box import Box
         
     | 
| 3 | 
         
             
            from pydub import AudioSegment
         
     | 
| 4 | 
         
             
            from typing import List, Union
         
     | 
| 
         @@ -160,7 +161,21 @@ class SynthesizeSegments: 
     | 
|
| 160 | 
         
             
                    for i in range(0, len(bucket), self.batch_size):
         
     | 
| 161 | 
         
             
                        batch = bucket[i : i + self.batch_size]
         
     | 
| 162 | 
         
             
                        param_arr = [self.segment_to_generate_params(segment) for segment in batch]
         
     | 
| 163 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 164 | 
         | 
| 165 | 
         
             
                        params = param_arr[0]
         
     | 
| 166 | 
         
             
                        audio_datas = generate_audio.generate_audio_batch(
         
     | 
| 
         @@ -182,6 +197,7 @@ class SynthesizeSegments: 
     | 
|
| 182 | 
         | 
| 183 | 
         
             
                            audio_segment = audio_data_to_segment(audio_data, sr)
         
     | 
| 184 | 
         
             
                            audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
         
     | 
| 
         | 
|
| 185 | 
         
             
                            original_index = src_segments.index(segment)
         
     | 
| 186 | 
         
             
                            audio_segments[original_index] = audio_segment
         
     | 
| 187 | 
         | 
| 
         @@ -226,13 +242,30 @@ class SynthesizeSegments: 
     | 
|
| 226 | 
         | 
| 227 | 
         
             
                        sentences = spliter.parse(text)
         
     | 
| 228 | 
         
             
                        for sentence in sentences:
         
     | 
| 229 | 
         
            -
                             
     | 
| 230 | 
         
            -
                                 
     | 
| 231 | 
         
            -
             
     | 
| 232 | 
         
            -
             
     | 
| 233 | 
         
            -
                                    params=copy.copy(segment.params),
         
     | 
| 234 | 
         
            -
                                )
         
     | 
| 235 | 
         
             
                            )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 236 | 
         | 
| 237 | 
         
             
                    return ret_segments
         
     | 
| 238 | 
         | 
| 
         | 
|
| 1 | 
         
             
            import copy
         
     | 
| 2 | 
         
            +
            import re
         
     | 
| 3 | 
         
             
            from box import Box
         
     | 
| 4 | 
         
             
            from pydub import AudioSegment
         
     | 
| 5 | 
         
             
            from typing import List, Union
         
     | 
| 
         | 
|
| 161 | 
         
             
                    for i in range(0, len(bucket), self.batch_size):
         
     | 
| 162 | 
         
             
                        batch = bucket[i : i + self.batch_size]
         
     | 
| 163 | 
         
             
                        param_arr = [self.segment_to_generate_params(segment) for segment in batch]
         
     | 
| 164 | 
         
            +
             
     | 
| 165 | 
         
            +
                        def append_eos(text: str):
         
     | 
| 166 | 
         
            +
                            text = text.strip()
         
     | 
| 167 | 
         
            +
                            eos_arr = ["[uv_break]", "[v_break]", "[lbreak]", "[llbreak]"]
         
     | 
| 168 | 
         
            +
                            has_eos = False
         
     | 
| 169 | 
         
            +
                            for eos in eos_arr:
         
     | 
| 170 | 
         
            +
                                if eos in text:
         
     | 
| 171 | 
         
            +
                                    has_eos = True
         
     | 
| 172 | 
         
            +
                                    break
         
     | 
| 173 | 
         
            +
                            if not has_eos:
         
     | 
| 174 | 
         
            +
                                text += self.eos
         
     | 
| 175 | 
         
            +
                            return text
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                        # 这里会添加 end_of_text 到 text 之后
         
     | 
| 178 | 
         
            +
                        texts = [append_eos(params.text) for params in param_arr]
         
     | 
| 179 | 
         | 
| 180 | 
         
             
                        params = param_arr[0]
         
     | 
| 181 | 
         
             
                        audio_datas = generate_audio.generate_audio_batch(
         
     | 
| 
         | 
|
| 197 | 
         | 
| 198 | 
         
             
                            audio_segment = audio_data_to_segment(audio_data, sr)
         
     | 
| 199 | 
         
             
                            audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
         
     | 
| 200 | 
         
            +
                            # compare by Box object
         
     | 
| 201 | 
         
             
                            original_index = src_segments.index(segment)
         
     | 
| 202 | 
         
             
                            audio_segments[original_index] = audio_segment
         
     | 
| 203 | 
         | 
| 
         | 
|
| 242 | 
         | 
| 243 | 
         
             
                        sentences = spliter.parse(text)
         
     | 
| 244 | 
         
             
                        for sentence in sentences:
         
     | 
| 245 | 
         
            +
                            seg = SSMLSegment(
         
     | 
| 246 | 
         
            +
                                text=sentence,
         
     | 
| 247 | 
         
            +
                                attrs=segment.attrs.copy(),
         
     | 
| 248 | 
         
            +
                                params=copy.copy(segment.params),
         
     | 
| 
         | 
|
| 
         | 
|
| 249 | 
         
             
                            )
         
     | 
| 250 | 
         
            +
                            ret_segments.append(seg)
         
     | 
| 251 | 
         
            +
                            setattr(seg, "_idx", len(ret_segments) - 1)
         
     | 
| 252 | 
         
            +
             
     | 
| 253 | 
         
            +
                    def is_none_speak_segment(segment: SSMLSegment):
         
     | 
| 254 | 
         
            +
                        text = segment.text.strip()
         
     | 
| 255 | 
         
            +
                        regexp = r"\[[^\]]+?\]"
         
     | 
| 256 | 
         
            +
                        text = re.sub(regexp, "", text)
         
     | 
| 257 | 
         
            +
                        text = text.strip()
         
     | 
| 258 | 
         
            +
                        if not text:
         
     | 
| 259 | 
         
            +
                            return True
         
     | 
| 260 | 
         
            +
                        return False
         
     | 
| 261 | 
         
            +
             
     | 
| 262 | 
         
            +
                    # 将 none_speak 合并到前一个 speak segment
         
     | 
| 263 | 
         
            +
                    for i in range(1, len(ret_segments)):
         
     | 
| 264 | 
         
            +
                        if is_none_speak_segment(ret_segments[i]):
         
     | 
| 265 | 
         
            +
                            ret_segments[i - 1].text += ret_segments[i].text
         
     | 
| 266 | 
         
            +
                            ret_segments[i].text = ""
         
     | 
| 267 | 
         
            +
                    # 移除空的 segment
         
     | 
| 268 | 
         
            +
                    ret_segments = [seg for seg in ret_segments if seg.text.strip()]
         
     | 
| 269 | 
         | 
| 270 | 
         
             
                    return ret_segments
         
     | 
| 271 | 
         | 
    	
        modules/api/app_config.py
    CHANGED
    
    | 
         @@ -1,6 +1,6 @@ 
     | 
|
| 1 | 
         
             
            app_description = """
         
     | 
| 2 | 
         
            -
            ChatTTS-Forge  
     | 
| 3 | 
         
            -
            ChatTTS-Forge is a  
     | 
| 4 | 
         | 
| 5 | 
         
             
            项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
         
     | 
| 6 | 
         | 
| 
         | 
|
| 1 | 
         
             
            app_description = """
         
     | 
| 2 | 
         
            +
            🍦 ChatTTS-Forge 是一个围绕 TTS 生成模型 ChatTTS 开发的项目,实现了 API Server 和 基于 Gradio 的 WebUI。<br/>
         
     | 
| 3 | 
         
            +
            🍦 ChatTTS-Forge is a project developed around the TTS generation model ChatTTS, implementing an API Server and a Gradio-based WebUI.
         
     | 
| 4 | 
         | 
| 5 | 
         
             
            项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
         
     | 
| 6 | 
         | 
    	
        modules/api/impl/google_api.py
    CHANGED
    
    | 
         @@ -1,38 +1,25 @@ 
     | 
|
| 1 | 
         
            -
            import  
     | 
| 2 | 
         
            -
            from typing import Literal
         
     | 
| 3 | 
         
             
            from fastapi import HTTPException
         
     | 
| 4 | 
         | 
| 5 | 
         
            -
            import io
         
     | 
| 6 | 
         
            -
            import soundfile as sf
         
     | 
| 7 | 
         
             
            from pydantic import BaseModel
         
     | 
| 8 | 
         | 
| 9 | 
         | 
| 10 | 
         
            -
            from modules.Enhancer.ResembleEnhance import (
         
     | 
| 11 | 
         
            -
                apply_audio_enhance,
         
     | 
| 12 | 
         
            -
                apply_audio_enhance_full,
         
     | 
| 13 | 
         
            -
            )
         
     | 
| 14 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 15 | 
         
            -
            from modules. 
     | 
| 16 | 
         
            -
            from modules. 
     | 
| 17 | 
         
            -
            from modules. 
     | 
| 18 | 
         
            -
            from modules. 
     | 
| 
         | 
|
| 19 | 
         | 
| 20 | 
         
            -
            from modules import  
     | 
| 21 | 
         
            -
            from modules.speaker import speaker_mgr
         
     | 
| 22 | 
         | 
| 23 | 
         | 
| 24 | 
         
            -
            from modules.ssml_parser.SSMLParser import create_ssml_parser
         
     | 
| 25 | 
         
            -
            from modules.SynthesizeSegments import (
         
     | 
| 26 | 
         
            -
                SynthesizeSegments,
         
     | 
| 27 | 
         
            -
                combine_audio_segments,
         
     | 
| 28 | 
         
            -
            )
         
     | 
| 29 | 
         
            -
             
     | 
| 30 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 31 | 
         | 
| 32 | 
         | 
| 33 | 
         
             
            class SynthesisInput(BaseModel):
         
     | 
| 34 | 
         
            -
                text: str =  
     | 
| 35 | 
         
            -
                ssml: str =  
     | 
| 36 | 
         | 
| 37 | 
         | 
| 38 | 
         
             
            class VoiceSelectionParams(BaseModel):
         
     | 
| 
         @@ -50,24 +37,15 @@ class VoiceSelectionParams(BaseModel): 
     | 
|
| 50 | 
         | 
| 51 | 
         | 
| 52 | 
         
             
            class AudioConfig(BaseModel):
         
     | 
| 53 | 
         
            -
                audioEncoding:  
     | 
| 54 | 
         
             
                speakingRate: float = 1
         
     | 
| 55 | 
         
             
                pitch: float = 0
         
     | 
| 56 | 
         
             
                volumeGainDb: float = 0
         
     | 
| 57 | 
         
             
                sampleRateHertz: int = 24000
         
     | 
| 58 | 
         
            -
                batchSize: int =  
     | 
| 59 | 
         
             
                spliterThreshold: int = 100
         
     | 
| 60 | 
         | 
| 61 | 
         | 
| 62 | 
         
            -
            class EnhancerConfig(BaseModel):
         
     | 
| 63 | 
         
            -
                enabled: bool = False
         
     | 
| 64 | 
         
            -
                model: str = "resemble-enhance"
         
     | 
| 65 | 
         
            -
                nfe: int = 32
         
     | 
| 66 | 
         
            -
                solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
         
     | 
| 67 | 
         
            -
                lambd: float = 0.5
         
     | 
| 68 | 
         
            -
                tau: float = 0.5
         
     | 
| 69 | 
         
            -
             
     | 
| 70 | 
         
            -
             
     | 
| 71 | 
         
             
            class GoogleTextSynthesizeRequest(BaseModel):
         
     | 
| 72 | 
         
             
                input: SynthesisInput
         
     | 
| 73 | 
         
             
                voice: VoiceSelectionParams
         
     | 
| 
         @@ -92,7 +70,11 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): 
     | 
|
| 92 | 
         
             
                voice_name = voice.name
         
     | 
| 93 | 
         
             
                infer_seed = voice.seed or 42
         
     | 
| 94 | 
         
             
                eos = voice.eos or "[uv_break]"
         
     | 
| 95 | 
         
            -
                audio_format = audioConfig.audioEncoding 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 96 | 
         
             
                speaking_rate = audioConfig.speakingRate or 1
         
     | 
| 97 | 
         
             
                pitch = audioConfig.pitch or 0
         
     | 
| 98 | 
         
             
                volume_gain_db = audioConfig.volumeGainDb or 0
         
     | 
| 
         @@ -101,6 +83,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): 
     | 
|
| 101 | 
         | 
| 102 | 
         
             
                spliter_threshold = audioConfig.spliterThreshold or 100
         
     | 
| 103 | 
         | 
| 
         | 
|
| 104 | 
         
             
                sample_rate = audioConfig.sampleRateHertz or 24000
         
     | 
| 105 | 
         | 
| 106 | 
         
             
                params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
         
     | 
| 
         @@ -111,92 +94,68 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest): 
     | 
|
| 111 | 
         
             
                        status_code=422, detail="The specified voice name is not supported."
         
     | 
| 112 | 
         
             
                    )
         
     | 
| 113 | 
         | 
| 114 | 
         
            -
                if  
     | 
| 115 | 
         
             
                    raise HTTPException(
         
     | 
| 116 | 
         
            -
                        status_code=422, detail=" 
     | 
| 117 | 
         
             
                    )
         
     | 
| 118 | 
         | 
| 119 | 
         
            -
                 
     | 
| 120 | 
         
            -
             
     | 
| 121 | 
         
            -
                     
     | 
| 122 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 123 | 
         
             
                try:
         
     | 
| 124 | 
         
             
                    if input.text:
         
     | 
| 125 | 
         
            -
                         
     | 
| 126 | 
         
            -
             
     | 
| 127 | 
         
            -
                         
     | 
| 128 | 
         
            -
                             
     | 
| 129 | 
         
            -
                             
     | 
| 130 | 
         
            -
             
     | 
| 131 | 
         
            -
             
     | 
| 132 | 
         
            -
             
     | 
| 133 | 
         
            -
                             
     | 
| 134 | 
         
            -
                            top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
         
     | 
| 135 | 
         
            -
                            top_K=voice.topK if voice.topK else params.get("top_k", 20),
         
     | 
| 136 | 
         
            -
                            spk=params.get("spk", -1),
         
     | 
| 137 | 
         
            -
                            infer_seed=infer_seed,
         
     | 
| 138 | 
         
            -
                            prompt1=params.get("prompt1", ""),
         
     | 
| 139 | 
         
            -
                            prompt2=params.get("prompt2", ""),
         
     | 
| 140 | 
         
            -
                            prefix=params.get("prefix", ""),
         
     | 
| 141 | 
         
            -
                            batch_size=batch_size,
         
     | 
| 142 | 
         
            -
                            spliter_threshold=spliter_threshold,
         
     | 
| 143 | 
         
            -
                            end_of_sentence=eos,
         
     | 
| 144 | 
         
             
                        )
         
     | 
| 145 | 
         | 
| 146 | 
         
            -
             
     | 
| 147 | 
         
            -
                         
     | 
| 148 | 
         
            -
                        segments = parser.parse(input.ssml)
         
     | 
| 149 | 
         
            -
                        for seg in segments:
         
     | 
| 150 | 
         
            -
                            seg["text"] = text_normalize(seg["text"], is_end=True)
         
     | 
| 151 | 
         
            -
             
     | 
| 152 | 
         
            -
                        if len(segments) == 0:
         
     | 
| 153 | 
         
            -
                            raise HTTPException(
         
     | 
| 154 | 
         
            -
                                status_code=422, detail="The SSML text is empty or parsing failed."
         
     | 
| 155 | 
         
            -
                            )
         
     | 
| 156 | 
         
            -
             
     | 
| 157 | 
         
            -
                        synthesize = SynthesizeSegments(
         
     | 
| 158 | 
         
            -
                            batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
         
     | 
| 159 | 
         
            -
                        )
         
     | 
| 160 | 
         
            -
                        audio_segments = synthesize.synthesize_segments(segments)
         
     | 
| 161 | 
         
            -
                        combined_audio = combine_audio_segments(audio_segments)
         
     | 
| 162 | 
         | 
| 163 | 
         
            -
             
     | 
| 164 | 
         
            -
             
     | 
| 165 | 
         
            -
                        raise HTTPException(
         
     | 
| 166 | 
         
            -
                            status_code=422, detail="Either text or SSML input must be provided."
         
     | 
| 167 | 
         
            -
                        )
         
     | 
| 168 | 
         | 
| 169 | 
         
            -
             
     | 
| 170 | 
         
            -
             
     | 
| 171 | 
         
            -
                             
     | 
| 172 | 
         
            -
                             
     | 
| 173 | 
         
            -
                             
     | 
| 174 | 
         
            -
                            solver=enhancerConfig.solver,
         
     | 
| 175 | 
         
            -
                            lambd=enhancerConfig.lambd,
         
     | 
| 176 | 
         
            -
                            tau=enhancerConfig.tau,
         
     | 
| 177 | 
         
             
                        )
         
     | 
| 178 | 
         | 
| 179 | 
         
            -
             
     | 
| 180 | 
         
            -
                        audio_data,
         
     | 
| 181 | 
         
            -
                        rate=speaking_rate,
         
     | 
| 182 | 
         
            -
                        pitch=pitch,
         
     | 
| 183 | 
         
            -
                        volume=volume_gain_db,
         
     | 
| 184 | 
         
            -
                        sr=sample_rate,
         
     | 
| 185 | 
         
            -
                    )
         
     | 
| 186 | 
         
            -
             
     | 
| 187 | 
         
            -
                    buffer = io.BytesIO()
         
     | 
| 188 | 
         
            -
                    sf.write(buffer, audio_data, sample_rate, format="wav")
         
     | 
| 189 | 
         
            -
                    buffer.seek(0)
         
     | 
| 190 | 
         | 
| 191 | 
         
            -
             
     | 
| 192 | 
         
            -
                        buffer = api_utils.wav_to_mp3(buffer)
         
     | 
| 193 | 
         | 
| 194 | 
         
            -
                     
     | 
| 195 | 
         
            -
             
     | 
| 196 | 
         
            -
             
     | 
| 197 | 
         
            -
             
     | 
| 198 | 
         
            -
                        "audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
         
     | 
| 199 | 
         
            -
                    }
         
     | 
| 200 | 
         | 
| 201 | 
         
             
                except Exception as e:
         
     | 
| 202 | 
         
             
                    import logging
         
     | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Union
         
     | 
| 
         | 
|
| 2 | 
         
             
            from fastapi import HTTPException
         
     | 
| 3 | 
         | 
| 
         | 
|
| 
         | 
|
| 4 | 
         
             
            from pydantic import BaseModel
         
     | 
| 5 | 
         | 
| 6 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 7 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 8 | 
         
            +
            from modules.api.impl.handler.SSMLHandler import SSMLHandler
         
     | 
| 9 | 
         
            +
            from modules.api.impl.handler.TTSHandler import TTSHandler
         
     | 
| 10 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
         
     | 
| 11 | 
         
            +
            from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
         
     | 
| 12 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 13 | 
         | 
| 14 | 
         
            +
            from modules.speaker import Speaker, speaker_mgr
         
     | 
| 
         | 
|
| 15 | 
         | 
| 16 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 17 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 18 | 
         | 
| 19 | 
         | 
| 20 | 
         
             
            class SynthesisInput(BaseModel):
         
     | 
| 21 | 
         
            +
                text: Union[str, None] = None
         
     | 
| 22 | 
         
            +
                ssml: Union[str, None] = None
         
     | 
| 23 | 
         | 
| 24 | 
         | 
| 25 | 
         
             
            class VoiceSelectionParams(BaseModel):
         
     | 
| 
         | 
|
| 37 | 
         | 
| 38 | 
         | 
| 39 | 
         
             
            class AudioConfig(BaseModel):
         
     | 
| 40 | 
         
            +
                audioEncoding: AudioFormat = AudioFormat.mp3
         
     | 
| 41 | 
         
             
                speakingRate: float = 1
         
     | 
| 42 | 
         
             
                pitch: float = 0
         
     | 
| 43 | 
         
             
                volumeGainDb: float = 0
         
     | 
| 44 | 
         
             
                sampleRateHertz: int = 24000
         
     | 
| 45 | 
         
            +
                batchSize: int = 4
         
     | 
| 46 | 
         
             
                spliterThreshold: int = 100
         
     | 
| 47 | 
         | 
| 48 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 49 | 
         
             
            class GoogleTextSynthesizeRequest(BaseModel):
         
     | 
| 50 | 
         
             
                input: SynthesisInput
         
     | 
| 51 | 
         
             
                voice: VoiceSelectionParams
         
     | 
| 
         | 
|
| 70 | 
         
             
                voice_name = voice.name
         
     | 
| 71 | 
         
             
                infer_seed = voice.seed or 42
         
     | 
| 72 | 
         
             
                eos = voice.eos or "[uv_break]"
         
     | 
| 73 | 
         
            +
                audio_format = audioConfig.audioEncoding
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                if not isinstance(audio_format, AudioFormat) and isinstance(audio_format, str):
         
     | 
| 76 | 
         
            +
                    audio_format = AudioFormat(audio_format)
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
             
                speaking_rate = audioConfig.speakingRate or 1
         
     | 
| 79 | 
         
             
                pitch = audioConfig.pitch or 0
         
     | 
| 80 | 
         
             
                volume_gain_db = audioConfig.volumeGainDb or 0
         
     | 
| 
         | 
|
| 83 | 
         | 
| 84 | 
         
             
                spliter_threshold = audioConfig.spliterThreshold or 100
         
     | 
| 85 | 
         | 
| 86 | 
         
            +
                # TODO
         
     | 
| 87 | 
         
             
                sample_rate = audioConfig.sampleRateHertz or 24000
         
     | 
| 88 | 
         | 
| 89 | 
         
             
                params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
         
     | 
| 
         | 
|
| 94 | 
         
             
                        status_code=422, detail="The specified voice name is not supported."
         
     | 
| 95 | 
         
             
                    )
         
     | 
| 96 | 
         | 
| 97 | 
         
            +
                if not isinstance(params.get("spk"), Speaker):
         
     | 
| 98 | 
         
             
                    raise HTTPException(
         
     | 
| 99 | 
         
            +
                        status_code=422, detail="The specified voice name is not supported."
         
     | 
| 100 | 
         
             
                    )
         
     | 
| 101 | 
         | 
| 102 | 
         
            +
                speaker = params.get("spk")
         
     | 
| 103 | 
         
            +
                tts_config = ChatTTSConfig(
         
     | 
| 104 | 
         
            +
                    style=params.get("style", ""),
         
     | 
| 105 | 
         
            +
                    temperature=voice.temperature,
         
     | 
| 106 | 
         
            +
                    top_k=voice.topK,
         
     | 
| 107 | 
         
            +
                    top_p=voice.topP,
         
     | 
| 108 | 
         
            +
                )
         
     | 
| 109 | 
         
            +
                infer_config = InferConfig(
         
     | 
| 110 | 
         
            +
                    batch_size=batch_size,
         
     | 
| 111 | 
         
            +
                    spliter_threshold=spliter_threshold,
         
     | 
| 112 | 
         
            +
                    eos=eos,
         
     | 
| 113 | 
         
            +
                    seed=infer_seed,
         
     | 
| 114 | 
         
            +
                )
         
     | 
| 115 | 
         
            +
                adjust_config = AdjustConfig(
         
     | 
| 116 | 
         
            +
                    speaking_rate=speaking_rate,
         
     | 
| 117 | 
         
            +
                    pitch=pitch,
         
     | 
| 118 | 
         
            +
                    volume_gain_db=volume_gain_db,
         
     | 
| 119 | 
         
            +
                )
         
     | 
| 120 | 
         
            +
                enhancer_config = enhancerConfig
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                mime_type = f"audio/{audio_format.value}"
         
     | 
| 123 | 
         
            +
                if audio_format == AudioFormat.mp3:
         
     | 
| 124 | 
         
            +
                    mime_type = "audio/mpeg"
         
     | 
| 125 | 
         
             
                try:
         
     | 
| 126 | 
         
             
                    if input.text:
         
     | 
| 127 | 
         
            +
                        text_content = input.text
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                        handler = TTSHandler(
         
     | 
| 130 | 
         
            +
                            text_content=text_content,
         
     | 
| 131 | 
         
            +
                            spk=speaker,
         
     | 
| 132 | 
         
            +
                            tts_config=tts_config,
         
     | 
| 133 | 
         
            +
                            infer_config=infer_config,
         
     | 
| 134 | 
         
            +
                            adjust_config=adjust_config,
         
     | 
| 135 | 
         
            +
                            enhancer_config=enhancer_config,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 136 | 
         
             
                        )
         
     | 
| 137 | 
         | 
| 138 | 
         
            +
                        base64_string = handler.enqueue_to_base64(format=audio_format)
         
     | 
| 139 | 
         
            +
                        return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 140 | 
         | 
| 141 | 
         
            +
                    elif input.ssml:
         
     | 
| 142 | 
         
            +
                        ssml_content = input.ssml
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 143 | 
         | 
| 144 | 
         
            +
                        handler = SSMLHandler(
         
     | 
| 145 | 
         
            +
                            ssml_content=ssml_content,
         
     | 
| 146 | 
         
            +
                            infer_config=infer_config,
         
     | 
| 147 | 
         
            +
                            adjust_config=adjust_config,
         
     | 
| 148 | 
         
            +
                            enhancer_config=enhancer_config,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 149 | 
         
             
                        )
         
     | 
| 150 | 
         | 
| 151 | 
         
            +
                        base64_string = handler.enqueue_to_base64(format=audio_format)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 152 | 
         | 
| 153 | 
         
            +
                        return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
         
     | 
| 
         | 
|
| 154 | 
         | 
| 155 | 
         
            +
                    else:
         
     | 
| 156 | 
         
            +
                        raise HTTPException(
         
     | 
| 157 | 
         
            +
                            status_code=422, detail="Invalid input text or ssml specified."
         
     | 
| 158 | 
         
            +
                        )
         
     | 
| 
         | 
|
| 
         | 
|
| 159 | 
         | 
| 160 | 
         
             
                except Exception as e:
         
     | 
| 161 | 
         
             
                    import logging
         
     | 
    	
        modules/api/impl/handler/AudioHandler.py
    ADDED
    
    | 
         @@ -0,0 +1,37 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import base64
         
     | 
| 2 | 
         
            +
            import io
         
     | 
| 3 | 
         
            +
            import numpy as np
         
     | 
| 4 | 
         
            +
            import soundfile as sf
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            from modules.api.impl.model.audio_model import AudioFormat
         
     | 
| 7 | 
         
            +
            from modules.api import utils as api_utils
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            class AudioHandler:
         
     | 
| 11 | 
         
            +
                def enqueue(self) -> tuple[np.ndarray, int]:
         
     | 
| 12 | 
         
            +
                    raise NotImplementedError
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
                def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
         
     | 
| 15 | 
         
            +
                    audio_data, sample_rate = self.enqueue()
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
                    buffer = io.BytesIO()
         
     | 
| 18 | 
         
            +
                    sf.write(buffer, audio_data, sample_rate, format="wav")
         
     | 
| 19 | 
         
            +
                    buffer.seek(0)
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
                    if format == AudioFormat.mp3:
         
     | 
| 22 | 
         
            +
                        buffer = api_utils.wav_to_mp3(buffer)
         
     | 
| 23 | 
         
            +
             
     | 
| 24 | 
         
            +
                    return buffer
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                def enqueue_to_bytes(self, format: AudioFormat) -> bytes:
         
     | 
| 27 | 
         
            +
                    buffer = self.enqueue_to_buffer(format=format)
         
     | 
| 28 | 
         
            +
                    binary = buffer.read()
         
     | 
| 29 | 
         
            +
                    return binary
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                def enqueue_to_base64(self, format: AudioFormat) -> str:
         
     | 
| 32 | 
         
            +
                    binary = self.enqueue_to_bytes(format=format)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    base64_encoded = base64.b64encode(binary)
         
     | 
| 35 | 
         
            +
                    base64_string = base64_encoded.decode("utf-8")
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                    return base64_string
         
     | 
    	
        modules/api/impl/handler/SSMLHandler.py
    ADDED
    
    | 
         @@ -0,0 +1,94 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from fastapi import HTTPException
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
         
     | 
| 5 | 
         
            +
            from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
         
     | 
| 6 | 
         
            +
            from modules.api.impl.handler.AudioHandler import AudioHandler
         
     | 
| 7 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig
         
     | 
| 8 | 
         
            +
            from modules.api.impl.model.chattts_model import InferConfig
         
     | 
| 9 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 10 | 
         
            +
            from modules.normalization import text_normalize
         
     | 
| 11 | 
         
            +
            from modules.ssml_parser.SSMLParser import create_ssml_parser
         
     | 
| 12 | 
         
            +
            from modules.utils import audio
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            class SSMLHandler(AudioHandler):
         
     | 
| 16 | 
         
            +
                def __init__(
         
     | 
| 17 | 
         
            +
                    self,
         
     | 
| 18 | 
         
            +
                    ssml_content: str,
         
     | 
| 19 | 
         
            +
                    infer_config: InferConfig,
         
     | 
| 20 | 
         
            +
                    adjust_config: AdjustConfig,
         
     | 
| 21 | 
         
            +
                    enhancer_config: EnhancerConfig,
         
     | 
| 22 | 
         
            +
                ) -> None:
         
     | 
| 23 | 
         
            +
                    assert isinstance(ssml_content, str), "ssml_content must be a string."
         
     | 
| 24 | 
         
            +
                    assert isinstance(
         
     | 
| 25 | 
         
            +
                        infer_config, InferConfig
         
     | 
| 26 | 
         
            +
                    ), "infer_config must be an InferConfig object."
         
     | 
| 27 | 
         
            +
                    assert isinstance(
         
     | 
| 28 | 
         
            +
                        adjust_config, AdjustConfig
         
     | 
| 29 | 
         
            +
                    ), "adjest_config should be AdjustConfig"
         
     | 
| 30 | 
         
            +
                    assert isinstance(
         
     | 
| 31 | 
         
            +
                        enhancer_config, EnhancerConfig
         
     | 
| 32 | 
         
            +
                    ), "enhancer_config must be an EnhancerConfig object."
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    self.ssml_content = ssml_content
         
     | 
| 35 | 
         
            +
                    self.infer_config = infer_config
         
     | 
| 36 | 
         
            +
                    self.adjest_config = adjust_config
         
     | 
| 37 | 
         
            +
                    self.enhancer_config = enhancer_config
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                    self.validate()
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                def validate(self):
         
     | 
| 42 | 
         
            +
                    # TODO params checker
         
     | 
| 43 | 
         
            +
                    pass
         
     | 
| 44 | 
         
            +
             
     | 
| 45 | 
         
            +
                def enqueue(self) -> tuple[np.ndarray, int]:
         
     | 
| 46 | 
         
            +
                    ssml_content = self.ssml_content
         
     | 
| 47 | 
         
            +
                    infer_config = self.infer_config
         
     | 
| 48 | 
         
            +
                    adjust_config = self.adjest_config
         
     | 
| 49 | 
         
            +
                    enhancer_config = self.enhancer_config
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                    parser = create_ssml_parser()
         
     | 
| 52 | 
         
            +
                    segments = parser.parse(ssml_content)
         
     | 
| 53 | 
         
            +
                    for seg in segments:
         
     | 
| 54 | 
         
            +
                        seg["text"] = text_normalize(seg["text"], is_end=True)
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                    if len(segments) == 0:
         
     | 
| 57 | 
         
            +
                        raise HTTPException(
         
     | 
| 58 | 
         
            +
                            status_code=422, detail="The SSML text is empty or parsing failed."
         
     | 
| 59 | 
         
            +
                        )
         
     | 
| 60 | 
         
            +
             
     | 
| 61 | 
         
            +
                    synthesize = SynthesizeSegments(
         
     | 
| 62 | 
         
            +
                        batch_size=infer_config.batch_size,
         
     | 
| 63 | 
         
            +
                        eos=infer_config.eos,
         
     | 
| 64 | 
         
            +
                        spliter_thr=infer_config.spliter_threshold,
         
     | 
| 65 | 
         
            +
                    )
         
     | 
| 66 | 
         
            +
                    audio_segments = synthesize.synthesize_segments(segments)
         
     | 
| 67 | 
         
            +
                    combined_audio = combine_audio_segments(audio_segments)
         
     | 
| 68 | 
         
            +
             
     | 
| 69 | 
         
            +
                    sample_rate, audio_data = audio.pydub_to_np(combined_audio)
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                    if enhancer_config.enabled:
         
     | 
| 72 | 
         
            +
                        nfe = enhancer_config.nfe
         
     | 
| 73 | 
         
            +
                        solver = enhancer_config.solver
         
     | 
| 74 | 
         
            +
                        lambd = enhancer_config.lambd
         
     | 
| 75 | 
         
            +
                        tau = enhancer_config.tau
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
                        audio_data, sample_rate = apply_audio_enhance_full(
         
     | 
| 78 | 
         
            +
                            audio_data=audio_data,
         
     | 
| 79 | 
         
            +
                            sr=sample_rate,
         
     | 
| 80 | 
         
            +
                            nfe=nfe,
         
     | 
| 81 | 
         
            +
                            solver=solver,
         
     | 
| 82 | 
         
            +
                            lambd=lambd,
         
     | 
| 83 | 
         
            +
                            tau=tau,
         
     | 
| 84 | 
         
            +
                        )
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                    audio_data = audio.apply_prosody_to_audio_data(
         
     | 
| 87 | 
         
            +
                        audio_data=audio_data,
         
     | 
| 88 | 
         
            +
                        rate=adjust_config.speed_rate,
         
     | 
| 89 | 
         
            +
                        pitch=adjust_config.pitch,
         
     | 
| 90 | 
         
            +
                        volume=adjust_config.volume_gain_db,
         
     | 
| 91 | 
         
            +
                        sr=sample_rate,
         
     | 
| 92 | 
         
            +
                    )
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                    return audio_data, sample_rate
         
     | 
    	
        modules/api/impl/handler/TTSHandler.py
    ADDED
    
    | 
         @@ -0,0 +1,97 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import numpy as np
         
     | 
| 2 | 
         
            +
            from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
         
     | 
| 3 | 
         
            +
            from modules.api.impl.handler.AudioHandler import AudioHandler
         
     | 
| 4 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig
         
     | 
| 5 | 
         
            +
            from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
         
     | 
| 6 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 7 | 
         
            +
            from modules.normalization import text_normalize
         
     | 
| 8 | 
         
            +
            from modules.speaker import Speaker
         
     | 
| 9 | 
         
            +
            from modules.synthesize_audio import synthesize_audio
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from modules.utils.audio import apply_prosody_to_audio_data
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            class TTSHandler(AudioHandler):
         
     | 
| 15 | 
         
            +
                def __init__(
         
     | 
| 16 | 
         
            +
                    self,
         
     | 
| 17 | 
         
            +
                    text_content: str,
         
     | 
| 18 | 
         
            +
                    spk: Speaker,
         
     | 
| 19 | 
         
            +
                    tts_config: ChatTTSConfig,
         
     | 
| 20 | 
         
            +
                    infer_config: InferConfig,
         
     | 
| 21 | 
         
            +
                    adjust_config: AdjustConfig,
         
     | 
| 22 | 
         
            +
                    enhancer_config: EnhancerConfig,
         
     | 
| 23 | 
         
            +
                ):
         
     | 
| 24 | 
         
            +
                    assert isinstance(text_content, str), "text_content should be str"
         
     | 
| 25 | 
         
            +
                    assert isinstance(spk, Speaker), "spk should be Speaker"
         
     | 
| 26 | 
         
            +
                    assert isinstance(
         
     | 
| 27 | 
         
            +
                        tts_config, ChatTTSConfig
         
     | 
| 28 | 
         
            +
                    ), "tts_config should be ChatTTSConfig"
         
     | 
| 29 | 
         
            +
                    assert isinstance(
         
     | 
| 30 | 
         
            +
                        infer_config, InferConfig
         
     | 
| 31 | 
         
            +
                    ), "infer_config should be InferConfig"
         
     | 
| 32 | 
         
            +
                    assert isinstance(
         
     | 
| 33 | 
         
            +
                        adjust_config, AdjustConfig
         
     | 
| 34 | 
         
            +
                    ), "adjest_config should be AdjustConfig"
         
     | 
| 35 | 
         
            +
                    assert isinstance(
         
     | 
| 36 | 
         
            +
                        enhancer_config, EnhancerConfig
         
     | 
| 37 | 
         
            +
                    ), "enhancer_config should be EnhancerConfig"
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
                    self.text_content = text_content
         
     | 
| 40 | 
         
            +
                    self.spk = spk
         
     | 
| 41 | 
         
            +
                    self.tts_config = tts_config
         
     | 
| 42 | 
         
            +
                    self.infer_config = infer_config
         
     | 
| 43 | 
         
            +
                    self.adjest_config = adjust_config
         
     | 
| 44 | 
         
            +
                    self.enhancer_config = enhancer_config
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                    self.validate()
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                def validate(self):
         
     | 
| 49 | 
         
            +
                    # TODO params checker
         
     | 
| 50 | 
         
            +
                    pass
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                def enqueue(self) -> tuple[np.ndarray, int]:
         
     | 
| 53 | 
         
            +
                    text = text_normalize(self.text_content)
         
     | 
| 54 | 
         
            +
                    tts_config = self.tts_config
         
     | 
| 55 | 
         
            +
                    infer_config = self.infer_config
         
     | 
| 56 | 
         
            +
                    adjust_config = self.adjest_config
         
     | 
| 57 | 
         
            +
                    enhancer_config = self.enhancer_config
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                    sample_rate, audio_data = synthesize_audio(
         
     | 
| 60 | 
         
            +
                        text,
         
     | 
| 61 | 
         
            +
                        spk=self.spk,
         
     | 
| 62 | 
         
            +
                        temperature=tts_config.temperature,
         
     | 
| 63 | 
         
            +
                        top_P=tts_config.top_p,
         
     | 
| 64 | 
         
            +
                        top_K=tts_config.top_k,
         
     | 
| 65 | 
         
            +
                        prompt1=tts_config.prompt1,
         
     | 
| 66 | 
         
            +
                        prompt2=tts_config.prompt2,
         
     | 
| 67 | 
         
            +
                        prefix=tts_config.prefix,
         
     | 
| 68 | 
         
            +
                        infer_seed=infer_config.seed,
         
     | 
| 69 | 
         
            +
                        batch_size=infer_config.batch_size,
         
     | 
| 70 | 
         
            +
                        spliter_threshold=infer_config.spliter_threshold,
         
     | 
| 71 | 
         
            +
                        end_of_sentence=infer_config.eos,
         
     | 
| 72 | 
         
            +
                    )
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                    if enhancer_config.enabled:
         
     | 
| 75 | 
         
            +
                        nfe = enhancer_config.nfe
         
     | 
| 76 | 
         
            +
                        solver = enhancer_config.solver
         
     | 
| 77 | 
         
            +
                        lambd = enhancer_config.lambd
         
     | 
| 78 | 
         
            +
                        tau = enhancer_config.tau
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                        audio_data, sample_rate = apply_audio_enhance_full(
         
     | 
| 81 | 
         
            +
                            audio_data=audio_data,
         
     | 
| 82 | 
         
            +
                            sr=sample_rate,
         
     | 
| 83 | 
         
            +
                            nfe=nfe,
         
     | 
| 84 | 
         
            +
                            solver=solver,
         
     | 
| 85 | 
         
            +
                            lambd=lambd,
         
     | 
| 86 | 
         
            +
                            tau=tau,
         
     | 
| 87 | 
         
            +
                        )
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    audio_data = apply_prosody_to_audio_data(
         
     | 
| 90 | 
         
            +
                        audio_data=audio_data,
         
     | 
| 91 | 
         
            +
                        rate=adjust_config.speed_rate,
         
     | 
| 92 | 
         
            +
                        pitch=adjust_config.pitch,
         
     | 
| 93 | 
         
            +
                        volume=adjust_config.volume_gain_db,
         
     | 
| 94 | 
         
            +
                        sr=sample_rate,
         
     | 
| 95 | 
         
            +
                    )
         
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
                    return audio_data, sample_rate
         
     | 
    	
        modules/api/impl/model/audio_model.py
    ADDED
    
    | 
         @@ -0,0 +1,14 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from enum import Enum
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            class AudioFormat(str, Enum):
         
     | 
| 7 | 
         
            +
                mp3 = "mp3"
         
     | 
| 8 | 
         
            +
                wav = "wav"
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            class AdjustConfig(BaseModel):
         
     | 
| 12 | 
         
            +
                pitch: float = 0
         
     | 
| 13 | 
         
            +
                speed_rate: float = 1
         
     | 
| 14 | 
         
            +
                volume_gain_db: float = 0
         
     | 
    	
        modules/api/impl/model/chattts_model.py
    ADDED
    
    | 
         @@ -0,0 +1,19 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            class ChatTTSConfig(BaseModel):
         
     | 
| 5 | 
         
            +
                style: str = ""
         
     | 
| 6 | 
         
            +
                temperature: float = 0.3
         
     | 
| 7 | 
         
            +
                top_p: float = 0.7
         
     | 
| 8 | 
         
            +
                top_k: int = 20
         
     | 
| 9 | 
         
            +
                prompt1: str = ""
         
     | 
| 10 | 
         
            +
                prompt2: str = ""
         
     | 
| 11 | 
         
            +
                prefix: str = ""
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            class InferConfig(BaseModel):
         
     | 
| 15 | 
         
            +
                batch_size: int = 4
         
     | 
| 16 | 
         
            +
                spliter_threshold: int = 100
         
     | 
| 17 | 
         
            +
                # end_of_sentence
         
     | 
| 18 | 
         
            +
                eos: str = "[uv_break]"
         
     | 
| 19 | 
         
            +
                seed: int = 42
         
     | 
    	
        modules/api/impl/model/enhancer_model.py
    ADDED
    
    | 
         @@ -0,0 +1,11 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Literal
         
     | 
| 2 | 
         
            +
            from pydantic import BaseModel
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class EnhancerConfig(BaseModel):
         
     | 
| 6 | 
         
            +
                enabled: bool = False
         
     | 
| 7 | 
         
            +
                model: str = "resemble-enhance"
         
     | 
| 8 | 
         
            +
                nfe: int = 32
         
     | 
| 9 | 
         
            +
                solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
         
     | 
| 10 | 
         
            +
                lambd: float = 0.5
         
     | 
| 11 | 
         
            +
                tau: float = 0.5
         
     | 
    	
        modules/api/impl/openai_api.py
    CHANGED
    
    | 
         @@ -1,42 +1,38 @@ 
     | 
|
| 1 | 
         
             
            from fastapi import File, Form, HTTPException, Body, UploadFile
         
     | 
| 2 | 
         
            -
            from fastapi.responses import StreamingResponse
         
     | 
| 3 | 
         | 
| 4 | 
         
            -
            import io
         
     | 
| 5 | 
         
             
            from numpy import clip
         
     | 
| 6 | 
         
            -
            import soundfile as sf
         
     | 
| 7 | 
         
             
            from pydantic import BaseModel, Field
         
     | 
| 8 | 
         
            -
            from fastapi.responses import  
     | 
| 9 | 
         
            -
             
     | 
| 10 | 
         | 
| 11 | 
         
            -
            from modules.synthesize_audio import synthesize_audio
         
     | 
| 12 | 
         
            -
            from modules.normalization import text_normalize
         
     | 
| 13 | 
         | 
| 14 | 
         
            -
            from modules import  
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
            -
            from typing import List,  
     | 
| 18 | 
         
            -
            import pyrubberband as pyrb
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 21 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 22 | 
         | 
| 23 | 
         
            -
            from modules.speaker import speaker_mgr
         
     | 
| 24 | 
         
             
            from modules.data import styles_mgr
         
     | 
| 25 | 
         | 
| 26 | 
         
            -
            import numpy as np
         
     | 
| 27 | 
         
            -
             
     | 
| 28 | 
         | 
| 29 | 
         
             
            class AudioSpeechRequest(BaseModel):
         
     | 
| 30 | 
         
             
                input: str  # 需要合成的文本
         
     | 
| 31 | 
         
             
                model: str = "chattts-4w"
         
     | 
| 32 | 
         
             
                voice: str = "female2"
         
     | 
| 33 | 
         
            -
                response_format:  
     | 
| 34 | 
         
             
                speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
         
     | 
| 35 | 
         
             
                seed: int = 42
         
     | 
| 
         | 
|
| 36 | 
         
             
                temperature: float = 0.3
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 37 | 
         
             
                style: str = ""
         
     | 
| 38 | 
         
            -
                # 是否开启batch合成,小于等于1表示不适用batch
         
     | 
| 39 | 
         
            -
                # 开启batch合成会自动分割句子
         
     | 
| 40 | 
         
             
                batch_size: int = Field(1, ge=1, le=20, description="Batch size")
         
     | 
| 41 | 
         
             
                spliter_threshold: float = Field(
         
     | 
| 42 | 
         
             
                    100, ge=10, le=1024, description="Threshold for sentence spliter"
         
     | 
| 
         @@ -44,6 +40,9 @@ class AudioSpeechRequest(BaseModel): 
     | 
|
| 44 | 
         
             
                # end of sentence
         
     | 
| 45 | 
         
             
                eos: str = "[uv_break]"
         
     | 
| 46 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 47 | 
         | 
| 48 | 
         
             
            async def openai_speech_api(
         
     | 
| 49 | 
         
             
                request: AudioSpeechRequest = Body(
         
     | 
| 
         @@ -55,7 +54,14 @@ async def openai_speech_api( 
     | 
|
| 55 | 
         
             
                voice = request.voice
         
     | 
| 56 | 
         
             
                style = request.style
         
     | 
| 57 | 
         
             
                eos = request.eos
         
     | 
| 
         | 
|
| 
         | 
|
| 58 | 
         
             
                response_format = request.response_format
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 59 | 
         
             
                batch_size = request.batch_size
         
     | 
| 60 | 
         
             
                spliter_threshold = request.spliter_threshold
         
     | 
| 61 | 
         
             
                speed = request.speed
         
     | 
| 
         @@ -71,49 +77,45 @@ async def openai_speech_api( 
     | 
|
| 71 | 
         
             
                except:
         
     | 
| 72 | 
         
             
                    raise HTTPException(status_code=400, detail="Invalid style.")
         
     | 
| 73 | 
         | 
| 74 | 
         
            -
                 
     | 
| 75 | 
         
            -
                    # Normalize the text
         
     | 
| 76 | 
         
            -
                    text = text_normalize(input_text, is_end=True)
         
     | 
| 77 | 
         
            -
             
     | 
| 78 | 
         
            -
                    # Calculate speaker and style based on input voice
         
     | 
| 79 | 
         
            -
                    params = api_utils.calc_spk_style(spk=voice, style=style)
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
            -
                    spk = params.get("spk", -1)
         
     | 
| 82 | 
         
            -
                    seed = params.get("seed", request.seed or 42)
         
     | 
| 83 | 
         
            -
                    temperature = params.get("temperature", request.temperature or 0.3)
         
     | 
| 84 | 
         
            -
                    prompt1 = params.get("prompt1", "")
         
     | 
| 85 | 
         
            -
                    prompt2 = params.get("prompt2", "")
         
     | 
| 86 | 
         
            -
                    prefix = params.get("prefix", "")
         
     | 
| 87 | 
         
            -
             
     | 
| 88 | 
         
            -
                    # Generate audio
         
     | 
| 89 | 
         
            -
                    sample_rate, audio_data = synthesize_audio(
         
     | 
| 90 | 
         
            -
                        text,
         
     | 
| 91 | 
         
            -
                        temperature=temperature,
         
     | 
| 92 | 
         
            -
                        top_P=0.7,
         
     | 
| 93 | 
         
            -
                        top_K=20,
         
     | 
| 94 | 
         
            -
                        spk=spk,
         
     | 
| 95 | 
         
            -
                        infer_seed=seed,
         
     | 
| 96 | 
         
            -
                        batch_size=batch_size,
         
     | 
| 97 | 
         
            -
                        spliter_threshold=spliter_threshold,
         
     | 
| 98 | 
         
            -
                        prompt1=prompt1,
         
     | 
| 99 | 
         
            -
                        prompt2=prompt2,
         
     | 
| 100 | 
         
            -
                        prefix=prefix,
         
     | 
| 101 | 
         
            -
                        end_of_sentence=eos,
         
     | 
| 102 | 
         
            -
                    )
         
     | 
| 103 | 
         | 
| 104 | 
         
            -
             
     | 
| 105 | 
         
            -
             
     | 
| 
         | 
|
| 106 | 
         | 
| 107 | 
         
            -
             
     | 
| 108 | 
         
            -
                     
     | 
| 109 | 
         
            -
                     
     | 
| 110 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 111 | 
         | 
| 112 | 
         
            -
                     
     | 
| 113 | 
         
            -
                        # Convert wav to mp3
         
     | 
| 114 | 
         
            -
                        buffer = api_utils.wav_to_mp3(buffer)
         
     | 
| 115 | 
         | 
| 116 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 117 | 
         | 
| 118 | 
         
             
                except Exception as e:
         
     | 
| 119 | 
         
             
                    import logging
         
     | 
| 
         @@ -150,7 +152,6 @@ class TranscriptionsVerboseResponse(BaseModel): 
     | 
|
| 150 | 
         
             
            def setup(app: APIManager):
         
     | 
| 151 | 
         
             
                app.post(
         
     | 
| 152 | 
         
             
                    "/v1/audio/speech",
         
     | 
| 153 | 
         
            -
                    response_class=FileResponse,
         
     | 
| 154 | 
         
             
                    description="""
         
     | 
| 155 | 
         
             
            openai api document: 
         
     | 
| 156 | 
         
             
            [https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
         
     | 
| 
         | 
|
| 1 | 
         
             
            from fastapi import File, Form, HTTPException, Body, UploadFile
         
     | 
| 
         | 
|
| 2 | 
         | 
| 
         | 
|
| 3 | 
         
             
            from numpy import clip
         
     | 
| 
         | 
|
| 4 | 
         
             
            from pydantic import BaseModel, Field
         
     | 
| 5 | 
         
            +
            from fastapi.responses import StreamingResponse
         
     | 
| 
         | 
|
| 6 | 
         | 
| 
         | 
|
| 
         | 
|
| 7 | 
         | 
| 8 | 
         
            +
            from modules.api.impl.handler.TTSHandler import TTSHandler
         
     | 
| 9 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
         
     | 
| 10 | 
         
            +
            from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
         
     | 
| 11 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 12 | 
         | 
| 13 | 
         | 
| 14 | 
         
            +
            from typing import List, Optional
         
     | 
| 
         | 
|
| 15 | 
         | 
| 16 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 17 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 18 | 
         | 
| 19 | 
         
            +
            from modules.speaker import Speaker, speaker_mgr
         
     | 
| 20 | 
         
             
            from modules.data import styles_mgr
         
     | 
| 21 | 
         | 
| 
         | 
|
| 
         | 
|
| 22 | 
         | 
| 23 | 
         
             
            class AudioSpeechRequest(BaseModel):
         
     | 
| 24 | 
         
             
                input: str  # 需要合成的文本
         
     | 
| 25 | 
         
             
                model: str = "chattts-4w"
         
     | 
| 26 | 
         
             
                voice: str = "female2"
         
     | 
| 27 | 
         
            +
                response_format: AudioFormat = "mp3"
         
     | 
| 28 | 
         
             
                speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
         
     | 
| 29 | 
         
             
                seed: int = 42
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
             
                temperature: float = 0.3
         
     | 
| 32 | 
         
            +
                top_k: int = 20
         
     | 
| 33 | 
         
            +
                top_p: float = 0.7
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
             
                style: str = ""
         
     | 
| 
         | 
|
| 
         | 
|
| 36 | 
         
             
                batch_size: int = Field(1, ge=1, le=20, description="Batch size")
         
     | 
| 37 | 
         
             
                spliter_threshold: float = Field(
         
     | 
| 38 | 
         
             
                    100, ge=10, le=1024, description="Threshold for sentence spliter"
         
     | 
| 
         | 
|
| 40 | 
         
             
                # end of sentence
         
     | 
| 41 | 
         
             
                eos: str = "[uv_break]"
         
     | 
| 42 | 
         | 
| 43 | 
         
            +
                enhance: bool = False
         
     | 
| 44 | 
         
            +
                denoise: bool = False
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         | 
| 47 | 
         
             
            async def openai_speech_api(
         
     | 
| 48 | 
         
             
                request: AudioSpeechRequest = Body(
         
     | 
| 
         | 
|
| 54 | 
         
             
                voice = request.voice
         
     | 
| 55 | 
         
             
                style = request.style
         
     | 
| 56 | 
         
             
                eos = request.eos
         
     | 
| 57 | 
         
            +
                seed = request.seed
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
             
                response_format = request.response_format
         
     | 
| 60 | 
         
            +
                if not isinstance(response_format, AudioFormat) and isinstance(
         
     | 
| 61 | 
         
            +
                    response_format, str
         
     | 
| 62 | 
         
            +
                ):
         
     | 
| 63 | 
         
            +
                    response_format = AudioFormat(response_format)
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
             
                batch_size = request.batch_size
         
     | 
| 66 | 
         
             
                spliter_threshold = request.spliter_threshold
         
     | 
| 67 | 
         
             
                speed = request.speed
         
     | 
| 
         | 
|
| 77 | 
         
             
                except:
         
     | 
| 78 | 
         
             
                    raise HTTPException(status_code=400, detail="Invalid style.")
         
     | 
| 79 | 
         | 
| 80 | 
         
            +
                ctx_params = api_utils.calc_spk_style(spk=voice, style=style)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 81 | 
         | 
| 82 | 
         
            +
                speaker = ctx_params.get("spk")
         
     | 
| 83 | 
         
            +
                if not isinstance(speaker, Speaker):
         
     | 
| 84 | 
         
            +
                    raise HTTPException(status_code=400, detail="Invalid voice.")
         
     | 
| 85 | 
         | 
| 86 | 
         
            +
                tts_config = ChatTTSConfig(
         
     | 
| 87 | 
         
            +
                    style=style,
         
     | 
| 88 | 
         
            +
                    temperature=request.temperature,
         
     | 
| 89 | 
         
            +
                    top_k=request.top_k,
         
     | 
| 90 | 
         
            +
                    top_p=request.top_p,
         
     | 
| 91 | 
         
            +
                )
         
     | 
| 92 | 
         
            +
                infer_config = InferConfig(
         
     | 
| 93 | 
         
            +
                    batch_size=batch_size,
         
     | 
| 94 | 
         
            +
                    spliter_threshold=spliter_threshold,
         
     | 
| 95 | 
         
            +
                    eos=eos,
         
     | 
| 96 | 
         
            +
                    seed=seed,
         
     | 
| 97 | 
         
            +
                )
         
     | 
| 98 | 
         
            +
                adjust_config = AdjustConfig(speaking_rate=speed)
         
     | 
| 99 | 
         
            +
                enhancer_config = EnhancerConfig(
         
     | 
| 100 | 
         
            +
                    enabled=request.enhance or request.denoise or False,
         
     | 
| 101 | 
         
            +
                    lambd=0.9 if request.denoise else 0.1,
         
     | 
| 102 | 
         
            +
                )
         
     | 
| 103 | 
         
            +
                try:
         
     | 
| 104 | 
         
            +
                    handler = TTSHandler(
         
     | 
| 105 | 
         
            +
                        text_content=input_text,
         
     | 
| 106 | 
         
            +
                        spk=speaker,
         
     | 
| 107 | 
         
            +
                        tts_config=tts_config,
         
     | 
| 108 | 
         
            +
                        infer_config=infer_config,
         
     | 
| 109 | 
         
            +
                        adjust_config=adjust_config,
         
     | 
| 110 | 
         
            +
                        enhancer_config=enhancer_config,
         
     | 
| 111 | 
         
            +
                    )
         
     | 
| 112 | 
         | 
| 113 | 
         
            +
                    buffer = handler.enqueue_to_buffer(response_format)
         
     | 
| 
         | 
|
| 
         | 
|
| 114 | 
         | 
| 115 | 
         
            +
                    mime_type = f"audio/{response_format.value}"
         
     | 
| 116 | 
         
            +
                    if response_format == AudioFormat.mp3:
         
     | 
| 117 | 
         
            +
                        mime_type = "audio/mpeg"
         
     | 
| 118 | 
         
            +
                    return StreamingResponse(buffer, media_type=mime_type)
         
     | 
| 119 | 
         | 
| 120 | 
         
             
                except Exception as e:
         
     | 
| 121 | 
         
             
                    import logging
         
     | 
| 
         | 
|
| 152 | 
         
             
            def setup(app: APIManager):
         
     | 
| 153 | 
         
             
                app.post(
         
     | 
| 154 | 
         
             
                    "/v1/audio/speech",
         
     | 
| 
         | 
|
| 155 | 
         
             
                    description="""
         
     | 
| 156 | 
         
             
            openai api document: 
         
     | 
| 157 | 
         
             
            [https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
         
     | 
    	
        modules/api/impl/refiner_api.py
    CHANGED
    
    | 
         @@ -31,6 +31,7 @@ async def refiner_prompt_post(request: RefineTextRequest): 
     | 
|
| 31 | 
         
             
                    text = request.text
         
     | 
| 32 | 
         
             
                    if request.normalize:
         
     | 
| 33 | 
         
             
                        text = text_normalize(request.text)
         
     | 
| 
         | 
|
| 34 | 
         
             
                    refined_text = refiner.refine_text(
         
     | 
| 35 | 
         
             
                        text=text,
         
     | 
| 36 | 
         
             
                        prompt=request.prompt,
         
     | 
| 
         | 
|
| 31 | 
         
             
                    text = request.text
         
     | 
| 32 | 
         
             
                    if request.normalize:
         
     | 
| 33 | 
         
             
                        text = text_normalize(request.text)
         
     | 
| 34 | 
         
            +
                    # TODO 其实这里可以做 spliter 和 batch 处理
         
     | 
| 35 | 
         
             
                    refined_text = refiner.refine_text(
         
     | 
| 36 | 
         
             
                        text=text,
         
     | 
| 37 | 
         
             
                        prompt=request.prompt,
         
     | 
    	
        modules/api/impl/ssml_api.py
    CHANGED
    
    | 
         @@ -1,27 +1,22 @@ 
     | 
|
| 1 | 
         
             
            from fastapi import HTTPException, Body
         
     | 
| 2 | 
         
             
            from fastapi.responses import StreamingResponse
         
     | 
| 3 | 
         | 
| 4 | 
         
            -
            import io
         
     | 
| 5 | 
         
             
            from pydantic import BaseModel
         
     | 
| 6 | 
         
             
            from fastapi.responses import FileResponse
         
     | 
| 7 | 
         | 
| 8 | 
         | 
| 9 | 
         
            -
            from modules. 
     | 
| 10 | 
         
            -
            from modules. 
     | 
| 11 | 
         
            -
            from modules. 
     | 
| 12 | 
         
            -
             
     | 
| 13 | 
         
            -
                combine_audio_segments,
         
     | 
| 14 | 
         
            -
            )
         
     | 
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
            -
            from modules.api import utils as api_utils
         
     | 
| 18 | 
         
            -
             
     | 
| 19 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 20 | 
         | 
| 21 | 
         | 
| 22 | 
         
             
            class SSMLRequest(BaseModel):
         
     | 
| 23 | 
         
             
                ssml: str
         
     | 
| 24 | 
         
            -
                format:  
     | 
| 25 | 
         | 
| 26 | 
         
             
                # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
         
     | 
| 27 | 
         
             
                batch_size: int = 4
         
     | 
| 
         @@ -31,6 +26,9 @@ class SSMLRequest(BaseModel): 
     | 
|
| 31 | 
         | 
| 32 | 
         
             
                spliter_thr: int = 100
         
     | 
| 33 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 34 | 
         | 
| 35 | 
         
             
            async def synthesize_ssml_api(
         
     | 
| 36 | 
         
             
                request: SSMLRequest = Body(
         
     | 
| 
         @@ -43,6 +41,8 @@ async def synthesize_ssml_api( 
     | 
|
| 43 | 
         
             
                    batch_size = request.batch_size
         
     | 
| 44 | 
         
             
                    eos = request.eos
         
     | 
| 45 | 
         
             
                    spliter_thr = request.spliter_thr
         
     | 
| 
         | 
|
| 
         | 
|
| 46 | 
         | 
| 47 | 
         
             
                    if batch_size < 1:
         
     | 
| 48 | 
         
             
                        raise HTTPException(
         
     | 
| 
         @@ -62,22 +62,27 @@ async def synthesize_ssml_api( 
     | 
|
| 62 | 
         
             
                            status_code=400, detail="Format must be 'mp3' or 'wav'."
         
     | 
| 63 | 
         
             
                        )
         
     | 
| 64 | 
         | 
| 65 | 
         
            -
                     
     | 
| 66 | 
         
            -
             
     | 
| 67 | 
         
            -
             
     | 
| 68 | 
         
            -
                         
     | 
| 69 | 
         
            -
             
     | 
| 70 | 
         
            -
                    synthesize = SynthesizeSegments(
         
     | 
| 71 | 
         
            -
                        batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
         
     | 
| 72 | 
         
             
                    )
         
     | 
| 73 | 
         
            -
                     
     | 
| 74 | 
         
            -
                     
     | 
| 75 | 
         
            -
             
     | 
| 76 | 
         
            -
                     
     | 
| 77 | 
         
            -
             
     | 
| 78 | 
         
            -
             
     | 
| 79 | 
         
            -
                         
     | 
| 80 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 81 | 
         | 
| 82 | 
         
             
                except Exception as e:
         
     | 
| 83 | 
         
             
                    import logging
         
     | 
| 
         | 
|
| 1 | 
         
             
            from fastapi import HTTPException, Body
         
     | 
| 2 | 
         
             
            from fastapi.responses import StreamingResponse
         
     | 
| 3 | 
         | 
| 
         | 
|
| 4 | 
         
             
            from pydantic import BaseModel
         
     | 
| 5 | 
         
             
            from fastapi.responses import FileResponse
         
     | 
| 6 | 
         | 
| 7 | 
         | 
| 8 | 
         
            +
            from modules.api.impl.handler.SSMLHandler import SSMLHandler
         
     | 
| 9 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
         
     | 
| 10 | 
         
            +
            from modules.api.impl.model.chattts_model import InferConfig
         
     | 
| 11 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 
         | 
|
| 
         | 
|
| 12 | 
         | 
| 13 | 
         | 
| 
         | 
|
| 
         | 
|
| 14 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
             
            class SSMLRequest(BaseModel):
         
     | 
| 18 | 
         
             
                ssml: str
         
     | 
| 19 | 
         
            +
                format: AudioFormat = "mp3"
         
     | 
| 20 | 
         | 
| 21 | 
         
             
                # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
         
     | 
| 22 | 
         
             
                batch_size: int = 4
         
     | 
| 
         | 
|
| 26 | 
         | 
| 27 | 
         
             
                spliter_thr: int = 100
         
     | 
| 28 | 
         | 
| 29 | 
         
            +
                enhancer: EnhancerConfig = EnhancerConfig()
         
     | 
| 30 | 
         
            +
                adjuster: AdjustConfig = AdjustConfig()
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         | 
| 33 | 
         
             
            async def synthesize_ssml_api(
         
     | 
| 34 | 
         
             
                request: SSMLRequest = Body(
         
     | 
| 
         | 
|
| 41 | 
         
             
                    batch_size = request.batch_size
         
     | 
| 42 | 
         
             
                    eos = request.eos
         
     | 
| 43 | 
         
             
                    spliter_thr = request.spliter_thr
         
     | 
| 44 | 
         
            +
                    enhancer = request.enhancer
         
     | 
| 45 | 
         
            +
                    adjuster = request.adjuster
         
     | 
| 46 | 
         | 
| 47 | 
         
             
                    if batch_size < 1:
         
     | 
| 48 | 
         
             
                        raise HTTPException(
         
     | 
| 
         | 
|
| 62 | 
         
             
                            status_code=400, detail="Format must be 'mp3' or 'wav'."
         
     | 
| 63 | 
         
             
                        )
         
     | 
| 64 | 
         | 
| 65 | 
         
            +
                    infer_config = InferConfig(
         
     | 
| 66 | 
         
            +
                        batch_size=batch_size,
         
     | 
| 67 | 
         
            +
                        spliter_threshold=spliter_thr,
         
     | 
| 68 | 
         
            +
                        eos=eos,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 69 | 
         
             
                    )
         
     | 
| 70 | 
         
            +
                    adjust_config = adjuster
         
     | 
| 71 | 
         
            +
                    enhancer_config = enhancer
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                    handler = SSMLHandler(
         
     | 
| 74 | 
         
            +
                        ssml_content=ssml,
         
     | 
| 75 | 
         
            +
                        infer_config=infer_config,
         
     | 
| 76 | 
         
            +
                        adjust_config=adjust_config,
         
     | 
| 77 | 
         
            +
                        enhancer_config=enhancer_config,
         
     | 
| 78 | 
         
            +
                    )
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                    buffer = handler.enqueue_to_buffer(format=request.format)
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                    mime_type = f"audio/{format}"
         
     | 
| 83 | 
         
            +
                    if format == AudioFormat.mp3:
         
     | 
| 84 | 
         
            +
                        mime_type = "audio/mpeg"
         
     | 
| 85 | 
         
            +
                    return StreamingResponse(buffer, media_type=mime_type)
         
     | 
| 86 | 
         | 
| 87 | 
         
             
                except Exception as e:
         
     | 
| 88 | 
         
             
                    import logging
         
     | 
    	
        modules/api/impl/tts_api.py
    CHANGED
    
    | 
         @@ -1,17 +1,18 @@ 
     | 
|
| 1 | 
         
             
            from fastapi import Depends, HTTPException, Query
         
     | 
| 2 | 
         
             
            from fastapi.responses import StreamingResponse
         
     | 
| 3 | 
         | 
| 4 | 
         
            -
            import io
         
     | 
| 5 | 
         
             
            from pydantic import BaseModel
         
     | 
| 6 | 
         
            -
            import soundfile as sf
         
     | 
| 7 | 
         
             
            from fastapi.responses import FileResponse
         
     | 
| 8 | 
         | 
| 9 | 
         | 
| 10 | 
         
            -
            from modules. 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 11 | 
         | 
| 12 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 13 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 14 | 
         
            -
            from modules. 
     | 
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
             
            class TTSParams(BaseModel):
         
     | 
| 
         @@ -23,10 +24,10 @@ class TTSParams(BaseModel): 
     | 
|
| 23 | 
         
             
                temperature: float = Query(
         
     | 
| 24 | 
         
             
                    0.3, description="Temperature for sampling (may be overridden by style or spk)"
         
     | 
| 25 | 
         
             
                )
         
     | 
| 26 | 
         
            -
                 
     | 
| 27 | 
         
             
                    0.5, description="Top P for sampling (may be overridden by style or spk)"
         
     | 
| 28 | 
         
             
                )
         
     | 
| 29 | 
         
            -
                 
     | 
| 30 | 
         
             
                    20, description="Top K for sampling (may be overridden by style or spk)"
         
     | 
| 31 | 
         
             
                )
         
     | 
| 32 | 
         
             
                seed: int = Query(
         
     | 
| 
         @@ -38,7 +39,14 @@ class TTSParams(BaseModel): 
     | 
|
| 38 | 
         
             
                prefix: str = Query("", description="Text prefix for inference")
         
     | 
| 39 | 
         
             
                bs: str = Query("8", description="Batch size for inference")
         
     | 
| 40 | 
         
             
                thr: str = Query("100", description="Threshold for sentence spliter")
         
     | 
| 41 | 
         
            -
                eos: str = Query("", description="End of sentence str")
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 42 | 
         | 
| 43 | 
         | 
| 44 | 
         
             
            async def synthesize_tts(params: TTSParams = Depends()):
         
     | 
| 
         @@ -55,18 +63,18 @@ async def synthesize_tts(params: TTSParams = Depends()): 
     | 
|
| 55 | 
         
             
                            status_code=422, detail="Temperature must be between 0 and 1"
         
     | 
| 56 | 
         
             
                        )
         
     | 
| 57 | 
         | 
| 58 | 
         
            -
                    # Validate  
     | 
| 59 | 
         
            -
                    if not (0 <= params. 
     | 
| 60 | 
         
            -
                        raise HTTPException(status_code=422, detail=" 
     | 
| 61 | 
         | 
| 62 | 
         
            -
                    # Validate  
     | 
| 63 | 
         
            -
                    if params. 
     | 
| 64 | 
         
             
                        raise HTTPException(
         
     | 
| 65 | 
         
            -
                            status_code=422, detail=" 
     | 
| 66 | 
         
             
                        )
         
     | 
| 67 | 
         
            -
                    if params. 
     | 
| 68 | 
         
             
                        raise HTTPException(
         
     | 
| 69 | 
         
            -
                            status_code=422, detail=" 
     | 
| 70 | 
         
             
                        )
         
     | 
| 71 | 
         | 
| 72 | 
         
             
                    # Validate format
         
     | 
| 
         @@ -76,11 +84,13 @@ async def synthesize_tts(params: TTSParams = Depends()): 
     | 
|
| 76 | 
         
             
                            detail="Invalid format. Supported formats are mp3 and wav",
         
     | 
| 77 | 
         
             
                        )
         
     | 
| 78 | 
         | 
| 79 | 
         
            -
                    text = text_normalize(params.text, is_end=False)
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
             
                    calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
         
     | 
| 82 | 
         | 
| 83 | 
         
             
                    spk = calc_params.get("spk", params.spk)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 84 | 
         
             
                    seed = params.seed or calc_params.get("seed", params.seed)
         
     | 
| 85 | 
         
             
                    temperature = params.temperature or calc_params.get(
         
     | 
| 86 | 
         
             
                        "temperature", params.temperature
         
     | 
| 
         @@ -93,29 +103,46 @@ async def synthesize_tts(params: TTSParams = Depends()): 
     | 
|
| 93 | 
         
             
                    batch_size = int(params.bs)
         
     | 
| 94 | 
         
             
                    threshold = int(params.thr)
         
     | 
| 95 | 
         | 
| 96 | 
         
            -
                     
     | 
| 97 | 
         
            -
                         
     | 
| 98 | 
         
             
                        temperature=temperature,
         
     | 
| 99 | 
         
            -
                         
     | 
| 100 | 
         
            -
                         
     | 
| 101 | 
         
            -
                         
     | 
| 102 | 
         
            -
                        infer_seed=seed,
         
     | 
| 103 | 
         
             
                        prompt1=prompt1,
         
     | 
| 104 | 
         
             
                        prompt2=prompt2,
         
     | 
| 105 | 
         
            -
             
     | 
| 
         | 
|
| 106 | 
         
             
                        batch_size=batch_size,
         
     | 
| 107 | 
         
             
                        spliter_threshold=threshold,
         
     | 
| 108 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 109 | 
         
             
                    )
         
     | 
| 110 | 
         | 
| 111 | 
         
            -
                     
     | 
| 112 | 
         
            -
             
     | 
| 113 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 114 | 
         | 
| 115 | 
         
            -
                     
     | 
| 116 | 
         
            -
                        buffer = api_utils.wav_to_mp3(buffer)
         
     | 
| 117 | 
         | 
| 118 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 119 | 
         | 
| 120 | 
         
             
                except Exception as e:
         
     | 
| 121 | 
         
             
                    import logging
         
     | 
| 
         | 
|
| 1 | 
         
             
            from fastapi import Depends, HTTPException, Query
         
     | 
| 2 | 
         
             
            from fastapi.responses import StreamingResponse
         
     | 
| 3 | 
         | 
| 
         | 
|
| 4 | 
         
             
            from pydantic import BaseModel
         
     | 
| 
         | 
|
| 5 | 
         
             
            from fastapi.responses import FileResponse
         
     | 
| 6 | 
         | 
| 7 | 
         | 
| 8 | 
         
            +
            from modules.api.impl.handler.TTSHandler import TTSHandler
         
     | 
| 9 | 
         
            +
            from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
         
     | 
| 10 | 
         
            +
            from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
         
     | 
| 11 | 
         
            +
            from modules.api.impl.model.enhancer_model import EnhancerConfig
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            from modules.api import utils as api_utils
         
     | 
| 14 | 
         
             
            from modules.api.Api import APIManager
         
     | 
| 15 | 
         
            +
            from modules.speaker import Speaker
         
     | 
| 16 | 
         | 
| 17 | 
         | 
| 18 | 
         
             
            class TTSParams(BaseModel):
         
     | 
| 
         | 
|
| 24 | 
         
             
                temperature: float = Query(
         
     | 
| 25 | 
         
             
                    0.3, description="Temperature for sampling (may be overridden by style or spk)"
         
     | 
| 26 | 
         
             
                )
         
     | 
| 27 | 
         
            +
                top_p: float = Query(
         
     | 
| 28 | 
         
             
                    0.5, description="Top P for sampling (may be overridden by style or spk)"
         
     | 
| 29 | 
         
             
                )
         
     | 
| 30 | 
         
            +
                top_k: int = Query(
         
     | 
| 31 | 
         
             
                    20, description="Top K for sampling (may be overridden by style or spk)"
         
     | 
| 32 | 
         
             
                )
         
     | 
| 33 | 
         
             
                seed: int = Query(
         
     | 
| 
         | 
|
| 39 | 
         
             
                prefix: str = Query("", description="Text prefix for inference")
         
     | 
| 40 | 
         
             
                bs: str = Query("8", description="Batch size for inference")
         
     | 
| 41 | 
         
             
                thr: str = Query("100", description="Threshold for sentence spliter")
         
     | 
| 42 | 
         
            +
                eos: str = Query("[uv_break]", description="End of sentence str")
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                enhance: bool = Query(False, description="Enable enhancer")
         
     | 
| 45 | 
         
            +
                denoise: bool = Query(False, description="Enable denoiser")
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                speed: float = Query(1.0, description="Speed of the audio")
         
     | 
| 48 | 
         
            +
                pitch: float = Query(0, description="Pitch of the audio")
         
     | 
| 49 | 
         
            +
                volume_gain: float = Query(0, description="Volume gain of the audio")
         
     | 
| 50 | 
         | 
| 51 | 
         | 
| 52 | 
         
             
            async def synthesize_tts(params: TTSParams = Depends()):
         
     | 
| 
         | 
|
| 63 | 
         
             
                            status_code=422, detail="Temperature must be between 0 and 1"
         
     | 
| 64 | 
         
             
                        )
         
     | 
| 65 | 
         | 
| 66 | 
         
            +
                    # Validate top_p
         
     | 
| 67 | 
         
            +
                    if not (0 <= params.top_p <= 1):
         
     | 
| 68 | 
         
            +
                        raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")
         
     | 
| 69 | 
         | 
| 70 | 
         
            +
                    # Validate top_k
         
     | 
| 71 | 
         
            +
                    if params.top_k <= 0:
         
     | 
| 72 | 
         
             
                        raise HTTPException(
         
     | 
| 73 | 
         
            +
                            status_code=422, detail="top_k must be a positive integer"
         
     | 
| 74 | 
         
             
                        )
         
     | 
| 75 | 
         
            +
                    if params.top_k > 100:
         
     | 
| 76 | 
         
             
                        raise HTTPException(
         
     | 
| 77 | 
         
            +
                            status_code=422, detail="top_k must be less than or equal to 100"
         
     | 
| 78 | 
         
             
                        )
         
     | 
| 79 | 
         | 
| 80 | 
         
             
                    # Validate format
         
     | 
| 
         | 
|
| 84 | 
         
             
                            detail="Invalid format. Supported formats are mp3 and wav",
         
     | 
| 85 | 
         
             
                        )
         
     | 
| 86 | 
         | 
| 
         | 
|
| 
         | 
|
| 87 | 
         
             
                    calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
         
     | 
| 88 | 
         | 
| 89 | 
         
             
                    spk = calc_params.get("spk", params.spk)
         
     | 
| 90 | 
         
            +
                    if not isinstance(spk, Speaker):
         
     | 
| 91 | 
         
            +
                        raise HTTPException(status_code=422, detail="Invalid speaker")
         
     | 
| 92 | 
         
            +
             
     | 
| 93 | 
         
            +
                    style = calc_params.get("style", params.style)
         
     | 
| 94 | 
         
             
                    seed = params.seed or calc_params.get("seed", params.seed)
         
     | 
| 95 | 
         
             
                    temperature = params.temperature or calc_params.get(
         
     | 
| 96 | 
         
             
                        "temperature", params.temperature
         
     | 
| 
         | 
|
| 103 | 
         
             
                    batch_size = int(params.bs)
         
     | 
| 104 | 
         
             
                    threshold = int(params.thr)
         
     | 
| 105 | 
         | 
| 106 | 
         
            +
                    tts_config = ChatTTSConfig(
         
     | 
| 107 | 
         
            +
                        style=style,
         
     | 
| 108 | 
         
             
                        temperature=temperature,
         
     | 
| 109 | 
         
            +
                        top_k=params.top_k,
         
     | 
| 110 | 
         
            +
                        top_p=params.top_p,
         
     | 
| 111 | 
         
            +
                        prefix=prefix,
         
     | 
| 
         | 
|
| 112 | 
         
             
                        prompt1=prompt1,
         
     | 
| 113 | 
         
             
                        prompt2=prompt2,
         
     | 
| 114 | 
         
            +
                    )
         
     | 
| 115 | 
         
            +
                    infer_config = InferConfig(
         
     | 
| 116 | 
         
             
                        batch_size=batch_size,
         
     | 
| 117 | 
         
             
                        spliter_threshold=threshold,
         
     | 
| 118 | 
         
            +
                        eos=eos,
         
     | 
| 119 | 
         
            +
                        seed=seed,
         
     | 
| 120 | 
         
            +
                    )
         
     | 
| 121 | 
         
            +
                    adjust_config = AdjustConfig(
         
     | 
| 122 | 
         
            +
                        pitch=params.pitch,
         
     | 
| 123 | 
         
            +
                        speed_rate=params.speed,
         
     | 
| 124 | 
         
            +
                        volume_gain_db=params.volume_gain,
         
     | 
| 125 | 
         
            +
                    )
         
     | 
| 126 | 
         
            +
                    enhancer_config = EnhancerConfig(
         
     | 
| 127 | 
         
            +
                        enabled=params.enhance or params.denoise or False,
         
     | 
| 128 | 
         
            +
                        lambd=0.9 if params.denoise else 0.1,
         
     | 
| 129 | 
         
             
                    )
         
     | 
| 130 | 
         | 
| 131 | 
         
            +
                    handler = TTSHandler(
         
     | 
| 132 | 
         
            +
                        text_content=params.text,
         
     | 
| 133 | 
         
            +
                        spk=spk,
         
     | 
| 134 | 
         
            +
                        tts_config=tts_config,
         
     | 
| 135 | 
         
            +
                        infer_config=infer_config,
         
     | 
| 136 | 
         
            +
                        adjust_config=adjust_config,
         
     | 
| 137 | 
         
            +
                        enhancer_config=enhancer_config,
         
     | 
| 138 | 
         
            +
                    )
         
     | 
| 139 | 
         | 
| 140 | 
         
            +
                    buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
         
     | 
| 
         | 
|
| 141 | 
         | 
| 142 | 
         
            +
                    media_type = f"audio/{params.format}"
         
     | 
| 143 | 
         
            +
                    if params.format == "mp3":
         
     | 
| 144 | 
         
            +
                        media_type = "audio/mpeg"
         
     | 
| 145 | 
         
            +
                    return StreamingResponse(buffer, media_type=media_type)
         
     | 
| 146 | 
         | 
| 147 | 
         
             
                except Exception as e:
         
     | 
| 148 | 
         
             
                    import logging
         
     | 
    	
        modules/api/impl/xtts_v2_api.py
    CHANGED
    
    | 
         @@ -30,8 +30,19 @@ class XTTS_V2_Settings: 
     | 
|
| 30 | 
         
             
                    self.top_k = 20
         
     | 
| 31 | 
         
             
                    self.enable_text_splitting = True
         
     | 
| 32 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 33 | 
         | 
| 34 | 
         
             
            class TTSSettingsRequest(BaseModel):
         
     | 
| 
         | 
|
| 35 | 
         
             
                stream_chunk_size: int
         
     | 
| 36 | 
         
             
                temperature: float
         
     | 
| 37 | 
         
             
                speed: float
         
     | 
| 
         @@ -41,6 +52,15 @@ class TTSSettingsRequest(BaseModel): 
     | 
|
| 41 | 
         
             
                top_k: int
         
     | 
| 42 | 
         
             
                enable_text_splitting: bool
         
     | 
| 43 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 44 | 
         | 
| 45 | 
         
             
            class SynthesisRequest(BaseModel):
         
     | 
| 46 | 
         
             
                text: str
         
     | 
| 
         @@ -79,17 +99,22 @@ def setup(app: APIManager): 
     | 
|
| 79 | 
         | 
| 80 | 
         
             
                    text = text_normalize(text, is_end=True)
         
     | 
| 81 | 
         
             
                    sample_rate, audio_data = synthesize_audio(
         
     | 
| 82 | 
         
            -
                         
     | 
| 83 | 
         
            -
                        temperature=XTTSV2.temperature,
         
     | 
| 84 | 
         
             
                        # length_penalty=XTTSV2.length_penalty,
         
     | 
| 85 | 
         
             
                        # repetition_penalty=XTTSV2.repetition_penalty,
         
     | 
| 
         | 
|
| 
         | 
|
| 86 | 
         
             
                        top_P=XTTSV2.top_p,
         
     | 
| 87 | 
         
             
                        top_K=XTTSV2.top_k,
         
     | 
| 88 | 
         
             
                        spk=spk,
         
     | 
| 89 | 
         
            -
                        spliter_threshold=XTTSV2. 
     | 
| 90 | 
         
            -
                         
     | 
| 91 | 
         
            -
                         
     | 
| 92 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 93 | 
         
             
                    )
         
     | 
| 94 | 
         | 
| 95 | 
         
             
                    if XTTSV2.speed:
         
     | 
| 
         @@ -145,6 +170,8 @@ def setup(app: APIManager): 
     | 
|
| 145 | 
         
             
                            )
         
     | 
| 146 | 
         | 
| 147 | 
         
             
                        XTTSV2.stream_chunk_size = request.stream_chunk_size
         
     | 
| 
         | 
|
| 
         | 
|
| 148 | 
         
             
                        XTTSV2.temperature = request.temperature
         
     | 
| 149 | 
         
             
                        XTTSV2.speed = request.speed
         
     | 
| 150 | 
         
             
                        XTTSV2.length_penalty = request.length_penalty
         
     | 
| 
         @@ -152,6 +179,25 @@ def setup(app: APIManager): 
     | 
|
| 152 | 
         
             
                        XTTSV2.top_p = request.top_p
         
     | 
| 153 | 
         
             
                        XTTSV2.top_k = request.top_k
         
     | 
| 154 | 
         
             
                        XTTSV2.enable_text_splitting = request.enable_text_splitting
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 155 | 
         
             
                        return {"message": "Settings successfully applied"}
         
     | 
| 156 | 
         
             
                    except Exception as e:
         
     | 
| 157 | 
         
             
                        if isinstance(e, HTTPException):
         
     | 
| 
         | 
|
| 30 | 
         
             
                    self.top_k = 20
         
     | 
| 31 | 
         
             
                    self.enable_text_splitting = True
         
     | 
| 32 | 
         | 
| 33 | 
         
            +
                    # 下面是额外配置 xtts_v2 中不包含的,但是本系统需要的
         
     | 
| 34 | 
         
            +
                    self.batch_size = 4
         
     | 
| 35 | 
         
            +
                    self.eos = "[uv_break]"
         
     | 
| 36 | 
         
            +
                    self.infer_seed = 42
         
     | 
| 37 | 
         
            +
                    self.use_decoder = True
         
     | 
| 38 | 
         
            +
                    self.prompt1 = ""
         
     | 
| 39 | 
         
            +
                    self.prompt2 = ""
         
     | 
| 40 | 
         
            +
                    self.prefix = ""
         
     | 
| 41 | 
         
            +
                    self.spliter_threshold = 100
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         | 
| 44 | 
         
             
            class TTSSettingsRequest(BaseModel):
         
     | 
| 45 | 
         
            +
                # 这个 stream_chunk 现在当作 spliter_threshold 用
         
     | 
| 46 | 
         
             
                stream_chunk_size: int
         
     | 
| 47 | 
         
             
                temperature: float
         
     | 
| 48 | 
         
             
                speed: float
         
     | 
| 
         | 
|
| 52 | 
         
             
                top_k: int
         
     | 
| 53 | 
         
             
                enable_text_splitting: bool
         
     | 
| 54 | 
         | 
| 55 | 
         
            +
                batch_size: int = None
         
     | 
| 56 | 
         
            +
                eos: str = None
         
     | 
| 57 | 
         
            +
                infer_seed: int = None
         
     | 
| 58 | 
         
            +
                use_decoder: bool = None
         
     | 
| 59 | 
         
            +
                prompt1: str = None
         
     | 
| 60 | 
         
            +
                prompt2: str = None
         
     | 
| 61 | 
         
            +
                prefix: str = None
         
     | 
| 62 | 
         
            +
                spliter_threshold: int = None
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         | 
| 65 | 
         
             
            class SynthesisRequest(BaseModel):
         
     | 
| 66 | 
         
             
                text: str
         
     | 
| 
         | 
|
| 99 | 
         | 
| 100 | 
         
             
                    text = text_normalize(text, is_end=True)
         
     | 
| 101 | 
         
             
                    sample_rate, audio_data = synthesize_audio(
         
     | 
| 102 | 
         
            +
                        # TODO: 这两个参数现在用不着...但是其实gpt是可以用的
         
     | 
| 
         | 
|
| 103 | 
         
             
                        # length_penalty=XTTSV2.length_penalty,
         
     | 
| 104 | 
         
             
                        # repetition_penalty=XTTSV2.repetition_penalty,
         
     | 
| 105 | 
         
            +
                        text=text,
         
     | 
| 106 | 
         
            +
                        temperature=XTTSV2.temperature,
         
     | 
| 107 | 
         
             
                        top_P=XTTSV2.top_p,
         
     | 
| 108 | 
         
             
                        top_K=XTTSV2.top_k,
         
     | 
| 109 | 
         
             
                        spk=spk,
         
     | 
| 110 | 
         
            +
                        spliter_threshold=XTTSV2.spliter_threshold,
         
     | 
| 111 | 
         
            +
                        batch_size=XTTSV2.batch_size,
         
     | 
| 112 | 
         
            +
                        end_of_sentence=XTTSV2.eos,
         
     | 
| 113 | 
         
            +
                        infer_seed=XTTSV2.infer_seed,
         
     | 
| 114 | 
         
            +
                        use_decoder=XTTSV2.use_decoder,
         
     | 
| 115 | 
         
            +
                        prompt1=XTTSV2.prompt1,
         
     | 
| 116 | 
         
            +
                        prompt2=XTTSV2.prompt2,
         
     | 
| 117 | 
         
            +
                        prefix=XTTSV2.prefix,
         
     | 
| 118 | 
         
             
                    )
         
     | 
| 119 | 
         | 
| 120 | 
         
             
                    if XTTSV2.speed:
         
     | 
| 
         | 
|
| 170 | 
         
             
                            )
         
     | 
| 171 | 
         | 
| 172 | 
         
             
                        XTTSV2.stream_chunk_size = request.stream_chunk_size
         
     | 
| 173 | 
         
            +
                        XTTSV2.spliter_threshold = request.stream_chunk_size
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
             
                        XTTSV2.temperature = request.temperature
         
     | 
| 176 | 
         
             
                        XTTSV2.speed = request.speed
         
     | 
| 177 | 
         
             
                        XTTSV2.length_penalty = request.length_penalty
         
     | 
| 
         | 
|
| 179 | 
         
             
                        XTTSV2.top_p = request.top_p
         
     | 
| 180 | 
         
             
                        XTTSV2.top_k = request.top_k
         
     | 
| 181 | 
         
             
                        XTTSV2.enable_text_splitting = request.enable_text_splitting
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                        # TODO: checker
         
     | 
| 184 | 
         
            +
                        if request.batch_size:
         
     | 
| 185 | 
         
            +
                            XTTSV2.batch_size = request.batch_size
         
     | 
| 186 | 
         
            +
                        if request.eos:
         
     | 
| 187 | 
         
            +
                            XTTSV2.eos = request.eos
         
     | 
| 188 | 
         
            +
                        if request.infer_seed:
         
     | 
| 189 | 
         
            +
                            XTTSV2.infer_seed = request.infer_seed
         
     | 
| 190 | 
         
            +
                        if request.use_decoder:
         
     | 
| 191 | 
         
            +
                            XTTSV2.use_decoder = request.use_decoder
         
     | 
| 192 | 
         
            +
                        if request.prompt1:
         
     | 
| 193 | 
         
            +
                            XTTSV2.prompt1 = request.prompt1
         
     | 
| 194 | 
         
            +
                        if request.prompt2:
         
     | 
| 195 | 
         
            +
                            XTTSV2.prompt2 = request.prompt2
         
     | 
| 196 | 
         
            +
                        if request.prefix:
         
     | 
| 197 | 
         
            +
                            XTTSV2.prefix = request.prefix
         
     | 
| 198 | 
         
            +
                        if request.spliter_threshold:
         
     | 
| 199 | 
         
            +
                            XTTSV2.spliter_threshold = request.spliter_threshold
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
             
                        return {"message": "Settings successfully applied"}
         
     | 
| 202 | 
         
             
                    except Exception as e:
         
     | 
| 203 | 
         
             
                        if isinstance(e, HTTPException):
         
     | 
    	
        modules/api/utils.py
    CHANGED
    
    | 
         @@ -1,9 +1,8 @@ 
     | 
|
| 1 | 
         
             
            from pydantic import BaseModel
         
     | 
| 2 | 
         
             
            from typing import Any, Union
         
     | 
| 3 | 
         | 
| 4 | 
         
            -
            import torch
         
     | 
| 5 | 
         | 
| 6 | 
         
            -
            from modules.speaker import  
     | 
| 7 | 
         | 
| 8 | 
         | 
| 9 | 
         
             
            from modules.data import styles_mgr
         
     | 
| 
         @@ -13,18 +12,10 @@ from pydub import AudioSegment 
     | 
|
| 13 | 
         
             
            from modules.ssml import merge_prompt
         
     | 
| 14 | 
         | 
| 15 | 
         | 
| 16 | 
         
            -
            from enum import Enum
         
     | 
| 17 | 
         
            -
             
     | 
| 18 | 
         
            -
             
     | 
| 19 | 
         
             
            class ParamsTypeError(Exception):
         
     | 
| 20 | 
         
             
                pass
         
     | 
| 21 | 
         | 
| 22 | 
         | 
| 23 | 
         
            -
            class AudioFormat(str, Enum):
         
     | 
| 24 | 
         
            -
                mp3 = "mp3"
         
     | 
| 25 | 
         
            -
                wav = "wav"
         
     | 
| 26 | 
         
            -
             
     | 
| 27 | 
         
            -
             
     | 
| 28 | 
         
             
            class BaseResponse(BaseModel):
         
     | 
| 29 | 
         
             
                message: str
         
     | 
| 30 | 
         
             
                data: Any
         
     | 
| 
         @@ -35,7 +26,7 @@ def success_response(data: Any, message: str = "ok") -> BaseResponse: 
     | 
|
| 35 | 
         | 
| 36 | 
         | 
| 37 | 
         
             
            def wav_to_mp3(wav_data, bitrate="48k"):
         
     | 
| 38 | 
         
            -
                audio = AudioSegment.from_wav(
         
     | 
| 39 | 
         
             
                    wav_data,
         
     | 
| 40 | 
         
             
                )
         
     | 
| 41 | 
         
             
                return audio.export(format="mp3", bitrate=bitrate)
         
     | 
| 
         | 
|
| 1 | 
         
             
            from pydantic import BaseModel
         
     | 
| 2 | 
         
             
            from typing import Any, Union
         
     | 
| 3 | 
         | 
| 
         | 
|
| 4 | 
         | 
| 5 | 
         
            +
            from modules.speaker import speaker_mgr
         
     | 
| 6 | 
         | 
| 7 | 
         | 
| 8 | 
         
             
            from modules.data import styles_mgr
         
     | 
| 
         | 
|
| 12 | 
         
             
            from modules.ssml import merge_prompt
         
     | 
| 13 | 
         | 
| 14 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 15 | 
         
             
            class ParamsTypeError(Exception):
         
     | 
| 16 | 
         
             
                pass
         
     | 
| 17 | 
         | 
| 18 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 19 | 
         
             
            class BaseResponse(BaseModel):
         
     | 
| 20 | 
         
             
                message: str
         
     | 
| 21 | 
         
             
                data: Any
         
     | 
| 
         | 
|
| 26 | 
         | 
| 27 | 
         | 
| 28 | 
         
             
            def wav_to_mp3(wav_data, bitrate="48k"):
         
     | 
| 29 | 
         
            +
                audio: AudioSegment = AudioSegment.from_wav(
         
     | 
| 30 | 
         
             
                    wav_data,
         
     | 
| 31 | 
         
             
                )
         
     | 
| 32 | 
         
             
                return audio.export(format="mp3", bitrate=bitrate)
         
     | 
    	
        modules/devices/devices.py
    CHANGED
    
    | 
         @@ -127,6 +127,12 @@ def reset_device(): 
     | 
|
| 127 | 
         
             
                global dtype_gpt
         
     | 
| 128 | 
         
             
                global dtype_decoder
         
     | 
| 129 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 130 | 
         
             
                if not config.runtime_env_vars.no_half:
         
     | 
| 131 | 
         
             
                    dtype = torch.float16
         
     | 
| 132 | 
         
             
                    dtype_dvae = torch.float16
         
     | 
| 
         @@ -144,7 +150,7 @@ def reset_device(): 
     | 
|
| 144 | 
         | 
| 145 | 
         
             
                    logger.info("Using full precision: torch.float32")
         
     | 
| 146 | 
         | 
| 147 | 
         
            -
                if config.runtime_env_vars.use_cpu 
     | 
| 148 | 
         
             
                    device = cpu
         
     | 
| 149 | 
         
             
                else:
         
     | 
| 150 | 
         
             
                    device = get_optimal_device()
         
     | 
| 
         | 
|
| 127 | 
         
             
                global dtype_gpt
         
     | 
| 128 | 
         
             
                global dtype_decoder
         
     | 
| 129 | 
         | 
| 130 | 
         
            +
                if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
         
     | 
| 131 | 
         
            +
                    logger.warning(
         
     | 
| 132 | 
         
            +
                        "Cannot use half precision with CPU, using full precision instead"
         
     | 
| 133 | 
         
            +
                    )
         
     | 
| 134 | 
         
            +
                    config.runtime_env_vars.no_half = True
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
             
                if not config.runtime_env_vars.no_half:
         
     | 
| 137 | 
         
             
                    dtype = torch.float16
         
     | 
| 138 | 
         
             
                    dtype_dvae = torch.float16
         
     | 
| 
         | 
|
| 150 | 
         | 
| 151 | 
         
             
                    logger.info("Using full precision: torch.float32")
         
     | 
| 152 | 
         | 
| 153 | 
         
            +
                if "all" in config.runtime_env_vars.use_cpu:
         
     | 
| 154 | 
         
             
                    device = cpu
         
     | 
| 155 | 
         
             
                else:
         
     | 
| 156 | 
         
             
                    device = get_optimal_device()
         
     | 
    	
        modules/finetune/train_speaker.py
    CHANGED
    
    | 
         @@ -45,9 +45,10 @@ def train_speaker_embeddings( 
     | 
|
| 45 | 
         
             
                        )
         
     | 
| 46 | 
         
             
                        for speaker in dataset.speakers
         
     | 
| 47 | 
         
             
                    }
         
     | 
| 48 | 
         
            -
             
     | 
| 49 | 
         
            -
             
     | 
| 50 | 
         
            -
             
     | 
| 
         | 
|
| 51 | 
         | 
| 52 | 
         
             
                SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
         
     | 
| 53 | 
         
             
                AUDIO_EOS_TOKEN_ID = 0
         
     | 
| 
         @@ -166,13 +167,13 @@ def train_speaker_embeddings( 
     | 
|
| 166 | 
         
             
                            audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
         
     | 
| 167 | 
         
             
                        )
         
     | 
| 168 | 
         
             
                        loss = audio_loss
         
     | 
| 169 | 
         
            -
             
     | 
| 170 | 
         
            -
             
     | 
| 171 | 
         
            -
             
     | 
| 172 | 
         
            -
             
     | 
| 173 | 
         
            -
             
     | 
| 174 | 
         
            -
             
     | 
| 175 | 
         
            -
             
     | 
| 176 | 
         | 
| 177 | 
         
             
                        gpt_gen_mel_specs = decoder_decoder(
         
     | 
| 178 | 
         
             
                            audio_hidden_states[:, :-1].transpose(1, 2)
         
     | 
| 
         @@ -181,7 +182,12 @@ def train_speaker_embeddings( 
     | 
|
| 181 | 
         
             
                        loss += 0.01 * mse_loss
         
     | 
| 182 | 
         | 
| 183 | 
         
             
                        optimizer.zero_grad()
         
     | 
| 184 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 185 | 
         
             
                        torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
         
     | 
| 186 | 
         
             
                        optimizer.step()
         
     | 
| 187 | 
         
             
                        logger.meters["loss"].update(loss.item(), n=batch_size)
         
     | 
| 
         @@ -203,6 +209,7 @@ if __name__ == "__main__": 
     | 
|
| 203 | 
         
             
                from modules.speaker import Speaker
         
     | 
| 204 | 
         | 
| 205 | 
         
             
                config.runtime_env_vars.no_half = True
         
     | 
| 
         | 
|
| 206 | 
         
             
                devices.reset_device()
         
     | 
| 207 | 
         | 
| 208 | 
         
             
                parser = argparse.ArgumentParser()
         
     | 
| 
         | 
|
| 45 | 
         
             
                        )
         
     | 
| 46 | 
         
             
                        for speaker in dataset.speakers
         
     | 
| 47 | 
         
             
                    }
         
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
                for speaker_embed in speaker_embeds.values():
         
     | 
| 50 | 
         
            +
                    std, mean = chat.pretrain_models["spk_stat"].chunk(2)
         
     | 
| 51 | 
         
            +
                    speaker_embed.data = speaker_embed.data * std + mean
         
     | 
| 52 | 
         | 
| 53 | 
         
             
                SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
         
     | 
| 54 | 
         
             
                AUDIO_EOS_TOKEN_ID = 0
         
     | 
| 
         | 
|
| 167 | 
         
             
                            audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
         
     | 
| 168 | 
         
             
                        )
         
     | 
| 169 | 
         
             
                        loss = audio_loss
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                        text_logits = gpt.head_text(text_hidden_states)
         
     | 
| 172 | 
         
            +
                        text_loss = loss_fn(
         
     | 
| 173 | 
         
            +
                            text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
         
     | 
| 174 | 
         
            +
                        )
         
     | 
| 175 | 
         
            +
                        loss += text_loss
         
     | 
| 176 | 
         
            +
                        logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
         
     | 
| 177 | 
         | 
| 178 | 
         
             
                        gpt_gen_mel_specs = decoder_decoder(
         
     | 
| 179 | 
         
             
                            audio_hidden_states[:, :-1].transpose(1, 2)
         
     | 
| 
         | 
|
| 182 | 
         
             
                        loss += 0.01 * mse_loss
         
     | 
| 183 | 
         | 
| 184 | 
         
             
                        optimizer.zero_grad()
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                        if train_text:
         
     | 
| 187 | 
         
            +
                            # just for test
         
     | 
| 188 | 
         
            +
                            text_loss.backward()
         
     | 
| 189 | 
         
            +
                        else:
         
     | 
| 190 | 
         
            +
                            loss.backward()
         
     | 
| 191 | 
         
             
                        torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
         
     | 
| 192 | 
         
             
                        optimizer.step()
         
     | 
| 193 | 
         
             
                        logger.meters["loss"].update(loss.item(), n=batch_size)
         
     | 
| 
         | 
|
| 209 | 
         
             
                from modules.speaker import Speaker
         
     | 
| 210 | 
         | 
| 211 | 
         
             
                config.runtime_env_vars.no_half = True
         
     | 
| 212 | 
         
            +
                config.runtime_env_vars.use_cpu = []
         
     | 
| 213 | 
         
             
                devices.reset_device()
         
     | 
| 214 | 
         | 
| 215 | 
         
             
                parser = argparse.ArgumentParser()
         
     | 
    	
        modules/prompts/news_oral_prompt.txt
    ADDED
    
    | 
         @@ -0,0 +1,14 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # 任务要求
         
     | 
| 2 | 
         
            +
            任务: 新闻稿口播化
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            你需要将一个新闻稿改写为口语化的口播文本
         
     | 
| 5 | 
         
            +
            同时,适当的添加一些 附语言 标签为文本增加多样性
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            目前可以使用的附语言标签如下:
         
     | 
| 8 | 
         
            +
            - `[laugh]`: 表示笑声
         
     | 
| 9 | 
         
            +
            - `[uv_break]`: 表示无声停顿
         
     | 
| 10 | 
         
            +
            - `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
         
     | 
| 11 | 
         
            +
            - `[lbreak]`: 表示一个长停顿一般表示段落结束
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            # 输入
         
     | 
| 14 | 
         
            +
            {{USER_INPUT}}
         
     | 
    	
        modules/prompts/podcast_prompt.txt
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            TODO
         
     | 
    	
        modules/ssml_parser/SSMLParser.py
    CHANGED
    
    | 
         @@ -1,13 +1,10 @@ 
     | 
|
| 1 | 
         
             
            from lxml import etree
         
     | 
| 2 | 
         | 
| 3 | 
         | 
| 4 | 
         
            -
            from typing import  
     | 
| 5 | 
         
             
            import logging
         
     | 
| 6 | 
         | 
| 7 | 
         
            -
            from modules.data import styles_mgr
         
     | 
| 8 | 
         
            -
            from modules.speaker import speaker_mgr
         
     | 
| 9 | 
         
             
            from box import Box
         
     | 
| 10 | 
         
            -
            import copy
         
     | 
| 11 | 
         | 
| 12 | 
         | 
| 13 | 
         
             
            class SSMLContext(Box):
         
     | 
| 
         | 
|
| 1 | 
         
             
            from lxml import etree
         
     | 
| 2 | 
         | 
| 3 | 
         | 
| 4 | 
         
            +
            from typing import List, Union
         
     | 
| 5 | 
         
             
            import logging
         
     | 
| 6 | 
         | 
| 
         | 
|
| 
         | 
|
| 7 | 
         
             
            from box import Box
         
     | 
| 
         | 
|
| 8 | 
         | 
| 9 | 
         | 
| 10 | 
         
             
            class SSMLContext(Box):
         
     | 
    	
        modules/webui/speaker/speaker_editor.py
    CHANGED
    
    | 
         @@ -25,7 +25,7 @@ def speaker_editor_ui(): 
     | 
|
| 25 | 
         
             
                    spk: Speaker = Speaker.from_file(spk_file)
         
     | 
| 26 | 
         
             
                    spk.name = name
         
     | 
| 27 | 
         
             
                    spk.gender = gender
         
     | 
| 28 | 
         
            -
                    spk. 
     | 
| 29 | 
         | 
| 30 | 
         
             
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
         
     | 
| 31 | 
         
             
                        torch.save(spk, tmp_file)
         
     | 
| 
         | 
|
| 25 | 
         
             
                    spk: Speaker = Speaker.from_file(spk_file)
         
     | 
| 26 | 
         
             
                    spk.name = name
         
     | 
| 27 | 
         
             
                    spk.gender = gender
         
     | 
| 28 | 
         
            +
                    spk.describe = desc
         
     | 
| 29 | 
         | 
| 30 | 
         
             
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
         
     | 
| 31 | 
         
             
                        torch.save(spk, tmp_file)
         
     | 
    	
        modules/webui/speaker/speaker_merger.py
    CHANGED
    
    | 
         @@ -38,12 +38,8 @@ def merge_spk( 
     | 
|
| 38 | 
         
             
                tensor_c = spk_to_tensor(spk_c)
         
     | 
| 39 | 
         
             
                tensor_d = spk_to_tensor(spk_d)
         
     | 
| 40 | 
         | 
| 41 | 
         
            -
                 
     | 
| 42 | 
         
            -
                     
     | 
| 43 | 
         
            -
                    or tensor_b is not None
         
     | 
| 44 | 
         
            -
                    or tensor_c is not None
         
     | 
| 45 | 
         
            -
                    or tensor_d is not None
         
     | 
| 46 | 
         
            -
                ), "At least one speaker should be selected"
         
     | 
| 47 | 
         | 
| 48 | 
         
             
                merge_tensor = torch.zeros_like(
         
     | 
| 49 | 
         
             
                    tensor_a
         
     | 
| 
         | 
|
| 38 | 
         
             
                tensor_c = spk_to_tensor(spk_c)
         
     | 
| 39 | 
         
             
                tensor_d = spk_to_tensor(spk_d)
         
     | 
| 40 | 
         | 
| 41 | 
         
            +
                if tensor_a is None and tensor_b is None and tensor_c is None and tensor_d is None:
         
     | 
| 42 | 
         
            +
                    raise gr.Error("At least one speaker should be selected")
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 43 | 
         | 
| 44 | 
         
             
                merge_tensor = torch.zeros_like(
         
     | 
| 45 | 
         
             
                    tensor_a
         
     |