eggarsway commited on
Commit
8b2537e
Β·
1 Parent(s): 02f0f02

Better gradio keyframing interface

Browse files
Files changed (1) hide show
  1. app.py +213 -82
app.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  from PIL import Image, ImageOps, ImageDraw, ImageFont, ImageColor
9
  from urllib.request import urlopen
10
 
11
- #os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
12
 
13
  root = os.path.dirname(os.path.abspath(__file__))
14
  static = os.path.join(root, "static")
@@ -33,12 +33,13 @@ unet3d_condition_model_forward_copy = UNet3DConditionModel.forward
33
  UNet3DConditionModel.forward = unet3d_condition_model_forward
34
 
35
 
36
-
37
  model_id = "cerspense/zeroscope_v2_576w"
38
  model_path = model_id
39
  pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
40
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
41
- pipe.to('cuda')
 
 
42
 
43
  @spaces.GPU(duration=120)
44
  def core(bundle):
@@ -55,25 +56,41 @@ def core(bundle):
55
 
56
 
57
  def clear_btn_fn():
58
- return "", "", "", ""
 
59
 
60
 
61
  def gen_btn_fn(
62
- prompts,
63
- bboxes,
64
- frames,
65
- word_prompt_indices,
66
- trailing_length,
67
- n_spatial_steps,
68
- n_temporal_steps,
69
- spatial_strengthen_scale,
70
- spatial_weaken_scale,
71
- temporal_strengthen_scale,
72
- temporal_weaken_scale,
73
- rand_seed,
74
- progress = gr.Progress(),
 
75
  ):
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  bundle = {}
78
  bundle["trailing_length"] = trailing_length
79
  bundle["num_dd_spatial_steps"] = n_spatial_steps
