File size: 10,401 Bytes
63f899c
 
43a4f7e
 
 
 
 
63f899c
 
446a654
 
 
 
5b8134f
43a4f7e
 
59d9186
63f899c
43a4f7e
bd786ec
43a4f7e
 
 
ec5560f
 
43a4f7e
446a654
 
43a4f7e
446a654
43a4f7e
446a654
 
 
 
 
 
 
dcda854
446a654
43a4f7e
 
 
446a654
 
 
 
43a4f7e
 
446a654
 
 
 
62e5071
 
446a654
 
62e5071
446a654
 
dcda854
446a654
dcda854
 
 
 
 
 
 
 
 
 
 
 
e7e9e9b
43a4f7e
e7e9e9b
 
43a4f7e
e7e9e9b
 
 
dcda854
 
446a654
5b8134f
 
43a4f7e
5b8134f
 
 
 
43a4f7e
446a654
43a4f7e
 
dcda854
 
 
43a4f7e
 
dcda854
43a4f7e
 
dcda854
 
43a4f7e
 
 
 
 
 
dcda854
 
43a4f7e
 
dcda854
43a4f7e
dcda854
446a654
 
dcda854
 
43a4f7e
 
dcda854
43a4f7e
dcda854
 
f634467
 
 
43a4f7e
 
f634467
43a4f7e
f634467
 
43a4f7e
446a654
59d9186
63f899c
43a4f7e
 
 
 
 
 
 
 
 
63f899c
 
43a4f7e
63f899c
dcda854
43a4f7e
446a654
43a4f7e
446a654
 
43a4f7e
62e5071
 
43a4f7e
 
446a654
 
 
 
5b8134f
686542a
dcda854
5b8134f
446a654
43a4f7e
a6075c0
43a4f7e
a6075c0
bd786ec
a6075c0
446a654
43a4f7e
dcda854
 
 
 
446a654
 
 
43a4f7e
 
446a654
 
43a4f7e
dcda854
446a654
43a4f7e
 
446a654
 
43a4f7e
 
446a654
f634467
43a4f7e
 
f634467
43a4f7e
 
dcda854
 
43a4f7e
 
 
 
 
 
 
 
 
 
 
446a654
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
import os
import shutil
import uuid
import argparse
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

from huggingface_hub import snapshot_download
import gradio as gr
from gradio_client import Client, handle_file
from mutagen.mp3 import MP3
from pydub import AudioSegment
from PIL import Image
import ffmpeg

# Set working directory
# Relative paths used throughout this module ('pretrained_models',
# 'configs/inference/default.yaml', and the generated output-*/converted_*
# media files) are resolved against the directory containing this script.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# NOTE(review): imported after the chdir, presumably so anything that
# scripts.inference resolves relative to the CWD at import time points at this
# directory — confirm before moving it up with the other imports.
from scripts.inference import inference_process

# Constants
# Upper bound on driving-audio length in milliseconds (pydub measures
# AudioSegment length in ms); applied via trim_audio only on the shared Space.
AUDIO_MAX_DURATION = 4000
# True when SPACE_ID identifies the public shared Space; enables audio trimming.
is_shared_ui = "fffiloni/tts-hallo-talking-portrait" in os.environ.get('SPACE_ID', '')
# Download the Hallo checkpoints into ./pretrained_models at import time
# (a blocking call that runs before the UI is built). The returned local path
# is not otherwise referenced in the code shown here.
hallo_dir = snapshot_download(repo_id="fudan-generative-ai/hallo", local_dir="pretrained_models")

# Utility Functions
def is_mp3(file_path):
    """Report whether *file_path* can be parsed as an MP3 by mutagen.

    Any exception raised while parsing (missing file, other format, ...) is
    treated as "not an MP3" instead of being propagated.
    """
    try:
        MP3(file_path)
    except Exception:
        return False
    return True

def convert_mp3_to_wav(mp3_file_path, wav_file_path):
    """Decode an MP3 file and write it back out as WAV.

    Returns:
        The *wav_file_path* that was written, so the call can be chained.
    """
    AudioSegment.from_mp3(mp3_file_path).export(wav_file_path, format="wav")
    return wav_file_path

