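"""Gradio demo comparing the original Wav2Lip with Nota AI's compressed Wav2Lip.

Checkpoints and the inference sample archive are downloaded at startup from the
LRS_ORIGINAL_URL, LRS_COMPRESSED_URL, and LRS_INFERENCE_SAMPLE environment
variables (e.g., provided as deployment secrets) when they are not already on
disk. Run the script directly (e.g., `python app.py`; the filename is assumed)
to launch the demo.
"""
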
import os
import subprocess
from pathlib import Path

import gradio as gr

from config import hparams as hp
from config import hparams_gradio as hp_gradio
from nota_wav2lip import Wav2LipModelComparisonGradio

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = hp_gradio.device
print(f'Using {device} for inference.')
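# Label -> sample file mappings for the bundled demo videos and audios (defined in hparams_gradio).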
video_label_dict = hp_gradio.sample.video
audio_label_dict = hp_gradio.sample.audio

# Download URLs for the checkpoints and the sample archive (e.g., supplied as deployment secrets).
LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)

# Fetch the original and compressed checkpoints if they are not already on disk.
if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)

# Fetch and extract the inference sample archive; skip extraction if the archive is missing.
path_inference_sample = "sample.tar.gz"
if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
    subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
if Path(path_inference_sample).exists():
    subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True)


if __name__ == "__main__":

    servicer = Wav2LipModelComparisonGradio(
        device=device,
        video_label_dict=video_label_dict,
        audio_label_list=audio_label_dict,
        default_video='v1',
        default_audio='a1'
    )

    # Register the bundled sample videos (each with its accompanying .json metadata) and audios with the servicer.
    for video_name in sorted(video_label_dict):
        video_stem = Path(video_label_dict[video_name])
        servicer.update_video(video_stem, video_stem.with_suffix('.json'),
                              name=video_name)

    for audio_name in sorted(audio_label_dict):
        audio_path = Path(audio_label_dict[audio_name])
        servicer.update_audio(audio_path, name=audio_name)

    with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
        gr.Markdown(Path('docs/header.md').read_text())
        gr.Markdown(Path('docs/description.md').read_text())
        with gr.Row():
            with gr.Column(variant='panel'):

                gr.Markdown('## Select or upload an input video and audio')
                # Define preview slots
                sample_video = gr.Video(interactive=False, label="Input Video")
                sample_audio = gr.Audio(interactive=False, label="Input Audio")

                # Optional uploads; an uploaded file takes precedence over the sample selection below.
                video_upload = gr.Video(source="upload", type="filepath", label="Upload Video")
                audio_upload = gr.Audio(source="upload", type="filepath", label="Upload Audio")
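                # Note: `source=` / `type=` follow the Gradio 3.x component API; newer Gradio
                # releases (4.x+) replace `source` with `sources=["upload"]`.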

                # Define radio inputs
                video_selection = gr.Radio(video_label_dict,
                                           type='value', label="Select an input video:")
                audio_selection = gr.Radio(audio_label_dict,
                                           type='value', label="Select an input audio:")

                # Define button inputs
                with gr.Row(equal_height=True):
                    generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
                    generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")

            with gr.Column(variant='panel'):
                # Define original model output components
                gr.Markdown('## Original Wav2Lip')
                original_model_output = gr.Video(label="Original Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        original_model_fps = gr.Textbox(value="", label="FPS")
                    original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")
            with gr.Column(variant='panel'):
                # Define compressed model output components
                gr.Markdown('## Compressed Wav2Lip (Ours)')
                compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
                with gr.Column():
                    with gr.Row(equal_height=True):
                        compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
                        compressed_model_fps = gr.Textbox(value="", label="FPS")
                    compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")

        # Switch video and audio samples when selecting the radio button
        video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
        audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)

        # Update preview when uploading
        video_upload.change(fn=lambda x: x, inputs=video_upload, outputs=sample_video)
        audio_upload.change(fn=lambda x: x, inputs=audio_upload, outputs=sample_audio)

        # Prefer an uploaded file; otherwise fall back to the selected sample.
        def resolve_inputs(video_choice, audio_choice, video_file, audio_file):
            video_path = video_file or video_label_dict.get(video_choice)
            audio_path = audio_file or audio_label_dict.get(audio_choice)
            return video_path, audio_path

        # Click the generate button for original model
        generate_original_button.click(
            fn=lambda v, a, vu, au: servicer.generate_original_model(*resolve_inputs(v, a, vu, au)),
            inputs=[video_selection, audio_selection, video_upload, audio_upload],
            outputs=[original_model_output, original_model_inference_time, original_model_fps]
        )

        # Click the generate button for compressed model
        generate_compressed_button.click(
            fn=lambda v, a, vu, au: servicer.generate_compressed_model(*resolve_inputs(v, a, vu, au)),
            inputs=[video_selection, audio_selection, video_upload, audio_upload],
            outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps]
        )

        gr.Markdown(Path('docs/footer.md').read_text())

    demo.queue().launch()