1inkusFace committed on
Commit b113647 · verified · 1 Parent(s): e6ed20e

Update app.py

Files changed (1): app.py (+42 -12)
app.py CHANGED
@@ -8,12 +8,7 @@ import subprocess
 from PIL import Image
 import numpy as np
 
-# subprocess.run(['sh', './sky.sh'])  # Removed as it's likely environment-specific
-# sys.path.append("./SkyReels-V1")  # Removed as it's likely environment-specific
-
-# from skyreelsinfer import TaskType  # Dummy classes cover this
-# from skyreelsinfer.offload import OffloadConfig  # Dummy classes cover this
-# from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer  # Dummy classes cover this
+# Removed environment-specific lines
 from diffusers.utils import export_to_video
 
 import torch
@@ -31,6 +26,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 logger = logging.getLogger(__name__)
 
+
 # --- Dummy Classes (Keep for standalone execution) ---
 class OffloadConfig:
     def __init__(
@@ -45,24 +41,30 @@ class OffloadConfig:
         self.compiler_transformer = compiler_transformer
         self.compiler_cache = compiler_cache
 
+
 class TaskType:  # Keep here for infer
     T2V = 0
     I2V = 1
 
+
 class LlamaModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return LlamaModel()
+
     def to(self, device):
         return self
 
+
 class HunyuanVideoTransformer3DModel:
     @staticmethod
     def from_pretrained(*args, **kwargs):
         return HunyuanVideoTransformer3DModel()
+
     def to(self, device):
         return self
 
+
 class SkyreelsVideoPipeline:
     @staticmethod
     def from_pretrained(*args, **kwargs):
@@ -75,17 +77,21 @@ class SkyreelsVideoPipeline:
         num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
         height = kwargs.get("height", 512)
         width = kwargs.get("width", 512)
+
         if "image" in kwargs:  # I2V
             image = kwargs["image"]
             # Convert PIL Image to PyTorch tensor (and normalize to [0, 1])
             image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
             image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
+
             # Create video by repeating the image
             frames = image_tensor.repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
             frames = frames + torch.randn_like(frames) * 0.05  # Add a little noise
-            # frames = frames.permute(0, 2, 1, 3, 4)  # NO PERMUTE HERE
+            # Correct shape: (1, C, T, H, W) - NO PERMUTE HERE
+
         else:  # T2V
-            frames = torch.randn(1, 3, num_frames, height, width)  # Use correct dims: (1, C, T, H, W)
+            frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W) - Correct!
+
         return type("obj", (object,), {"frames": frames})()  # No longer a list!
 
     def __init__(self):
@@ -101,12 +107,18 @@ class SkyreelsVideoPipeline:
     def enable_tiling(self):
         pass
 
+
 def quantize_(*args, **kwargs):
     return
 
+
 def float8_weight_only():
     return
 
+
+# --- End Dummy Classes ---
+
+
 class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True
@@ -118,6 +130,7 @@ class SkyReelsVideoSingleGpuInfer:
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id, torch_dtype=torch.bfloat16, device="cpu"
         ).to("cpu")
+
         if quant_model:
             quantize_(text_encoder, float8_weight_only())
             text_encoder.to("cpu")
@@ -125,6 +138,7 @@ class SkyReelsVideoSingleGpuInfer:
             quantize_(transformer, float8_weight_only())
             transformer.to("cpu")
             torch.cuda.empty_cache()
+
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
         ).to("cpu")
@@ -155,14 +169,18 @@ class SkyReelsVideoSingleGpuInfer:
         """Initializes the model and moves it to the GPU."""
         if self.is_initialized:
             return
+
         if not torch.cuda.is_available():
             raise RuntimeError("CUDA is not available. Cannot initialize model.")
+
         self.gpu_device = "cuda:0"
         self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
+
         if self.is_offload:
-            pass  # Offloading logic (if any) would go here
+            pass
         else:
             self.pipe.to(self.gpu_device)
+
         if self.offload_config.compiler_transformer:
             torch._dynamo.config.suppress_errors = True
             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
@@ -177,6 +195,7 @@ class SkyReelsVideoSingleGpuInfer:
     def warm_up(self):
         if not self.is_initialized:
             raise RuntimeError("Model must be initialized before warm-up.")
+
         init_kwargs = {
             "prompt": "A woman is dancing in a room",
             "height": 544,
@@ -204,8 +223,10 @@ class SkyReelsVideoSingleGpuInfer:
         result = self.pipe(**kwargs).frames  # Return the tensor directly
         return result
 
+
 _predictor = None
 
+
 @spaces.GPU(duration=90)
 def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict]:
     """Generates a video based on the given prompt and seed.
@@ -219,9 +240,11 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         A tuple containing the path to the generated video and the parameters used.
     """
     global _predictor
+
     if seed == -1:
         random.seed()
         seed = int(random.randrange(4294967294))
+
     if image is None:
         task_type = TaskType.T2V
         model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
@@ -249,8 +272,9 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         "guidance_scale": 6.0,
         "embedded_guidance_scale": 1.0,
         "negative_prompt": "Aerial view, low quality, bad hands",
-        "cfg_for": False,  # Keep if present in the original
+        "cfg_for": False,
     }
+
     if _predictor is None:
         _predictor = SkyReelsVideoSingleGpuInfer(
             task_type=task_type,
@@ -265,12 +289,16 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
         )
         _predictor.initialize()
         logger.info("Predictor initialized")
+
     with torch.no_grad():
-        output = _predictor.infer(**kwargs)  # Removed [0]
+        output = _predictor.infer(**kwargs)
+
     output = (output.numpy() * 255).astype(np.uint8)
     # Correct Transpose: (1, C, T, H, W) -> (1, T, H, W, C)
     output = output.transpose(0, 2, 3, 4, 1)
-    # output = output[0]  # Remove batch dimension: (T, H, W, C)
+    output = output[0]  # Remove batch dimension: (T, H, W, C)
+
+
     save_dir = f"./result"
     os.makedirs(save_dir, exist_ok=True)
     video_out_file = f"{save_dir}/{seed}.mp4"
@@ -278,6 +306,7 @@ def generate_video(prompt: str, seed: int, image: str = None) -> tuple[str, dict
     export_to_video(output, video_out_file, fps=24)
     return video_out_file, kwargs
 
+
 def create_gradio_interface():
     with gr.Blocks() as demo:
         with gr.Row():
@@ -297,6 +326,7 @@ def create_gradio_interface():
         )
     return demo
 
+
 if __name__ == "__main__":
     demo = create_gradio_interface()
     demo.queue().launch()
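The dummy SkyreelsVideoPipeline now keeps frames in (1, C, T, H, W) throughout, which is what the post-processing in generate_video expects. Below is a minimal standalone sketch of that shape bookkeeping (variable names are illustrative, not from app.py). One subtlety: calling repeat(1, 1, T, 1, 1) directly on a 4-D (1, C, H, W) tensor broadcasts a leading singleton dim and yields (1, 1, C*T, H, W), so the time axis should be inserted explicitly first:

import numpy as np
import torch
from PIL import Image

num_frames = 16
image = Image.new("RGB", (512, 512))  # stand-in for the uploaded image

x = torch.from_numpy(np.array(image)).float() / 255.0  # (H, W, C), values in [0, 1]
x = x.permute(2, 0, 1).unsqueeze(0)                    # (1, C, H, W)

# Insert the time axis, then tile the image num_frames times along it.
frames = x.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
frames = frames + torch.randn_like(frames) * 0.05       # light noise, as in the dummy

assert frames.shape == (1, 3, num_frames, 512, 512)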
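quantize_ and float8_weight_only are stubbed above so the file runs standalone, but the call pattern in _load_model matches torchao's quantization API. A hedged sketch of what the real (non-dummy) calls would look like, assuming torchao is installed and the PyTorch build supports float8 dtypes:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, float8_weight_only  # assumed source of the stubs

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# quantize_ rewrites the module in place, converting Linear weights to float8.
quantize_(model, float8_weight_only())
model.to("cpu")
torch.cuda.empty_cache()  # mirrors _load_model's cleanup after quantization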
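The compiler_transformer branch in initialize() flips two global switches before compiling: torch._dynamo.config.suppress_errors makes Dynamo fall back to eager execution instead of raising on unsupported code, and TORCHINDUCTOR_FX_GRAPH_CACHE=1 lets Inductor cache compiled FX graphs across runs. A small sketch of the same setup around a stand-in function (the real code would wrap pipe.transformer):

import os
import torch

torch._dynamo.config.suppress_errors = True       # fall back to eager on compile errors
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"  # reuse compiled graphs between runs

def fake_transformer(x: torch.Tensor) -> torch.Tensor:  # stand-in for pipe.transformer
    return torch.nn.functional.gelu(x) * 2.0

compiled = torch.compile(fake_transformer)
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])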
 
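The functional core of this commit is the post-processing chain in generate_video: infer now returns the full batched tensor, which is converted to uint8, transposed to channel-last, and only then stripped of its batch dimension. A standalone sketch of that chain with a random stand-in tensor; the export call is left commented because export_to_video's dtype expectations vary across diffusers versions:

import numpy as np
import torch

output = torch.rand(1, 3, 16, 544, 960)        # stand-in for _predictor.infer(**kwargs)

arr = (output.numpy() * 255).astype(np.uint8)  # float [0, 1] -> uint8 pixels
arr = arr.transpose(0, 2, 3, 4, 1)             # (1, C, T, H, W) -> (1, T, H, W, C)
arr = arr[0]                                   # drop the batch dim -> (T, H, W, C)

assert arr.shape == (16, 544, 960, 3)
# from diffusers.utils import export_to_video
# export_to_video(list(arr), "out.mp4", fps=24)  # check your version's expected dtype/range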