def trim_audio(file_path, output_path, max_duration):
    """Write at most *max_duration* milliseconds of a WAV file to *output_path*.

    Clips that are already short enough are re-exported unchanged.

    Returns:
        *output_path*.
    """
    segment = AudioSegment.from_wav(file_path)
    clipped = segment[:max_duration] if len(segment) > max_duration else segment
    clipped.export(output_path, format="wav")
    return output_path

def add_silence_to_wav(wav_file_path, duration_s=1):
    """Append *duration_s* seconds of silence to a WAV file, in place.

    The file at *wav_file_path* is overwritten with the padded audio.

    Returns:
        *wav_file_path*.
    """
    padding_ms = duration_s * 1000
    padded = AudioSegment.from_wav(wav_file_path) + AudioSegment.silent(duration=padding_ms)
    padded.export(wav_file_path, format="wav")
    return wav_file_path

def check_mp3(file_path):
    """Convert an uploaded MP3 to WAV; pass any other audio file through.

    MP3 input is written to ``<original stem>-<uuid>.wav`` next to the source.

    Returns:
        (path, update): the path to use as the voice input, and a Gradio update
        that sets that same path on the target component and makes it visible.
    """
    if not is_mp3(file_path):
        print("The file is not an MP3 file.")
        return file_path, gr.update(value=file_path, visible=True)

    stem = os.path.splitext(file_path)[0]
    wav_file_path = f"{stem}-{uuid.uuid4()}.wav"
    converted_audio = convert_mp3_to_wav(file_path, wav_file_path)
    print(f"File converted to {wav_file_path}")
    return converted_audio, gr.update(value=converted_audio, visible=True)

def check_and_convert_webp_to_png(input_path, output_path):
    """Convert a WebP image to PNG; leave any other image untouched.

    Args:
        input_path: Path of the image to inspect.
        output_path: Where the PNG copy is written when *input_path* is WebP.

    Returns:
        *output_path* if a conversion happened, otherwise *input_path*.
        When the file cannot be opened or written (missing file, not an image),
        *input_path* is returned explicitly instead of an implicit ``None``.
        Returning ``None`` would silently clear the Gradio image component.
    """
    if not input_path:
        # Nothing to inspect (e.g. an empty component value); pass it through.
        return input_path
    try:
        with Image.open(input_path) as img:
            if img.format != 'WEBP':
                print(f"The file {input_path} is not in WebP format.")
                return input_path
            img.save(output_path, 'PNG')
    except OSError:
        # OSError is what IOError aliases. It also covers PIL's
        # UnidentifiedImageError for files that are not images.
        print(f"Cannot open {input_path}. The file might not exist or is not an image.")
        return input_path
    print(f"Converted {input_path} to {output_path}")
    return output_path

def convert_user_uploaded_webp(input_path):
    """Return a PNG-friendly path for an uploaded or generated portrait.

    WebP images are re-saved as ``converted_to_png_portrait-<uuid>.png`` in the
    working directory. For anything else, the result of
    check_and_convert_webp_to_png is returned as-is.
    """
    png_target = f"converted_to_png_portrait-{uuid.uuid4()}.png"
    ready_png = check_and_convert_webp_to_png(input_path, png_target)
    print(f"PORTRAIT PNG FILE: {ready_png}")
    return ready_png

def clear_audio_elms():
    """Return a Gradio update that empties and hides the target component."""
    hidden_empty = gr.update(value=None, visible=False)
    return hidden_empty

def change_video_codec(input_file, output_file, codec='libx264', audio_codec='aac'):
    """Re-encode *input_file* into *output_file* with the given codecs via ffmpeg.

    Args:
        input_file: Source video path.
        output_file: Destination path; overwritten if it already exists.
        codec: ffmpeg video encoder name.
        audio_codec: ffmpeg audio encoder name.

    Returns:
        True on success, False if ffmpeg reported an error. Failures are
        printed rather than raised, so callers that ignore the return value
        behave as before.
    """
    try:
        ffmpeg.input(input_file).output(output_file, vcodec=codec, acodec=audio_codec).run(overwrite_output=True)
    except ffmpeg.Error as e:
        # stderr is not captured here (ffmpeg streams it to the console), so
        # e.stderr is None; calling .decode() on it unconditionally would raise
        # AttributeError from inside this handler.
        details = e.stderr.decode(errors='replace') if e.stderr else str(e)
        print(f'Error occurred: {details}')
        return False
    print(f'Successfully changed codec of {input_file} and saved as {output_file}')
    return True

