ford442 committed · Commit ac7d295 · verified · 1 Parent(s): 83771ff

Create musicgen_colab.py

demos/musicgen_colab.py ADDED
@@ -0,0 +1,494 @@
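# MusicGen Gradio demo, adapted for Hugging Face Spaces: it runs the model
# in-process when the app is daemonic (Spaces / ZeroGPU via `spaces.GPU`)
# and in a dedicated worker process when launched locally.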
import spaces  # <--- IMPORTANT: Add this import
import argparse
import logging
import os
import queue  # for queue.Empty when polling the worker's result queue
from pathlib import Path
import subprocess as sp
import sys
import time
import typing as tp
from tempfile import NamedTemporaryFile, gettempdir
from einops import rearrange
import torch
import gradio as gr
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models.encodec import InterleaveStereoCompressionModel
from audiocraft.models import MusicGen, MultiBandDiffusion
import multiprocessing as mp
import warnings

# os.putenv does not update os.environ for the current process; set it directly
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["SAFETENSORS_FAST_GPU"] = "1"
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
# torch.backends.cuda.preferred_blas_library = "cublas"
# torch.backends.cuda.preferred_linalg_library = "cusolver"
torch.set_float32_matmul_precision("highest")

class FileCleaner:
    """Deletes registered files once they are older than `file_lifetime` seconds."""

    def __init__(self, file_lifetime: float = 3600):
        self.file_lifetime = file_lifetime
        self.files = []

    def add(self, path: tp.Union[str, Path]):
        self._cleanup()
        self.files.append((time.time(), Path(path)))

    def _cleanup(self):
        now = time.time()
        for time_added, path in list(self.files):
            if now - time_added > self.file_lifetime:
                if path.exists():
                    path.unlink()
                self.files.pop(0)
            else:
                break


file_cleaner = FileCleaner()
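
# Usage sketch (hypothetical paths): every generated artifact is registered so
# it is unlinked roughly an hour after creation, keeping the temp dir bounded:
#   file_cleaner.add("/tmp/out.wav")
#   file_cleaner.add("/tmp/out.mp4")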

def convert_wav_to_mp4(wav_path, output_path=None):
    """Converts a WAV file to a waveform MP4 video using ffmpeg."""
    if output_path is None:
        # Create the output path next to the input file
        output_path = Path(wav_path).with_suffix(".mp4")
    try:
        command = [
            "ffmpeg",
            "-y",  # Overwrite the output file if it exists
            "-i", str(wav_path),
            "-filter_complex",
            "[0:a]showwaves=s=1280x202:mode=line,format=yuv420p[v]",  # Render the waveform as video
            "-map", "[v]",
            "-map", "0:a",
            "-c:v", "libx264",  # Video codec
            "-c:a", "aac",  # Audio codec
            "-preset", "fast",  # Important: don't use veryslow, encoding time matters here
            str(output_path),
        ]
        sp.run(command, capture_output=True, text=True, check=True)
        return str(output_path)
    except sp.CalledProcessError as e:
        print(f"Error in ffmpeg conversion: {e}")
        print(f"ffmpeg stdout: {e.stdout}")
        print(f"ffmpeg stderr: {e.stderr}")
        raise  # Re-raise so Gradio surfaces the error
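
# Example (hypothetical file): convert_wav_to_mp4("/tmp/out.wav") returns
# "/tmp/out.mp4" on success, or raises CalledProcessError with ffmpeg's logs.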

def model_worker(model_name: str, depth: str, task_queue: mp.Queue, result_queue: mp.Queue):
    """
    Persistent worker process (used when NOT running as a daemon).
    Loads the model once, then serves generation tasks from `task_queue`
    and pushes `(task_id, result_or_exception)` tuples onto `result_queue`.
    """
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # `depth` is forwarded here so local (non-daemon) runs honor the
        # precision choice, mirroring the daemon path in Predictor.__init__
        model = MusicGen.get_pretrained(model_name, device=device, depth=depth)
        mbd = MultiBandDiffusion.get_mbd_musicgen(device=device)
        while True:
            task = task_queue.get()
            if task is None:  # Sentinel: shut down the worker
                break
            task_id, text, melody, duration, use_diffusion, gen_params = task
            try:
                model.set_generation_params(duration=duration, **gen_params)
                target_sr = model.sample_rate
                target_ac = 1
                processed_melody = None
                if melody:
                    sr, melody_data = melody
                    melody_tensor = torch.from_numpy(melody_data).to(device).float().t()
                    if melody_tensor.ndim == 1:
                        melody_tensor = melody_tensor.unsqueeze(0)
                    melody_tensor = melody_tensor[..., :int(sr * duration)]
                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
                if processed_melody is not None:
                    output, tokens = model.generate_with_chroma(
                        descriptions=[text],
                        melody_wavs=[processed_melody],
                        melody_sample_rate=target_sr,
                        progress=True,
                        return_tokens=True,
                    )
                else:
                    output, tokens = model.generate([text], progress=True, return_tokens=True)
                output = output.detach().cpu()
                if use_diffusion:
                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
                        left, right = model.compression_model.get_left_right_codes(tokens)
                        tokens = torch.cat([left, right])
                    outputs_diffusion = mbd.tokens_to_wav(tokens)
                    if isinstance(model.compression_model, InterleaveStereoCompressionModel):
                        assert outputs_diffusion.shape[1] == 1  # output is mono
                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
                    outputs_diffusion = outputs_diffusion.detach().cpu()
                    result_queue.put((task_id, (output, outputs_diffusion)))
                else:
                    result_queue.put((task_id, (output, None)))
            except Exception as e:
                result_queue.put((task_id, e))
    except Exception as e:
        result_queue.put((-1, e))  # task_id -1 signals a model-loading failure
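# Queue protocol sketch (illustrative values): the parent process enqueues
#   (task_id, text, melody, duration, use_diffusion, gen_params)
# and receives either (task_id, (wav_tensor, mbd_wav_or_None)) on success
# or (task_id, exception) on failure; `None` as a task shuts the worker down.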

class Predictor:
    """Runs MusicGen either in-process (daemonic, e.g. on Spaces) or in a
    dedicated worker process (local runs); both modes share the same
    predict/get_result API."""

    def __init__(self, model_name: str, depth: str):
        self.model_name = model_name
        self.is_daemon = mp.current_process().daemon
        self._results = {}  # task_id -> result (daemon mode only)
        if self.is_daemon:
            # Running in a daemonic process (e.g., on Spaces): load models in-process
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model = MusicGen.get_pretrained(self.model_name, device=self.device, depth=depth)
            self.mbd = MultiBandDiffusion.get_mbd_musicgen(device=self.device)  # Load MBD here too
            self.current_task_id = 0
        else:
            # Running in a non-daemonic process (e.g., locally): spawn a worker
            self.task_queue = mp.Queue()
            self.result_queue = mp.Queue()
            self.process = mp.Process(
                target=model_worker,
                args=(self.model_name, depth, self.task_queue, self.result_queue),
            )
            self.process.start()
            self.current_task_id = 0
            self._check_initialization()

    def _check_initialization(self):
        """Check whether the worker process failed to load the model (non-daemon mode only)."""
        if not self.is_daemon:
            time.sleep(2)
            try:
                task_id, result = self.result_queue.get(timeout=3)
            except queue.Empty:
                return  # No error reported; the worker is presumably still loading
            if task_id == -1 and isinstance(result, Exception):
                raise RuntimeError("Model loading failed in worker process.") from result

    def predict(self, text, melody, duration, use_diffusion, **gen_params):
        """Submits a prediction task and returns its task id; fetch results with `get_result`."""
        self.current_task_id += 1
        task_id = self.current_task_id
        if self.is_daemon:
            # Single-process mode: run the prediction directly and stash the result
            try:
                self.model.set_generation_params(duration=duration, **gen_params)
                target_sr = self.model.sample_rate
                target_ac = 1
                processed_melody = None
                if melody:
                    sr, melody_data = melody
                    melody_tensor = torch.from_numpy(melody_data).to(self.device).float().t()
                    if melody_tensor.ndim == 1:
                        melody_tensor = melody_tensor.unsqueeze(0)
                    melody_tensor = melody_tensor[..., :int(sr * duration)]
                    processed_melody = convert_audio(melody_tensor, sr, target_sr, target_ac)
                if processed_melody is not None:
                    output, tokens = self.model.generate_with_chroma(
                        descriptions=[text],
                        melody_wavs=[processed_melody],
                        melody_sample_rate=target_sr,
                        progress=True,
                        return_tokens=True,
                    )
                else:
                    output, tokens = self.model.generate([text], progress=True, return_tokens=True)
                output = output.detach().cpu()
                if use_diffusion:
                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
                        left, right = self.model.compression_model.get_left_right_codes(tokens)
                        tokens = torch.cat([left, right])
                    outputs_diffusion = self.mbd.tokens_to_wav(tokens)
                    if isinstance(self.model.compression_model, InterleaveStereoCompressionModel):
                        assert outputs_diffusion.shape[1] == 1  # output is mono
                        outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
                    outputs_diffusion = outputs_diffusion.detach().cpu()
                    self._results[task_id] = (output, outputs_diffusion)
                else:
                    self._results[task_id] = (output, None)
            except Exception as e:
                self._results[task_id] = e
        else:
            # Multi-process mode: hand the task to the worker via the queue
            task = (task_id, text, melody, duration, use_diffusion, gen_params)
            self.task_queue.put(task)
        return task_id

    def get_result(self, task_id):
        """Retrieves the result of a prediction task, raising any stored exception."""
        if self.is_daemon:
            # Results were computed synchronously in `predict` and stored by task id
            result = self._results.pop(task_id)
        else:
            # Block until the worker returns the matching task id
            while True:
                result_task_id, result = self.result_queue.get()
                if result_task_id == task_id:
                    break
        if isinstance(result, Exception):
            raise result
        return result

    def shutdown(self):
        """Shuts down the worker process (if one is running)."""
        if not self.is_daemon and self.process.is_alive():
            self.task_queue.put(None)  # Sentinel tells the worker loop to exit
            self.process.join()
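# Usage sketch (hypothetical prompt): the two-step API is identical in both modes:
#   predictor = Predictor("facebook/musicgen-melody", "float32")
#   tid = predictor.predict(text="lofi beat", melody=None, duration=10,
#                           use_diffusion=False, top_k=250, top_p=0.0,
#                           temperature=1.0, cfg_coef=3.0)
#   wav, mbd_wav = predictor.get_result(tid)
#   predictor.shutdown()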

_default_model_name = "facebook/musicgen-melody"


@spaces.GPU(duration=90)  # Use the decorator for Spaces
def predict_full(model, model_path, depth, use_mbd, text, melody, duration, topk, topp, temperature, cfg_coef):
    # `model_path` is kept for UI compatibility but is currently unused.
    # The Decoder radio passes a string; normalize it to a boolean so that
    # the truthy string "Default" does not accidentally trigger MultiBand Diffusion.
    use_mbd = (use_mbd == "MultiBand_Diffusion")
    # Initialize the Predictor *INSIDE* the function so the model lives within the GPU allocation
    predictor = Predictor(model, depth)
    task_id = predictor.predict(
        text=text,
        melody=melody,
        duration=duration,
        use_diffusion=use_mbd,
        top_k=int(topk),  # gr.Number yields floats; top-k sampling needs an int
        top_p=topp,
        temperature=temperature,
        cfg_coef=cfg_coef,
    )
    wav, diffusion_wav = predictor.get_result(task_id)
    # The released MusicGen models decode audio at 32 kHz; writing with the
    # previously hardcoded 44100 would pitch the audio up.
    sample_rate = predictor.model.sample_rate if predictor.is_daemon else 32000
    # Save and return audio files
    wav_paths = []
    video_paths = []
    # Save standard output
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
        audio_write(
            file.name, wav[0], sample_rate, strategy="loudness",
            loudness_headroom_db=16, loudness_compressor=True, add_suffix=False,
        )
        wav_paths.append(file.name)
    # Make the waveform video and register both files for cleanup
    video_path = convert_wav_to_mp4(file.name)
    video_paths.append(video_path)
    file_cleaner.add(file.name)
    file_cleaner.add(video_path)
    # Save MBD output if used
    if diffusion_wav is not None:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, diffusion_wav[0], sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False,
            )
            wav_paths.append(file.name)
        video_path = convert_wav_to_mp4(file.name)
        video_paths.append(video_path)
        file_cleaner.add(file.name)
        file_cleaner.add(video_path)
    # Shut down the predictor to avoid leaving worker processes hanging!
    if not predictor.is_daemon:
        predictor.shutdown()
    if use_mbd:
        return video_paths[0], wav_paths[0], video_paths[1], wav_paths[1]
    return video_paths[0], wav_paths[0], None, None

def toggle_audio_src(choice):
    # gr.Audio expects `sources` as a list of source names
    if choice == "mic":
        return gr.update(sources=["microphone"], value=None, label="Microphone")
    else:
        return gr.update(sources=["upload"], value=None, label="File")


def toggle_diffusion(choice):
    if choice == "MultiBand_Diffusion":
        return [gr.update(visible=True)] * 2
    else:
        return [gr.update(visible=False)] * 2

def ui_full(launch_kwargs):
    with gr.Blocks() as interface:
        gr.Markdown(
            """
            # MusicGen
            This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
            a simple and controllable model for music generation,
            presented in ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Input Text", interactive=True)
                    with gr.Column():
                        radio = gr.Radio(["file", "mic"], value="file",
                                         label="Condition on a melody (optional), file or mic")
                        melody = gr.Audio(sources=["upload"], type="numpy", label="File",
                                          interactive=True, elem_id="melody-input")
                with gr.Row():
                    submit = gr.Button("Submit")
                    # _ = gr.Button("Interrupt").click(fn=interrupt, queue=False)  # Interrupt is now handled implicitly
                with gr.Row():
                    model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
                                      "facebook/musicgen-large", "facebook/musicgen-melody-large",
                                      "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
                                      "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
                                      "facebook/musicgen-stereo-melody-large"],
                                     label="Model", value="facebook/musicgen-melody", interactive=True)
                    model_path = gr.Text(label="Model Path (custom models)", interactive=False, visible=False)  # Kept for the API, but hidden
                    depth = gr.Radio(["float32", "bfloat16", "float16"],
                                     label="Model Precision", value="float32", interactive=True)
                with gr.Row():
                    decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
                                       label="Decoder", value="Default", interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            with gr.Column():
                output = gr.Video(label="Generated Music")
                audio_output = gr.Audio(label="Generated Music (wav)", type='filepath')
                diffusion_output = gr.Video(label="MultiBand Diffusion Decoder", visible=False)
                audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath', visible=False)

        submit.click(
            toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False
        ).then(
            predict_full,
            inputs=[model, model_path, depth, decoder, text, melody, duration, topk, topp, temperature, cfg_coef],
            outputs=[output, audio_output, diffusion_output, audio_diffusion],
        )
        radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)

        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "facebook/musicgen-melody",
                    "Default",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "facebook/musicgen-medium",
                    "Default",
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
                    "facebook/musicgen-melody",
                    "Default",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "facebook/musicgen-medium",
                    "Default",
                ],
                [
                    "Punk rock with loud drum and power guitar",
                    None,
                    "facebook/musicgen-medium",
                    "MultiBand_Diffusion",
                ],
            ],
            inputs=[text, melody, model, decoder],
            outputs=[output],
        )
        gr.Markdown(
            """
            ### More details

            The model will generate a short music extract based on the description you provided.
            The model can generate up to 30 seconds of audio in one pass.

            The model was trained with descriptions from a stock music catalog; the descriptions that
            work best include some detail about the instruments present, along with an intended use case
            (e.g. adding "perfect for a commercial" can help).

            Using one of the `melody` models (e.g. `musicgen-melody-*`), you can optionally provide a reference audio
            from which a broad melody will be extracted.
            The model will then try to follow both the description and melody provided.
            For best results, the melody should be 30 seconds long (I know, the samples we provide are not...).

            It is now possible to extend the generation by feeding back the end of the previous chunk of audio.
            This can take a long time, and the model might lose consistency. The model might also
            decide at arbitrary positions that the song ends.

            **WARNING:** Choosing long durations will take a long time to generate (2 min might take ~10 min).
            An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
            are generated each time.

            We present 10 model variations:
            1. facebook/musicgen-melody -- a music generation model conditioned on text and melody inputs.
                **Note:** you can also use text only.
            2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
            3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
            4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.
            5. facebook/musicgen-melody-large -- a 3.3B transformer decoder conditioned on text and melody.
            6. facebook/musicgen-stereo-small -- a 300M transformer decoder conditioned on text only, fine-tuned for stereo output.
            7. facebook/musicgen-stereo-medium -- a 1.5B transformer decoder conditioned on text only, fine-tuned for stereo output.
            8. facebook/musicgen-stereo-melody -- a 1.5B transformer decoder conditioned on text and melody, fine-tuned for stereo output.
            9. facebook/musicgen-stereo-large -- a 3.3B transformer decoder conditioned on text only, fine-tuned for stereo output.
            10. facebook/musicgen-stereo-melody-large -- a 3.3B transformer decoder conditioned on text and melody, fine-tuned for stereo output.

            We also present two ways of decoding the audio tokens:
            1. Use the default GAN-based compression model. It can suffer from artifacts, especially
                for crashes, snares, etc.
            2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). This should improve the audio quality,
                at an extra computational cost. When this is selected, we provide both the GAN-based decoded
                audio and the one obtained with MBD.

            See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
            for more details.
            """
        )

    interface.queue().launch(**launch_kwargs)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0' if 'SPACE_ID' in os.environ else '127.0.0.1',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )
    args = parser.parse_args()
    launch_kwargs = {}
    launch_kwargs['server_name'] = args.listen
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = True
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
    # Predictor instances are created and shut down inside `predict_full`,
    # so no module-level predictor needs cleaning up here.
    ui_full(launch_kwargs)
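
# Example local launch (assumed invocation, using the flags defined above):
#   python demos/musicgen_colab.py --listen 127.0.0.1 --server_port 7860 --inbrowser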