@@ -87,26 +104,14 @@ def gen_btn_fn(
87
  bundle["token_inds"] = [int(v) for v in word_prompt_indices.split(",")]
88
 
89
  bundle["keyframe"] = []
90
- frames = frames.split(";")
91
- bboxes = bboxes.split(";")
92
- if ";" in prompts:
93
- prompts = prompts.split(";")
94
- else:
95
- prompts = [prompts for i in range(len(frames))]
96
-
97
- assert (
98
- len(frames) == len(bboxes) == len(prompts)
99
- ), "Inconsistent number of keyframes in the given inputs."
100
-
101
- frames.pop()
102
- bboxes.pop()
103
- prompts.pop()
104
-
105
- for i in range(len(frames)):
106
  keyframe = {}
107
- keyframe["bbox_ratios"] = [float(v) for v in bboxes[i].split(",")]
108
- keyframe["frame"] = int(frames[i])
109
- keyframe["prompt"] = prompts[i]
 
 
110
  bundle["keyframe"].append(keyframe)
111
  print(bundle)
112
  result = core(bundle)
@@ -114,6 +119,15 @@ def gen_btn_fn(
114
  return path
115
 
116
 
 
 
 
 
 
 
 
 
 
117
  def save_mask(inputs):
118
  layers = inputs["layers"]
119
  if not layers:
@@ -191,7 +205,7 @@ with gr.Blocks(
191
  ) as main:
192
 
193
  description = """
194
- <h1 align="center" style="font-size: 48px">TrailBlazer: Trajectory Control for Diffusion-Based Video Generation</h1>
195
  <h4 align="center" style="margin: 0;">If you like our project, please give us a star ✨ at our Huggingface space, and our Github repository.</h4>
196
  <br>
197
  <span align="center" style="font-size: 18px">
@@ -205,8 +219,9 @@ with gr.Blocks(
205
  <p>
206
  <strong>Usage:</strong> Our Gradio app is implemented based on our executable script CmdTrailBlazer in our github repository. Please see our general information below for a quick guidance, as well as the hints within the app widgets.
207
  <ul>
208
- <li>Basic: The bounding box (bbox) is the tuple of four floats for the rectangular corners: left, top, right, bottom in the normalized ratio. The Word prompt indices is a list of 1-indexed numbers determining the prompt word.</li>
209
- <li>Advanced Options: We also offer some key parameters to adjust the synthesis result. Please see our paper for more information about the ablations.</li>
 
210
  </ul>
211
 
212
  For your initial use, it is advisable to select one of the examples provided below and attempt to swap the subject first (e.g., cat -> lion). Subsequently, define the keyframe with the associated bbox/frame/prompt. Please note that our current work is based on the ZeroScope (cerspense/zeroscope_v2_576w) model. Using prompts that are commonly recognized in the ZeroScope model context is recommended.
@@ -214,25 +229,66 @@ with gr.Blocks(
214
  """
215
 
216
  gr.HTML(description)
217
- dummy_note = gr.Textbox(
218
- interactive=True, label="Note", visible=False
219
- )
 
 
 
 
220
  with gr.Row():
221
  with gr.Column(scale=2):
222
  with gr.Row():
223
  with gr.Tab("Main"):
224
- text_prompt_tb = gr.Textbox(
225
- interactive=True, label="Keyframe: Prompt"
226
- )
227
- bboxes_tb = gr.Textbox(interactive=True, label="Keyframe: Bboxes")
228
- frame_tb = gr.Textbox(
229
- interactive=True, label="Keyframe: frame indices"
 
 
 
 
 
 
230
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  with gr.Row():
232
  word_prompt_indices_tb = gr.Textbox(
233
  interactive=True, label="Word prompt indices:"
234
  )
235
- text = "<strong>Hint</strong>: Each keyframe ends with <strong>SEMICOLON</strong>, and <strong>COMMA</strong> for separating each value in the keyframe. The prompt field can be a single prompt without semicolon, or multiple prompts ended semicolon. One can use the SketchPadHelper tab to help to design the bboxes field."
236
  gr.HTML(text)
237
  with gr.Row():
238
  clear_btn = gr.Button(value="Clear")
@@ -343,57 +399,132 @@ with gr.Blocks(
343
  with gr.Row():
344
  out_gen_1 = gr.Video(visible=True, show_label=False)
345
 
 
 
346
  with gr.Row():
347
  gr.Examples(
348
  examples=[
349
  [
 
 
 
350
  "A clownfish swimming in a coral reef",
351
- "0.5,0.35,1.0,0.65; 0.0,0.35,0.5,0.65;",
352
- "0; 24;",
353
- "1, 2",
 
 
 
354
  "123451232531",
355
- "It generates clownfish at right, then move to left",
356
- "assets/gradio/fish-RL.mp4",
357
- ],
358
- [
359
- "A cat is running on the grass",
360
- "0.0,0.35,0.4,0.65; 0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;"
361
- "0.6,0.35,1.0,0.65; 0.0,0.35,0.4,0.65;",
362
- "0; 6; 12; 18; 24;",
363
- "1, 2",
364
- "123451232530",
365
- "The cat will run Left/Right/Left/Right",
366
- "assets/gradio/cat-LRLR.mp4",
367
  ],
368
  [
 
 
 
369
  "A fish swimming in the ocean",
370
- "0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
371
- "0; 24;",
 
 
 
372
  "1, 2",
373
  "0",
374
- "The fish moves from top left to bottom right, from far to near.",
375
- "assets/gradio/fish-TL2BR.mp4"
376
  ],
377
  [
 
 
 
 
378
  "A tiger walking alone down the street",
379
- "0.0,0.0,0.1,0.1; 0.5,0.5,1.0,1.0;",
380
- "0; 24;",
 
 
381
  "1, 2",
382
  "0",
383
- "Same with the above but now the prompt associates with tiger",
384
- "assets/gradio/tiger-TL2BR.mp4"
385
  ],
386
  [
387
- "A white cat walking on the grass; A yellow dog walking on the grass;",
388
- "0.7,0.4,1.0,0.65; 0.0,0.4,0.3,0.65;",
389
- "0; 24;",
 
 
 
 
 
 
390
  "1,2,3",
391
  "123451232531",
392
- "The subject will deformed from cat to dog.",
393
- "assets/gradio/Cat2Dog.mp4",
394
  ],
395
  ],
396
- inputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb, rand_seed, dummy_note, out_gen_1],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  outputs=None,
398
  fn=None,
399
  cache_examples=False,
@@ -402,16 +533,16 @@ with gr.Blocks(
402
  clear_btn.click(
403
  clear_btn_fn,
404
  inputs=[],
405
- outputs=[text_prompt_tb, bboxes_tb, frame_tb, word_prompt_indices_tb],
406
  queue=False,
407
  )
408
 
409
  gen_btn.click(
410
  gen_btn_fn,
411
  inputs=[
412
- text_prompt_tb,
413
- bboxes_tb,
414
- frame_tb,
415
  word_prompt_indices_tb,
416
  trailing_length,
417
  n_spatial_steps,
 
8
  from PIL import Image, ImageOps, ImageDraw, ImageFont, ImageColor
9
  from urllib.request import urlopen
10
 
11
+ # os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
12
 
13
  root = os.path.dirname(os.path.abspath(__file__))
14
  static = os.path.join(root, "static")
 
33
  UNet3DConditionModel.forward = unet3d_condition_model_forward
34
 
35
 
 
36
  model_id = "cerspense/zeroscope_v2_576w"
37
  model_path = model_id
38
  pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
39
  pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
40
+ pipe.to("cuda")
41
+
42
+ MAX_KEYS = 10
43
 
44
  @spaces.GPU(duration=120)
45
  def core(bundle):
 
56
 
57
 
58
  def clear_btn_fn():
59
+
60
+ return 0, *[""] * MAX_KEYS, *[""] * MAX_KEYS, *[""] * MAX_KEYS, ""
61
 
62
 
63
  def gen_btn_fn(
64
+ *args,
65
+ # *prompts,
66
+ # *bboxes,
67
+ # *frame_indices,
68
+ # word_prompt_indices_tb,
69
+ # trailing_length,
70
+ # n_spatial_steps,
71
+ # n_temporal_steps,
72
+ # spatial_strengthen_scale,
73
+ # spatial_weaken_scale,
74
+ # temporal_strengthen_scale,
75
+ # temporal_weaken_scale,
76
+ # rand_seed,
77
+ progress=gr.Progress(),
78
  ):
79
 
80
+ # no prompt at all
81
+ if not args[0]:
82
+ return
83
+
84
+ rand_seed = args[-1]
85
+ temporal_weaken_scale = args[-2]
86
+ temporal_strengthen_scale = args[-3]
87
+ spatial_weaken_scale = args[-4]
88
+ spatial_strengthen_scale = args[-5]
89
+ n_temporal_steps = args[-6]
90
+ n_spatial_steps = args[-7]
91
+ trailing_length = args[-8]
92
+ word_prompt_indices = args[-9]
93
+
94
  bundle = {}
95
  bundle["trailing_length"] = trailing_length
96
  bundle["num_dd_spatial_steps"] = n_spatial_steps
 
104
  bundle["token_inds"] = [int(v) for v in word_prompt_indices.split(",")]
105
 
106
  bundle["keyframe"] = []
107
+
108
+ for i in range(MAX_KEYS):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  keyframe = {}
110
+ if not args[i]:
111
+ break
112
+ keyframe["prompt"] = args[i]
113
+ keyframe["bbox_ratios"] = [float(v) for v in args[i + MAX_KEYS].split(",")]
114
+ keyframe["frame"] = int(args[i + 2 * MAX_KEYS])
115
  bundle["keyframe"].append(keyframe)
116
  print(bundle)
117
  result = core(bundle)
 
119
  return path
120
 
121
 
122
+ def keyframe_update(num):
123
+ keyframes = []
124
+ for i in range(num):
125
+ keyframes.append(gr.Row(visible=True))
126
+ for i in range(MAX_KEYS - num):
127
+ keyframes.append(gr.Row(visible=False))
128
+ return keyframes
129
+
130
+
131
  def save_mask(inputs):
132
  layers = inputs["layers"]
133
  if not layers:
 
205
  ) as main:
206
 
207
  description = """
208
+ <h1 align="center" style="font-size: 48px">TrailBlazer: Trajectory Control for Diffusion-Based Video Generation (v0.0.3)</h1>
209
  <h4 align="center" style="margin: 0;">If you like our project, please give us a star ✨ at our Huggingface space, and our Github repository.</h4>
210
  <br>
211
  <span align="center" style="font-size: 18px">
 
219
  <p>
220
  <strong>Usage:</strong> Our Gradio app is implemented based on our executable script CmdTrailBlazer in our github repository. Please see our general information below for a quick guidance, as well as the hints within the app widgets.
221
  <ul>
222
+ <li><strong>Basic</strong>: Our app workflow is straightforward. First, select the Number of keyframes, then fill all the values in the appeared prompt/frame indice/bounding box(bbox) for each keyframe, as well as the word prompt indices. Finally, hit the Generate button to run the TrailBlazer. It is roughly 90secs to get the result. We also provide the <strong>SketchPadHelper</strong> to visually design our bbox format</li>
223
+
224
+ <li><strong>Advanced Options</strong>: We also offer some key parameters to adjust the synthesis result. Please see our paper for more information about the ablations.</li>
225
  </ul>
226
 
227
  For your initial use, it is advisable to select one of the examples provided below and attempt to swap the subject first (e.g., cat -> lion). Subsequently, define the keyframe with the associated bbox/frame/prompt. Please note that our current work is based on the ZeroScope (cerspense/zeroscope_v2_576w) model. Using prompts that are commonly recognized in the ZeroScope model context is recommended.
 
229
  """
230
 
231
  gr.HTML(description)
232
+ dummy_note = gr.Textbox(interactive=True, label="Note", visible=False)
233
+
234
+ keyframes = []
235
+ prompts = []
236
+ bboxes = []
237
+ frame_indices = []
238
+
239
  with gr.Row():
240
  with gr.Column(scale=2):
241
  with gr.Row():
242
  with gr.Tab("Main"):
243
+
244
+ # text_prompt_tb = gr.Textbox(
245
+ # interactive=True, label="Keyframe: Prompt"
246
+ # )
247
+ # bboxes_tb = gr.Textbox(interactive=True, label="Keyframe: Bboxes")
248
+ # frame_tb = gr.Textbox(
249
+ # interactive=True, label="Keyframe: frame indices"
250
+ # )
251
+
252
+ dropdown = gr.Dropdown(
253
+ label="Number of keyframes",
254
+ choices=range(2, MAX_KEYS),
255
  )
256
+ for i in range(MAX_KEYS):
257
+ with gr.Row(visible=False) as row:
258
+ text = f"Keyframe #{i}"
259
+ text = gr.HTML(text, visible=True)
260
+ prompt = gr.Textbox(
261
+ None,
262
+ label=f"Prompt #{i}",
263
+ visible=True,
264
+ interactive=True,
265
+ scale=4,
266
+ )
267
+ frame_ids = gr.Textbox(
268
+ None,
269
+ label=f"Frame indice #{i}",
270
+ visible=True,
271
+ interactive=True,
272
+ scale=1,
273
+ )
274
+ bbox = gr.Textbox(
275
+ None,
276
+ label=f"BBox #{i}",
277
+ visible=True,
278
+ interactive=True,
279
+ scale=3,
280
+ )
281
+ prompts.append(prompt)
282
+ bboxes.append(bbox)
283
+ frame_indices.append(frame_ids)
284
+ keyframes.append(row)
285
+ dropdown.change(keyframe_update, dropdown, keyframes)
286
+
287
  with gr.Row():
288
  word_prompt_indices_tb = gr.Textbox(
289
  interactive=True, label="Word prompt indices:"
290
  )
291
+ text = "<strong>Hint</strong>: Each keyframe is associated with a prompt, frame indice, and the corresponding bbox. The bbox is the tuple of the four floats determining the four bbox corners (left, top, right, bottom) in normalized ratio. The word prompt indices is 1-indexed value to indicate the word in prompt. Note that we use <strong>COMMA</strong> to separate the multiple values."
292
  gr.HTML(text)
293
  with gr.Row():
294
  clear_btn = gr.Button(value="Clear")
 
399
  with gr.Row():
400
  out_gen_1 = gr.Video(visible=True, show_label=False)
401
 
402
+ with gr.Row():
403
+ gr.Markdown("## Two keyframes example")
404
  with gr.Row():
405
  gr.Examples(
406
  examples=[
407
  [
408
+ "assets/gradio/fish-RL.mp4",
409
+ "It generates clownfish at right, then move to left",
410
+ 2,
411
  "A clownfish swimming in a coral reef",
412
+ "A clownfish swimming in a coral reef",
413
+ "0",
414
+ "24",
415
+ "0.5,0.35,1.0,0.65",
416
+ "0.0,0.35,0.5,0.65",
417
+ "1,2",
418
  "123451232531",
 
 
 
 
 
 
 
 
 
 
 
 
419
  ],
420
  [
421
+ "assets/gradio/fish-TL2BR.mp4",
422
+ "The fish moves from top left to bottom right, from far to near.",
423
+ 2,
424
  "A fish swimming in the ocean",
425
+ "A fish swimming in the ocean",
426
+ "0",
427
+ "24",
428
+ "0.0,0.0,0.1,0.1",
429
+ "0.5,0.5,1.0,1.0",
430
  "1, 2",
431
  "0",
 
 
432
  ],
433
  [
434
+ "assets/gradio/tiger-TL2BR.mp4",
435
+ "Same with the above but now the prompt associates with tiger",
436
+ 2,
437
+ "A tiger walking alone down the street",
438
  "A tiger walking alone down the street",
439
+ "0",
440
+ "24",
441
+ "0.0,0.0,0.1,0.1",
442
+ "0.5,0.5,1.0,1.0",
443
  "1, 2",
444
  "0",
 
 
445
  ],
446
  [
447
+ "assets/gradio/Cat2Dog.mp4",
448
+ "The subject will deformed from cat to dog.",
449
+ 2,
450
+ "A white cat walking on the grass",
451
+ "A yellow dog walking on the grass",
452
+ "0",
453
+ "24",
454
+ "0.7,0.4,1.0,0.65",
455
+ "0.0,0.4,0.3,0.65",
456
  "1,2,3",
457
  "123451232531",
 
 
458
  ],
459
  ],
460
+ inputs=[
461
+ out_gen_1,
462
+ dummy_note,
463
+ dropdown,
464
+ prompts[0],
465
+ prompts[1],
466
+ frame_indices[0],
467
+ frame_indices[1],
468
+ bboxes[0],
469
+ bboxes[1],
470
+ word_prompt_indices_tb,
471
+ rand_seed,
472
+ ],
473
+ outputs=None,
474
+ fn=None,
475
+ cache_examples=False,
476
+ )
477
+
478
+ with gr.Row():
479
+ gr.Markdown("## Five keyframes example")
480
+ with gr.Row():
481
+ gr.Examples(
482
+ examples=[
483
+ [
484
+ "assets/gradio/cat-LRLR.mp4",
485
+ "The poor cat will run Left/Right/Left/Right :(",
486
+ 5,
487
+ "A cat is running on the grass",
488
+ "A cat is running on the grass",
489
+ "A cat is running on the grass",
490
+ "A cat is running on the grass",
491
+ "A cat is running on the grass",
492
+ "0",
493
+ "6",
494
+ "12",
495
+ "18",
496
+ "24",
497
+ "0.0,0.35,0.4,0.65",
498
+ "0.6,0.35,1.0,0.65",
499
+ "0.0,0.35,0.4,0.65",
500
+ "0.6,0.35,1.0,0.65",
501
+ "0.0,0.35,0.4,0.65",
502
+ "1, 2",
503
+ "123451232530",
504
+ ],
505
+ ],
506
+ inputs=[
507
+ out_gen_1,
508
+ dummy_note,
509
+ dropdown,
510
+ prompts[0],
511
+ prompts[1],
512
+ prompts[2],
513
+ prompts[3],
514
+ prompts[4],
515
+ frame_indices[0],
516
+ frame_indices[1],
517
+ frame_indices[2],
518
+ frame_indices[3],
519
+ frame_indices[4],
520
+ bboxes[0],
521
+ bboxes[1],
522
+ bboxes[2],
523
+ bboxes[3],
524
+ bboxes[4],
525
+ word_prompt_indices_tb,
526
+ rand_seed,
527
+ ],
528
  outputs=None,
529
  fn=None,
530
  cache_examples=False,
 
533
  clear_btn.click(
534
  clear_btn_fn,
535
  inputs=[],
536
+ outputs=[dropdown, *prompts, *bboxes, *frame_indices, word_prompt_indices_tb],
537
  queue=False,
538
  )
539
 
540
  gen_btn.click(
541
  gen_btn_fn,
542
  inputs=[
543
+ *prompts,
544
+ *bboxes,
545
+ *frame_indices,
546
  word_prompt_indices_tb,
547
  trailing_length,
548
  n_spatial_steps,