# Gradio APIs
def generate_portrait(prompt_image):
    """Generate a portrait image from a text prompt via the SDXL-Lightning Space.

    Raises:
        gr.Error: if the prompt is empty, or if the gradio_client Client for
            the Space cannot be created.

    Returns:
        The path produced by convert_user_uploaded_webp for the generated image.
    """
    if not prompt_image:
        raise gr.Error("Can't generate a portrait without a prompt!")

    try:
        sdxl_client = Client("ByteDance/SDXL-Lightning")
    except Exception:
        raise gr.Error("ByteDance/SDXL-Lightning space's API might not be ready, please wait, or upload an image instead.")

    generated_image = sdxl_client.predict(prompt=prompt_image, ckpt="4-Step", api_name="/generate_image")
    return convert_user_uploaded_webp(generated_image)

def generate_voice_with_parler(prompt_audio, voice_description):
    """Synthesize *prompt_audio* with the Parler-TTS mini Space.

    An empty *voice_description* is allowed, but a gr.Info hint is shown.

    Raises:
        gr.Error: if there is no text to synthesize, or the Client for the
            Space cannot be created.

    Returns:
        (audio_path, update): the generated audio path, and a Gradio update
        that sets it on the target component and makes it visible.
    """
    if not prompt_audio:
        raise gr.Error("Can't generate a voice without text to synthesize!")

    if not voice_description:
        gr.Info("For better control, you may want to provide a voice character description next time.", duration=10, visible=True)

    try:
        tts_client = Client("parler-tts/parler_tts_mini")
    except Exception:
        raise gr.Error("parler-tts/parler_tts_mini space's API might not be ready, please wait, or upload an audio instead.")

    audio_path = tts_client.predict(text=prompt_audio, description=voice_description, api_name="/gen_tts")
    return audio_path, gr.update(value=audio_path, visible=True)

def get_whisperspeech(prompt_audio_whisperspeech, audio_to_clone):
    """Clone the voice in *audio_to_clone* speaking the given text via WhisperSpeech.

    Inputs are validated up front, matching generate_voice_with_parler. Without
    the checks, empty text or a missing sample (``None``) would reach
    handle_file / the remote API and fail with an unfriendly error.

    Raises:
        gr.Error: if the text or the voice sample is missing, or the Client
            for the Space cannot be created.

    Returns:
        (audio_path, update): the generated audio path, and a Gradio update
        that sets it on the target component and makes it visible.
    """
    if not prompt_audio_whisperspeech:
        raise gr.Error("Can't generate a voice without text to synthesize!")
    if not audio_to_clone:
        raise gr.Error("Please provide a voice sample to clone.")

    try:
        client = Client("collabora/WhisperSpeech")
    except Exception:
        raise gr.Error("collabora/WhisperSpeech space's API might not be ready, please wait, or upload an audio instead.")

    result = client.predict(multilingual_text=prompt_audio_whisperspeech, speaker_audio=handle_file(audio_to_clone), speaker_url="", cps=14, api_name="/whisper_speech_demo")
    return result, gr.update(value=result, visible=True)

