1inkusFace committed (verified)
Commit 9191d3a · Parent(s): 5568046

Update app.py

Files changed (1): app.py (+3 -37)
app.py CHANGED

@@ -31,7 +31,6 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 logger = logging.getLogger(__name__)
 
-
 # --- Dummy Classes (Keep for standalone execution) ---
 class OffloadConfig:
     def __init__(
@@ -46,30 +45,24 @@ class OffloadConfig:
         self.compiler_transformer = compiler_transformer
         self.compiler_cache = compiler_cache
 
-
 class TaskType: # Keep here for infer
     T2V = 0
     I2V = 1
 
-
 class LlamaModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return LlamaModel()
-
     def to(self, device):
         return self
 
-
 class HunyuanVideoTransformer3DModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return HunyuanVideoTransformer3DModel()
-
     def to(self, device):
         return self
 
-
 class SkyreelsVideoPipeline:
     @staticmethod
     def from_pretrained(*args, **kwargs):
@@ -82,21 +75,17 @@ class SkyreelsVideoPipeline:
         num_frames = kwargs.get("num_frames", 16) # Default to 16 frames
         height = kwargs.get("height", 512)
         width = kwargs.get("width", 512)
-
         if "image" in kwargs: # I2V
             image = kwargs["image"]
             # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) # (H, W, C) -> (1, C, H, W)
-
             # Create video by repeating the image
             frames = image_tensor.repeat(1, 1, num_frames, 1, 1) # (1, C, T, H, W)
             frames = frames + torch.randn_like(frames) * 0.05 # Add a little noise
             # frames = frames.permute(0, 2, 1, 3, 4) # NO PERMUTE HERE
-
         else: # T2V
             frames = torch.randn(1, 3, num_frames, height, width) # Use correct dims: (1, C, T, H, W)
-
         return type("obj", (object,), {"frames": frames})() # No longer a list!
 
     def __init__(self):
@@ -112,18 +101,12 @@ class SkyreelsVideoPipeline:
     def enable_tiling(self):
         pass
 
-
 def quantize_(*args, **kwargs):
     return
 
-
 def float8_weight_only():
     return
 
-
-# --- End Dummy Classes ---
-
-
 class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
@@ -135,7 +118,6 @@ class SkyReelsVideoSingleGpuInfer:
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id, torch_dtype=torch.bfloat16, device="cpu"
         ).to("cpu")
-
         if quant_model:
             quantize_(text_encoder, float8_weight_only())
             text_encoder.to("cpu")
@@ -143,7 +125,6 @@ class SkyReelsVideoSingleGpuInfer:
             quantize_(transformer, float8_weight_only())
             transformer.to("cpu")
             torch.cuda.empty_cache()
-
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
         ).to("cpu")
@@ -174,18 +155,14 @@ class SkyReelsVideoSingleGpuInfer:
         """Initializes the model and moves it to the GPU."""
         if self.is_initialized:
             return
-
         if not torch.cuda.is_available():
             raise RuntimeError("CUDA is not available. Cannot initialize model.")
-
         self.gpu_device = "cuda:0"
         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
-
         if self.is_offload:
             pass # Offloading logic (if any) would go here
         else:
             self.pipe.to(self.gpu_device)
-
         if self.offload_config.compiler_transformer:
             torch._dynamo.config.suppress_errors = True
             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
@@ -200,7 +177,6 @@ class SkyReelsVideoSingleGpuInfer:
     def warm_up(self):
        if not self.is_initialized:
            raise RuntimeError("Model must be initialized before warm-up.")
-
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 544,
@@ -228,10 +204,8 @@ class SkyReelsVideoSingleGpuInfer:
         result = self.pipe(**kwargs).frames # Return the tensor directly
         return result
 
-
 _predictor = None
 
-
 @spaces.GPU(duration=90)
 def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
     """Generates a video based on the given prompt and seed.
@@ -245,11 +219,9 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         A tuple containing the path to the generated video and the parameters used.
     """
     global _predictor
-
     if seed == -1:
         random.seed()
         seed = int(random.randrange(4294967294))
-
     if image is None:
         task_type = TaskType.T2V
         model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
@@ -279,7 +251,6 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         "negative_prompt": "Aerial view, low quality, bad hands",
         "cfg_for": False, #Keep if present in the original
     }
-
     if _predictor is None:
         _predictor = SkyReelsVideoSingleGpuInfer(
             task_type=task_type,
@@ -294,15 +265,12 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         )
         _predictor.initialize()
         logger.info("Predictor initialized")
-
     with torch.no_grad():
         output = _predictor.infer(**kwargs) #Removed [0]
-
     output = (output.numpy() * 255).astype(np.uint8)
-    # CRITICAL CHANGE: Transpose *after* converting to numpy and taking output[0]
-    #output = output.transpose(1, 2, 0, 3) # (T, H, W, C)
-    print(output.shape)
-    print(output[0].shape)
+    # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
+    output = output.transpose(0, 2, 3, 4, 1)
+    output = output[0] # Remove batch dimension: (T, H, W, C)
     save_dir = f"./result"
     os.makedirs(save_dir, exist_ok=True)
     video_out_file = f"{save_dir}/{seed}.mp4"
@@ -310,7 +278,6 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
     export_to_video(output, video_out_file, fps=24)
     return video_out_file, kwargs
 
-
 def create_gradio_interface():
     with gr.Blocks() as demo:
         with gr.Row():
@@ -330,7 +297,6 @@ def create_gradio_interface():
     )
     return demo
 
-
 if __name__ == "__main__":
     demo = create_gradio_interface()
     demo.queue().launch()
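The heart of the commit is the last change to generate_video: the debug prints and the commented-out output.transpose(1, 2, 0, 3) (four axes, which would raise ValueError against a 5-D array) give way to a permutation that matches the (1, C, T, H, W) layout the pipeline documents. A minimal numpy sketch of the axis bookkeeping, assuming the file's 544x960 warm-up resolution and 16 frames as a stand-in shape:

import numpy as np

# Stand-in for the pipeline output: (batch, C, T, H, W), uint8 after scaling.
output = np.zeros((1, 3, 16, 544, 960), dtype=np.uint8)

# transpose(0, 2, 3, 4, 1): result axis i is input axis axes[i],
# so (1, C, T, H, W) -> (1, T, H, W, C).
output = output.transpose(0, 2, 3, 4, 1)
assert output.shape == (1, 16, 544, 960, 3)

# Dropping the batch axis leaves (T, H, W, C); iterating it yields
# one (H, W, C) frame per time step for export_to_video.
output = output[0]
assert output.shape == (16, 544, 960, 3)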
 
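A caveat the commit leaves untouched in the dummy SkyreelsVideoPipeline: image_tensor is 4-D (1, C, H, W) after unsqueeze(0), and Tensor.repeat called with five sizes prepends a new leading axis before repeating, so repeat(1, 1, num_frames, 1, 1) actually produces (1, 1, C*num_frames, H, W), not the (1, C, T, H, W) the inline comment claims. A sketch of the discrepancy, with a hypothetical fix (not part of this commit) that inserts the time axis explicitly:

import torch

num_frames, C, H, W = 16, 3, 64, 64
image_tensor = torch.rand(1, C, H, W)  # as after permute(2, 0, 1).unsqueeze(0)

# As written in the dummy pipeline: repeat() left-pads the shape to
# (1, 1, C, H, W), then multiplies dim 2, giving (1, 1, C * num_frames, H, W).
frames_as_written = image_tensor.repeat(1, 1, num_frames, 1, 1)
assert frames_as_written.shape == (1, 1, C * num_frames, H, W)

# Hypothetical fix: create the time axis at dim 2 first, then repeat
# along it to get a true (1, C, T, H, W) clip.
frames_fixed = image_tensor.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)
assert frames_fixed.shape == (1, C, num_frames, H, W)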