1inkusFace committed · Commit 3587696 · verified · 1 Parent(s): 5aa86e1

Update skyreelsinfer/skyreels_video_infer.py

Files changed (1): skyreelsinfer/skyreels_video_infer.py (+24 -148)
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -1,23 +1,19 @@
 import logging
 import os
-import threading
 import time
 from datetime import timedelta
 from typing import Any
 from typing import Dict
 
 import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
 from diffusers import HunyuanVideoTransformer3DModel
 from PIL import Image
 from torchao.quantization import float8_weight_only
 from torchao.quantization import quantize_
 from transformers import LlamaModel
 
-from . import TaskType
-from .offload import Offload
-from .offload import OffloadConfig
+from . import TaskType  # Assuming these are still needed
+from .offload import Offload, OffloadConfig
 from .pipelines import SkyreelsVideoPipeline
 
 logger = logging.getLogger("SkyreelsVideoInfer")
@@ -30,7 +26,6 @@ formatter = logging.Formatter(
 console_handler.setFormatter(formatter)
 logger.addHandler(console_handler)
 
-
 class SkyReelsVideoSingleGpuInfer:
     def _load_model(
         self,
@@ -44,28 +39,26 @@ class SkyReelsVideoSingleGpuInfer:
             base_model_id,
             subfolder="text_encoder",
             torch_dtype=torch.bfloat16,
-        ).to("cpu")
+        ).to(gpu_device)  # Directly to GPU
         transformer = HunyuanVideoTransformer3DModel.from_pretrained(
             model_id,
             # subfolder="transformer",
             torch_dtype=torch.bfloat16,
-            device="cpu",
-        ).to("cpu")
+            # device="cpu",  # No longer needed, use gpu_device directly
+        ).to(gpu_device)  # Directly to GPU
         if quant_model:
-            quantize_(text_encoder, float8_weight_only(), device=gpu_device)
-            text_encoder.to("cpu")
-            torch.cuda.empty_cache()
-            quantize_(transformer, float8_weight_only(), device=gpu_device)
-            transformer.to("cpu")
-            torch.cuda.empty_cache()
+            quantize_(text_encoder, float8_weight_only(), device=gpu_device)  # Quantize in place
+            quantize_(transformer, float8_weight_only(), device=gpu_device)  # Quantize in place
+            # No need for text_encoder.to("cpu") and transformer.to("cpu") with torch.cuda.empty_cache().
+            # We put models to gpu_device in advance.
         pipe = SkyreelsVideoPipeline.from_pretrained(
             base_model_id,
             transformer=transformer,
             text_encoder=text_encoder,
             torch_dtype=torch.bfloat16,
-        ).to("cpu")
+        ).to(gpu_device)  # Directly to GPU
         pipe.vae.enable_tiling()
-        torch.cuda.empty_cache()
+        # torch.cuda.empty_cache()  # Generally good practice, but placement matters.
         return pipe
 
     def __init__(
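
Note on the hunk above: torchao's quantize_ modifies the module in place, and its device argument moves the module to that device before quantizing, so once both models are kept on gpu_device there is nothing left to shuttle back to the CPU. A minimal standalone sketch of the same float8 weight-only pattern follows; the Llama checkpoint id is illustrative, not from this commit.

    import torch
    from torchao.quantization import float8_weight_only, quantize_
    from transformers import LlamaModel

    # Illustrative checkpoint; the commit loads the text encoder from its own base_model_id.
    text_encoder = LlamaModel.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.bfloat16,
    )

    # In-place float8 weight-only quantization; device= moves the module to cuda:0 before quantizing.
    quantize_(text_encoder, float8_weight_only(), device="cuda:0")
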
@@ -73,39 +66,19 @@ class SkyReelsVideoSingleGpuInfer:
         task_type: TaskType,
         model_id: str,
         quant_model: bool = True,
-        local_rank: int = 0,
-        world_size: int = 1,
         is_offload: bool = True,
         offload_config: OffloadConfig = OffloadConfig(),
-        enable_cfg_parallel: bool = True,
     ):
         self.task_type = task_type
-        self.gpu_rank = local_rank
-        os.environ["LOCAL_RANK"] = str(local_rank)
-        torch.cuda.set_device(0)
-        torch.backends.cuda.enable_cudnn_sdp(False)
+        # os.environ["LOCAL_RANK"] = "0"  # No longer needed in single-GPU
+        torch.cuda.set_device(0)  # Still a good idea to be explicit.
+        torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep it.
         gpu_device = "cuda:0"
 
         self.pipe: SkyreelsVideoPipeline = self._load_model(
             model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
         )
 
-        from para_attn.context_parallel import init_context_parallel_mesh
-        from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
-        from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
-
-        max_batch_dim_size = 2 if enable_cfg_parallel and world_size > 1 else 1
-        max_ulysses_dim_size = int(world_size / max_batch_dim_size)
-        logger.info(f"max_batch_dim_size: {max_batch_dim_size}, max_ulysses_dim_size:{max_ulysses_dim_size}")
-
-        mesh = init_context_parallel_mesh(
-            self.pipe.device.type,
-            max_ring_dim_size=1,
-            max_batch_dim_size=max_batch_dim_size,
-        )
-        parallelize_pipe(self.pipe, mesh=mesh)
-        parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
-
         if is_offload:
             Offload.offload(
                 pipeline=self.pipe,
@@ -117,7 +90,7 @@ class SkyReelsVideoSingleGpuInfer:
         if offload_config.compiler_transformer:
             torch._dynamo.config.suppress_errors = True
             os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
-            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_{world_size}"
+            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # _1 represents 1 gpu.
             self.pipe.transformer = torch.compile(
                 self.pipe.transformer,
                 mode="max-autotune-no-cudagraphs",
@@ -141,110 +114,13 @@ class SkyReelsVideoSingleGpuInfer:
             init_kwargs["image"] = Image.new("RGB", (544, 960), color="black")
         self.pipe(**init_kwargs)
 
-    def damon_inference(self, request_queue: mp.Queue, response_queue: mp.Queue):
-        response_queue.put(f"rank:{self.gpu_rank} ready")
-        logger.info(f"rank:{self.gpu_rank} finish init pipe")
-        while True:
-            logger.info(f"rank:{self.gpu_rank} waiting for request")
-            kwargs = request_queue.get()
-            logger.info(f"rank:{self.gpu_rank} kwargs: {kwargs}")
-            if "seed" in kwargs:
-                kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
-                del kwargs["seed"]
-            start_time = time.time()
-            assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
-            out = self.pipe(**kwargs).frames[0]
-            logger.info(f"rank:{dist.get_rank()} inference time: {time.time() - start_time}")
-            if dist.get_rank() == 0:
-                response_queue.put(out)
-
-
-def single_gpu_run(
-    rank,
-    task_type: TaskType,
-    model_id: str,
-    request_queue: mp.Queue,
-    response_queue: mp.Queue,
-    quant_model: bool = True,
-    world_size: int = 1,
-    is_offload: bool = True,
-    offload_config: OffloadConfig = OffloadConfig(),
-    enable_cfg_parallel: bool = True,
-):
-    pipe = SkyReelsVideoSingleGpuInfer(
-        task_type=task_type,
-        model_id=model_id,
-        quant_model=quant_model,
-        local_rank=rank,
-        world_size=world_size,
-        is_offload=is_offload,
-        offload_config=offload_config,
-        enable_cfg_parallel=enable_cfg_parallel,
-    )
-    pipe.damon_inference(request_queue, response_queue)
-
-
-class SkyReelsVideoInfer:
-    def __init__(
-        self,
-        task_type: TaskType,
-        model_id: str,
-        quant_model: bool = True,
-        world_size: int = 1,
-        is_offload: bool = True,
-        offload_config: OffloadConfig = OffloadConfig(),
-        enable_cfg_parallel: bool = True,
-    ):
-        self.world_size = world_size
-        smp = mp.get_context("spawn")
-        self.REQ_QUEUES: mp.Queue = smp.Queue()
-        self.RESP_QUEUE: mp.Queue = smp.Queue()
-        assert self.world_size > 0, "gpu_num must be greater than 0"
-        spawn_thread = threading.Thread(
-            target=self.lauch_single_gpu_infer,
-            args=(task_type, model_id, quant_model, world_size, is_offload, offload_config, enable_cfg_parallel),
-            daemon=True,
-        )
-        spawn_thread.start()
-        logger.info(f"Started multi-GPU thread with GPU_NUM: {world_size}")
-        print(f"Started multi-GPU thread with GPU_NUM: {world_size}")
-        # Block and wait for the prediction process to start
-        for _ in range(world_size):
-            msg = self.RESP_QUEUE.get()
-            logger.info(f"launch_multi_gpu get init msg: {msg}")
-            print(f"launch_multi_gpu get init msg: {msg}")
-
-    def lauch_single_gpu_infer(
-        self,
-        task_type: TaskType,
-        model_id: str,
-        quant_model: bool = True,
-        world_size: int = 1,
-        is_offload: bool = True,
-        offload_config: OffloadConfig = OffloadConfig(),
-        enable_cfg_parallel: bool = True,
-    ):
-        mp.spawn(
-            single_gpu_run,
-            nprocs=world_size,
-            join=True,
-            daemon=False,
-            args=(
-                task_type,
-                model_id,
-                self.REQ_QUEUES,
-                self.RESP_QUEUE,
-                quant_model,
-                world_size,
-                is_offload,
-                offload_config,
-                enable_cfg_parallel,
-            ),
-        )
-        logger.info(f"finish lanch multi gpu infer, world_size:{world_size}")
-
     def inference(self, kwargs: Dict[str, Any]):
-        # put request to singlegpuinfer
-        for _ in range(self.world_size):
-            self.REQ_QUEUES.put(kwargs)
-        return self.RESP_QUEUE.get()
+        logger.info(f"kwargs: {kwargs}")
+        if "seed" in kwargs:
+            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
+            del kwargs["seed"]
+        start_time = time.time()
+        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
+        out = self.pipe(**kwargs).frames[0]
+        logger.info(f"inference time: {time.time() - start_time}")
+        return out
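
With the multiprocessing queue and spawn machinery removed, the remaining class is driven directly. A hedged usage sketch under the package layout shown in the diff; the checkpoint id, prompt, and generation kwargs below are illustrative and not part of this commit, while the constructor signature and the seed handling match the code above.

    from skyreelsinfer import TaskType
    from skyreelsinfer.offload import OffloadConfig
    from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer

    infer = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.T2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # illustrative checkpoint
        quant_model=True,
        is_offload=True,
        offload_config=OffloadConfig(),
    )

    # kwargs are passed straight to SkyreelsVideoPipeline; "seed" is converted to a
    # CUDA torch.Generator inside inference() before the pipeline call.
    frames = infer.inference({
        "prompt": "a red fox running through fresh snow",
        "height": 544,
        "width": 960,
        "num_frames": 97,
        "seed": 42,
    })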
 
 
 
 
 
 