def get_maskGCT_TTS(prompt_audio_maskGCT, audio_to_clone):
    """Clone the voice in *audio_to_clone* speaking the given text via MaskGCT.

    Inputs are validated up front, matching generate_voice_with_parler. Without
    the checks, empty text or a missing sample (``None``) would reach
    handle_file / the remote API and fail with an unfriendly error.

    Raises:
        gr.Error: if the text or the voice sample is missing, or the Client
            for the Space cannot be created.

    Returns:
        (audio_path, update): the generated audio path, and a Gradio update
        that sets it on the target component and makes it visible.
    """
    if not prompt_audio_maskGCT:
        raise gr.Error("Can't generate a voice without text to synthesize!")
    if not audio_to_clone:
        raise gr.Error("Please provide a voice sample to clone.")

    try:
        client = Client("amphion/maskgct")
    except Exception:
        raise gr.Error("amphion/maskgct space's API might not be ready, please wait, or upload an audio instead.")

    # target_len=-1 is passed through unchanged from the original call.
    result = client.predict(prompt_wav=handle_file(audio_to_clone), target_text=prompt_audio_maskGCT, target_len=-1, n_timesteps=25, api_name="/predict")
    return result, gr.update(value=result, visible=True)

# Talking Portrait Generation
def run_hallo(source_image, driving_audio, progress=gr.Progress(track_tqdm=True)):
    """Animate *source_image* with *driving_audio* using Hallo inference.

    *progress* is never read here. Declaring it with ``track_tqdm=True`` is how
    Gradio is asked to mirror tqdm progress bars emitted during inference.

    Returns:
        The relative path ``output-<uuid>.mp4`` passed to inference_process as
        ``args.output``. The assumption is that inference writes the video there;
        confirm against scripts.inference.
    """
    output_path = f'output-{uuid.uuid4()}.mp4'
    inference_args = argparse.Namespace(
        config='configs/inference/default.yaml',
        source_image=source_image,
        driving_audio=driving_audio,
        output=output_path,
        pose_weight=1.0,
        face_weight=1.0,
        lip_weight=1.0,
        face_expand_ratio=1.2,
        checkpoint=None,
    )
    inference_process(inference_args)
    return output_path

def generate_talking_portrait(portrait, voice, progress=gr.Progress(track_tqdm=True)):
    """Produce a talking-portrait video from a portrait image and a voice clip.

    Pipeline:
      1. Copy the voice into a fresh working file. On the shared Space it is
         also trimmed to AUDIO_MAX_DURATION.
      2. Pad the working file with one second of silence.
      3. Run Hallo.
      4. Re-encode the result with change_video_codec.

    Raises:
        gr.Error: when the portrait or the voice input is missing.

    Returns:
        Path of the re-encoded video. If the re-encode did not produce a file,
        the raw Hallo output path is returned instead.
    """
    if not portrait:
        raise gr.Error("Please provide a portrait to animate.")
    if not voice:
        raise gr.Error("Please provide audio (4 seconds max).")

    unique_id = uuid.uuid4()
    if is_shared_ui:
        # Prefixed name so the temp file never starts with '-', which
        # command-line tools can mistake for an option.
        working_audio = trim_audio(voice, f"trimmed-{unique_id}.wav", AUDIO_MAX_DURATION)
    else:
        # add_silence_to_wav overwrites its argument. Work on a copy so the
        # user's own clip is never modified (and never padded repeatedly).
        working_audio = f"voice-{unique_id}{Path(voice).suffix}"
        shutil.copyfile(voice, working_audio)

    ready_audio = add_silence_to_wav(working_audio)
    print(f"1 second of silence added to {working_audio}")

    talking_portrait_vid = run_hallo(portrait, ready_audio)
    final_output_file = f"converted_{talking_portrait_vid}"
    change_video_codec(talking_portrait_vid, final_output_file)

    if not os.path.exists(final_output_file):
        # change_video_codec only prints ffmpeg failures. Fall back to the
        # un-re-encoded video rather than returning a path that does not exist.
        print(f"Re-encode failed; returning {talking_portrait_vid} instead.")
        return talking_portrait_vid
    return final_output_file

# Gradio Interface
# Stylesheet handed to gr.Blocks below. It currently holds only a placeholder
# comment, so the elem_id hooks used in the layout receive no custom styling.
css = '''
/* Your CSS here */
'''

