Upload 31 files

- hyvideo/config.py +6 -0
- hyvideo/inference.py +5 -6
- hyvideo/modules/fp8_optimization.py +102 -0
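In short, this commit adds an opt-in fp8 inference path for the DiT: a new --use-fp8 flag in config.py, a convert_fp8_linear hook in inference.py that is applied to the transformer before its weights are loaded, and a new fp8_optimization module with the quantize/dequantize machinery. It also moves parallelize_transformer from predict() into the HunyuanVideoSampler constructor, dropping the per-call seed assertion.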
hyvideo/config.py
CHANGED

@@ -346,6 +346,12 @@ def add_inference_args(parser: argparse.ArgumentParser):
         help="Embeded classifier free guidance scale.",
     )
 
+    group.add_argument(
+        "--use-fp8",
+        action="store_true",
+        help="Enable use fp8 for inference acceleration."
+    )
+
     group.add_argument(
         "--reproduce",
         action="store_true",
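A note on the wiring, since the flag is consumed in the next file: argparse maps the dashed option name to an underscored attribute, which is what the inference.py hunk below tests. A minimal standalone sketch of that mapping; the group title here is an assumption, as the surrounding parser setup is outside this hunk:

import argparse

# Minimal sketch: "--use-fp8" surfaces as args.use_fp8, the attribute the
# inference.py change below reads.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="inference")  # title is assumed
group.add_argument(
    "--use-fp8",
    action="store_true",
    help="Enable use fp8 for inference acceleration.",
)
args = parser.parse_args(["--use-fp8"])
assert args.use_fp8 is True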
hyvideo/inference.py
CHANGED

@@ -15,6 +15,7 @@ from hyvideo.modules import load_model
 from hyvideo.text_encoder import TextEncoder
 from hyvideo.utils.data_utils import align_to
 from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
+from hyvideo.modules.fp8_optimization import convert_fp8_linear
 from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
 from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
 
@@ -196,6 +197,8 @@ class Inference(object):
             out_channels=out_channels,
             factor_kwargs=factor_kwargs,
         )
+        if args.use_fp8:
+            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
         model = model.to(device)
         model = Inference.load_state_dict(args, model, pretrained_model_path)
         model.eval()
@@ -402,6 +405,8 @@ class HunyuanVideoSampler(Inference):
         )
 
         self.default_negative_prompt = NEGATIVE_PROMPT
+        if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
+            parallelize_transformer(self.pipeline)
 
     def load_diffusion_pipeline(
         self,
@@ -521,12 +526,6 @@ class HunyuanVideoSampler(Inference):
             num_images_per_prompt (int): The number of images per prompt. Default is 1.
             infer_steps (int): The number of inference steps. Default is 100.
         """
-        if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
-            assert seed is not None, \
-                "You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
-
-            parallelize_transformer(self.pipeline)
-
         out_dict = dict()
 
         # ========================================================================
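Net effect of these hunks: the fp8 conversion runs on the freshly built transformer before it is moved to the device and before the checkpoint is loaded, and parallelize_transformer now runs once at sampler construction instead of on every predict() call, with the per-call seed assertion removed along with it. A hedged sketch of how a caller reaches the new path; parse_args and HunyuanVideoSampler.from_pretrained are assumed to be the repo's existing entry points, and the paths are purely illustrative:

from hyvideo.config import parse_args  # assumed existing helper
from hyvideo.inference import HunyuanVideoSampler

# Illustrative invocation, not part of the commit; imagine the process was
# launched with: --use-fp8 --dit-weight <fp8 checkpoint .pt>
args = parse_args()
models_root_path = "ckpts"  # hypothetical layout

# Inside from_pretrained -> load_model, the new branch runs before weights load:
#   if args.use_fp8:
#       convert_fp8_linear(model, args.dit_weight,
#                          original_dtype=PRECISION_TO_TYPE[args.precision])
# and, when ulysses_degree/ring_degree > 1, parallelize_transformer(...) is
# applied once here rather than inside every predict() call.
sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)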
hyvideo/modules/fp8_optimization.py
ADDED

@@ -0,0 +1,102 @@
+import os
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
+    _bits = torch.tensor(bits)
+    _mantissa_bit = torch.tensor(mantissa_bit)
+    _sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
+    E = _bits - _sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    return maxval
+
+def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
+    """
+    Default is E4M3.
+    """
+    bits = torch.tensor(bits)
+    mantissa_bit = torch.tensor(mantissa_bit)
+    sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
+    E = bits - sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    minval = - maxval
+    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
+    input_clamp = torch.min(torch.max(x, minval), maxval)
+    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
+    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
+    # dequant
+    qdq_out = torch.round(input_clamp / log_scales) * log_scales
+    return qdq_out, log_scales
+
+def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
+    for i in range(len(x.shape) - 1):
+        scale = scale.unsqueeze(-1)
+    new_x = x / scale
+    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
+    return quant_dequant_x, scale, log_scales
+
+def fp8_activation_dequant(qdq_out, scale, dtype):
+    qdq_out = qdq_out.type(dtype)
+    quant_dequant_x = qdq_out * scale.to(dtype)
+    return quant_dequant_x
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    #####
+    if cls.weight.dtype != torch.float8_e4m3fn:
+        maxval = get_fp_maxval()
+        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
+        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
+        linear_weight = linear_weight.to(torch.float8_e4m3fn)
+        weight_dtype = linear_weight.dtype
+    else:
+        scale = cls.fp8_scale.to(cls.weight.device)
+        linear_weight = cls.weight
+    #####
+
+    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
+        if True or len(input.shape) == 3:
+            cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
+            if cls.bias != None:
+                output = F.linear(input, cls_dequant, cls.bias)
+            else:
+                output = F.linear(input, cls_dequant)
+            return output
+        else:
+            return cls.original_forward(input.to(original_dtype))
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
+    setattr(module, "fp8_matmul_enabled", True)
+
+    # loading fp8 mapping file
+    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
+    if os.path.exists(fp8_map_path):
+        fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
+    else:
+        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
+
+    fp8_layers = []
+    for key, layer in module.named_modules():
+        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
+            fp8_layers.append(key)
+            original_forward = layer.forward
+            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
+            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
+            setattr(layer, "original_forward", original_forward)
+            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
+
+
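Two quick checks on the new module. First, the format arithmetic: with the defaults bits=8, mantissa_bit=3, sign_bits=1 (E4M3), we get E = 4, bias = 7, and a mantissa sum of 1 + 1/2 + 1/4 = 1.75, so get_fp_maxval() returns 1.75 * 2**(15 - 7) = 448.0, matching torch.finfo(torch.float8_e4m3fn).max. Second, a hedged, self-contained sketch of the patching path; TinyDiT and all paths below are demo-only inventions, and only the imports from fp8_optimization are real:

import torch
import torch.nn as nn

from hyvideo.modules.fp8_optimization import convert_fp8_linear, get_fp_maxval

# E4M3 sanity check: E = 8 - 1 - 3 = 4, bias = 7, mantissa sum = 1.75,
# maxval = 1.75 * 2**(2**4 - 1 - 7) = 448.0.
assert float(get_fp_maxval()) == torch.finfo(torch.float8_e4m3fn).max

class TinyDiT(nn.Module):
    # Demo-only stand-in; the module key must contain 'double_blocks' or
    # 'single_blocks' for convert_fp8_linear to patch the Linear.
    def __init__(self):
        super().__init__()
        self.double_blocks = nn.ModuleList([nn.Linear(8, 8, dtype=torch.bfloat16)])

model = TinyDiT()

# convert_fp8_linear derives '<dit_weight>_map.pt' from the weight path and
# expects per-layer scales keyed by module name.
torch.save({"double_blocks.0": torch.tensor(1.0)}, "/tmp/demo_dit_map.pt")
convert_fp8_linear(model, "/tmp/demo_dit.pt", original_dtype=torch.bfloat16)

lin = model.double_blocks[0]
print(lin.weight.dtype)  # torch.float8_e4m3fn
print(lin.fp8_scale)     # tensor(1., dtype=torch.bfloat16), from the map
# From here on, calls route through fp8_linear_forward, which dequantizes the
# fp8 weight back to original_dtype and runs a plain F.linear in that dtype.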