Xenobd committed
Commit ca5f9a7 · verified · 1 Parent(s): b58dd63

Update simple_app.py

Files changed (1)
  1. simple_app.py +25 -51
simple_app.py CHANGED
@@ -5,6 +5,7 @@ import time
 from tqdm import tqdm
 from huggingface_hub import snapshot_download
 import torch
+import os

 # Force the device to CPU
 device = torch.device("cpu")
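
Note: the unchanged context in this hunk pins inference to the CPU, and the newly added `import os` is only used further down for the output-file existence check. As a reminder of what the device pin means for anything created against it, a tiny illustration (the tensor below is an example, not code from the app):

```python
import torch

device = torch.device("cpu")           # same pin as in simple_app.py
x = torch.zeros(2, 3, device=device)   # illustrative tensor, not from the app
print(x.device)                        # -> cpu
```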
@@ -17,29 +18,24 @@ snapshot_download(
 print("Model downloaded successfully.")

 def infer(prompt, progress=gr.Progress(track_tqdm=True)):
-    # Configuration:
-    total_process_steps = 11  # Total INFO messages expected
-    irrelevant_steps = 4      # First 4 INFO messages are ignored
-    relevant_steps = total_process_steps - irrelevant_steps  # 7 overall steps
+    total_process_steps = 11
+    irrelevant_steps = 4
+    relevant_steps = total_process_steps - irrelevant_steps

-    # Create overall progress bar (Level 1)
     overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                        ncols=120, dynamic_ncols=False, leave=True)
     processed_steps = 0

-    # Regex for video generation progress (Level 3)
     progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
     video_progress_bar = None

-    # Variables for sub-step progress bar (Level 2)
     sub_bar = None
     sub_ticks = 0
     sub_tick_total = 1500
     video_phase = False

-    # Command to run the video generation
     command = [
-        "python", "-u", "-m", "generate",  # using -u for unbuffered output
+        "python", "-u", "-m", "generate",
         "--task", "t2v-1.3B",
         "--size", "480*480",
         "--ckpt_dir", "./Wan2.1-T2V-1.3B",
@@ -47,40 +43,32 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
         "--sample_guide_scale", "6",
         "--prompt", prompt,
         "--t5_cpu",
-        "--offload_model", "True",  # Change from True (bool) to "True" (str)
+        "--offload_model", "True",
         "--save_file", "generated_video.mp4"
     ]
+
     print("Starting video generation process...")
+    process = subprocess.Popen(
+        command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        bufsize=1
+    )

-    process = subprocess.Popen(command,
-                               stdout=subprocess.PIPE,
-                               stderr=subprocess.STDOUT,
-                               text=True,
-                               bufsize=1)
-
-    # Print logs
-
     stdout = process.stdout
-    stderr = process.stderr
-    print(stdout)
-    while True:
-        line = stdout.readline()
-        if not line:
-            break
+
+    for line in iter(stdout.readline, ''):
         stripped_line = line.strip()
         if not stripped_line:
             continue

-        # Check for video generation progress (Level 3)
         progress_match = progress_pattern.search(stripped_line)
-
         if progress_match:
-            if sub_bar is not None:
-                if sub_ticks < sub_tick_total:
-                    sub_bar.update(sub_tick_total - sub_ticks)
+            if sub_bar is not None and sub_ticks < sub_tick_total:
+                sub_bar.update(sub_tick_total - sub_ticks)
                 sub_bar.close()
                 overall_bar.update(1)
-                overall_bar.refresh()
                 sub_bar = None
                 sub_ticks = 0
             video_phase = True
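
The core change in this hunk is the read loop: `iter(stdout.readline, '')` keeps yielding lines until the pipe reaches EOF, replacing the old `while True` / `readline()` / `break` pattern, while the `Popen` call keeps the same merged-stderr, line-buffered text mode. A self-contained sketch of that streaming setup, using a stand-in child command rather than the real `generate` invocation:

```python
import subprocess

# Stand-in child process; simple_app.py runs "python -u -m generate ..." instead.
proc = subprocess.Popen(
    ["python", "-u", "-c", "print('step 1'); print('step 2')"],
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # merge stderr into stdout, as the commit does
    text=True,                 # decode to str, so readline() returns '' at EOF
    bufsize=1,                 # line buffering (effective in text mode)
)

# readline() returns '' only at EOF, so iter(..., '') stops once the child exits.
for line in iter(proc.stdout.readline, ''):
    print("child:", line.rstrip())

proc.wait()
print("exit code:", proc.returncode)
```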
@@ -90,36 +78,28 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
                 video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                           ncols=120, dynamic_ncols=True, leave=True)
             video_progress_bar.update(current - video_progress_bar.n)
-            video_progress_bar.refresh()
             if video_progress_bar.n >= video_progress_bar.total:
                 video_phase = False
                 overall_bar.update(1)
-                overall_bar.refresh()
                 video_progress_bar.close()
                 video_progress_bar = None
             continue

-        # Process INFO messages (Level 2 sub-step)
         if "INFO:" in stripped_line:
             parts = stripped_line.split("INFO:", 1)
             msg = parts[1].strip() if len(parts) > 1 else ""
-            print(f"[INFO]: {msg}")  # Log the message
+            print(f"[INFO]: {msg}")

-            # For the first 4 INFO messages, simply count them.
             if processed_steps < irrelevant_steps:
                 processed_steps += 1
                 continue
             else:
-                # A new relevant INFO message has arrived.
-                if sub_bar is not None:
-                    if sub_ticks < sub_tick_total:
-                        sub_bar.update(sub_tick_total - sub_ticks)
+                if sub_bar is not None and sub_ticks < sub_tick_total:
+                    sub_bar.update(sub_tick_total - sub_ticks)
                     sub_bar.close()
                     overall_bar.update(1)
-                    overall_bar.refresh()
                     sub_bar = None
                     sub_ticks = 0
-                # Start a new sub-step bar for the current INFO message.
                 sub_bar = tqdm(total=sub_tick_total, desc=msg, position=2,
                                ncols=120, dynamic_ncols=False, leave=True)
                 sub_ticks = 0
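
The lines kept by this hunk still drive the video bar from the regex groups: `current` and `total` come from `progress_pattern`, and the bar is advanced by `current - video_progress_bar.n`. A compact sketch of that parse-and-update step in isolation (the sample line is illustrative, not captured output from `generate`):

```python
import re
from tqdm import tqdm

# Same pattern as in simple_app.py: percent, then "current/total".
progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")

# Illustrative tqdm-style line, not real output from the generate process.
sample = "40%|████      | 20/50 [00:12<00:18,  1.60it/s]"

match = progress_pattern.search(sample)
if match:
    current, total = int(match.group(2)), int(match.group(3))
    bar = tqdm(total=total, desc="Video Generation", position=0, ncols=120, leave=True)
    bar.update(current - bar.n)  # bar.n is the count already registered
    bar.close()
```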
@@ -127,28 +107,22 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
         else:
             print(stripped_line)

-    # Drain any remaining output
-    for line in process.stdout:
-        print(line.strip())
-
     process.wait()

-    # Finalize progress bars
     if video_progress_bar is not None:
         video_progress_bar.close()
     if sub_bar is not None:
         sub_bar.close()
     overall_bar.close()

-    # Add log for successful video generation
-    if process.returncode == 0:
+    if process.returncode == 0 and os.path.exists("generated_video.mp4"):
         print("Video generation completed successfully.")
         return "generated_video.mp4"
     else:
-        print("Error executing command.")
-        raise Exception("Error executing command")
+        print("Error: Video generation failed.")
+        raise gr.Error("Video generation failed. Check logs for details.")

-# Gradio UI to trigger inference
+# Gradio UI
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# Wan 2.1 1.3B")
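
The tail of the function now treats success as a clean exit plus the output file actually existing, and failures surface through `gr.Error`, which Gradio shows as an error message in the UI instead of a raw traceback. A minimal sketch of that pattern on its own (the demo wiring below is illustrative, not the app's actual UI):

```python
import os
import gradio as gr

def fetch_video():
    # Mirror the commit's check: succeed only if the expected file is on disk.
    if not os.path.exists("generated_video.mp4"):
        raise gr.Error("Video generation failed. Check logs for details.")
    return "generated_video.mp4"

with gr.Blocks() as demo:
    video = gr.Video()
    btn = gr.Button("Fetch video")
    btn.click(fn=fetch_video, outputs=video)

if __name__ == "__main__":
    demo.launch()
```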
 