# Three-column layout:
#   (1) portrait input with optional SDXL generation,
#   (2) voice input with three optional TTS / voice-clone tabs,
#   (3) result video and the main submit button.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# TTS x Hallo Talking Portrait Generator")
        with gr.Row(elem_id="column-names"):
            gr.Markdown("## 1. Load Portrait")
            gr.Markdown("## 2. Load Voice")
            gr.Markdown("## 3. Result")
        with gr.Group(elem_id="main-group"):
            with gr.Row():
                # Column 1: portrait, handed to callbacks as a file path.
                with gr.Column():
                    portrait = gr.Image(sources=["upload"], type="filepath", format="png", elem_id="image-block")
                    prompt_image = gr.Textbox(label="Generate image", lines=2, max_lines=2)
                    gen_image_btn = gr.Button("Generate portrait (optional)")
                # Column 2: driving audio plus one tab per voice generator.
                with gr.Column(elem_id="audio-column"):
                    voice = gr.Audio(type="filepath", elem_id="audio-block")
                    # Hidden at start; audio callbacks return a gr.update that sets
                    # its value and makes it visible, and clear_audio_elms hides it again.
                    preprocess_audio_file = gr.File(visible=False)
                    with gr.Tab("Parler TTS", elem_id="parler-tab"):
                        prompt_audio = gr.Textbox(label="Text to synthesize", lines=3, max_lines=3, elem_id="text-synth")
                        voice_description = gr.Textbox(label="Voice description", lines=3, max_lines=3, elem_id="voice-desc")
                        gen_voice_btn = gr.Button("Generate voice (optional)")
                    with gr.Tab("WhisperSpeech", elem_id="whisperspeech-tab"):
                        prompt_audio_whisperspeech = gr.Textbox(label="Text to synthesize", lines=2, max_lines=2, elem_id="text-synth-wsp")
                        audio_to_clone = gr.Audio(label="Voice to clone", type="filepath", elem_id="audio-clone-elm")
                        gen_wsp_voice_btn = gr.Button("Generate voice clone (optional)")
                    with gr.Tab("MaskGCT TTS", elem_id="maskGCT-tab"):
                        prompt_audio_maskGCT = gr.Textbox(label="Text to synthesize", lines=2, max_lines=2, elem_id="text-synth-maskGCT")
                        audio_to_clone_maskGCT = gr.Audio(label="Voice to clone", type="filepath", elem_id="audio-clone-elm-maskGCT")
                        gen_maskGCT_voice_btn = gr.Button("Generate voice clone (optional)")
                # Column 3: output video and the main action button.
                with gr.Column(elem_id="result-column"):
                    result = gr.Video(elem_id="video-block")
                    submit_btn = gr.Button("Go talking Portrait !", elem_id="main-submit")
        with gr.Row(elem_id="pro-tips"):
            gr.Markdown("# Hallo Pro Tips:")
            gr.Markdown("# TTS Pro Tips:")

    # Event wiring. Input preprocessing and the remote-Space generators are
    # registered with queue=False; only the Hallo submit uses the default queue.
    # show_api=False keeps every endpoint off the auto-generated API page.
    portrait.upload(convert_user_uploaded_webp, inputs=[portrait], outputs=[portrait], queue=False, show_api=False)
    voice.upload(check_mp3, inputs=[voice], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    voice.clear(clear_audio_elms, inputs=None, outputs=[preprocess_audio_file], queue=False, show_api=False)
    gen_image_btn.click(generate_portrait, inputs=[prompt_image], outputs=[portrait], queue=False, show_api=False)
    gen_voice_btn.click(generate_voice_with_parler, inputs=[prompt_audio, voice_description], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    gen_wsp_voice_btn.click(get_whisperspeech, inputs=[prompt_audio_whisperspeech, audio_to_clone], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    gen_maskGCT_voice_btn.click(get_maskGCT_TTS, inputs=[prompt_audio_maskGCT, audio_to_clone_maskGCT], outputs=[voice, preprocess_audio_file], queue=False, show_api=False)
    submit_btn.click(generate_talking_portrait, inputs=[portrait, voice], outputs=[result], show_api=False)

# Queue holds at most 2 pending events; show_error surfaces raised errors
# (e.g. gr.Error) in the browser. Launch runs at import time (no __main__ guard).
demo.queue(max_size=2).launch(show_error=True, show_api=False)