Fabrice-TIERCELIN committed on
Commit 7a285be · verified · 1 Parent(s): 67b8f7c

' instead of "

Files changed (1)
  1. hyvideo/modules/fp8_optimization.py +102 -102
hyvideo/modules/fp8_optimization.py CHANGED
@@ -1,102 +1,102 @@
-import os
-
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
-    _bits = torch.tensor(bits)
-    _mantissa_bit = torch.tensor(mantissa_bit)
-    _sign_bits = torch.tensor(sign_bits)
-    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
-    E = _bits - _sign_bits - M
-    bias = 2 ** (E - 1) - 1
-    mantissa = 1
-    for i in range(mantissa_bit - 1):
-        mantissa += 1 / (2 ** (i+1))
-    maxval = mantissa * 2 ** (2**E - 1 - bias)
-    return maxval
-
-def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
-    """
-    Default is E4M3.
-    """
-    bits = torch.tensor(bits)
-    mantissa_bit = torch.tensor(mantissa_bit)
-    sign_bits = torch.tensor(sign_bits)
-    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
-    E = bits - sign_bits - M
-    bias = 2 ** (E - 1) - 1
-    mantissa = 1
-    for i in range(mantissa_bit - 1):
-        mantissa += 1 / (2 ** (i+1))
-    maxval = mantissa * 2 ** (2**E - 1 - bias)
-    minval = - maxval
-    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
-    input_clamp = torch.min(torch.max(x, minval), maxval)
-    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
-    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
-    # dequant
-    qdq_out = torch.round(input_clamp / log_scales) * log_scales
-    return qdq_out, log_scales
-
-def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
-    for i in range(len(x.shape) - 1):
-        scale = scale.unsqueeze(-1)
-    new_x = x / scale
-    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
-    return quant_dequant_x, scale, log_scales
-
-def fp8_activation_dequant(qdq_out, scale, dtype):
-    qdq_out = qdq_out.type(dtype)
-    quant_dequant_x = qdq_out * scale.to(dtype)
-    return quant_dequant_x
-
-def fp8_linear_forward(cls, original_dtype, input):
-    weight_dtype = cls.weight.dtype
-    #####
-    if cls.weight.dtype != torch.float8_e4m3fn:
-        maxval = get_fp_maxval()
-        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
-        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
-        linear_weight = linear_weight.to(torch.float8_e4m3fn)
-        weight_dtype = linear_weight.dtype
-    else:
-        scale = cls.fp8_scale.to(cls.weight.device)
-        linear_weight = cls.weight
-    #####
-
-    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
-        if True or len(input.shape) == 3:
-            cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
-            if cls.bias != None:
-                output = F.linear(input, cls_dequant, cls.bias)
-            else:
-                output = F.linear(input, cls_dequant)
-            return output
-        else:
-            return cls.original_forward(input.to(original_dtype))
-    else:
-        return cls.original_forward(input)
-
-def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
-    setattr(module, "fp8_matmul_enabled", True)
-
-    # loading fp8 mapping file
-    fp8_map_path = dit_weight_path.replace('.pt', '_map.pt')
-    if os.path.exists(fp8_map_path):
-        fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
-    else:
-        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
-
-    fp8_layers = []
-    for key, layer in module.named_modules():
-        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
-            fp8_layers.append(key)
-            original_forward = layer.forward
-            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
-            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
-            setattr(layer, "original_forward", original_forward)
-            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
-
-
+import os
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
+    _bits = torch.tensor(bits)
+    _mantissa_bit = torch.tensor(mantissa_bit)
+    _sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(_mantissa_bit), 1, _bits - _sign_bits)
+    E = _bits - _sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    return maxval
+
+def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
+    """
+    Default is E4M3.
+    """
+    bits = torch.tensor(bits)
+    mantissa_bit = torch.tensor(mantissa_bit)
+    sign_bits = torch.tensor(sign_bits)
+    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
+    E = bits - sign_bits - M
+    bias = 2 ** (E - 1) - 1
+    mantissa = 1
+    for i in range(mantissa_bit - 1):
+        mantissa += 1 / (2 ** (i+1))
+    maxval = mantissa * 2 ** (2**E - 1 - bias)
+    minval = - maxval
+    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
+    input_clamp = torch.min(torch.max(x, minval), maxval)
+    log_scales = torch.clamp((torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
+    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
+    # dequant
+    qdq_out = torch.round(input_clamp / log_scales) * log_scales
+    return qdq_out, log_scales
+
+def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
+    for i in range(len(x.shape) - 1):
+        scale = scale.unsqueeze(-1)
+    new_x = x / scale
+    quant_dequant_x, log_scales = quantize_to_fp8(new_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
+    return quant_dequant_x, scale, log_scales
+
+def fp8_activation_dequant(qdq_out, scale, dtype):
+    qdq_out = qdq_out.type(dtype)
+    quant_dequant_x = qdq_out * scale.to(dtype)
+    return quant_dequant_x
+
+def fp8_linear_forward(cls, original_dtype, input):
+    weight_dtype = cls.weight.dtype
+    #####
+    if cls.weight.dtype != torch.float8_e4m3fn:
+        maxval = get_fp_maxval()
+        scale = torch.max(torch.abs(cls.weight.flatten())) / maxval
+        linear_weight, scale, log_scales = fp8_tensor_quant(cls.weight, scale)
+        linear_weight = linear_weight.to(torch.float8_e4m3fn)
+        weight_dtype = linear_weight.dtype
+    else:
+        scale = cls.fp8_scale.to(cls.weight.device)
+        linear_weight = cls.weight
+    #####
+
+    if weight_dtype == torch.float8_e4m3fn and cls.weight.sum() != 0:
+        if True or len(input.shape) == 3:
+            cls_dequant = fp8_activation_dequant(linear_weight, scale, original_dtype)
+            if cls.bias != None:
+                output = F.linear(input, cls_dequant, cls.bias)
+            else:
+                output = F.linear(input, cls_dequant)
+            return output
+        else:
+            return cls.original_forward(input.to(original_dtype))
+    else:
+        return cls.original_forward(input)
+
+def convert_fp8_linear(module, dit_weight_path, original_dtype, params_to_keep={}):
+    setattr(module, "fp8_matmul_enabled", True)
+
+    # loading fp8 mapping file
+    fp8_map_path = dit_weight_path.replace(".pt", "_map.pt")
+    if os.path.exists(fp8_map_path):
+        fp8_map = torch.load(fp8_map_path, map_location=lambda storage, loc: storage)
+    else:
+        raise ValueError(f"Invalid fp8_map path: {fp8_map_path}.")
+
+    fp8_layers = []
+    for key, layer in module.named_modules():
+        if isinstance(layer, nn.Linear) and ('double_blocks' in key or 'single_blocks' in key):
+            fp8_layers.append(key)
+            original_forward = layer.forward
+            layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn))
+            setattr(layer, "fp8_scale", fp8_map[key].to(dtype=original_dtype))
+            setattr(layer, "original_forward", original_forward)
+            setattr(layer, "forward", lambda input, m=layer: fp8_linear_forward(m, original_dtype, input))
+
+
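
For context, a minimal usage sketch of the converted module (not part of this commit): the model loader, checkpoint path, and dtype below are hypothetical placeholders, and the only assumption grounded in the file is that a `*_map.pt` scale file sits next to the DiT weights.

# Minimal sketch, assuming a DiT-style model whose double_blocks/single_blocks
# nn.Linear layers appear as keys in the fp8 scale map.
import torch
from hyvideo.modules.fp8_optimization import convert_fp8_linear

model = load_my_dit_model()              # hypothetical loader, not defined in this file
dit_weight_path = "ckpts/dit_weight.pt"  # convert_fp8_linear expects "ckpts/dit_weight_map.pt" beside it
convert_fp8_linear(model, dit_weight_path, original_dtype=torch.bfloat16)

# After conversion, the matching nn.Linear layers hold float8_e4m3fn weights plus an
# fp8_scale attribute, and their patched forward() dequantizes the weight back to
# original_dtype before calling F.linear.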