yiren98 committed
Commit abd09b6 · verified · 1 parent: 12ae7b3

Upload 98 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +25 -1
  2. images/datasets/claysculpture.png +3 -0
  3. images/datasets/claytoys.png +3 -0
  4. images/datasets/cook.png +3 -0
  5. images/datasets/emoji.png +3 -0
  6. images/datasets/fabrictoys.png +3 -0
  7. images/datasets/icon.png +3 -0
  8. images/datasets/illustration.png +3 -0
  9. images/datasets/inkpainting.png +3 -0
  10. images/datasets/jadecarving.png +3 -0
  11. images/datasets/landscape.png +3 -0
  12. images/datasets/lego.png +3 -0
  13. images/datasets/linedraw.png +3 -0
  14. images/datasets/oilpainting.png +3 -0
  15. images/datasets/painting.png +3 -0
  16. images/datasets/pencilsketch.png +3 -0
  17. images/datasets/portrait.png +3 -0
  18. images/datasets/sandart.png +3 -0
  19. images/datasets/sketch.png +3 -0
  20. images/datasets/transformer.png +3 -0
  21. images/datasets/woodsculpture.png +3 -0
  22. images/datasets/zbrush.png +3 -0
  23. images/i2i.png +3 -0
  24. images/oneshot.png +3 -0
  25. images/t2i.png +3 -0
  26. images/teaser.png +3 -0
  27. library/__init__.py +0 -0
  28. library/adafactor_fused.py +138 -0
  29. library/attention_processors.py +227 -0
  30. library/config_util.py +716 -0
  31. library/custom_offloading_utils.py +227 -0
  32. library/custom_train_functions.py +559 -0
  33. library/deepspeed_utils.py +139 -0
  34. library/device_utils.py +84 -0
  35. library/flux_models.py +1237 -0
  36. library/flux_train_utils.py +582 -0
  37. library/flux_train_utils_recraft.py +659 -0
  38. library/flux_utils.py +472 -0
  39. library/huggingface_util.py +84 -0
  40. library/hypernetwork.py +223 -0
  41. library/ipex/__init__.py +180 -0
  42. library/ipex/attention.py +177 -0
  43. library/ipex/diffusers.py +312 -0
  44. library/ipex/gradscaler.py +183 -0
  45. library/ipex/hijacks.py +313 -0
  46. library/lpw_stable_diffusion.py +1233 -0
  47. library/model_util.py +1356 -0
  48. library/original_unet.py +1919 -0
  49. library/sai_model_spec.py +334 -0
  50. library/sd3_models.py +1413 -0
.gitattributes CHANGED
@@ -43,4 +43,28 @@ asy_results
  recraft_results
  drop
  SplitAsy
- example*
+ example*images/datasets/claysculpture.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/claytoys.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/cook.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/emoji.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/fabrictoys.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/icon.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/illustration.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/inkpainting.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/jadecarving.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/landscape.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/lego.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/linedraw.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/oilpainting.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/painting.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/pencilsketch.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/portrait.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/sandart.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/sketch.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/transformer.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/woodsculpture.png filter=lfs diff=lfs merge=lfs -text
+ images/datasets/zbrush.png filter=lfs diff=lfs merge=lfs -text
+ images/i2i.png filter=lfs diff=lfs merge=lfs -text
+ images/oneshot.png filter=lfs diff=lfs merge=lfs -text
+ images/t2i.png filter=lfs diff=lfs merge=lfs -text
+ images/teaser.png filter=lfs diff=lfs merge=lfs -text
images/datasets/claysculpture.png ADDED

Git LFS Details

  • SHA256: f566abc3b8197f1bdb491947352fa7984d1ffa99bcc91e43fade76116d3232eb
  • Pointer size: 131 Bytes
  • Size of remote file: 392 kB
images/datasets/claytoys.png ADDED

Git LFS Details

  • SHA256: 86de33d83a349ff4319ec99f4edf7121453e93175fd43e477cdc7a71e03ad5ec
  • Pointer size: 131 Bytes
  • Size of remote file: 805 kB
images/datasets/cook.png ADDED

Git LFS Details

  • SHA256: 4fd01153e9ae01670bb57f48729adf46afcb4463a12f7f685747c73892373eaa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
images/datasets/emoji.png ADDED

Git LFS Details

  • SHA256: bc4d1abe5cad6f9f61838473a5449fcc65c361792479a50434b2a77b7d177b76
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
images/datasets/fabrictoys.png ADDED

Git LFS Details

  • SHA256: 72274d9a4fdae9b66081e4508618f7c88eaf749a24881574af4190df5c6c3159
  • Pointer size: 131 Bytes
  • Size of remote file: 619 kB
images/datasets/icon.png ADDED

Git LFS Details

  • SHA256: b9886684d28e956d041bdf835f0d6365ea2dbd6298140f10faa8c076ce7215f3
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
images/datasets/illustration.png ADDED

Git LFS Details

  • SHA256: 26f14d9ff97810012a809cbd4163c0662768c819b19f9f77622c233cdaaef4e8
  • Pointer size: 131 Bytes
  • Size of remote file: 238 kB
images/datasets/inkpainting.png ADDED

Git LFS Details

  • SHA256: b513e1b03984a127d424ef8e89ce687f2497e030f71efbdfd5501a71509ef961
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
images/datasets/jadecarving.png ADDED

Git LFS Details

  • SHA256: 53d8976040d3329bed1811568f453ae6095683699d6dd962369220079ef738f9
  • Pointer size: 131 Bytes
  • Size of remote file: 905 kB
images/datasets/landscape.png ADDED

Git LFS Details

  • SHA256: 0972b0332c1e1dc9a7a01681d682625b2ab42c09579550bbaa43271dc5833880
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
images/datasets/lego.png ADDED

Git LFS Details

  • SHA256: 63a87beac48dff4e1110e8c40c6bbe8899253bcf784725ea9183940c427a379a
  • Pointer size: 131 Bytes
  • Size of remote file: 557 kB
images/datasets/linedraw.png ADDED

Git LFS Details

  • SHA256: 70aada932b271a04b2a9e2adf78cacc2a5541e19ebeb01a33fdb225a3957fa00
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
images/datasets/oilpainting.png ADDED

Git LFS Details

  • SHA256: 6316f86c9b8e28d1210574f3b8e041d15548c91ff0838e8ad3502c35fb60db9f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
images/datasets/painting.png ADDED

Git LFS Details

  • SHA256: b1a10a776f880c6fd217cec659374d81fc3a7bbccf0a11dd5a2155b7113d83d8
  • Pointer size: 131 Bytes
  • Size of remote file: 680 kB
images/datasets/pencilsketch.png ADDED

Git LFS Details

  • SHA256: 884ac2f80d63aef5b8fa54e8e37d428eb137fe5b37c3695b5d57eb1af79a0b7e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
images/datasets/portrait.png ADDED

Git LFS Details

  • SHA256: 6153919f4ed6ca2b18c88f75ffc44d8af37490cf615636931e710c73e01de015
  • Pointer size: 131 Bytes
  • Size of remote file: 606 kB
images/datasets/sandart.png ADDED

Git LFS Details

  • SHA256: 41d064a27b3ad7759784de84d7f2a5a6284b697d63a2e7e6f6aa54fc34c748dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
images/datasets/sketch.png ADDED

Git LFS Details

  • SHA256: 84810c3716f1174ae994544ffc5fa166e0a14aee7d45e23db3fd9a4f31772728
  • Pointer size: 131 Bytes
  • Size of remote file: 490 kB
images/datasets/transformer.png ADDED

Git LFS Details

  • SHA256: af20605f443e31c9ad0ada9ee41a6f06364f77fef21c11756fc64878e217e63f
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB
images/datasets/woodsculpture.png ADDED

Git LFS Details

  • SHA256: a9b43a434bcc423ca659327665957e38d4855c0701839f0d2222b3f7621fa1d2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
images/datasets/zbrush.png ADDED

Git LFS Details

  • SHA256: 43547722eb67c24356a1f6c28485a0ec2cd2513a3b99579b697d2966c908892e
  • Pointer size: 131 Bytes
  • Size of remote file: 316 kB
images/i2i.png ADDED

Git LFS Details

  • SHA256: 78a94a6ba5fceea6e2f6e5a85dece91ce700b73c15884d1a7e8d435c71a1d5cd
  • Pointer size: 131 Bytes
  • Size of remote file: 555 kB
images/oneshot.png ADDED

Git LFS Details

  • SHA256: a14fb05ad2b648876021f2fa87044bc5693b99e8977920022c38e137a3122e61
  • Pointer size: 131 Bytes
  • Size of remote file: 572 kB
images/t2i.png ADDED

Git LFS Details

  • SHA256: 91445e3d371ec49fc47f828f97662597a065163b6893f457d1700404d3272e0c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
images/teaser.png ADDED

Git LFS Details

  • SHA256: 5bc7be63232429c28d5633a1654344de04dba2e13e163574f56bea6887a6402f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.98 MB
library/__init__.py ADDED
File without changes
library/adafactor_fused.py ADDED
@@ -0,0 +1,138 @@
import math
import torch
from transformers import Adafactor

# stochastic rounding for bfloat16
# The implementation was provided by 2kpr. Thank you very much!

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
    """
    copies source into target using stochastic rounding

    Args:
        target: the target tensor with dtype=bfloat16
        source: the source tensor with dtype=float32
    """
    # create a random 16 bit integer
    result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))

    # add the random number to the lower 16 bit of the mantissa
    result.add_(source.view(dtype=torch.int32))

    # mask off the lower 16 bit of the mantissa
    result.bitwise_and_(-65536)  # -65536 = FFFF0000 as a signed int32

    # copy the higher 16 bit into the target tensor
    target.copy_(result.view(dtype=torch.float32))

    del result


@torch.no_grad()
def adafactor_step_param(self, p, group):
    if p.grad is None:
        return
    grad = p.grad
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    if grad.is_sparse:
        raise RuntimeError("Adafactor does not support sparse gradients.")

    state = self.state[p]
    grad_shape = grad.shape

    factored, use_first_moment = Adafactor._get_options(group, grad_shape)
    # State Initialization
    if len(state) == 0:
        state["step"] = 0

        if use_first_moment:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
            state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
        else:
            state["exp_avg_sq"] = torch.zeros_like(grad)

        state["RMS"] = 0
    else:
        if use_first_moment:
            state["exp_avg"] = state["exp_avg"].to(grad)
        if factored:
            state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
            state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
        else:
            state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    state["step"] += 1
    state["RMS"] = Adafactor._rms(p_data_fp32)
    lr = Adafactor._get_lr(group, state)

    beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
    update = (grad**2) + group["eps"][0]
    if factored:
        exp_avg_sq_row = state["exp_avg_sq_row"]
        exp_avg_sq_col = state["exp_avg_sq_col"]

        exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
        exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))

        # Approximation of exponential moving average of square of gradient
        update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
        update.mul_(grad)
    else:
        exp_avg_sq = state["exp_avg_sq"]

        exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
        update = exp_avg_sq.rsqrt().mul_(grad)

    update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
    update.mul_(lr)

    if use_first_moment:
        exp_avg = state["exp_avg"]
        exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
        update = exp_avg

    if group["weight_decay"] != 0:
        p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))

    p_data_fp32.add_(-update)

    # if p.dtype in {torch.float16, torch.bfloat16}:
    #     p.copy_(p_data_fp32)

    if p.dtype == torch.bfloat16:
        copy_stochastic_(p, p_data_fp32)
    elif p.dtype == torch.float16:
        p.copy_(p_data_fp32)


@torch.no_grad()
def adafactor_step(self, closure=None):
    """
    Performs a single optimization step

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            adafactor_step_param(self, p, group)

    return loss


def patch_adafactor_fused(optimizer: Adafactor):
    optimizer.step_param = adafactor_step_param.__get__(optimizer)
    optimizer.step = adafactor_step.__get__(optimizer)
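The patch above rebinds step_param and step onto an existing transformers.Adafactor instance so a training loop can update parameters one at a time (for example from a fused backward pass), with stochastic-rounding copies for bf16 weights. The snippet below is a minimal usage sketch only; the toy model, hyperparameters, and loop are illustrative assumptions, not this repository's training code.

import torch
import torch.nn as nn
from transformers import Adafactor

from library.adafactor_fused import patch_adafactor_fused

# Toy bf16 model and optimizer; real scripts build these from the training config.
model = nn.Linear(16, 16).to(torch.bfloat16)
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False, scale_parameter=False)

# Bind the fused step functions onto this optimizer instance.
patch_adafactor_fused(optimizer)

x = torch.randn(4, 16, dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()

# Update each parameter individually and free its gradient right away,
# which is the point of the step_param variant.
for group in optimizer.param_groups:
    for p in group["params"]:
        if p.grad is not None:
            optimizer.step_param(p, group)
            p.grad = None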
library/attention_processors.py ADDED
@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135

EPSILON = 1e-6


class FlashAttentionFunction(torch.autograd.function.Function):
    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper"""

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        if mask is None:
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper"""

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, l, m = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            l.split(q_bucket_size, dim=-2),
            m.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                exp_attn_weights = torch.exp(attn_weights - mc)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                D = (doc * oc).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


class FlashAttnProcessor:
    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        q_bucket_size = 512
        k_bucket_size = 1024

        h = attn.heads
        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
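As a quick sanity check of the bucketed attention above, FlashAttentionFunction can be exercised directly on random tensors and compared with plain softmax attention. The shapes and bucket sizes below are arbitrary demo values, and the sketch assumes einops and diffusers are installed so the module imports.

import torch

from library.attention_processors import FlashAttentionFunction

# (batch, heads, sequence length, head dim) -- arbitrary demo shapes.
q = torch.randn(2, 8, 1024, 64)
k = torch.randn(2, 8, 1024, 64)
v = torch.randn(2, 8, 1024, 64)

# mask=None, causal=False, q_bucket_size=512, k_bucket_size=1024,
# matching the defaults FlashAttnProcessor passes in.
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)

# Reference: ordinary softmax attention; the chunked version should agree
# up to floating-point error.
ref = torch.softmax(q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5, dim=-1) @ v
print(out.shape, (out - ref).abs().max().item())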
library/config_util.py ADDED
@@ -0,0 +1,716 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import (
18
+ Any,
19
+ ExactSequence,
20
+ MultipleInvalid,
21
+ Object,
22
+ Required,
23
+ Schema,
24
+ )
25
+ from transformers import CLIPTokenizer
26
+
27
+ from . import train_util
28
+ from .train_util import (
29
+ DreamBoothSubset,
30
+ FineTuningSubset,
31
+ ControlNetSubset,
32
+ DreamBoothDataset,
33
+ FineTuningDataset,
34
+ ControlNetDataset,
35
+ DatasetGroup,
36
+ )
37
+ from .utils import setup_logging
38
+
39
+ setup_logging()
40
+ import logging
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ def add_config_arguments(parser: argparse.ArgumentParser):
46
+ parser.add_argument(
47
+ "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
48
+ )
49
+
50
+
51
+ # TODO: inherit Params class in Subset, Dataset
52
+
53
+
54
+ @dataclass
55
+ class BaseSubsetParams:
56
+ image_dir: Optional[str] = None
57
+ num_repeats: int = 1
58
+ shuffle_caption: bool = False
59
+ caption_separator: str = (",",)
60
+ keep_tokens: int = 0
61
+ keep_tokens_separator: str = (None,)
62
+ secondary_separator: Optional[str] = None
63
+ enable_wildcard: bool = False
64
+ color_aug: bool = False
65
+ flip_aug: bool = False
66
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
67
+ random_crop: bool = False
68
+ caption_prefix: Optional[str] = None
69
+ caption_suffix: Optional[str] = None
70
+ caption_dropout_rate: float = 0.0
71
+ caption_dropout_every_n_epochs: int = 0
72
+ caption_tag_dropout_rate: float = 0.0
73
+ token_warmup_min: int = 1
74
+ token_warmup_step: float = 0
75
+ custom_attributes: Optional[Dict[str, Any]] = None
76
+
77
+
78
+ @dataclass
79
+ class DreamBoothSubsetParams(BaseSubsetParams):
80
+ is_reg: bool = False
81
+ class_tokens: Optional[str] = None
82
+ caption_extension: str = ".caption"
83
+ cache_info: bool = False
84
+ alpha_mask: bool = False
85
+
86
+
87
+ @dataclass
88
+ class FineTuningSubsetParams(BaseSubsetParams):
89
+ metadata_file: Optional[str] = None
90
+ alpha_mask: bool = False
91
+
92
+
93
+ @dataclass
94
+ class ControlNetSubsetParams(BaseSubsetParams):
95
+ conditioning_data_dir: str = None
96
+ caption_extension: str = ".caption"
97
+ cache_info: bool = False
98
+
99
+
100
+ @dataclass
101
+ class BaseDatasetParams:
102
+ resolution: Optional[Tuple[int, int]] = None
103
+ network_multiplier: float = 1.0
104
+ debug_dataset: bool = False
105
+
106
+
107
+ @dataclass
108
+ class DreamBoothDatasetParams(BaseDatasetParams):
109
+ batch_size: int = 1
110
+ enable_bucket: bool = False
111
+ min_bucket_reso: int = 256
112
+ max_bucket_reso: int = 1024
113
+ bucket_reso_steps: int = 64
114
+ bucket_no_upscale: bool = False
115
+ prior_loss_weight: float = 1.0
116
+
117
+
118
+ @dataclass
119
+ class FineTuningDatasetParams(BaseDatasetParams):
120
+ batch_size: int = 1
121
+ enable_bucket: bool = False
122
+ min_bucket_reso: int = 256
123
+ max_bucket_reso: int = 1024
124
+ bucket_reso_steps: int = 64
125
+ bucket_no_upscale: bool = False
126
+
127
+
128
+ @dataclass
129
+ class ControlNetDatasetParams(BaseDatasetParams):
130
+ batch_size: int = 1
131
+ enable_bucket: bool = False
132
+ min_bucket_reso: int = 256
133
+ max_bucket_reso: int = 1024
134
+ bucket_reso_steps: int = 64
135
+ bucket_no_upscale: bool = False
136
+
137
+
138
+ @dataclass
139
+ class SubsetBlueprint:
140
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
141
+
142
+
143
+ @dataclass
144
+ class DatasetBlueprint:
145
+ is_dreambooth: bool
146
+ is_controlnet: bool
147
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
148
+ subsets: Sequence[SubsetBlueprint]
149
+
150
+
151
+ @dataclass
152
+ class DatasetGroupBlueprint:
153
+ datasets: Sequence[DatasetBlueprint]
154
+
155
+
156
+ @dataclass
157
+ class Blueprint:
158
+ dataset_group: DatasetGroupBlueprint
159
+
160
+
161
+ class ConfigSanitizer:
162
+ # @curry
163
+ @staticmethod
164
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
165
+ Schema(ExactSequence([klass, klass]))(value)
166
+ return tuple(value)
167
+
168
+ # @curry
169
+ @staticmethod
170
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
171
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
172
+ try:
173
+ Schema(klass)(value)
174
+ return (value, value)
175
+ except:
176
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
177
+
178
+ # subset schema
179
+ SUBSET_ASCENDABLE_SCHEMA = {
180
+ "color_aug": bool,
181
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
182
+ "flip_aug": bool,
183
+ "num_repeats": int,
184
+ "random_crop": bool,
185
+ "shuffle_caption": bool,
186
+ "keep_tokens": int,
187
+ "keep_tokens_separator": str,
188
+ "secondary_separator": str,
189
+ "caption_separator": str,
190
+ "enable_wildcard": bool,
191
+ "token_warmup_min": int,
192
+ "token_warmup_step": Any(float, int),
193
+ "caption_prefix": str,
194
+ "caption_suffix": str,
195
+ "custom_attributes": dict,
196
+ }
197
+ # DO means DropOut
198
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
199
+ "caption_dropout_every_n_epochs": int,
200
+ "caption_dropout_rate": Any(float, int),
201
+ "caption_tag_dropout_rate": Any(float, int),
202
+ }
203
+ # DB means DreamBooth
204
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
205
+ "caption_extension": str,
206
+ "class_tokens": str,
207
+ "cache_info": bool,
208
+ }
209
+ DB_SUBSET_DISTINCT_SCHEMA = {
210
+ Required("image_dir"): str,
211
+ "is_reg": bool,
212
+ "alpha_mask": bool,
213
+ }
214
+ # FT means FineTuning
215
+ FT_SUBSET_DISTINCT_SCHEMA = {
216
+ Required("metadata_file"): str,
217
+ "image_dir": str,
218
+ "alpha_mask": bool,
219
+ }
220
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
221
+ "caption_extension": str,
222
+ "cache_info": bool,
223
+ }
224
+ CN_SUBSET_DISTINCT_SCHEMA = {
225
+ Required("image_dir"): str,
226
+ Required("conditioning_data_dir"): str,
227
+ }
228
+
229
+ # datasets schema
230
+ DATASET_ASCENDABLE_SCHEMA = {
231
+ "batch_size": int,
232
+ "bucket_no_upscale": bool,
233
+ "bucket_reso_steps": int,
234
+ "enable_bucket": bool,
235
+ "max_bucket_reso": int,
236
+ "min_bucket_reso": int,
237
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
238
+ "network_multiplier": float,
239
+ }
240
+
241
+ # options handled by argparse but not handled by user config
242
+ ARGPARSE_SPECIFIC_SCHEMA = {
243
+ "debug_dataset": bool,
244
+ "max_token_length": Any(None, int),
245
+ "prior_loss_weight": Any(float, int),
246
+ }
247
+ # for handling default None value of argparse
248
+ ARGPARSE_NULLABLE_OPTNAMES = [
249
+ "face_crop_aug_range",
250
+ "resolution",
251
+ ]
252
+ # prepare map because option name may differ among argparse and user config
253
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
254
+ "train_batch_size": "batch_size",
255
+ "dataset_repeats": "num_repeats",
256
+ }
257
+
258
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
259
+ assert support_dreambooth or support_finetuning or support_controlnet, (
260
+ "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
261
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
262
+ )
263
+
264
+ self.db_subset_schema = self.__merge_dict(
265
+ self.SUBSET_ASCENDABLE_SCHEMA,
266
+ self.DB_SUBSET_DISTINCT_SCHEMA,
267
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
268
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
269
+ )
270
+
271
+ self.ft_subset_schema = self.__merge_dict(
272
+ self.SUBSET_ASCENDABLE_SCHEMA,
273
+ self.FT_SUBSET_DISTINCT_SCHEMA,
274
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
275
+ )
276
+
277
+ self.cn_subset_schema = self.__merge_dict(
278
+ self.SUBSET_ASCENDABLE_SCHEMA,
279
+ self.CN_SUBSET_DISTINCT_SCHEMA,
280
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
281
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
282
+ )
283
+
284
+ self.db_dataset_schema = self.__merge_dict(
285
+ self.DATASET_ASCENDABLE_SCHEMA,
286
+ self.SUBSET_ASCENDABLE_SCHEMA,
287
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
288
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
289
+ {"subsets": [self.db_subset_schema]},
290
+ )
291
+
292
+ self.ft_dataset_schema = self.__merge_dict(
293
+ self.DATASET_ASCENDABLE_SCHEMA,
294
+ self.SUBSET_ASCENDABLE_SCHEMA,
295
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
296
+ {"subsets": [self.ft_subset_schema]},
297
+ )
298
+
299
+ self.cn_dataset_schema = self.__merge_dict(
300
+ self.DATASET_ASCENDABLE_SCHEMA,
301
+ self.SUBSET_ASCENDABLE_SCHEMA,
302
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
303
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
304
+ {"subsets": [self.cn_subset_schema]},
305
+ )
306
+
307
+ if support_dreambooth and support_finetuning:
308
+
309
+ def validate_flex_dataset(dataset_config: dict):
310
+ subsets_config = dataset_config.get("subsets", [])
311
+
312
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
313
+ return Schema(self.cn_dataset_schema)(dataset_config)
314
+ # check dataset meets FT style
315
+ # NOTE: all FT subsets should have "metadata_file"
316
+ elif all(["metadata_file" in subset for subset in subsets_config]):
317
+ return Schema(self.ft_dataset_schema)(dataset_config)
318
+ # check dataset meets DB style
319
+ # NOTE: all DB subsets should have no "metadata_file"
320
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
321
+ return Schema(self.db_dataset_schema)(dataset_config)
322
+ else:
323
+ raise voluptuous.Invalid(
324
+ "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
325
+ )
326
+
327
+ self.dataset_schema = validate_flex_dataset
328
+ elif support_dreambooth:
329
+ if support_controlnet:
330
+ self.dataset_schema = self.cn_dataset_schema
331
+ else:
332
+ self.dataset_schema = self.db_dataset_schema
333
+ elif support_finetuning:
334
+ self.dataset_schema = self.ft_dataset_schema
335
+ elif support_controlnet:
336
+ self.dataset_schema = self.cn_dataset_schema
337
+
338
+ self.general_schema = self.__merge_dict(
339
+ self.DATASET_ASCENDABLE_SCHEMA,
340
+ self.SUBSET_ASCENDABLE_SCHEMA,
341
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
342
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
343
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
344
+ )
345
+
346
+ self.user_config_validator = Schema(
347
+ {
348
+ "general": self.general_schema,
349
+ "datasets": [self.dataset_schema],
350
+ }
351
+ )
352
+
353
+ self.argparse_schema = self.__merge_dict(
354
+ self.general_schema,
355
+ self.ARGPARSE_SPECIFIC_SCHEMA,
356
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
357
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
358
+ )
359
+
360
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
361
+
362
+ def sanitize_user_config(self, user_config: dict) -> dict:
363
+ try:
364
+ return self.user_config_validator(user_config)
365
+ except MultipleInvalid:
366
+ # TODO: エラー発生時のメッセージをわかりやすくする
367
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
368
+ raise
369
+
370
+ # NOTE: In nature, argument parser result is not needed to be sanitize
371
+ # However this will help us to detect program bug
372
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
373
+ try:
374
+ return self.argparse_config_validator(argparse_namespace)
375
+ except MultipleInvalid:
376
+ # XXX: this should be a bug
377
+ logger.error(
378
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
379
+ )
380
+ raise
381
+
382
+ # NOTE: value would be overwritten by latter dict if there is already the same key
383
+ @staticmethod
384
+ def __merge_dict(*dict_list: dict) -> dict:
385
+ merged = {}
386
+ for schema in dict_list:
387
+ # merged |= schema
388
+ for k, v in schema.items():
389
+ merged[k] = v
390
+ return merged
391
+
392
+
393
+ class BlueprintGenerator:
394
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
395
+
396
+ def __init__(self, sanitizer: ConfigSanitizer):
397
+ self.sanitizer = sanitizer
398
+
399
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
400
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
401
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
402
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
403
+
404
+ # convert argparse namespace to dict like config
405
+ # NOTE: it is ok to have extra entries in dict
406
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
407
+ argparse_config = {
408
+ optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
409
+ }
410
+
411
+ general_config = sanitized_user_config.get("general", {})
412
+
413
+ dataset_blueprints = []
414
+ for dataset_config in sanitized_user_config.get("datasets", []):
415
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
416
+ subsets = dataset_config.get("subsets", [])
417
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
418
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
419
+ if is_controlnet:
420
+ subset_params_klass = ControlNetSubsetParams
421
+ dataset_params_klass = ControlNetDatasetParams
422
+ elif is_dreambooth:
423
+ subset_params_klass = DreamBoothSubsetParams
424
+ dataset_params_klass = DreamBoothDatasetParams
425
+ else:
426
+ subset_params_klass = FineTuningSubsetParams
427
+ dataset_params_klass = FineTuningDatasetParams
428
+
429
+ subset_blueprints = []
430
+ for subset_config in subsets:
431
+ params = self.generate_params_by_fallbacks(
432
+ subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
433
+ )
434
+ subset_blueprints.append(SubsetBlueprint(params))
435
+
436
+ params = self.generate_params_by_fallbacks(
437
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
438
+ )
439
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
440
+
441
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
442
+
443
+ return Blueprint(dataset_group_blueprint)
444
+
445
+ @staticmethod
446
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
447
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
448
+ search_value = BlueprintGenerator.search_value
449
+ default_params = asdict(param_klass())
450
+ param_names = default_params.keys()
451
+
452
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
453
+
454
+ return param_klass(**params)
455
+
456
+ @staticmethod
457
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
458
+ for cand in fallbacks:
459
+ value = cand.get(key)
460
+ if value is not None:
461
+ return value
462
+
463
+ return default_value
464
+
465
+
466
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
467
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
468
+
469
+ for dataset_blueprint in dataset_group_blueprint.datasets:
470
+ if dataset_blueprint.is_controlnet:
471
+ subset_klass = ControlNetSubset
472
+ dataset_klass = ControlNetDataset
473
+ elif dataset_blueprint.is_dreambooth:
474
+ subset_klass = DreamBoothSubset
475
+ dataset_klass = DreamBoothDataset
476
+ else:
477
+ subset_klass = FineTuningSubset
478
+ dataset_klass = FineTuningDataset
479
+
480
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
481
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
482
+ datasets.append(dataset)
483
+
484
+ # print info
485
+ info = ""
486
+ for i, dataset in enumerate(datasets):
487
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
488
+ is_controlnet = isinstance(dataset, ControlNetDataset)
489
+ info += dedent(
490
+ f"""\
491
+ [Dataset {i}]
492
+ batch_size: {dataset.batch_size}
493
+ resolution: {(dataset.width, dataset.height)}
494
+ enable_bucket: {dataset.enable_bucket}
495
+ network_multiplier: {dataset.network_multiplier}
496
+ """
497
+ )
498
+
499
+ if dataset.enable_bucket:
500
+ info += indent(
501
+ dedent(
502
+ f"""\
503
+ min_bucket_reso: {dataset.min_bucket_reso}
504
+ max_bucket_reso: {dataset.max_bucket_reso}
505
+ bucket_reso_steps: {dataset.bucket_reso_steps}
506
+ bucket_no_upscale: {dataset.bucket_no_upscale}
507
+ \n"""
508
+ ),
509
+ " ",
510
+ )
511
+ else:
512
+ info += "\n"
513
+
514
+ for j, subset in enumerate(dataset.subsets):
515
+ info += indent(
516
+ dedent(
517
+ f"""\
518
+ [Subset {j} of Dataset {i}]
519
+ image_dir: "{subset.image_dir}"
520
+ image_count: {subset.img_count}
521
+ num_repeats: {subset.num_repeats}
522
+ shuffle_caption: {subset.shuffle_caption}
523
+ keep_tokens: {subset.keep_tokens}
524
+ keep_tokens_separator: {subset.keep_tokens_separator}
525
+ caption_separator: {subset.caption_separator}
526
+ secondary_separator: {subset.secondary_separator}
527
+ enable_wildcard: {subset.enable_wildcard}
528
+ caption_dropout_rate: {subset.caption_dropout_rate}
529
+ caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
530
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
531
+ caption_prefix: {subset.caption_prefix}
532
+ caption_suffix: {subset.caption_suffix}
533
+ color_aug: {subset.color_aug}
534
+ flip_aug: {subset.flip_aug}
535
+ face_crop_aug_range: {subset.face_crop_aug_range}
536
+ random_crop: {subset.random_crop}
537
+ token_warmup_min: {subset.token_warmup_min}
538
+ token_warmup_step: {subset.token_warmup_step}
539
+ alpha_mask: {subset.alpha_mask}
540
+ custom_attributes: {subset.custom_attributes}
541
+ """
542
+ ),
543
+ " ",
544
+ )
545
+
546
+ if is_dreambooth:
547
+ info += indent(
548
+ dedent(
549
+ f"""\
550
+ is_reg: {subset.is_reg}
551
+ class_tokens: {subset.class_tokens}
552
+ caption_extension: {subset.caption_extension}
553
+ \n"""
554
+ ),
555
+ " ",
556
+ )
557
+ elif not is_controlnet:
558
+ info += indent(
559
+ dedent(
560
+ f"""\
561
+ metadata_file: {subset.metadata_file}
562
+ \n"""
563
+ ),
564
+ " ",
565
+ )
566
+
567
+ logger.info(f"{info}")
568
+
569
+ # make buckets first because it determines the length of dataset
570
+ # and set the same seed for all datasets
571
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
572
+ for i, dataset in enumerate(datasets):
573
+ logger.info(f"[Dataset {i}]")
574
+ dataset.make_buckets()
575
+ dataset.set_seed(seed)
576
+
577
+ return DatasetGroup(datasets)
578
+
579
+
580
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
581
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
582
+ tokens = name.split("_")
583
+ try:
584
+ n_repeats = int(tokens[0])
585
+ except ValueError as e:
586
+ logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
587
+ return 0, ""
588
+ caption_by_folder = "_".join(tokens[1:])
589
+ return n_repeats, caption_by_folder
590
+
591
+ def generate(base_dir: Optional[str], is_reg: bool):
592
+ if base_dir is None:
593
+ return []
594
+
595
+ base_dir: Path = Path(base_dir)
596
+ if not base_dir.is_dir():
597
+ return []
598
+
599
+ subsets_config = []
600
+ for subdir in base_dir.iterdir():
601
+ if not subdir.is_dir():
602
+ continue
603
+
604
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
605
+ if num_repeats < 1:
606
+ continue
607
+
608
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
609
+ subsets_config.append(subset_config)
610
+
611
+ return subsets_config
612
+
613
+ subsets_config = []
614
+ subsets_config += generate(train_data_dir, False)
615
+ subsets_config += generate(reg_data_dir, True)
616
+
617
+ return subsets_config
618
+
619
+
620
+ def generate_controlnet_subsets_config_by_subdirs(
621
+ train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
622
+ ):
623
+ def generate(base_dir: Optional[str]):
624
+ if base_dir is None:
625
+ return []
626
+
627
+ base_dir: Path = Path(base_dir)
628
+ if not base_dir.is_dir():
629
+ return []
630
+
631
+ subsets_config = []
632
+ subset_config = {
633
+ "image_dir": train_data_dir,
634
+ "conditioning_data_dir": conditioning_data_dir,
635
+ "caption_extension": caption_extension,
636
+ "num_repeats": 1,
637
+ }
638
+ subsets_config.append(subset_config)
639
+
640
+ return subsets_config
641
+
642
+ subsets_config = []
643
+ subsets_config += generate(train_data_dir)
644
+
645
+ return subsets_config
646
+
647
+
648
+ def load_user_config(file: str) -> dict:
649
+ file: Path = Path(file)
650
+ if not file.is_file():
651
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
652
+
653
+ if file.name.lower().endswith(".json"):
654
+ try:
655
+ with open(file, "r") as f:
656
+ config = json.load(f)
657
+ except Exception:
658
+ logger.error(
659
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
660
+ )
661
+ raise
662
+ elif file.name.lower().endswith(".toml"):
663
+ try:
664
+ config = toml.load(file)
665
+ except Exception:
666
+ logger.error(
667
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
668
+ )
669
+ raise
670
+ else:
671
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
672
+
673
+ return config
674
+
675
+
676
+ # for config test
677
+ if __name__ == "__main__":
678
+ parser = argparse.ArgumentParser()
679
+ parser.add_argument("--support_dreambooth", action="store_true")
680
+ parser.add_argument("--support_finetuning", action="store_true")
681
+ parser.add_argument("--support_controlnet", action="store_true")
682
+ parser.add_argument("--support_dropout", action="store_true")
683
+ parser.add_argument("dataset_config")
684
+ config_args, remain = parser.parse_known_args()
685
+
686
+ parser = argparse.ArgumentParser()
687
+ train_util.add_dataset_arguments(
688
+ parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
689
+ )
690
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
691
+ argparse_namespace = parser.parse_args(remain)
692
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
693
+
694
+ logger.info("[argparse_namespace]")
695
+ logger.info(f"{vars(argparse_namespace)}")
696
+
697
+ user_config = load_user_config(config_args.dataset_config)
698
+
699
+ logger.info("")
700
+ logger.info("[user_config]")
701
+ logger.info(f"{user_config}")
702
+
703
+ sanitizer = ConfigSanitizer(
704
+ config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
705
+ )
706
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
707
+
708
+ logger.info("")
709
+ logger.info("[sanitized_user_config]")
710
+ logger.info(f"{sanitized_user_config}")
711
+
712
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
713
+
714
+ logger.info("")
715
+ logger.info("[blueprint]")
716
+ logger.info(f"{blueprint}")
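To show how these schemas are consumed, here is a small sketch that runs an inline TOML dataset config through ConfigSanitizer. The directory path and class tokens are placeholders, and the four booleans passed to ConfigSanitizer are only one plausible combination; real training scripts read the file named by --dataset_config via load_user_config and then build datasets with BlueprintGenerator and generate_dataset_group_by_blueprint.

import toml

from library.config_util import ConfigSanitizer

# Minimal DreamBooth-style config; the path and class tokens are placeholders.
user_config = toml.loads(
    """
    [general]
    enable_bucket = true
    resolution = 512

    [[datasets]]
    batch_size = 2

    [[datasets.subsets]]
    image_dir = "/path/to/train_images"
    class_tokens = "sks style"
    num_repeats = 10
    """
)

# (support_dreambooth, support_finetuning, support_controlnet, support_dropout)
sanitizer = ConfigSanitizer(True, True, False, True)
sanitized = sanitizer.sanitize_user_config(user_config)
print(sanitized)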
library/custom_offloading_utils.py ADDED
@@ -0,0 +1,227 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn

from library.device_utils import clean_memory_on_device


def synchronize_device(device: torch.device):
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []

    # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
    # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
    #     print(module_to_cpu.__class__, module_to_cuda.__class__)
    #     if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
    #         weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
    for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
        if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
            module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
            if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
                weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
            else:
                if module_to_cuda.weight.data.device.type != device.type:
                    # print(
                    #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
                    # )
                    module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

    stream.synchronize()
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    """
    not tested
    """
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    # device to cpu
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

    synchronize_device(device)

    # cpu to device
    for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
        cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
        module_to_cuda.weight.data = cuda_data_view

    synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data = module.weight.data.to(device, non_blocking=True)


class Offloader:
    """
    common offloading class
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        if self.cuda_available:
            swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            if self.debug:
                start_time = time.perf_counter()
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda  # , event

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        future = self.futures.pop(block_idx)
        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if self.debug:
            print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


class ModelOffloader(Offloader):
    """
    supports forward offloading
    """

    def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        super().__init__(num_blocks, blocks_to_swap, device, debug)

        # register backward hooks
        self.remove_handles = []
        for i, block in enumerate(blocks):
            hook = self.create_backward_hook(blocks, i)
            if hook is not None:
                handle = block.register_full_backward_hook(hook)
                self.remove_handles.append(handle)

    def __del__(self):
        for handle in self.remove_handles:
            handle.remove()

    def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
        # -1 for 0-based index
        num_blocks_propagated = self.num_blocks - block_index - 1
        swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
        waiting = block_index > 0 and block_index <= self.blocks_to_swap

        if not swapping and not waiting:
            return None

        # create hook
        block_idx_to_cpu = self.num_blocks - num_blocks_propagated
        block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
        block_idx_to_wait = block_index - 1

        def backward_hook(module, grad_input, grad_output):
            if self.debug:
                print(f"Backward hook for block {block_index}")

            if swapping:
                self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
            if waiting:
                self._wait_blocks_move(block_idx_to_wait)
            return None

        return backward_hook

    def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return

        if self.debug:
            print("Prepare block devices before forward")

        for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
            b.to(self.device)
            weighs_to_device(b, self.device)  # make sure weights are on device

        for b in blocks[self.num_blocks - self.blocks_to_swap :]:
            b.to(self.device)  # move block to device first
            weighs_to_device(b, "cpu")  # make sure weights are on cpu

        synchronize_device(self.device)
        clean_memory_on_device(self.device)

    def wait_for_block(self, block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self._wait_blocks_move(block_idx)

    def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        if block_idx >= self.blocks_to_swap:
            return
        block_idx_to_cpu = block_idx
        block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
        self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
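For orientation, the sketch below shows the intended calling pattern around a stack of blocks: prepare devices once before the forward pass, wait for a block just before running it, and submit a swap for it afterwards so later blocks stream in from CPU. The toy linear blocks, the block counts, and the CUDA device are assumptions for illustration; the real call sites are the model forward implementations, not shown here.

import torch
import torch.nn as nn

from library.custom_offloading_utils import ModelOffloader

device = torch.device("cuda")  # the swapping path sketched here targets CUDA

# Stand-in "transformer blocks"; real models pass their own block list.
blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(8)])
num_blocks = len(blocks)
blocks_to_swap = 4  # keep the last 4 blocks' weights on CPU until needed

offloader = ModelOffloader(blocks, num_blocks, blocks_to_swap, device, debug=True)
offloader.prepare_block_devices_before_forward(blocks)

x = torch.randn(2, 64, device=device)
for i, block in enumerate(blocks):
    offloader.wait_for_block(i)              # ensure this block's weights are on the GPU
    x = block(x)
    offloader.submit_move_blocks(blocks, i)  # start moving a finished block back to CPU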
library/custom_train_functions.py ADDED
@@ -0,0 +1,559 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
15
+ if hasattr(noise_scheduler, "all_snr"):
16
+ return
17
+
18
+ alphas_cumprod = noise_scheduler.alphas_cumprod
19
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
20
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
21
+ alpha = sqrt_alphas_cumprod
22
+ sigma = sqrt_one_minus_alphas_cumprod
23
+ all_snr = (alpha / sigma) ** 2
24
+
25
+ noise_scheduler.all_snr = all_snr.to(device)
26
+
27
+
28
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
29
+ # fix beta: zero terminal SNR
30
+ logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
31
+
32
+ def enforce_zero_terminal_snr(betas):
33
+ # Convert betas to alphas_bar_sqrt
34
+ alphas = 1 - betas
35
+ alphas_bar = alphas.cumprod(0)
36
+ alphas_bar_sqrt = alphas_bar.sqrt()
37
+
38
+ # Store old values.
39
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
40
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
41
+ # Shift so last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+ # Scale so first timestep is back to old value.
44
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
45
+
46
+ # Convert alphas_bar_sqrt to betas
47
+ alphas_bar = alphas_bar_sqrt**2
48
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
49
+ alphas = torch.cat([alphas_bar[0:1], alphas])
50
+ betas = 1 - alphas
51
+ return betas
52
+
53
+ betas = noise_scheduler.betas
54
+ betas = enforce_zero_terminal_snr(betas)
55
+ alphas = 1.0 - betas
56
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
57
+
58
+ # logger.info(f"original: {noise_scheduler.betas}")
59
+ # logger.info(f"fixed: {betas}")
60
+
61
+ noise_scheduler.betas = betas
62
+ noise_scheduler.alphas = alphas
63
+ noise_scheduler.alphas_cumprod = alphas_cumprod
64
+
65
+
66
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
67
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
68
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
69
+ if v_prediction:
70
+ snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
71
+ else:
72
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
73
+ loss = loss * snr_weight
74
+ return loss
75
+
76
+
77
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
78
+ scale = get_snr_scale(timesteps, noise_scheduler)
79
+ loss = loss * scale
80
+ return loss
81
+
82
+
83
+ def get_snr_scale(timesteps, noise_scheduler):
84
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
85
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
86
+ scale = snr_t / (snr_t + 1)
87
+ # # show debug info
88
+ # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
89
+ return scale
90
+
91
+
92
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
93
+ scale = get_snr_scale(timesteps, noise_scheduler)
94
+ # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
95
+ loss = loss + loss / scale * v_pred_like_loss
96
+ return loss
97
+
98
+
99
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
100
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102
+ if v_prediction:
103
+ weight = 1 / (snr_t + 1)
104
+ else:
105
+ weight = 1 / torch.sqrt(snr_t)
106
+ loss = weight * loss
107
+ return loss
108
+
109
+
110
+ # TODO: this logic is duplicated in train_util; consolidate it into one of the two
111
+
112
+
113
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
114
+ parser.add_argument(
115
+ "--min_snr_gamma",
116
+ type=float,
117
+ default=None,
118
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
119
+ )
120
+ parser.add_argument(
121
+ "--scale_v_pred_loss_like_noise_pred",
122
+ action="store_true",
123
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
124
+ )
125
+ parser.add_argument(
126
+ "--v_pred_like_loss",
127
+ type=float,
128
+ default=None,
129
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
130
+ )
131
+ parser.add_argument(
132
+ "--debiased_estimation_loss",
133
+ action="store_true",
134
+ help="debiased estimation loss / debiased estimation loss",
135
+ )
136
+ if support_weighted_captions:
137
+ parser.add_argument(
138
+ "--weighted_captions",
139
+ action="store_true",
140
+ default=False,
141
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
142
+ )
143
+
144
+
145
+ re_attention = re.compile(
146
+ r"""
147
+ \\\(|
148
+ \\\)|
149
+ \\\[|
150
+ \\]|
151
+ \\\\|
152
+ \\|
153
+ \(|
154
+ \[|
155
+ :([+-]?[.\d]+)\)|
156
+ \)|
157
+ ]|
158
+ [^\\()\[\]:]+|
159
+ :
160
+ """,
161
+ re.X,
162
+ )
163
+
164
+
165
+ def parse_prompt_attention(text):
166
+ """
167
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
168
+ Accepted tokens are:
169
+ (abc) - increases attention to abc by a multiplier of 1.1
170
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
171
+ [abc] - decreases attention to abc by a multiplier of 1.1
172
+ \( - literal character '('
173
+ \[ - literal character '['
174
+ \) - literal character ')'
175
+ \] - literal character ']'
176
+ \\ - literal character '\'
177
+ anything else - just text
178
+ >>> parse_prompt_attention('normal text')
179
+ [['normal text', 1.0]]
180
+ >>> parse_prompt_attention('an (important) word')
181
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
182
+ >>> parse_prompt_attention('(unbalanced')
183
+ [['unbalanced', 1.1]]
184
+ >>> parse_prompt_attention('\(literal\]')
185
+ [['(literal]', 1.0]]
186
+ >>> parse_prompt_attention('(unnecessary)(parens)')
187
+ [['unnecessaryparens', 1.1]]
188
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
189
+ [['a ', 1.0],
190
+ ['house', 1.5730000000000004],
191
+ [' ', 1.1],
192
+ ['on', 1.0],
193
+ [' a ', 1.1],
194
+ ['hill', 0.55],
195
+ [', sun, ', 1.1],
196
+ ['sky', 1.4641000000000006],
197
+ ['.', 1.1]]
198
+ """
199
+
200
+ res = []
201
+ round_brackets = []
202
+ square_brackets = []
203
+
204
+ round_bracket_multiplier = 1.1
205
+ square_bracket_multiplier = 1 / 1.1
206
+
207
+ def multiply_range(start_position, multiplier):
208
+ for p in range(start_position, len(res)):
209
+ res[p][1] *= multiplier
210
+
211
+ for m in re_attention.finditer(text):
212
+ text = m.group(0)
213
+ weight = m.group(1)
214
+
215
+ if text.startswith("\\"):
216
+ res.append([text[1:], 1.0])
217
+ elif text == "(":
218
+ round_brackets.append(len(res))
219
+ elif text == "[":
220
+ square_brackets.append(len(res))
221
+ elif weight is not None and len(round_brackets) > 0:
222
+ multiply_range(round_brackets.pop(), float(weight))
223
+ elif text == ")" and len(round_brackets) > 0:
224
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
225
+ elif text == "]" and len(square_brackets) > 0:
226
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
227
+ else:
228
+ res.append([text, 1.0])
229
+
230
+ for pos in round_brackets:
231
+ multiply_range(pos, round_bracket_multiplier)
232
+
233
+ for pos in square_brackets:
234
+ multiply_range(pos, square_bracket_multiplier)
235
+
236
+ if len(res) == 0:
237
+ res = [["", 1.0]]
238
+
239
+ # merge runs of identical weights
240
+ i = 0
241
+ while i + 1 < len(res):
242
+ if res[i][1] == res[i + 1][1]:
243
+ res[i][0] += res[i + 1][0]
244
+ res.pop(i + 1)
245
+ else:
246
+ i += 1
247
+
248
+ return res
249
+
250
+
251
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
252
+ r"""
253
+ Tokenize a list of prompts and return its tokens with weights of each token.
254
+
255
+ No padding, starting or ending token is included.
256
+ """
257
+ tokens = []
258
+ weights = []
259
+ truncated = False
260
+ for text in prompt:
261
+ texts_and_weights = parse_prompt_attention(text)
262
+ text_token = []
263
+ text_weight = []
264
+ for word, weight in texts_and_weights:
265
+ # tokenize and discard the starting and the ending token
266
+ token = tokenizer(word).input_ids[1:-1]
267
+ text_token += token
268
+ # copy the weight by length of token
269
+ text_weight += [weight] * len(token)
270
+ # stop if the text is too long (longer than truncation limit)
271
+ if len(text_token) > max_length:
272
+ truncated = True
273
+ break
274
+ # truncate
275
+ if len(text_token) > max_length:
276
+ truncated = True
277
+ text_token = text_token[:max_length]
278
+ text_weight = text_weight[:max_length]
279
+ tokens.append(text_token)
280
+ weights.append(text_weight)
281
+ if truncated:
282
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
283
+ return tokens, weights
284
+
285
+
286
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
287
+ r"""
288
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
289
+ """
290
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
291
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
292
+ for i in range(len(tokens)):
293
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
294
+ if no_boseos_middle:
295
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
296
+ else:
297
+ w = []
298
+ if len(weights[i]) == 0:
299
+ w = [1.0] * weights_length
300
+ else:
301
+ for j in range(max_embeddings_multiples):
302
+ w.append(1.0) # weight for starting token in this chunk
303
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
304
+ w.append(1.0) # weight for ending token in this chunk
305
+ w += [1.0] * (weights_length - len(w))
306
+ weights[i] = w[:]
307
+
308
+ return tokens, weights
309
+
310
+
311
+ def get_unweighted_text_embeddings(
312
+ tokenizer,
313
+ text_encoder,
314
+ text_input: torch.Tensor,
315
+ chunk_length: int,
316
+ clip_skip: int,
317
+ eos: int,
318
+ pad: int,
319
+ no_boseos_middle: Optional[bool] = True,
320
+ ):
321
+ """
322
+ When the length of tokens is a multiple of the capacity of the text encoder,
323
+ it should be split into chunks and sent to the text encoder individually.
324
+ """
325
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
326
+ if max_embeddings_multiples > 1:
327
+ text_embeddings = []
328
+ for i in range(max_embeddings_multiples):
329
+ # extract the i-th chunk
330
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
331
+
332
+ # cover the head and the tail by the starting and the ending tokens
333
+ text_input_chunk[:, 0] = text_input[0, 0]
334
+ if pad == eos: # v1
335
+ text_input_chunk[:, -1] = text_input[0, -1]
336
+ else: # v2
337
+ for j in range(len(text_input_chunk)):
338
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # a normal (non-EOS, non-PAD) token is at the end
339
+ text_input_chunk[j, -1] = eos
340
+ if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
341
+ text_input_chunk[j, 1] = eos
342
+
343
+ if clip_skip is None or clip_skip == 1:
344
+ text_embedding = text_encoder(text_input_chunk)[0]
345
+ else:
346
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
347
+ text_embedding = enc_out["hidden_states"][-clip_skip]
348
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
349
+
350
+ if no_boseos_middle:
351
+ if i == 0:
352
+ # discard the ending token
353
+ text_embedding = text_embedding[:, :-1]
354
+ elif i == max_embeddings_multiples - 1:
355
+ # discard the starting token
356
+ text_embedding = text_embedding[:, 1:]
357
+ else:
358
+ # discard both starting and ending tokens
359
+ text_embedding = text_embedding[:, 1:-1]
360
+
361
+ text_embeddings.append(text_embedding)
362
+ text_embeddings = torch.concat(text_embeddings, axis=1)
363
+ else:
364
+ if clip_skip is None or clip_skip == 1:
365
+ text_embeddings = text_encoder(text_input)[0]
366
+ else:
367
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
368
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
369
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
370
+ return text_embeddings
371
+
372
+
373
+ def get_weighted_text_embeddings(
374
+ tokenizer,
375
+ text_encoder,
376
+ prompt: Union[str, List[str]],
377
+ device,
378
+ max_embeddings_multiples: Optional[int] = 3,
379
+ no_boseos_middle: Optional[bool] = False,
380
+ clip_skip=None,
381
+ ):
382
+ r"""
383
+ Prompts can be assigned with local weights using brackets. For example,
384
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
385
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
386
+
387
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
388
+
389
+ Args:
390
+ prompt (`str` or `List[str]`):
391
+ The prompt or prompts to guide the image generation.
392
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
393
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
394
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
395
+ If the length of the text tokens is a multiple of the text encoder's capacity, whether to keep the starting and
396
+ ending tokens in each middle chunk.
397
+ skip_parsing (`bool`, *optional*, defaults to `False`):
398
+ Skip the parsing of brackets.
399
+ skip_weighting (`bool`, *optional*, defaults to `False`):
400
+ Skip the weighting. When the parsing is skipped, it is forced True.
401
+ """
402
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
403
+ if isinstance(prompt, str):
404
+ prompt = [prompt]
405
+
406
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
407
+
408
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
409
+ max_length = max([len(token) for token in prompt_tokens])
410
+
411
+ max_embeddings_multiples = min(
412
+ max_embeddings_multiples,
413
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
414
+ )
415
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
416
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
417
+
418
+ # pad the length of tokens and weights
419
+ bos = tokenizer.bos_token_id
420
+ eos = tokenizer.eos_token_id
421
+ pad = tokenizer.pad_token_id
422
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
423
+ prompt_tokens,
424
+ prompt_weights,
425
+ max_length,
426
+ bos,
427
+ eos,
428
+ no_boseos_middle=no_boseos_middle,
429
+ chunk_length=tokenizer.model_max_length,
430
+ )
431
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
432
+
433
+ # get the embeddings
434
+ text_embeddings = get_unweighted_text_embeddings(
435
+ tokenizer,
436
+ text_encoder,
437
+ prompt_tokens,
438
+ tokenizer.model_max_length,
439
+ clip_skip,
440
+ eos,
441
+ pad,
442
+ no_boseos_middle=no_boseos_middle,
443
+ )
444
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
445
+
446
+ # assign weights to the prompts and normalize in the sense of mean
447
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
448
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
449
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
450
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
451
+
452
+ return text_embeddings
453
+
454
+
455
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
456
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
457
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
458
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
459
+ for i in range(iterations):
460
+ r = random.random() * 2 + 2 # Rather than always going 2x,
461
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
462
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
463
+ if wn == 1 or hn == 1:
464
+ break # Lowest resolution is 1x1
465
+ return noise / noise.std() # Scaled back to roughly unit variance
466
+
467
+
468
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
469
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
470
+ if noise_offset is None:
471
+ return noise
472
+ if adaptive_noise_scale is not None:
473
+ # latent shape: (batch_size, channels, height, width)
474
+ # abs mean value for each channel
475
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
476
+
477
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
478
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
479
+ noise_offset = torch.clamp(noise_offset, 0.0, None)  # in case the adaptive noise scale is negative
480
+
481
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
482
+ return noise
483
+
484
+
485
+ def apply_masked_loss(loss, batch):
486
+ if "conditioning_images" in batch:
487
+ # conditioning image is -1 to 1. we need to convert it to 0 to 1
488
+ mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
489
+ mask_image = mask_image / 2 + 0.5
490
+ # print(f"conditioning_image: {mask_image.shape}")
491
+ elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
492
+ # alpha mask is 0 to 1
493
+ mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
494
+ # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
495
+ else:
496
+ return loss
497
+
498
+ # resize to the same size as the loss
499
+ mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
500
+ loss = loss * mask_image
501
+ return loss
502
+
503
+
504
+ """
505
+ ##########################################
506
+ # Perlin Noise
507
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
508
+ delta = (res[0] / shape[0], res[1] / shape[1])
509
+ d = (shape[0] // res[0], shape[1] // res[1])
510
+
511
+ grid = (
512
+ torch.stack(
513
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
514
+ dim=-1,
515
+ )
516
+ % 1
517
+ )
518
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
519
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
520
+
521
+ tile_grads = (
522
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
523
+ .repeat_interleave(d[0], 0)
524
+ .repeat_interleave(d[1], 1)
525
+ )
526
+ dot = lambda grad, shift: (
527
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
528
+ * grad[: shape[0], : shape[1]]
529
+ ).sum(dim=-1)
530
+
531
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
532
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
533
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
534
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
535
+ t = fade(grid[: shape[0], : shape[1]])
536
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
537
+
538
+
539
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
540
+ noise = torch.zeros(shape, device=device)
541
+ frequency = 1
542
+ amplitude = 1
543
+ for _ in range(octaves):
544
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
545
+ frequency *= 2
546
+ amplitude *= persistence
547
+ return noise
548
+
549
+
550
+ def perlin_noise(noise, device, octaves):
551
+ _, c, w, h = noise.shape
552
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
553
+ noise_perlin = []
554
+ for _ in range(c):
555
+ noise_perlin.append(perlin())
556
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
557
+ noise += noise_perlin # broadcast for each batch
558
+ return noise / noise.std() # Scaled back to roughly unit variance
559
+ """
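Usage note (not part of the commit): a minimal sketch exercising the min-SNR helpers above in isolation. The SimpleNamespace scheduler with a linear beta schedule is a stand-in for a real diffusers DDPMScheduler; prepare_scheduler_for_custom_training only needs its alphas_cumprod attribute.

import types
import torch
from library.custom_train_functions import prepare_scheduler_for_custom_training, apply_snr_weight

# Stand-in scheduler: only alphas_cumprod is needed to derive per-timestep SNR.
betas = torch.linspace(1e-4, 0.02, 1000)
scheduler = types.SimpleNamespace(alphas_cumprod=torch.cumprod(1.0 - betas, dim=0))

prepare_scheduler_for_custom_training(scheduler, torch.device("cpu"))  # attaches scheduler.all_snr

timesteps = torch.randint(0, 1000, (4,))
per_sample_loss = torch.rand(4)  # placeholder per-sample loss values
weighted = apply_snr_weight(per_sample_loss, timesteps, scheduler, gamma=5.0)
print(weighted)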
library/deepspeed_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from accelerate import DeepSpeedPlugin, Accelerator
5
+
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15
+ # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16
+ parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17
+ parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18
+ parser.add_argument(
19
+ "--offload_optimizer_device",
20
+ type=str,
21
+ default=None,
22
+ choices=[None, "cpu", "nvme"],
23
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24
+ )
25
+ parser.add_argument(
26
+ "--offload_optimizer_nvme_path",
27
+ type=str,
28
+ default=None,
29
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30
+ )
31
+ parser.add_argument(
32
+ "--offload_param_device",
33
+ type=str,
34
+ default=None,
35
+ choices=[None, "cpu", "nvme"],
36
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37
+ )
38
+ parser.add_argument(
39
+ "--offload_param_nvme_path",
40
+ type=str,
41
+ default=None,
42
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43
+ )
44
+ parser.add_argument(
45
+ "--zero3_init_flag",
46
+ action="store_true",
47
+ help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
48
+ "Only applicable with ZeRO Stage-3.",
49
+ )
50
+ parser.add_argument(
51
+ "--zero3_save_16bit_model",
52
+ action="store_true",
53
+ help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54
+ )
55
+ parser.add_argument(
56
+ "--fp16_master_weights_and_gradients",
57
+ action="store_true",
58
+ help="fp16_master_weights_and_gradients requires an optimizer that supports keeping fp16 master weights and gradients while keeping the optimizer states in fp32.",
59
+ )
60
+
61
+
62
+ def prepare_deepspeed_args(args: argparse.Namespace):
63
+ if not args.deepspeed:
64
+ return
65
+
66
+ # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67
+ args.max_data_loader_n_workers = 1
68
+
69
+
70
+ def prepare_deepspeed_plugin(args: argparse.Namespace):
71
+ if not args.deepspeed:
72
+ return None
73
+
74
+ try:
75
+ import deepspeed
76
+ except ImportError as e:
77
+ logger.error(
78
+ "DeepSpeed is not installed. Please install it in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed"
79
+ )
80
+ exit(1)
81
+
82
+ deepspeed_plugin = DeepSpeedPlugin(
83
+ zero_stage=args.zero_stage,
84
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
85
+ gradient_clipping=args.max_grad_norm,
86
+ offload_optimizer_device=args.offload_optimizer_device,
87
+ offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88
+ offload_param_device=args.offload_param_device,
89
+ offload_param_nvme_path=args.offload_param_nvme_path,
90
+ zero3_init_flag=args.zero3_init_flag,
91
+ zero3_save_16bit_model=args.zero3_save_16bit_model,
92
+ )
93
+ deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94
+ deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95
+ args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96
+ )
97
+ deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98
+ if args.mixed_precision.lower() == "fp16":
99
+ deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100
+ if args.full_fp16 or args.fp16_master_weights_and_gradients:
101
+ if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102
+ deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103
+ logger.info("[DeepSpeed] full fp16 enabled.")
104
+ else:
105
+ logger.info(
106
+ "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported with ZeRO-Offload and DeepSpeedCPUAdam on ZeRO stage 2."
107
+ )
108
+
109
+ if args.offload_optimizer_device is not None:
110
+ logger.info("[DeepSpeed] start to manually build cpu_adam.")
111
+ deepspeed.ops.op_builder.CPUAdamBuilder().load()
112
+ logger.info("[DeepSpeed] building cpu_adam done.")
113
+
114
+ return deepspeed_plugin
115
+
116
+
117
+ # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
118
+ def prepare_deepspeed_model(args: argparse.Namespace, **models):
119
+ # remove None from models
120
+ models = {k: v for k, v in models.items() if v is not None}
121
+
122
+ class DeepSpeedWrapper(torch.nn.Module):
123
+ def __init__(self, **kw_models) -> None:
124
+ super().__init__()
125
+ self.models = torch.nn.ModuleDict()
126
+
127
+ for key, model in kw_models.items():
128
+ if isinstance(model, list):
129
+ model = torch.nn.ModuleList(model)
130
+ assert isinstance(
131
+ model, torch.nn.Module
132
+ ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133
+ self.models.update(torch.nn.ModuleDict({key: model}))
134
+
135
+ def get_models(self):
136
+ return self.models
137
+
138
+ ds_model = DeepSpeedWrapper(**models)
139
+ return ds_model
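Usage note (not part of the commit): a rough sketch of how these helpers would be wired into argument parsing and Accelerator construction. The extra argparse flags below are assumptions standing in for options that the training scripts define elsewhere; prepare_deepspeed_plugin reads them from args.

import argparse
from accelerate import Accelerator
from library import deepspeed_utils

parser = argparse.ArgumentParser()
deepspeed_utils.add_deepspeed_arguments(parser)
# Assumed stand-ins for flags normally added by the training scripts.
parser.add_argument("--train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--mixed_precision", type=str, default="no")
parser.add_argument("--full_fp16", action="store_true")
args = parser.parse_args()

deepspeed_utils.prepare_deepspeed_args(args)             # forces a single dataloader worker when --deepspeed is set
plugin = deepspeed_utils.prepare_deepspeed_plugin(args)  # returns None unless --deepspeed is passed
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    deepspeed_plugin=plugin,
)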
library/device_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import functools
2
+ import gc
3
+
4
+ import torch
5
+
6
+ try:
7
+ HAS_CUDA = torch.cuda.is_available()
8
+ except Exception:
9
+ HAS_CUDA = False
10
+
11
+ try:
12
+ HAS_MPS = torch.backends.mps.is_available()
13
+ except Exception:
14
+ HAS_MPS = False
15
+
16
+ try:
17
+ import intel_extension_for_pytorch as ipex # noqa
18
+
19
+ HAS_XPU = torch.xpu.is_available()
20
+ except Exception:
21
+ HAS_XPU = False
22
+
23
+
24
+ def clean_memory():
25
+ gc.collect()
26
+ if HAS_CUDA:
27
+ torch.cuda.empty_cache()
28
+ if HAS_XPU:
29
+ torch.xpu.empty_cache()
30
+ if HAS_MPS:
31
+ torch.mps.empty_cache()
32
+
33
+
34
+ def clean_memory_on_device(device: torch.device):
35
+ r"""
36
+ Clean memory on the specified device, will be called from training scripts.
37
+ """
38
+ gc.collect()
39
+
40
+ # device may "cuda" or "cuda:0", so we need to check the type of device
41
+ if device.type == "cuda":
42
+ torch.cuda.empty_cache()
43
+ if device.type == "xpu":
44
+ torch.xpu.empty_cache()
45
+ if device.type == "mps":
46
+ torch.mps.empty_cache()
47
+
48
+
49
+ @functools.lru_cache(maxsize=None)
50
+ def get_preferred_device() -> torch.device:
51
+ r"""
52
+ Do not call this function from training scripts. Use accelerator.device instead.
53
+ """
54
+ if HAS_CUDA:
55
+ device = torch.device("cuda")
56
+ elif HAS_XPU:
57
+ device = torch.device("xpu")
58
+ elif HAS_MPS:
59
+ device = torch.device("mps")
60
+ else:
61
+ device = torch.device("cpu")
62
+ print(f"get_preferred_device() -> {device}")
63
+ return device
64
+
65
+
66
+ def init_ipex():
67
+ """
68
+ Apply IPEX's hijacks of the torch.cuda API using `library.ipex.ipex_init`.
69
+
70
+ This function should run right after importing torch and before doing anything else.
71
+
72
+ If IPEX is not available, this function does nothing.
73
+ """
74
+ try:
75
+ if HAS_XPU:
76
+ from library.ipex import ipex_init
77
+
78
+ is_initialized, error_message = ipex_init()
79
+ if not is_initialized:
80
+ print("failed to initialize ipex:", error_message)
81
+ else:
82
+ return
83
+ except Exception as e:
84
+ print("failed to initialize ipex:", e)
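Usage note (not part of the commit): a minimal sketch of how the helpers above compose, assuming the script runs from the repository root so that the library package is importable.

import torch
from library.device_utils import init_ipex, get_preferred_device, clean_memory_on_device

init_ipex()  # no-op unless an Intel XPU with IPEX is present

device = get_preferred_device()  # prefers cuda, then xpu, then mps, then cpu
x = torch.randn(1024, 1024, device=device)
del x
clean_memory_on_device(device)  # gc.collect() plus the matching backend's empty_cache()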
library/flux_models.py ADDED
@@ -0,0 +1,1237 @@
1
+ # copy from FLUX repo: https://github.com/black-forest-labs/flux
2
+ # license: Apache-2.0 License
3
+
4
+
5
+ from concurrent.futures import Future, ThreadPoolExecutor
6
+ from dataclasses import dataclass
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from library import utils
13
+ from library.device_utils import init_ipex, clean_memory_on_device
14
+
15
+ init_ipex()
16
+
17
+ import torch
18
+ from einops import rearrange
19
+ from torch import Tensor, nn
20
+ from torch.utils.checkpoint import checkpoint
21
+ from library import custom_offloading_utils
22
+
23
+ # USE_REENTRANT = True
24
+
25
+
26
+ @dataclass
27
+ class FluxParams:
28
+ in_channels: int
29
+ vec_in_dim: int
30
+ context_in_dim: int
31
+ hidden_size: int
32
+ mlp_ratio: float
33
+ num_heads: int
34
+ depth: int
35
+ depth_single_blocks: int
36
+ axes_dim: list[int]
37
+ theta: int
38
+ qkv_bias: bool
39
+ guidance_embed: bool
40
+
41
+
42
+ # region autoencoder
43
+
44
+
45
+ @dataclass
46
+ class AutoEncoderParams:
47
+ resolution: int
48
+ in_channels: int
49
+ ch: int
50
+ out_ch: int
51
+ ch_mult: list[int]
52
+ num_res_blocks: int
53
+ z_channels: int
54
+ scale_factor: float
55
+ shift_factor: float
56
+
57
+
58
+ def swish(x: Tensor) -> Tensor:
59
+ return x * torch.sigmoid(x)
60
+
61
+
62
+ class AttnBlock(nn.Module):
63
+ def __init__(self, in_channels: int):
64
+ super().__init__()
65
+ self.in_channels = in_channels
66
+
67
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
68
+
69
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
70
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
71
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
72
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
73
+
74
+ def attention(self, h_: Tensor) -> Tensor:
75
+ h_ = self.norm(h_)
76
+ q = self.q(h_)
77
+ k = self.k(h_)
78
+ v = self.v(h_)
79
+
80
+ b, c, h, w = q.shape
81
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
82
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
83
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
84
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
85
+
86
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
87
+
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ return x + self.proj_out(self.attention(x))
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, in_channels: int, out_channels: int):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+ out_channels = in_channels if out_channels is None else out_channels
97
+ self.out_channels = out_channels
98
+
99
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
100
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
101
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
102
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
105
+
106
+ def forward(self, x):
107
+ h = x
108
+ h = self.norm1(h)
109
+ h = swish(h)
110
+ h = self.conv1(h)
111
+
112
+ h = self.norm2(h)
113
+ h = swish(h)
114
+ h = self.conv2(h)
115
+
116
+ if self.in_channels != self.out_channels:
117
+ x = self.nin_shortcut(x)
118
+
119
+ return x + h
120
+
121
+
122
+ class Downsample(nn.Module):
123
+ def __init__(self, in_channels: int):
124
+ super().__init__()
125
+ # no asymmetric padding in torch conv, must do it ourselves
126
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
127
+
128
+ def forward(self, x: Tensor):
129
+ pad = (0, 1, 0, 1)
130
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
131
+ x = self.conv(x)
132
+ return x
133
+
134
+
135
+ class Upsample(nn.Module):
136
+ def __init__(self, in_channels: int):
137
+ super().__init__()
138
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
139
+
140
+ def forward(self, x: Tensor):
141
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
142
+ x = self.conv(x)
143
+ return x
144
+
145
+
146
+ class Encoder(nn.Module):
147
+ def __init__(
148
+ self,
149
+ resolution: int,
150
+ in_channels: int,
151
+ ch: int,
152
+ ch_mult: list[int],
153
+ num_res_blocks: int,
154
+ z_channels: int,
155
+ ):
156
+ super().__init__()
157
+ self.ch = ch
158
+ self.num_resolutions = len(ch_mult)
159
+ self.num_res_blocks = num_res_blocks
160
+ self.resolution = resolution
161
+ self.in_channels = in_channels
162
+ # downsampling
163
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
164
+
165
+ curr_res = resolution
166
+ in_ch_mult = (1,) + tuple(ch_mult)
167
+ self.in_ch_mult = in_ch_mult
168
+ self.down = nn.ModuleList()
169
+ block_in = self.ch
170
+ for i_level in range(self.num_resolutions):
171
+ block = nn.ModuleList()
172
+ attn = nn.ModuleList()
173
+ block_in = ch * in_ch_mult[i_level]
174
+ block_out = ch * ch_mult[i_level]
175
+ for _ in range(self.num_res_blocks):
176
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
177
+ block_in = block_out
178
+ down = nn.Module()
179
+ down.block = block
180
+ down.attn = attn
181
+ if i_level != self.num_resolutions - 1:
182
+ down.downsample = Downsample(block_in)
183
+ curr_res = curr_res // 2
184
+ self.down.append(down)
185
+
186
+ # middle
187
+ self.mid = nn.Module()
188
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
189
+ self.mid.attn_1 = AttnBlock(block_in)
190
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
191
+
192
+ # end
193
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
194
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
195
+
196
+ def forward(self, x: Tensor) -> Tensor:
197
+ # downsampling
198
+ hs = [self.conv_in(x)]
199
+ for i_level in range(self.num_resolutions):
200
+ for i_block in range(self.num_res_blocks):
201
+ h = self.down[i_level].block[i_block](hs[-1])
202
+ if len(self.down[i_level].attn) > 0:
203
+ h = self.down[i_level].attn[i_block](h)
204
+ hs.append(h)
205
+ if i_level != self.num_resolutions - 1:
206
+ hs.append(self.down[i_level].downsample(hs[-1]))
207
+
208
+ # middle
209
+ h = hs[-1]
210
+ h = self.mid.block_1(h)
211
+ h = self.mid.attn_1(h)
212
+ h = self.mid.block_2(h)
213
+ # end
214
+ h = self.norm_out(h)
215
+ h = swish(h)
216
+ h = self.conv_out(h)
217
+ return h
218
+
219
+
220
+ class Decoder(nn.Module):
221
+ def __init__(
222
+ self,
223
+ ch: int,
224
+ out_ch: int,
225
+ ch_mult: list[int],
226
+ num_res_blocks: int,
227
+ in_channels: int,
228
+ resolution: int,
229
+ z_channels: int,
230
+ ):
231
+ super().__init__()
232
+ self.ch = ch
233
+ self.num_resolutions = len(ch_mult)
234
+ self.num_res_blocks = num_res_blocks
235
+ self.resolution = resolution
236
+ self.in_channels = in_channels
237
+ self.ffactor = 2 ** (self.num_resolutions - 1)
238
+
239
+ # compute in_ch_mult, block_in and curr_res at lowest res
240
+ block_in = ch * ch_mult[self.num_resolutions - 1]
241
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
242
+ self.z_shape = (1, z_channels, curr_res, curr_res)
243
+
244
+ # z to block_in
245
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
246
+
247
+ # middle
248
+ self.mid = nn.Module()
249
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
250
+ self.mid.attn_1 = AttnBlock(block_in)
251
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
252
+
253
+ # upsampling
254
+ self.up = nn.ModuleList()
255
+ for i_level in reversed(range(self.num_resolutions)):
256
+ block = nn.ModuleList()
257
+ attn = nn.ModuleList()
258
+ block_out = ch * ch_mult[i_level]
259
+ for _ in range(self.num_res_blocks + 1):
260
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
261
+ block_in = block_out
262
+ up = nn.Module()
263
+ up.block = block
264
+ up.attn = attn
265
+ if i_level != 0:
266
+ up.upsample = Upsample(block_in)
267
+ curr_res = curr_res * 2
268
+ self.up.insert(0, up) # prepend to get consistent order
269
+
270
+ # end
271
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
272
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
273
+
274
+ def forward(self, z: Tensor) -> Tensor:
275
+ # z to block_in
276
+ h = self.conv_in(z)
277
+
278
+ # middle
279
+ h = self.mid.block_1(h)
280
+ h = self.mid.attn_1(h)
281
+ h = self.mid.block_2(h)
282
+
283
+ # upsampling
284
+ for i_level in reversed(range(self.num_resolutions)):
285
+ for i_block in range(self.num_res_blocks + 1):
286
+ h = self.up[i_level].block[i_block](h)
287
+ if len(self.up[i_level].attn) > 0:
288
+ h = self.up[i_level].attn[i_block](h)
289
+ if i_level != 0:
290
+ h = self.up[i_level].upsample(h)
291
+
292
+ # end
293
+ h = self.norm_out(h)
294
+ h = swish(h)
295
+ h = self.conv_out(h)
296
+ return h
297
+
298
+
299
+ class DiagonalGaussian(nn.Module):
300
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
301
+ super().__init__()
302
+ self.sample = sample
303
+ self.chunk_dim = chunk_dim
304
+
305
+ def forward(self, z: Tensor) -> Tensor:
306
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
307
+ if self.sample:
308
+ std = torch.exp(0.5 * logvar)
309
+ return mean + std * torch.randn_like(mean)
310
+ else:
311
+ return mean
312
+
313
+
314
+ class AutoEncoder(nn.Module):
315
+ def __init__(self, params: AutoEncoderParams):
316
+ super().__init__()
317
+ self.encoder = Encoder(
318
+ resolution=params.resolution,
319
+ in_channels=params.in_channels,
320
+ ch=params.ch,
321
+ ch_mult=params.ch_mult,
322
+ num_res_blocks=params.num_res_blocks,
323
+ z_channels=params.z_channels,
324
+ )
325
+ self.decoder = Decoder(
326
+ resolution=params.resolution,
327
+ in_channels=params.in_channels,
328
+ ch=params.ch,
329
+ out_ch=params.out_ch,
330
+ ch_mult=params.ch_mult,
331
+ num_res_blocks=params.num_res_blocks,
332
+ z_channels=params.z_channels,
333
+ )
334
+ self.reg = DiagonalGaussian()
335
+
336
+ self.scale_factor = params.scale_factor
337
+ self.shift_factor = params.shift_factor
338
+
339
+ @property
340
+ def device(self) -> torch.device:
341
+ return next(self.parameters()).device
342
+
343
+ @property
344
+ def dtype(self) -> torch.dtype:
345
+ return next(self.parameters()).dtype
346
+
347
+ def encode(self, x: Tensor) -> Tensor:
348
+ z = self.reg(self.encoder(x))
349
+ z = self.scale_factor * (z - self.shift_factor)
350
+ return z
351
+
352
+ def decode(self, z: Tensor) -> Tensor:
353
+ z = z / self.scale_factor + self.shift_factor
354
+ return self.decoder(z)
355
+
356
+ def forward(self, x: Tensor) -> Tensor:
357
+ return self.decode(self.encode(x))
358
+
359
+
360
+ # endregion
361
+ # region config
362
+
363
+
364
+ @dataclass
365
+ class ModelSpec:
366
+ params: FluxParams
367
+ ae_params: AutoEncoderParams
368
+ ckpt_path: str | None
369
+ ae_path: str | None
370
+ # repo_id: str | None
371
+ # repo_flow: str | None
372
+ # repo_ae: str | None
373
+
374
+
375
+ configs = {
376
+ "dev": ModelSpec(
377
+ # repo_id="black-forest-labs/FLUX.1-dev",
378
+ # repo_flow="flux1-dev.sft",
379
+ # repo_ae="ae.sft",
380
+ ckpt_path=None, # os.getenv("FLUX_DEV"),
381
+ params=FluxParams(
382
+ in_channels=64,
383
+ vec_in_dim=768,
384
+ context_in_dim=4096,
385
+ hidden_size=3072,
386
+ mlp_ratio=4.0,
387
+ num_heads=24,
388
+ depth=19,
389
+ depth_single_blocks=38,
390
+ axes_dim=[16, 56, 56],
391
+ theta=10_000,
392
+ qkv_bias=True,
393
+ guidance_embed=True,
394
+ ),
395
+ ae_path=None, # os.getenv("AE"),
396
+ ae_params=AutoEncoderParams(
397
+ resolution=256,
398
+ in_channels=3,
399
+ ch=128,
400
+ out_ch=3,
401
+ ch_mult=[1, 2, 4, 4],
402
+ num_res_blocks=2,
403
+ z_channels=16,
404
+ scale_factor=0.3611,
405
+ shift_factor=0.1159,
406
+ ),
407
+ ),
408
+ "schnell": ModelSpec(
409
+ # repo_id="black-forest-labs/FLUX.1-schnell",
410
+ # repo_flow="flux1-schnell.sft",
411
+ # repo_ae="ae.sft",
412
+ ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
413
+ params=FluxParams(
414
+ in_channels=64,
415
+ vec_in_dim=768,
416
+ context_in_dim=4096,
417
+ hidden_size=3072,
418
+ mlp_ratio=4.0,
419
+ num_heads=24,
420
+ depth=19,
421
+ depth_single_blocks=38,
422
+ axes_dim=[16, 56, 56],
423
+ theta=10_000,
424
+ qkv_bias=True,
425
+ guidance_embed=False,
426
+ ),
427
+ ae_path=None, # os.getenv("AE"),
428
+ ae_params=AutoEncoderParams(
429
+ resolution=256,
430
+ in_channels=3,
431
+ ch=128,
432
+ out_ch=3,
433
+ ch_mult=[1, 2, 4, 4],
434
+ num_res_blocks=2,
435
+ z_channels=16,
436
+ scale_factor=0.3611,
437
+ shift_factor=0.1159,
438
+ ),
439
+ ),
440
+ }
441
+
442
+
443
+ # endregion
444
+
445
+ # region math
446
+
447
+
448
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
449
+ q, k = apply_rope(q, k, pe)
450
+
451
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
452
+ x = rearrange(x, "B H L D -> B L (H D)")
453
+
454
+ return x
455
+
456
+
457
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
458
+ assert dim % 2 == 0
459
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
460
+ omega = 1.0 / (theta**scale)
461
+ out = torch.einsum("...n,d->...nd", pos, omega)
462
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
463
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
464
+ return out.float()
465
+
466
+
467
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
468
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
469
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
470
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
471
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
472
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
473
+
474
+
475
+ # endregion
476
+
477
+
478
+ # region layers
479
+
480
+
481
+ # for cpu_offload_checkpointing
482
+
483
+
484
+ def to_cuda(x):
485
+ if isinstance(x, torch.Tensor):
486
+ return x.cuda()
487
+ elif isinstance(x, (list, tuple)):
488
+ return [to_cuda(elem) for elem in x]
489
+ elif isinstance(x, dict):
490
+ return {k: to_cuda(v) for k, v in x.items()}
491
+ else:
492
+ return x
493
+
494
+
495
+ def to_cpu(x):
496
+ if isinstance(x, torch.Tensor):
497
+ return x.cpu()
498
+ elif isinstance(x, (list, tuple)):
499
+ return [to_cpu(elem) for elem in x]
500
+ elif isinstance(x, dict):
501
+ return {k: to_cpu(v) for k, v in x.items()}
502
+ else:
503
+ return x
504
+
505
+
506
+ class EmbedND(nn.Module):
507
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
508
+ super().__init__()
509
+ self.dim = dim
510
+ self.theta = theta
511
+ self.axes_dim = axes_dim
512
+
513
+ def forward(self, ids: Tensor) -> Tensor:
514
+ n_axes = ids.shape[-1]
515
+ emb = torch.cat(
516
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
517
+ dim=-3,
518
+ )
519
+
520
+ return emb.unsqueeze(1)
521
+
522
+
523
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
524
+ """
525
+ Create sinusoidal timestep embeddings.
526
+ :param t: a 1-D Tensor of N indices, one per batch element.
527
+ These may be fractional.
528
+ :param dim: the dimension of the output.
529
+ :param max_period: controls the minimum frequency of the embeddings.
530
+ :return: an (N, D) Tensor of positional embeddings.
531
+ """
532
+ t = time_factor * t
533
+ half = dim // 2
534
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
535
+
536
+ args = t[:, None].float() * freqs[None]
537
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
538
+ if dim % 2:
539
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
540
+ if torch.is_floating_point(t):
541
+ embedding = embedding.to(t)
542
+ return embedding
543
+
544
+
545
+ class MLPEmbedder(nn.Module):
546
+ def __init__(self, in_dim: int, hidden_dim: int):
547
+ super().__init__()
548
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
549
+ self.silu = nn.SiLU()
550
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
551
+
552
+ self.gradient_checkpointing = False
553
+
554
+ def enable_gradient_checkpointing(self):
555
+ self.gradient_checkpointing = True
556
+
557
+ def disable_gradient_checkpointing(self):
558
+ self.gradient_checkpointing = False
559
+
560
+ def _forward(self, x: Tensor) -> Tensor:
561
+ return self.out_layer(self.silu(self.in_layer(x)))
562
+
563
+ def forward(self, *args, **kwargs):
564
+ if self.training and self.gradient_checkpointing:
565
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
566
+ else:
567
+ return self._forward(*args, **kwargs)
568
+
569
+ # def forward(self, x):
570
+ # if self.training and self.gradient_checkpointing:
571
+ # def create_custom_forward(func):
572
+ # def custom_forward(*inputs):
573
+ # return func(*inputs)
574
+ # return custom_forward
575
+ # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
576
+ # else:
577
+ # return self._forward(x)
578
+
579
+
580
+ class RMSNorm(torch.nn.Module):
581
+ def __init__(self, dim: int):
582
+ super().__init__()
583
+ self.scale = nn.Parameter(torch.ones(dim))
584
+
585
+ def forward(self, x: Tensor):
586
+ x_dtype = x.dtype
587
+ x = x.float()
588
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
589
+ # return (x * rrms).to(dtype=x_dtype) * self.scale
590
+ return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
591
+
592
+
593
+ class QKNorm(torch.nn.Module):
594
+ def __init__(self, dim: int):
595
+ super().__init__()
596
+ self.query_norm = RMSNorm(dim)
597
+ self.key_norm = RMSNorm(dim)
598
+
599
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
600
+ q = self.query_norm(q)
601
+ k = self.key_norm(k)
602
+ return q.to(v), k.to(v)
603
+
604
+
605
+ class SelfAttention(nn.Module):
606
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
607
+ super().__init__()
608
+ self.num_heads = num_heads
609
+ head_dim = dim // num_heads
610
+
611
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
612
+ self.norm = QKNorm(head_dim)
613
+ self.proj = nn.Linear(dim, dim)
614
+
615
+ # this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
616
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
617
+ qkv = self.qkv(x)
618
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
619
+ q, k = self.norm(q, k, v)
620
+ x = attention(q, k, v, pe=pe)
621
+ x = self.proj(x)
622
+ return x
623
+
624
+
625
+ @dataclass
626
+ class ModulationOut:
627
+ shift: Tensor
628
+ scale: Tensor
629
+ gate: Tensor
630
+
631
+
632
+ class Modulation(nn.Module):
633
+ def __init__(self, dim: int, double: bool):
634
+ super().__init__()
635
+ self.is_double = double
636
+ self.multiplier = 6 if double else 3
637
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
638
+
639
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
640
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
641
+
642
+ return (
643
+ ModulationOut(*out[:3]),
644
+ ModulationOut(*out[3:]) if self.is_double else None,
645
+ )
646
+
647
+
648
+ class DoubleStreamBlock(nn.Module):
649
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
650
+ super().__init__()
651
+
652
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
653
+ self.num_heads = num_heads
654
+ self.hidden_size = hidden_size
655
+ self.img_mod = Modulation(hidden_size, double=True)
656
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
657
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
658
+
659
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
660
+ self.img_mlp = nn.Sequential(
661
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
662
+ nn.GELU(approximate="tanh"),
663
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
664
+ )
665
+
666
+ self.txt_mod = Modulation(hidden_size, double=True)
667
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
668
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
669
+
670
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
671
+ self.txt_mlp = nn.Sequential(
672
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
673
+ nn.GELU(approximate="tanh"),
674
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
675
+ )
676
+
677
+ self.gradient_checkpointing = False
678
+ self.cpu_offload_checkpointing = False
679
+
680
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
681
+ self.gradient_checkpointing = True
682
+ self.cpu_offload_checkpointing = cpu_offload
683
+
684
+ def disable_gradient_checkpointing(self):
685
+ self.gradient_checkpointing = False
686
+ self.cpu_offload_checkpointing = False
687
+
688
+ def _forward(
689
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
690
+ ) -> tuple[Tensor, Tensor]:
691
+ img_mod1, img_mod2 = self.img_mod(vec)
692
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
693
+
694
+ # prepare image for attention
695
+ img_modulated = self.img_norm1(img)
696
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
697
+ img_qkv = self.img_attn.qkv(img_modulated)
698
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
699
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
700
+
701
+ # prepare txt for attention
702
+ txt_modulated = self.txt_norm1(txt)
703
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
704
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
705
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
706
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
707
+
708
+ # run actual attention
709
+ q = torch.cat((txt_q, img_q), dim=2)
710
+ k = torch.cat((txt_k, img_k), dim=2)
711
+ v = torch.cat((txt_v, img_v), dim=2)
712
+
713
+ # make attention mask if not None
714
+ attn_mask = None
715
+ if txt_attention_mask is not None:
716
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
717
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
718
+ attn_mask = torch.cat(
719
+ (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
720
+ ) # b, seq_len + img_len
721
+
722
+ # broadcast attn_mask to all heads
723
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
724
+
725
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
726
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
727
+
728
+ # calculate the img blocks
729
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
730
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
731
+
732
+ # calculate the txt blocks
733
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
734
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
735
+ return img, txt
736
+
737
+ def forward(
738
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
739
+ ) -> tuple[Tensor, Tensor]:
740
+ if self.training and self.gradient_checkpointing:
741
+ if not self.cpu_offload_checkpointing:
742
+ return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
743
+ # cpu offload checkpointing
744
+
745
+ def create_custom_forward(func):
746
+ def custom_forward(*inputs):
747
+ cuda_inputs = to_cuda(inputs)
748
+ outputs = func(*cuda_inputs)
749
+ return to_cpu(outputs)
750
+
751
+ return custom_forward
752
+
753
+ return torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
755
+ )
756
+
757
+ else:
758
+ return self._forward(img, txt, vec, pe, txt_attention_mask)
759
+
760
+
761
+ class SingleStreamBlock(nn.Module):
762
+ """
763
+ A DiT block with parallel linear layers as described in
764
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
765
+ """
766
+
767
+ def __init__(
768
+ self,
769
+ hidden_size: int,
770
+ num_heads: int,
771
+ mlp_ratio: float = 4.0,
772
+ qk_scale: float | None = None,
773
+ ):
774
+ super().__init__()
775
+ self.hidden_dim = hidden_size
776
+ self.num_heads = num_heads
777
+ head_dim = hidden_size // num_heads
778
+ self.scale = qk_scale or head_dim**-0.5
779
+
780
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
781
+ # qkv and mlp_in
782
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
783
+ # proj and mlp_out
784
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
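+ # linear1 fuses the attention qkv projection (3 * hidden_size) with the MLP input
+ # projection (mlp_hidden_dim); linear2 fuses the attention output and MLP output
+ # projections. Illustrative sizes (hypothetical hidden_size=3072, mlp_ratio=4.0):
+ # linear1: 3072 -> 21504 (9216 + 12288), linear2: 15360 (3072 + 12288) -> 3072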
785
+
786
+ self.norm = QKNorm(head_dim)
787
+
788
+ self.hidden_size = hidden_size
789
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
790
+
791
+ self.mlp_act = nn.GELU(approximate="tanh")
792
+ self.modulation = Modulation(hidden_size, double=False)
793
+
794
+ self.gradient_checkpointing = False
795
+ self.cpu_offload_checkpointing = False
796
+
797
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
798
+ self.gradient_checkpointing = True
799
+ self.cpu_offload_checkpointing = cpu_offload
800
+
801
+ def disable_gradient_checkpointing(self):
802
+ self.gradient_checkpointing = False
803
+ self.cpu_offload_checkpointing = False
804
+
805
+ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
806
+ mod, _ = self.modulation(vec)
807
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
808
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
809
+
810
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
811
+ q, k = self.norm(q, k, v)
812
+
813
+ # make attention mask if not None
814
+ attn_mask = None
815
+ if txt_attention_mask is not None:
816
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
817
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
818
+ attn_mask = torch.cat(
819
+ (
820
+ attn_mask,
821
+ torch.ones(
822
+ attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
823
+ ),
824
+ ),
825
+ dim=1,
826
+ ) # b, seq_len + img_len = x_len
827
+
828
+ # broadcast attn_mask to all heads
829
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
830
+
831
+ # compute attention
832
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
833
+
834
+ # compute activation in mlp stream, cat again and run second linear layer
835
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
836
+ return x + mod.gate * output
837
+
838
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
839
+ if self.training and self.gradient_checkpointing:
840
+ if not self.cpu_offload_checkpointing:
841
+ return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
842
+
843
+ # cpu offload checkpointing
844
+
845
+ def create_custom_forward(func):
846
+ def custom_forward(*inputs):
847
+ cuda_inputs = to_cuda(inputs)
848
+ outputs = func(*cuda_inputs)
849
+ return to_cpu(outputs)
850
+
851
+ return custom_forward
852
+
853
+ return torch.utils.checkpoint.checkpoint(
854
+ create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
855
+ )
856
+ else:
857
+ return self._forward(x, vec, pe, txt_attention_mask)
858
+
859
+
860
+ class LastLayer(nn.Module):
861
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
862
+ super().__init__()
863
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
864
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
865
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
866
+
867
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
868
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
869
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
870
+ x = self.linear(x)
871
+ return x
872
+
873
+
874
+ # endregion
875
+
876
+
877
+ class Flux(nn.Module):
878
+ """
879
+ Transformer model for flow matching on sequences.
880
+ """
881
+
882
+ def __init__(self, params: FluxParams):
883
+ super().__init__()
884
+
885
+ self.params = params
886
+ self.in_channels = params.in_channels
887
+ self.out_channels = self.in_channels
888
+ if params.hidden_size % params.num_heads != 0:
889
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
890
+ pe_dim = params.hidden_size // params.num_heads
891
+ if sum(params.axes_dim) != pe_dim:
892
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
893
+ self.hidden_size = params.hidden_size
894
+ self.num_heads = params.num_heads
895
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
896
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
897
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
898
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
899
+ self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
900
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
901
+
902
+ self.double_blocks = nn.ModuleList(
903
+ [
904
+ DoubleStreamBlock(
905
+ self.hidden_size,
906
+ self.num_heads,
907
+ mlp_ratio=params.mlp_ratio,
908
+ qkv_bias=params.qkv_bias,
909
+ )
910
+ for _ in range(params.depth)
911
+ ]
912
+ )
913
+
914
+ self.single_blocks = nn.ModuleList(
915
+ [
916
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
917
+ for _ in range(params.depth_single_blocks)
918
+ ]
919
+ )
920
+
921
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
922
+
923
+ self.gradient_checkpointing = False
924
+ self.cpu_offload_checkpointing = False
925
+ self.blocks_to_swap = None
926
+
927
+ self.offloader_double = None
928
+ self.offloader_single = None
929
+ self.num_double_blocks = len(self.double_blocks)
930
+ self.num_single_blocks = len(self.single_blocks)
931
+
932
+ @property
933
+ def device(self):
934
+ return next(self.parameters()).device
935
+
936
+ @property
937
+ def dtype(self):
938
+ return next(self.parameters()).dtype
939
+
940
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
941
+ self.gradient_checkpointing = True
942
+ self.cpu_offload_checkpointing = cpu_offload
943
+
944
+ self.time_in.enable_gradient_checkpointing()
945
+ self.vector_in.enable_gradient_checkpointing()
946
+ if self.guidance_in.__class__ != nn.Identity:
947
+ self.guidance_in.enable_gradient_checkpointing()
948
+
949
+ for block in self.double_blocks + self.single_blocks:
950
+ block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
951
+
952
+ print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
953
+
954
+ def disable_gradient_checkpointing(self):
955
+ self.gradient_checkpointing = False
956
+ self.cpu_offload_checkpointing = False
957
+
958
+ self.time_in.disable_gradient_checkpointing()
959
+ self.vector_in.disable_gradient_checkpointing()
960
+ if self.guidance_in.__class__ != nn.Identity:
961
+ self.guidance_in.disable_gradient_checkpointing()
962
+
963
+ for block in self.double_blocks + self.single_blocks:
964
+ block.disable_gradient_checkpointing()
965
+
966
+ print("FLUX: Gradient checkpointing disabled.")
967
+
968
+ def enable_block_swap(self, num_blocks: int, device: torch.device):
969
+ self.blocks_to_swap = num_blocks
970
+ double_blocks_to_swap = num_blocks // 2
971
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
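+ # illustrative split (hypothetical num_blocks=6): 6 // 2 = 3 double blocks and
+ # (6 - 3) * 2 = 6 single blocks are swapped; a double block holds roughly twice the
+ # parameters of a single block, so the two halves offload a comparable amount of memory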
972
+
973
+ assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
974
+ f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
975
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
976
+ )
977
+
978
+ self.offloader_double = custom_offloading_utils.ModelOffloader(
979
+ self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
980
+ )
981
+ self.offloader_single = custom_offloading_utils.ModelOffloader(
982
+ self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
983
+ )
984
+ print(
985
+ f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
986
+ )
987
+
988
+ def move_to_device_except_swap_blocks(self, device: torch.device):
989
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
990
+ if self.blocks_to_swap:
991
+ save_double_blocks = self.double_blocks
992
+ save_single_blocks = self.single_blocks
993
+ self.double_blocks = None
994
+ self.single_blocks = None
995
+
996
+ self.to(device)
997
+
998
+ if self.blocks_to_swap:
999
+ self.double_blocks = save_double_blocks
1000
+ self.single_blocks = save_single_blocks
1001
+
1002
+ def prepare_block_swap_before_forward(self):
1003
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
1004
+ return
1005
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
1006
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
1007
+
1008
+ def forward(
1009
+ self,
1010
+ img: Tensor,
1011
+ img_ids: Tensor,
1012
+ txt: Tensor,
1013
+ txt_ids: Tensor,
1014
+ timesteps: Tensor,
1015
+ y: Tensor,
1016
+ guidance: Tensor | None = None,
1017
+ txt_attention_mask: Tensor | None = None,
1018
+ ) -> Tensor:
1019
+ if img.ndim != 3 or txt.ndim != 3:
1020
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
1021
+
1022
+ # running on sequences img
1023
+ img = self.img_in(img)
1024
+ vec = self.time_in(timestep_embedding(timesteps, 256))
1025
+ if self.params.guidance_embed:
1026
+ if guidance is None:
1027
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1028
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1029
+ vec = vec + self.vector_in(y)
1030
+ txt = self.txt_in(txt)
1031
+
1032
+ ids = torch.cat((txt_ids, img_ids), dim=1)
1033
+ pe = self.pe_embedder(ids)
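+ # positional (RoPE) embeddings are computed once for the concatenated text + image id
+ # sequence and shared by all double and single blocks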
1034
+
1035
+ if not self.blocks_to_swap:
1036
+ for block in self.double_blocks:
1037
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1038
+ img = torch.cat((txt, img), 1)
1039
+ for block in self.single_blocks:
1040
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1041
+ else:
1042
+ for block_idx, block in enumerate(self.double_blocks):
1043
+ self.offloader_double.wait_for_block(block_idx)
1044
+
1045
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1046
+
1047
+ self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
1048
+
1049
+ img = torch.cat((txt, img), 1)
1050
+
1051
+ for block_idx, block in enumerate(self.single_blocks):
1052
+ self.offloader_single.wait_for_block(block_idx)
1053
+
1054
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1055
+
1056
+ self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
1057
+
1058
+ img = img[:, txt.shape[1] :, ...]
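+ # drop the text tokens that were prepended before the single blocks; only the image
+ # tokens are passed to the final projection layer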
1059
+
1060
+ if self.training and self.cpu_offload_checkpointing:
1061
+ img = img.to(self.device)
1062
+ vec = vec.to(self.device)
1063
+
1064
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
1065
+
1066
+ return img
1067
+
1068
+
1069
+ """
1070
+ class FluxUpper(nn.Module):
1071
+ ""
1072
+ Transformer model for flow matching on sequences.
1073
+ ""
1074
+
1075
+ def __init__(self, params: FluxParams):
1076
+ super().__init__()
1077
+
1078
+ self.params = params
1079
+ self.in_channels = params.in_channels
1080
+ self.out_channels = self.in_channels
1081
+ if params.hidden_size % params.num_heads != 0:
1082
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
1083
+ pe_dim = params.hidden_size // params.num_heads
1084
+ if sum(params.axes_dim) != pe_dim:
1085
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
1086
+ self.hidden_size = params.hidden_size
1087
+ self.num_heads = params.num_heads
1088
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
1089
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
1090
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
1091
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
1092
+ self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
1093
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
1094
+
1095
+ self.double_blocks = nn.ModuleList(
1096
+ [
1097
+ DoubleStreamBlock(
1098
+ self.hidden_size,
1099
+ self.num_heads,
1100
+ mlp_ratio=params.mlp_ratio,
1101
+ qkv_bias=params.qkv_bias,
1102
+ )
1103
+ for _ in range(params.depth)
1104
+ ]
1105
+ )
1106
+
1107
+ self.gradient_checkpointing = False
1108
+
1109
+ @property
1110
+ def device(self):
1111
+ return next(self.parameters()).device
1112
+
1113
+ @property
1114
+ def dtype(self):
1115
+ return next(self.parameters()).dtype
1116
+
1117
+ def enable_gradient_checkpointing(self):
1118
+ self.gradient_checkpointing = True
1119
+
1120
+ self.time_in.enable_gradient_checkpointing()
1121
+ self.vector_in.enable_gradient_checkpointing()
1122
+ if self.guidance_in.__class__ != nn.Identity:
1123
+ self.guidance_in.enable_gradient_checkpointing()
1124
+
1125
+ for block in self.double_blocks:
1126
+ block.enable_gradient_checkpointing()
1127
+
1128
+ print("FLUX: Gradient checkpointing enabled.")
1129
+
1130
+ def disable_gradient_checkpointing(self):
1131
+ self.gradient_checkpointing = False
1132
+
1133
+ self.time_in.disable_gradient_checkpointing()
1134
+ self.vector_in.disable_gradient_checkpointing()
1135
+ if self.guidance_in.__class__ != nn.Identity:
1136
+ self.guidance_in.disable_gradient_checkpointing()
1137
+
1138
+ for block in self.double_blocks:
1139
+ block.disable_gradient_checkpointing()
1140
+
1141
+ print("FLUX: Gradient checkpointing disabled.")
1142
+
1143
+ def forward(
1144
+ self,
1145
+ img: Tensor,
1146
+ img_ids: Tensor,
1147
+ txt: Tensor,
1148
+ txt_ids: Tensor,
1149
+ timesteps: Tensor,
1150
+ y: Tensor,
1151
+ guidance: Tensor | None = None,
1152
+ txt_attention_mask: Tensor | None = None,
1153
+ ) -> Tensor:
1154
+ if img.ndim != 3 or txt.ndim != 3:
1155
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
1156
+
1157
+ # running on sequences img
1158
+ img = self.img_in(img)
1159
+ vec = self.time_in(timestep_embedding(timesteps, 256))
1160
+ if self.params.guidance_embed:
1161
+ if guidance is None:
1162
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1163
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1164
+ vec = vec + self.vector_in(y)
1165
+ txt = self.txt_in(txt)
1166
+
1167
+ ids = torch.cat((txt_ids, img_ids), dim=1)
1168
+ pe = self.pe_embedder(ids)
1169
+
1170
+ for block in self.double_blocks:
1171
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1172
+
1173
+ return img, txt, vec, pe
1174
+
1175
+
1176
+ class FluxLower(nn.Module):
1177
+ ""
1178
+ Transformer model for flow matching on sequences.
1179
+ ""
1180
+
1181
+ def __init__(self, params: FluxParams):
1182
+ super().__init__()
1183
+ self.hidden_size = params.hidden_size
1184
+ self.num_heads = params.num_heads
1185
+ self.out_channels = params.in_channels
1186
+
1187
+ self.single_blocks = nn.ModuleList(
1188
+ [
1189
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
1190
+ for _ in range(params.depth_single_blocks)
1191
+ ]
1192
+ )
1193
+
1194
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
1195
+
1196
+ self.gradient_checkpointing = False
1197
+
1198
+ @property
1199
+ def device(self):
1200
+ return next(self.parameters()).device
1201
+
1202
+ @property
1203
+ def dtype(self):
1204
+ return next(self.parameters()).dtype
1205
+
1206
+ def enable_gradient_checkpointing(self):
1207
+ self.gradient_checkpointing = True
1208
+
1209
+ for block in self.single_blocks:
1210
+ block.enable_gradient_checkpointing()
1211
+
1212
+ print("FLUX: Gradient checkpointing enabled.")
1213
+
1214
+ def disable_gradient_checkpointing(self):
1215
+ self.gradient_checkpointing = False
1216
+
1217
+ for block in self.single_blocks:
1218
+ block.disable_gradient_checkpointing()
1219
+
1220
+ print("FLUX: Gradient checkpointing disabled.")
1221
+
1222
+ def forward(
1223
+ self,
1224
+ img: Tensor,
1225
+ txt: Tensor,
1226
+ vec: Tensor | None = None,
1227
+ pe: Tensor | None = None,
1228
+ txt_attention_mask: Tensor | None = None,
1229
+ ) -> Tensor:
1230
+ img = torch.cat((txt, img), 1)
1231
+ for block in self.single_blocks:
1232
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1233
+ img = img[:, txt.shape[1] :, ...]
1234
+
1235
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
1236
+ return img
1237
+ """
library/flux_train_utils.py ADDED
@@ -0,0 +1,582 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import numpy as np
5
+ import toml
6
+ import json
7
+ import time
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from accelerate import Accelerator, PartialState
12
+ from transformers import CLIPTextModel
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ from safetensors.torch import save_file
16
+
17
+ from library import flux_models, flux_utils, strategy_base, train_util
18
+ from library.device_utils import init_ipex, clean_memory_on_device
19
+
20
+ init_ipex()
21
+
22
+ from .utils import setup_logging, mem_eff_save_file
23
+
24
+ setup_logging()
25
+ import logging
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ # region sample images
31
+
32
+
33
+ def sample_images(
34
+ accelerator: Accelerator,
35
+ args: argparse.Namespace,
36
+ epoch,
37
+ steps,
38
+ flux,
39
+ ae,
40
+ text_encoders,
41
+ sample_prompts_te_outputs,
42
+ prompt_replacement=None,
43
+ ):
44
+ if steps == 0:
45
+ if not args.sample_at_first:
46
+ return
47
+ else:
48
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
49
+ return
50
+ if args.sample_every_n_epochs is not None:
51
+ # sample_every_n_steps is ignored
52
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
53
+ return
54
+ else:
55
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
56
+ return
57
+
58
+ logger.info("")
59
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
60
+ if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
61
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
62
+ return
63
+
64
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
65
+
66
+ # unwrap unet and text_encoder(s)
67
+ flux = accelerator.unwrap_model(flux)
68
+ if text_encoders is not None:
69
+ text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
70
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
71
+
72
+ prompts = train_util.load_prompts(args.sample_prompts)
73
+
74
+ save_dir = args.output_dir + "/sample"
75
+ os.makedirs(save_dir, exist_ok=True)
76
+
77
+ # save random state to restore later
78
+ rng_state = torch.get_rng_state()
79
+ cuda_rng_state = None
80
+ try:
81
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
82
+ except Exception:
83
+ pass
84
+
85
+ if distributed_state.num_processes <= 1:
86
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
87
+ with torch.no_grad(), accelerator.autocast():
88
+ for prompt_dict in prompts:
89
+ sample_image_inference(
90
+ accelerator,
91
+ args,
92
+ flux,
93
+ text_encoders,
94
+ ae,
95
+ save_dir,
96
+ prompt_dict,
97
+ epoch,
98
+ steps,
99
+ sample_prompts_te_outputs,
100
+ prompt_replacement,
101
+ )
102
+ else:
103
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
104
+ # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
105
+ per_process_prompts = [] # list of lists
106
+ for i in range(distributed_state.num_processes):
107
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
108
+
109
+ with torch.no_grad():
110
+ with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
111
+ for prompt_dict in prompt_dict_lists[0]:
112
+ sample_image_inference(
113
+ accelerator,
114
+ args,
115
+ flux,
116
+ text_encoders,
117
+ ae,
118
+ save_dir,
119
+ prompt_dict,
120
+ epoch,
121
+ steps,
122
+ sample_prompts_te_outputs,
123
+ prompt_replacement,
124
+ )
125
+
126
+ torch.set_rng_state(rng_state)
127
+ if cuda_rng_state is not None:
128
+ torch.cuda.set_rng_state(cuda_rng_state)
129
+
130
+ clean_memory_on_device(accelerator.device)
131
+
132
+
133
+ def sample_image_inference(
134
+ accelerator: Accelerator,
135
+ args: argparse.Namespace,
136
+ flux: flux_models.Flux,
137
+ text_encoders: Optional[List[CLIPTextModel]],
138
+ ae: flux_models.AutoEncoder,
139
+ save_dir,
140
+ prompt_dict,
141
+ epoch,
142
+ steps,
143
+ sample_prompts_te_outputs,
144
+ prompt_replacement,
145
+ ):
146
+ assert isinstance(prompt_dict, dict)
147
+ # negative_prompt = prompt_dict.get("negative_prompt")
148
+ sample_steps = prompt_dict.get("sample_steps", 20)
149
+ width = prompt_dict.get("width", 512)
150
+ height = prompt_dict.get("height", 512)
151
+ scale = prompt_dict.get("scale", 3.5)
152
+ seed = prompt_dict.get("seed")
153
+ # controlnet_image = prompt_dict.get("controlnet_image")
154
+ prompt: str = prompt_dict.get("prompt", "")
155
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
156
+
157
+ if prompt_replacement is not None:
158
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
159
+ # if negative_prompt is not None:
160
+ # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
161
+
162
+ if seed is not None:
163
+ torch.manual_seed(seed)
164
+ torch.cuda.manual_seed(seed)
165
+ else:
166
+ # True random sample image generation
167
+ torch.seed()
168
+ torch.cuda.seed()
169
+
170
+ # if negative_prompt is None:
171
+ # negative_prompt = ""
172
+
173
+ height = max(64, height - height % 16) # round to divisible by 16
174
+ width = max(64, width - width % 16) # round to divisible by 16
175
+ logger.info(f"prompt: {prompt}")
176
+ # logger.info(f"negative_prompt: {negative_prompt}")
177
+ logger.info(f"height: {height}")
178
+ logger.info(f"width: {width}")
179
+ logger.info(f"sample_steps: {sample_steps}")
180
+ logger.info(f"scale: {scale}")
181
+ # logger.info(f"sample_sampler: {sampler_name}")
182
+ if seed is not None:
183
+ logger.info(f"seed: {seed}")
184
+
185
+ # encode prompts
186
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
187
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
188
+
189
+ text_encoder_conds = []
190
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
191
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
192
+ print(f"Using cached text encoder outputs for prompt: {prompt}")
193
+ if text_encoders is not None:
194
+ print(f"Encoding prompt: {prompt}")
195
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
196
+ # strategy has apply_t5_attn_mask option
197
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
198
+
199
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
200
+ if len(text_encoder_conds) == 0:
201
+ text_encoder_conds = encoded_text_encoder_conds
202
+ else:
203
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
204
+ for i in range(len(encoded_text_encoder_conds)):
205
+ if encoded_text_encoder_conds[i] is not None:
206
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
207
+
208
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
209
+
210
+ # sample image
211
+ weight_dtype = ae.dtype # TODO: pass dtype as an argument
212
+ packed_latent_height = height // 16
213
+ packed_latent_width = width // 16
214
+ noise = torch.randn(
215
+ 1,
216
+ packed_latent_height * packed_latent_width,
217
+ 16 * 2 * 2,
218
+ device=accelerator.device,
219
+ dtype=weight_dtype,
220
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
221
+ )
222
+ timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
223
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
224
+ t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
225
+
226
+ with accelerator.autocast(), torch.no_grad():
227
+ x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
228
+
229
+ x = x.float()
230
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
231
+
232
+ # latent to image
233
+ clean_memory_on_device(accelerator.device)
234
+ org_vae_device = ae.device # will be on cpu
235
+ ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
236
+ with accelerator.autocast(), torch.no_grad():
237
+ x = ae.decode(x)
238
+ ae.to(org_vae_device)
239
+ clean_memory_on_device(accelerator.device)
240
+
241
+ x = x.clamp(-1, 1)
242
+ x = x.permute(0, 2, 3, 1)
243
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
244
+
245
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
246
+ # but adding 'enum' to the filename should be enough
247
+
248
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
249
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
250
+ seed_suffix = "" if seed is None else f"_{seed}"
251
+ i: int = prompt_dict["enum"]
252
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
253
+ image.save(os.path.join(save_dir, img_filename))
254
+
255
+ # send images to wandb if enabled
256
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
257
+ wandb_tracker = accelerator.get_tracker("wandb")
258
+
259
+ import wandb
260
+
261
+ # not to commit images to avoid inconsistency between training and logging steps
262
+ wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
263
+
264
+
265
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
266
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
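+ # note: with sigma=1 this equals the discrete flow shift t -> (shift * t) / (1 + (shift - 1) * t)
+ # with shift = exp(mu); larger mu keeps sampled timesteps closer to the noisy end (t near 1)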
267
+
268
+
269
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
270
+ m = (y2 - y1) / (x2 - x1)
271
+ b = y1 - m * x1
272
+ return lambda x: m * x + b
273
+
274
+
275
+ def get_schedule(
276
+ num_steps: int,
277
+ image_seq_len: int,
278
+ base_shift: float = 0.5,
279
+ max_shift: float = 1.15,
280
+ shift: bool = True,
281
+ ) -> list[float]:
282
+ # extra step for zero
283
+ timesteps = torch.linspace(1, 0, num_steps + 1)
284
+
285
+ # shifting the schedule to favor high timesteps for higher signal images
286
+ if shift:
287
+ # estimate mu via linear interpolation between two points
288
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
289
+ timesteps = time_shift(mu, 1.0, timesteps)
290
+
291
+ return timesteps.tolist()
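+ # illustrative example (hypothetical values): get_schedule(20, 4096) returns 21 values
+ # decreasing from 1.0 to 0.0; with shift=True, mu = get_lin_function()(4096) = 1.15, so
+ # early steps stay closer to 1.0 than the plain linear schedule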
292
+
293
+
294
+ def denoise(
295
+ model: flux_models.Flux,
296
+ img: torch.Tensor,
297
+ img_ids: torch.Tensor,
298
+ txt: torch.Tensor,
299
+ txt_ids: torch.Tensor,
300
+ vec: torch.Tensor,
301
+ timesteps: list[float],
302
+ guidance: float = 4.0,
303
+ t5_attn_mask: Optional[torch.Tensor] = None,
304
+ ):
305
+ # this is ignored for schnell
306
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
307
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
308
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
309
+ model.prepare_block_swap_before_forward()
310
+ pred = model(
311
+ img=img,
312
+ img_ids=img_ids,
313
+ txt=txt,
314
+ txt_ids=txt_ids,
315
+ y=vec,
316
+ timesteps=t_vec,
317
+ guidance=guidance_vec,
318
+ txt_attention_mask=t5_attn_mask,
319
+ )
320
+
321
+ img = img + (t_prev - t_curr) * pred
322
+
323
+ model.prepare_block_swap_before_forward()
324
+ return img
325
+
326
+
327
+ # endregion
328
+
329
+
330
+ # region train
331
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
332
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
333
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
334
+ timesteps = timesteps.to(device)
335
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
336
+
337
+ sigma = sigmas[step_indices].flatten()
338
+ while len(sigma.shape) < n_dim:
339
+ sigma = sigma.unsqueeze(-1)
340
+ return sigma
341
+
342
+
343
+ def compute_density_for_timestep_sampling(
344
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
345
+ ):
346
+ """Compute the density for sampling the timesteps when doing SD3 training.
347
+
348
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
349
+
350
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
351
+ """
352
+ if weighting_scheme == "logit_normal":
353
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
354
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
355
+ u = torch.nn.functional.sigmoid(u)
356
+ elif weighting_scheme == "mode":
357
+ u = torch.rand(size=(batch_size,), device="cpu")
358
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
359
+ else:
360
+ u = torch.rand(size=(batch_size,), device="cpu")
361
+ return u
362
+
363
+
364
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
365
+ """Computes loss weighting scheme for SD3 training.
366
+
367
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
368
+
369
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
370
+ """
371
+ if weighting_scheme == "sigma_sqrt":
372
+ weighting = (sigmas**-2.0).float()
373
+ elif weighting_scheme == "cosmap":
374
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
375
+ weighting = 2 / (math.pi * bot)
376
+ else:
377
+ weighting = torch.ones_like(sigmas)
378
+ return weighting
379
+
380
+
381
+ def get_noisy_model_input_and_timesteps(
382
+ args, noise_scheduler, latents, noise, device, dtype
383
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
384
+ bsz, _, h, w = latents.shape
385
+ sigmas = None
386
+
387
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
388
+ # Simple random t-based noise sampling
389
+ if args.timestep_sampling == "sigmoid":
390
+ # https://github.com/XLabs-AI/x-flux/tree/main
391
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
392
+ else:
393
+ t = torch.rand((bsz,), device=device)
394
+
395
+ timesteps = t * 1000.0
396
+ t = t.view(-1, 1, 1, 1)
397
+ noisy_model_input = (1 - t) * latents + t * noise
398
+ elif args.timestep_sampling == "shift":
399
+ shift = args.discrete_flow_shift
400
+ logits_norm = torch.randn(bsz, device=device)
401
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
402
+ timesteps = logits_norm.sigmoid()
403
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
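+ # illustrative example (hypothetical values): with discrete_flow_shift=3.0 and a raw
+ # sigmoid sample of 0.5, the shifted timestep is (0.5 * 3) / (1 + 2 * 0.5) = 0.75, i.e.
+ # sampling is biased toward noisier timesteps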
404
+
405
+ t = timesteps.view(-1, 1, 1, 1)
406
+ timesteps = timesteps * 1000.0
407
+ noisy_model_input = (1 - t) * latents + t * noise
408
+ elif args.timestep_sampling == "flux_shift":
409
+ logits_norm = torch.randn(bsz, device=device)
410
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
411
+ timesteps = logits_norm.sigmoid()
412
+ mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
413
+ timesteps = time_shift(mu, 1.0, timesteps)
414
+
415
+ t = timesteps.view(-1, 1, 1, 1)
416
+ timesteps = timesteps * 1000.0
417
+ noisy_model_input = (1 - t) * latents + t * noise
418
+ else:
419
+ # Sample a random timestep for each image
420
+ # for weighting schemes where we sample timesteps non-uniformly
421
+ u = compute_density_for_timestep_sampling(
422
+ weighting_scheme=args.weighting_scheme,
423
+ batch_size=bsz,
424
+ logit_mean=args.logit_mean,
425
+ logit_std=args.logit_std,
426
+ mode_scale=args.mode_scale,
427
+ )
428
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
429
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
430
+
431
+ # Add noise according to flow matching.
432
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
433
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
434
+
435
+ return noisy_model_input, timesteps, sigmas
436
+
437
+
438
+ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
439
+ weighting = None
440
+ if args.model_prediction_type == "raw":
441
+ pass
442
+ elif args.model_prediction_type == "additive":
443
+ # add the model_pred to the noisy_model_input
444
+ model_pred = model_pred + noisy_model_input
445
+ elif args.model_prediction_type == "sigma_scaled":
446
+ # apply sigma scaling
447
+ model_pred = model_pred * (-sigmas) + noisy_model_input
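+ # for flow matching with noisy_model_input = (1 - sigma) * latents + sigma * noise and a
+ # velocity-style prediction of roughly (noise - latents), noisy_model_input - sigma * pred
+ # recovers an estimate of the clean latents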
448
+
449
+ # these weighting schemes use a uniform timestep sampling
450
+ # and instead post-weight the loss
451
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
452
+
453
+ return model_pred, weighting
454
+
455
+
456
+ def save_models(
457
+ ckpt_path: str,
458
+ flux: flux_models.Flux,
459
+ sai_metadata: Optional[dict],
460
+ save_dtype: Optional[torch.dtype] = None,
461
+ use_mem_eff_save: bool = False,
462
+ ):
463
+ state_dict = {}
464
+
465
+ def update_sd(prefix, sd):
466
+ for k, v in sd.items():
467
+ key = prefix + k
468
+ if save_dtype is not None and v.dtype != save_dtype:
469
+ v = v.detach().clone().to("cpu").to(save_dtype)
470
+ state_dict[key] = v
471
+
472
+ update_sd("", flux.state_dict())
473
+
474
+ if not use_mem_eff_save:
475
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
476
+ else:
477
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
478
+
479
+
480
+ def save_flux_model_on_train_end(
481
+ args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
482
+ ):
483
+ def sd_saver(ckpt_file, epoch_no, global_step):
484
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
485
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
486
+
487
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
488
+
489
+
490
+ # epoch-end and stepwise saving are unified here because the metadata includes epoch/step and the arguments are identical
491
+ # on_epoch_end: True to save at the end of an epoch, False to save after a number of steps
492
+ def save_flux_model_on_epoch_end_or_stepwise(
493
+ args: argparse.Namespace,
494
+ on_epoch_end: bool,
495
+ accelerator,
496
+ save_dtype: torch.dtype,
497
+ epoch: int,
498
+ num_train_epochs: int,
499
+ global_step: int,
500
+ flux: flux_models.Flux,
501
+ ):
502
+ def sd_saver(ckpt_file, epoch_no, global_step):
503
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
504
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
505
+
506
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
507
+ args,
508
+ on_epoch_end,
509
+ accelerator,
510
+ True,
511
+ True,
512
+ epoch,
513
+ num_train_epochs,
514
+ global_step,
515
+ sd_saver,
516
+ None,
517
+ )
518
+
519
+
520
+ # endregion
521
+
522
+
523
+ def add_flux_train_arguments(parser: argparse.ArgumentParser):
524
+ parser.add_argument(
525
+ "--clip_l",
526
+ type=str,
527
+ help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
528
+ )
529
+ parser.add_argument(
530
+ "--t5xxl",
531
+ type=str,
532
+ help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
533
+ )
534
+ parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
535
+ parser.add_argument(
536
+ "--t5xxl_max_token_length",
537
+ type=int,
538
+ default=None,
539
+ help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
540
+ " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
541
+ )
542
+ parser.add_argument(
543
+ "--apply_t5_attn_mask",
544
+ action="store_true",
545
+ help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
546
+ )
547
+
548
+ parser.add_argument(
549
+ "--guidance_scale",
550
+ type=float,
551
+ default=3.5,
552
+ help="guidance scale for sampling; the FLUX.1 dev variant is a guidance-distilled model",
553
+ )
554
+
555
+ parser.add_argument(
556
+ "--timestep_sampling",
557
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
558
+ default="sigma",
559
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
560
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
561
+ )
562
+ parser.add_argument(
563
+ "--sigmoid_scale",
564
+ type=float,
565
+ default=1.0,
566
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
567
+ )
568
+ parser.add_argument(
569
+ "--model_prediction_type",
570
+ choices=["raw", "additive", "sigma_scaled"],
571
+ default="sigma_scaled",
572
+ help="How to interpret and process the model prediction: "
573
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
574
+ " / モデル予測の解釈と処理方法:"
575
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
576
+ )
577
+ parser.add_argument(
578
+ "--discrete_flow_shift",
579
+ type=float,
580
+ default=3.0,
581
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
582
+ )
library/flux_train_utils_recraft.py ADDED
@@ -0,0 +1,659 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import numpy as np
5
+ import toml
6
+ import json
7
+ import time
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+ import pdb
10
+
11
+ import torch
12
+ from accelerate import Accelerator, PartialState
13
+ from transformers import CLIPTextModel
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from safetensors.torch import save_file
17
+
18
+ from library import flux_models, flux_utils, strategy_base, train_util
19
+ from library.device_utils import init_ipex, clean_memory_on_device
20
+
21
+ init_ipex()
22
+
23
+ from .utils import setup_logging, mem_eff_save_file
24
+
25
+ setup_logging()
26
+ import logging
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ # region sample images
32
+
33
+ def sample_images(
34
+ accelerator: Accelerator,
35
+ args: argparse.Namespace,
36
+ epoch,
37
+ steps,
38
+ flux,
39
+ ae,
40
+ text_encoders,
41
+ sample_prompts_te_outputs,
42
+ prompt_replacement=None,
43
+ sample_images_ae_outputs=None
44
+ ):
45
+ if steps == 0:
46
+ if not args.sample_at_first:
47
+ return
48
+ else:
49
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
50
+ return
51
+ if args.sample_every_n_epochs is not None:
52
+ # sample_every_n_steps is ignored
53
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
54
+ return
55
+ else:
56
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
57
+ return
58
+
59
+ logger.info("")
60
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
61
+ if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
62
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
63
+ return
64
+
65
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
66
+
67
+ # unwrap unet and text_encoder(s)
68
+ flux = accelerator.unwrap_model(flux)
69
+ if text_encoders is not None:
70
+ text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
71
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
72
+
73
+ prompts = train_util.load_prompts(args.sample_prompts)
74
+
75
+ save_dir = args.output_dir + "/sample"
76
+ os.makedirs(save_dir, exist_ok=True)
77
+
78
+ # save random state to restore later
79
+ rng_state = torch.get_rng_state()
80
+ cuda_rng_state = None
81
+ try:
82
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
83
+ except Exception:
84
+ pass
85
+
86
+ if distributed_state.num_processes <= 1:
87
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
88
+ with torch.no_grad(), accelerator.autocast():
89
+ for prompt_dict in prompts:
90
+ sample_image_inference(
91
+ accelerator,
92
+ args,
93
+ flux,
94
+ text_encoders,
95
+ ae,
96
+ save_dir,
97
+ prompt_dict,
98
+ epoch,
99
+ steps,
100
+ sample_prompts_te_outputs,
101
+ prompt_replacement,
102
+ sample_images_ae_outputs
103
+ )
104
+ else:
105
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
106
+ # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
107
+ per_process_prompts = [] # list of lists
108
+ for i in range(distributed_state.num_processes):
109
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
110
+
111
+ with torch.no_grad():
112
+ with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
113
+ for prompt_dict in prompt_dict_lists[0]:
114
+ sample_image_inference(
115
+ accelerator,
116
+ args,
117
+ flux,
118
+ text_encoders,
119
+ ae,
120
+ save_dir,
121
+ prompt_dict,
122
+ epoch,
123
+ steps,
124
+ sample_prompts_te_outputs,
125
+ prompt_replacement,
126
+ sample_images_ae_outputs
127
+ )
128
+
129
+ torch.set_rng_state(rng_state)
130
+ if cuda_rng_state is not None:
131
+ torch.cuda.set_rng_state(cuda_rng_state)
132
+
133
+ clean_memory_on_device(accelerator.device)
134
+
135
+
136
+ def sample_image_inference(
137
+ accelerator: Accelerator,
138
+ args: argparse.Namespace,
139
+ flux: flux_models.Flux,
140
+ text_encoders: Optional[List[CLIPTextModel]],
141
+ ae: flux_models.AutoEncoder,
142
+ save_dir,
143
+ prompt_dict,
144
+ epoch,
145
+ steps,
146
+ sample_prompts_te_outputs,
147
+ prompt_replacement,
148
+ sample_images_ae_outputs
149
+ ):
150
+ assert isinstance(prompt_dict, dict)
151
+ # negative_prompt = prompt_dict.get("negative_prompt")
152
+ sample_steps = prompt_dict.get("sample_steps", 20)
153
+ width = prompt_dict.get("width", 1024) if args.frame_num==4 else prompt_dict.get("width", 1056)
154
+ height = prompt_dict.get("height", 1024) if args.frame_num==4 else prompt_dict.get("height", 1056)
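+ # default canvas sizes (assumed to match the latent grids used in denoise()): 1024 for the
+ # 2x2 frame sheet (frame_num == 4), 1056 for the 3x3 sheet (1056 = 3 * 352, divisible by 16)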
155
+ scale = prompt_dict.get("scale", 1.0)
156
+ seed = prompt_dict.get("seed")
157
+ # controlnet_image = prompt_dict.get("controlnet_image")
158
+ prompt: str = prompt_dict.get("prompt", "")
159
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
160
+
161
+ if prompt_replacement is not None:
162
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
163
+ # if negative_prompt is not None:
164
+ # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
165
+
166
+ if seed is not None:
167
+ torch.manual_seed(seed)
168
+ torch.cuda.manual_seed(seed)
169
+ else:
170
+ # True random sample image generation
171
+ torch.seed()
172
+ torch.cuda.seed()
173
+
174
+ # if negative_prompt is None:
175
+ # negative_prompt = ""
176
+
177
+ height = max(64, height - height % 16) # round to divisible by 16
178
+ width = max(64, width - width % 16) # round to divisible by 16
179
+ logger.info(f"prompt: {prompt}")
180
+ # logger.info(f"negative_prompt: {negative_prompt}")
181
+ logger.info(f"height: {height}")
182
+ logger.info(f"width: {width}")
183
+ logger.info(f"sample_steps: {sample_steps}")
184
+ logger.info(f"scale: {scale}")
185
+ # logger.info(f"sample_sampler: {sampler_name}")
186
+ if seed is not None:
187
+ logger.info(f"seed: {seed}")
188
+
189
+ # encode prompts
190
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
191
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
192
+
193
+ text_encoder_conds = []
194
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
195
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
196
+ print(f"Using cached text encoder outputs for prompt: {prompt}")
197
+ if text_encoders is not None:
198
+ print(f"Encoding prompt: {prompt}")
199
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
200
+ # strategy has apply_t5_attn_mask option
201
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
202
+
203
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
204
+ if len(text_encoder_conds) == 0:
205
+ text_encoder_conds = encoded_text_encoder_conds
206
+ else:
207
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
208
+ for i in range(len(encoded_text_encoder_conds)):
209
+ if encoded_text_encoder_conds[i] is not None:
210
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
211
+
212
+ if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
213
+ ae_outputs = sample_images_ae_outputs[prompt]
214
+ else:
215
+ ae_outputs = None
216
+
217
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
218
+
219
+ # sample image
220
+ weight_dtype = ae.dtype # TODO: pass dtype as an argument
221
+ packed_latent_height = height // 16
222
+ packed_latent_width = width // 16
223
+ noise = torch.randn(
224
+ 1,
225
+ packed_latent_height * packed_latent_width,
226
+ 16 * 2 * 2,
227
+ device=accelerator.device,
228
+ dtype=weight_dtype,
229
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
230
+ )
231
+ timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
232
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
233
+ t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
234
+
235
+ with accelerator.autocast(), torch.no_grad():
236
+ x = denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
237
+
238
+ x = x.float()
239
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
240
+
241
+ # latent to image
242
+ clean_memory_on_device(accelerator.device)
243
+ org_vae_device = ae.device # will be on cpu
244
+ ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
245
+ with accelerator.autocast(), torch.no_grad():
246
+ x = ae.decode(x)
247
+ ae.to(org_vae_device)
248
+ clean_memory_on_device(accelerator.device)
249
+
250
+ x = x.clamp(-1, 1)
251
+ x = x.permute(0, 2, 3, 1)
252
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
253
+
254
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
255
+ # but adding 'enum' to the filename should be enough
256
+
257
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
258
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
259
+ seed_suffix = "" if seed is None else f"_{seed}"
260
+ i: int = prompt_dict["enum"]
261
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
262
+ image.save(os.path.join(save_dir, img_filename))
263
+
264
+ # send images to wandb if enabled
265
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
266
+ wandb_tracker = accelerator.get_tracker("wandb")
267
+
268
+ import wandb
269
+ # not to commit images to avoid inconsistency between training and logging steps
270
+ wandb_tracker.log(
271
+ {f"sample_{i}": wandb.Image(
272
+ image,
273
+ caption=prompt # positive prompt as a caption
274
+ )},
275
+ commit=False
276
+ )
277
+
278
+
279
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
280
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
281
+
282
+
283
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
284
+ m = (y2 - y1) / (x2 - x1)
285
+ b = y1 - m * x1
286
+ return lambda x: m * x + b
287
+
288
+
289
+ def get_schedule(
290
+ num_steps: int,
291
+ image_seq_len: int,
292
+ base_shift: float = 0.5,
293
+ max_shift: float = 1.15,
294
+ shift: bool = True,
295
+ ) -> list[float]:
296
+ # extra step for zero
297
+ timesteps = torch.linspace(1, 0, num_steps + 1)
298
+
299
+ # shifting the schedule to favor high timesteps for higher signal images
300
+ if shift:
301
+ # estimate mu via linear interpolation between two points
302
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
303
+ timesteps = time_shift(mu, 1.0, timesteps)
304
+
305
+ return timesteps.tolist()
306
+
307
+
308
+ def denoise(
309
+ args: argparse.Namespace,
310
+ model: flux_models.Flux,
311
+ img: torch.Tensor,
312
+ img_ids: torch.Tensor,
313
+ txt: torch.Tensor,
314
+ txt_ids: torch.Tensor,
315
+ vec: torch.Tensor,
316
+ timesteps: list[float],
317
+ guidance: float = 4.0,
318
+ t5_attn_mask: Optional[torch.Tensor] = None,
319
+ ae_outputs: torch.Tensor = None,
320
+ ):
321
+ # this is ignored for schnell
322
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
323
+ img_ids = img_ids.to(img.device)
324
+ txt_ids = txt_ids.to(img.device)
325
+ vec = vec.to(img.device)
326
+ txt = txt.to(img.device)
327
+
328
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
329
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
330
+ model.prepare_block_swap_before_forward()
331
+ if args.frame_num == 4:
332
+ packed_latent_height, packed_latent_width = ae_outputs.shape[2]*2 // 2, ae_outputs.shape[3]*2 // 2
333
+ img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
334
+ img[:,:, img.shape[2] // 2: img.shape[2], :img.shape[3] // 2] = ae_outputs
335
+ else:
336
+ packed_latent_height, packed_latent_width = ae_outputs.shape[2]*3 // 2, ae_outputs.shape[3]*3 // 2
337
+ img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
338
+ img[:,:, 2*img.shape[2] // 3: img.shape[2], 2*img.shape[3] // 3:img.shape[3]] = ae_outputs
339
+
340
+ img = flux_utils.pack_latents(img)
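+ # at every denoising step the encoded reference latents (ae_outputs) are written back
+ # into a fixed region of the unpacked latent grid (the lower-left quadrant of a 2x2
+ # sheet when frame_num == 4, otherwise the lower-right cell of a 3x3 sheet) before the
+ # latents are re-packed, so only the remaining cells are actually generated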
341
+ pred = model(
342
+ img=img,
343
+ img_ids=img_ids,
344
+ txt=txt,
345
+ txt_ids=txt_ids,
346
+ y=vec,
347
+ timesteps=t_vec,
348
+ guidance=guidance_vec,
349
+ txt_attention_mask=t5_attn_mask,
350
+ )
351
+
352
+ img = img + (t_prev - t_curr) * pred
353
+
354
+ model.prepare_block_swap_before_forward()
355
+ return img
356
+
357
+
358
+ # endregion
359
+
360
+
361
+ # region train
362
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
363
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
364
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
365
+ timesteps = timesteps.to(device)
366
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
367
+
368
+ sigma = sigmas[step_indices].flatten()
369
+ while len(sigma.shape) < n_dim:
370
+ sigma = sigma.unsqueeze(-1)
371
+ return sigma
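+
+ # Note: for 4D latents this returns the sigma of each sampled timestep reshaped to
+ # [bsz, 1, 1, 1] so that it broadcasts over the channel and spatial dimensions.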
372
+
373
+
374
+ def compute_density_for_timestep_sampling(
375
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
376
+ ):
377
+ """Compute the density for sampling the timesteps when doing SD3 training.
378
+
379
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
380
+
381
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
382
+ """
383
+ if weighting_scheme == "logit_normal":
384
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
385
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
386
+ u = torch.nn.functional.sigmoid(u)
387
+ elif weighting_scheme == "mode":
388
+ u = torch.rand(size=(batch_size,), device="cpu")
389
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
390
+ else:
391
+ u = torch.rand(size=(batch_size,), device="cpu")
392
+ return u
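+
+ # Example (illustrative sketch): with the "logit_normal" scheme, u is sigmoid(N(logit_mean, logit_std)),
+ # i.e. concentrated around sigmoid(logit_mean) and bounded to (0, 1):
+ #   u = compute_density_for_timestep_sampling("logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0)
+ #   indices = (u * noise_scheduler.config.num_train_timesteps).long()  # as done below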
393
+
394
+
395
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
396
+ """Computes loss weighting scheme for SD3 training.
397
+
398
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
399
+
400
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
401
+ """
402
+ if weighting_scheme == "sigma_sqrt":
403
+ weighting = (sigmas**-2.0).float()
404
+ elif weighting_scheme == "cosmap":
405
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
406
+ weighting = 2 / (math.pi * bot)
407
+ else:
408
+ weighting = torch.ones_like(sigmas)
409
+ return weighting
410
+
411
+
412
+ def get_noisy_model_input_and_timesteps(
413
+ args, noise_scheduler, latents, noise, device, dtype
414
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
415
+ bsz, _, h, w = latents.shape
416
+ sigmas = None
417
+
418
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
419
+ # Simple random t-based noise sampling
420
+ if args.timestep_sampling == "sigmoid":
421
+ # https://github.com/XLabs-AI/x-flux/tree/main
422
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
423
+ else:
424
+ t = torch.rand((bsz,), device=device)
425
+
426
+ timesteps = t * 1000.0
427
+ t = t.view(-1, 1, 1, 1)
428
+ noisy_model_input = (1 - t) * latents + t * noise
429
+ elif args.timestep_sampling == "shift":
430
+ shift = args.discrete_flow_shift
431
+ logits_norm = torch.randn(bsz, device=device)
432
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
433
+ timesteps = logits_norm.sigmoid()
434
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
435
+
436
+ t = timesteps.view(-1, 1, 1, 1)
437
+ timesteps = timesteps * 1000.0
438
+ noisy_model_input = (1 - t) * latents + t * noise
439
+ elif args.timestep_sampling == "flux_shift":
440
+ logits_norm = torch.randn(bsz, device=device)
441
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
442
+ timesteps = logits_norm.sigmoid()
443
+ mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
444
+ timesteps = time_shift(mu, 1.0, timesteps)
445
+
446
+ t = timesteps.view(-1, 1, 1, 1)
447
+ timesteps = timesteps * 1000.0
448
+ noisy_model_input = (1 - t) * latents + t * noise
449
+ else:
450
+ # Sample a random timestep for each image
451
+ # for weighting schemes where we sample timesteps non-uniformly
452
+ u = compute_density_for_timestep_sampling(
453
+ weighting_scheme=args.weighting_scheme,
454
+ batch_size=bsz,
455
+ logit_mean=args.logit_mean,
456
+ logit_std=args.logit_std,
457
+ mode_scale=args.mode_scale,
458
+ )
459
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
460
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
461
+
462
+ # Add noise according to flow matching.
463
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
464
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
465
+
466
+ # replace part of the noisy input region with the original latents
467
+ h, w = noisy_model_input.shape[2], noisy_model_input.shape[3]
468
+ # import pdb; pdb.set_trace()
469
+ if args.frame_num == 4:
470
+ noisy_model_input[:, :, h//2 : h, w//2 : w] = latents[:, :, h//2:h, w//2:w]
471
+ else:
472
+ noisy_model_input[:, :, 2*h//3 : h, 2*w//3 : w] = latents[:, :, 2*h//3:h, 2*w//3:w]
473
+
474
+
475
+ return noisy_model_input, timesteps, sigmas
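+
+ # Note: every branch above builds the flow-matching interpolation
+ #   x_t = (1 - t) * latents + t * noise   (equivalently sigmas * noise + (1 - sigmas) * latents),
+ # and then the conditioning region of x_t is overwritten with the clean latents, so only the
+ # generated portion of the grid is actually noised.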
476
+
477
+
478
+ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
479
+ weighting = None
480
+ if args.model_prediction_type == "raw":
481
+ pass
482
+ elif args.model_prediction_type == "additive":
483
+ # add the model_pred to the noisy_model_input
484
+ model_pred = model_pred + noisy_model_input
485
+ elif args.model_prediction_type == "sigma_scaled":
486
+ # apply sigma scaling
487
+ model_pred = model_pred * (-sigmas) + noisy_model_input
488
+
489
+ # these weighting schemes use a uniform timestep sampling
490
+ # and instead post-weight the loss
491
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
492
+
493
+ return model_pred, weighting
494
+
495
+
496
+ def save_models(
497
+ ckpt_path: str,
498
+ flux: flux_models.Flux,
499
+ sai_metadata: Optional[dict],
500
+ save_dtype: Optional[torch.dtype] = None,
501
+ use_mem_eff_save: bool = False,
502
+ ):
503
+ state_dict = {}
504
+
505
+ def update_sd(prefix, sd):
506
+ for k, v in sd.items():
507
+ key = prefix + k
508
+ if save_dtype is not None and v.dtype != save_dtype:
509
+ v = v.detach().clone().to("cpu").to(save_dtype)
510
+ state_dict[key] = v
511
+
512
+ update_sd("", flux.state_dict())
513
+
514
+ if not use_mem_eff_save:
515
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
516
+ else:
517
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
518
+
519
+
520
+ def save_flux_model_on_train_end(
521
+ args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
522
+ ):
523
+ def sd_saver(ckpt_file, epoch_no, global_step):
524
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
525
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
526
+
527
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
528
+
529
+
530
+ # epoch- and step-based saving are merged here because the metadata includes epoch/step and the arguments are identical
531
+ # on_epoch_end: True = at the end of an epoch, False = after the given number of steps
532
+ def save_flux_model_on_epoch_end_or_stepwise(
533
+ args: argparse.Namespace,
534
+ on_epoch_end: bool,
535
+ accelerator,
536
+ save_dtype: torch.dtype,
537
+ epoch: int,
538
+ num_train_epochs: int,
539
+ global_step: int,
540
+ flux: flux_models.Flux,
541
+ ):
542
+ def sd_saver(ckpt_file, epoch_no, global_step):
543
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
544
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
545
+
546
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
547
+ args,
548
+ on_epoch_end,
549
+ accelerator,
550
+ True,
551
+ True,
552
+ epoch,
553
+ num_train_epochs,
554
+ global_step,
555
+ sd_saver,
556
+ None,
557
+ )
558
+
559
+
560
+ # endregion
561
+
562
+
563
+ def add_flux_train_arguments(parser: argparse.ArgumentParser):
564
+ parser.add_argument(
565
+ "--clip_l",
566
+ type=str,
567
+ help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
568
+ )
569
+ parser.add_argument(
570
+ "--t5xxl",
571
+ type=str,
572
+ help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
573
+ )
574
+ parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
575
+ parser.add_argument(
576
+ "--t5xxl_max_token_length",
577
+ type=int,
578
+ default=None,
579
+ help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
580
+ " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
581
+ )
582
+ parser.add_argument(
583
+ "--apply_t5_attn_mask",
584
+ action="store_true",
585
+ help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
586
+ )
587
+ parser.add_argument(
588
+ "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
589
+ )
590
+ parser.add_argument(
591
+ "--cache_text_encoder_outputs_to_disk",
592
+ action="store_true",
593
+ help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
594
+ )
595
+ parser.add_argument(
596
+ "--text_encoder_batch_size",
597
+ type=int,
598
+ default=None,
599
+ help="text encoder batch size (default: None, use dataset's batch size)"
600
+ + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
601
+ )
602
+ parser.add_argument(
603
+ "--disable_mmap_load_safetensors",
604
+ action="store_true",
605
+ help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
606
+ )
607
+
608
+ # copy from Diffusers
609
+ parser.add_argument(
610
+ "--weighting_scheme",
611
+ type=str,
612
+ default="none",
613
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
614
+ )
615
+ parser.add_argument(
616
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
617
+ )
618
+ parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
619
+ parser.add_argument(
620
+ "--mode_scale",
621
+ type=float,
622
+ default=1.29,
623
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
624
+ )
625
+ parser.add_argument(
626
+ "--guidance_scale",
627
+ type=float,
628
+ default=3.5,
629
+ help="the FLUX.1 dev variant is a guidance distilled model",
630
+ )
631
+
632
+ parser.add_argument(
633
+ "--timestep_sampling",
634
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
635
+ default="sigma",
636
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
637
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
638
+ )
639
+ parser.add_argument(
640
+ "--sigmoid_scale",
641
+ type=float,
642
+ default=1.0,
643
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
644
+ )
645
+ parser.add_argument(
646
+ "--model_prediction_type",
647
+ choices=["raw", "additive", "sigma_scaled"],
648
+ default="sigma_scaled",
649
+ help="How to interpret and process the model prediction: "
650
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
651
+ " / モデル予測の解釈と処理方法:"
652
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
653
+ )
654
+ parser.add_argument(
655
+ "--discrete_flow_shift",
656
+ type=float,
657
+ default=3.0,
658
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
659
+ )
library/flux_utils.py ADDED
@@ -0,0 +1,472 @@
1
+ from dataclasses import replace
2
+ import json
3
+ import os
4
+ from typing import List, Optional, Tuple, Union
5
+ import einops
6
+ import torch
7
+
8
+ from safetensors.torch import load_file
9
+ from safetensors import safe_open
10
+ from accelerate import init_empty_weights
11
+ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
12
+
13
+ from library.utils import setup_logging
14
+
15
+ setup_logging()
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ from library import flux_models
21
+ from library.utils import load_safetensors
22
+
23
+ MODEL_VERSION_FLUX_V1 = "flux1"
24
+ MODEL_NAME_DEV = "dev"
25
+ MODEL_NAME_SCHNELL = "schnell"
26
+
27
+
28
+ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
29
+ """
30
+ Analyze the checkpoint state and return whether it is Diffusers or BFL, dev or schnell, and the number of blocks.
31
+
32
+ Args:
33
+ ckpt_path (str): Path to the checkpoint file or directory.
34
+
35
+ Returns:
36
+ Tuple[bool, bool, Tuple[int, int], List[str]]:
37
+ - bool: Flag indicating whether the checkpoint is in Diffusers format.
38
+ - bool: Flag indicating whether the model is schnell.
39
+ - Tuple[int, int]: The numbers of double and single blocks.
40
+ - List[str]: The list of keys contained in the checkpoint.
41
+ """
42
+ # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
43
+ logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
44
+
45
+ if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
46
+ ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
47
+ if "00001-of-00003" in ckpt_path:
48
+ ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
49
+ else:
50
+ ckpt_paths = [ckpt_path]
51
+
52
+ keys = []
53
+ for ckpt_path in ckpt_paths:
54
+ with safe_open(ckpt_path, framework="pt") as f:
55
+ keys.extend(f.keys())
56
+
57
+ # if the key has annoying prefix, remove it
58
+ if keys[0].startswith("model.diffusion_model."):
59
+ keys = [key.replace("model.diffusion_model.", "") for key in keys]
60
+
61
+ is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
62
+ is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
63
+
64
+ # check number of double and single blocks
65
+ if not is_diffusers:
66
+ max_double_block_index = max(
67
+ [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
68
+ )
69
+ max_single_block_index = max(
70
+ [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
71
+ )
72
+ else:
73
+ max_double_block_index = max(
74
+ [
75
+ int(key.split(".")[1])
76
+ for key in keys
77
+ if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
78
+ ]
79
+ )
80
+ max_single_block_index = max(
81
+ [
82
+ int(key.split(".")[1])
83
+ for key in keys
84
+ if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
85
+ ]
86
+ )
87
+
88
+ num_double_blocks = max_double_block_index + 1
89
+ num_single_blocks = max_single_block_index + 1
90
+
91
+ return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
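+
+ # Example (illustrative sketch, the path is hypothetical):
+ #   is_diffusers, is_schnell, (n_double, n_single), paths = analyze_checkpoint_state("flux1-dev.safetensors")
+ #   # a stock FLUX.1 dev checkpoint in BFL format yields (False, False, (19, 38), [path])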
92
+
93
+
94
+ def load_flow_model(
95
+ ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
96
+ ) -> Tuple[bool, flux_models.Flux]:
97
+ is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
98
+ name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
99
+
100
+ # build model
101
+ logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
102
+ with torch.device("meta"):
103
+ params = flux_models.configs[name].params
104
+
105
+ # set the number of blocks
106
+ if params.depth != num_double_blocks:
107
+ logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
108
+ params = replace(params, depth=num_double_blocks)
109
+ if params.depth_single_blocks != num_single_blocks:
110
+ logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
111
+ params = replace(params, depth_single_blocks=num_single_blocks)
112
+
113
+ model = flux_models.Flux(params)
114
+ if dtype is not None:
115
+ model = model.to(dtype)
116
+
117
+ # load_sft doesn't support torch.device
118
+ logger.info(f"Loading state dict from {ckpt_path}")
119
+ sd = {}
120
+ for ckpt_path in ckpt_paths:
121
+ sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
122
+
123
+ # convert Diffusers to BFL
124
+ if is_diffusers:
125
+ logger.info("Converting Diffusers to BFL")
126
+ sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
127
+ logger.info("Converted Diffusers to BFL")
128
+
129
+ # if the key has annoying prefix, remove it
130
+ for key in list(sd.keys()):
131
+ new_key = key.replace("model.diffusion_model.", "")
132
+ if new_key == key:
133
+ break # the model doesn't have annoying prefix
134
+ sd[new_key] = sd.pop(key)
135
+
136
+ info = model.load_state_dict(sd, strict=False, assign=True)
137
+ logger.info(f"Loaded Flux: {info}")
138
+ return is_schnell, model
139
+
140
+
141
+ def load_ae(
142
+ ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
143
+ ) -> flux_models.AutoEncoder:
144
+ logger.info("Building AutoEncoder")
145
+ with torch.device("meta"):
146
+ # dev and schnell have the same AE params
147
+ ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
148
+
149
+ logger.info(f"Loading state dict from {ckpt_path}")
150
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
151
+ info = ae.load_state_dict(sd, strict=False, assign=True)
152
+ logger.info(f"Loaded AE: {info}")
153
+ return ae
154
+
155
+
156
+ def load_clip_l(
157
+ ckpt_path: Optional[str],
158
+ dtype: torch.dtype,
159
+ device: Union[str, torch.device],
160
+ disable_mmap: bool = False,
161
+ state_dict: Optional[dict] = None,
162
+ ) -> CLIPTextModel:
163
+ logger.info("Building CLIP-L")
164
+ CLIPL_CONFIG = {
165
+ "_name_or_path": "clip-vit-large-patch14/",
166
+ "architectures": ["CLIPModel"],
167
+ "initializer_factor": 1.0,
168
+ "logit_scale_init_value": 2.6592,
169
+ "model_type": "clip",
170
+ "projection_dim": 768,
171
+ # "text_config": {
172
+ "_name_or_path": "",
173
+ "add_cross_attention": False,
174
+ "architectures": None,
175
+ "attention_dropout": 0.0,
176
+ "bad_words_ids": None,
177
+ "bos_token_id": 0,
178
+ "chunk_size_feed_forward": 0,
179
+ "cross_attention_hidden_size": None,
180
+ "decoder_start_token_id": None,
181
+ "diversity_penalty": 0.0,
182
+ "do_sample": False,
183
+ "dropout": 0.0,
184
+ "early_stopping": False,
185
+ "encoder_no_repeat_ngram_size": 0,
186
+ "eos_token_id": 2,
187
+ "finetuning_task": None,
188
+ "forced_bos_token_id": None,
189
+ "forced_eos_token_id": None,
190
+ "hidden_act": "quick_gelu",
191
+ "hidden_size": 768,
192
+ "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
193
+ "initializer_factor": 1.0,
194
+ "initializer_range": 0.02,
195
+ "intermediate_size": 3072,
196
+ "is_decoder": False,
197
+ "is_encoder_decoder": False,
198
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
199
+ "layer_norm_eps": 1e-05,
200
+ "length_penalty": 1.0,
201
+ "max_length": 20,
202
+ "max_position_embeddings": 77,
203
+ "min_length": 0,
204
+ "model_type": "clip_text_model",
205
+ "no_repeat_ngram_size": 0,
206
+ "num_attention_heads": 12,
207
+ "num_beam_groups": 1,
208
+ "num_beams": 1,
209
+ "num_hidden_layers": 12,
210
+ "num_return_sequences": 1,
211
+ "output_attentions": False,
212
+ "output_hidden_states": False,
213
+ "output_scores": False,
214
+ "pad_token_id": 1,
215
+ "prefix": None,
216
+ "problem_type": None,
217
+ "projection_dim": 768,
218
+ "pruned_heads": {},
219
+ "remove_invalid_values": False,
220
+ "repetition_penalty": 1.0,
221
+ "return_dict": True,
222
+ "return_dict_in_generate": False,
223
+ "sep_token_id": None,
224
+ "task_specific_params": None,
225
+ "temperature": 1.0,
226
+ "tie_encoder_decoder": False,
227
+ "tie_word_embeddings": True,
228
+ "tokenizer_class": None,
229
+ "top_k": 50,
230
+ "top_p": 1.0,
231
+ "torch_dtype": None,
232
+ "torchscript": False,
233
+ "transformers_version": "4.16.0.dev0",
234
+ "use_bfloat16": False,
235
+ "vocab_size": 49408,
236
+ "hidden_act": "gelu",
237
+ "hidden_size": 1280,
238
+ "intermediate_size": 5120,
239
+ "num_attention_heads": 20,
240
+ "num_hidden_layers": 32,
241
+ # },
242
+ # "text_config_dict": {
243
+ "hidden_size": 768,
244
+ "intermediate_size": 3072,
245
+ "num_attention_heads": 12,
246
+ "num_hidden_layers": 12,
247
+ "projection_dim": 768,
248
+ # },
249
+ # "torch_dtype": "float32",
250
+ # "transformers_version": None,
251
+ }
252
+ config = CLIPConfig(**CLIPL_CONFIG)
253
+ with init_empty_weights():
254
+ clip = CLIPTextModel._from_config(config)
255
+
256
+ if state_dict is not None:
257
+ sd = state_dict
258
+ else:
259
+ logger.info(f"Loading state dict from {ckpt_path}")
260
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
261
+ info = clip.load_state_dict(sd, strict=False, assign=True)
262
+ logger.info(f"Loaded CLIP-L: {info}")
263
+ return clip
264
+
265
+
266
+ def load_t5xxl(
267
+ ckpt_path: str,
268
+ dtype: Optional[torch.dtype],
269
+ device: Union[str, torch.device],
270
+ disable_mmap: bool = False,
271
+ state_dict: Optional[dict] = None,
272
+ ) -> T5EncoderModel:
273
+ T5_CONFIG_JSON = """
274
+ {
275
+ "architectures": [
276
+ "T5EncoderModel"
277
+ ],
278
+ "classifier_dropout": 0.0,
279
+ "d_ff": 10240,
280
+ "d_kv": 64,
281
+ "d_model": 4096,
282
+ "decoder_start_token_id": 0,
283
+ "dense_act_fn": "gelu_new",
284
+ "dropout_rate": 0.1,
285
+ "eos_token_id": 1,
286
+ "feed_forward_proj": "gated-gelu",
287
+ "initializer_factor": 1.0,
288
+ "is_encoder_decoder": true,
289
+ "is_gated_act": true,
290
+ "layer_norm_epsilon": 1e-06,
291
+ "model_type": "t5",
292
+ "num_decoder_layers": 24,
293
+ "num_heads": 64,
294
+ "num_layers": 24,
295
+ "output_past": true,
296
+ "pad_token_id": 0,
297
+ "relative_attention_max_distance": 128,
298
+ "relative_attention_num_buckets": 32,
299
+ "tie_word_embeddings": false,
300
+ "torch_dtype": "float16",
301
+ "transformers_version": "4.41.2",
302
+ "use_cache": true,
303
+ "vocab_size": 32128
304
+ }
305
+ """
306
+ config = json.loads(T5_CONFIG_JSON)
307
+ config = T5Config(**config)
308
+ with init_empty_weights():
309
+ t5xxl = T5EncoderModel._from_config(config)
310
+
311
+ if state_dict is not None:
312
+ sd = state_dict
313
+ else:
314
+ logger.info(f"Loading state dict from {ckpt_path}")
315
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
316
+ info = t5xxl.load_state_dict(sd, strict=False, assign=True)
317
+ logger.info(f"Loaded T5xxl: {info}")
318
+ return t5xxl
319
+
320
+
321
+ def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
322
+ # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
323
+ return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
324
+
325
+
326
+ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
327
+ img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
328
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
329
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
330
+ img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
331
+ return img_ids
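+
+ # Example (illustrative sketch): prepare_img_ids(1, 2, 2) returns a [1, 4, 3] tensor whose last
+ # two channels enumerate the (row, col) position of each packed latent token:
+ #   [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]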
332
+
333
+
334
+ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
335
+ """
336
+ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
337
+ """
338
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
339
+ return x
340
+
341
+
342
+ def pack_latents(x: torch.Tensor) -> torch.Tensor:
343
+ """
344
+ x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
345
+ """
346
+ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
347
+ return x
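+
+ # Example (illustrative sketch): pack_latents and unpack_latents are inverses. A latent of shape
+ # [1, 16, 128, 128] packs to [1, 64 * 64, 16 * 2 * 2] = [1, 4096, 64], and
+ # unpack_latents(packed, 64, 64) restores the original [1, 16, 128, 128] layout.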
348
+
349
+
350
+ # region Diffusers
351
+
352
+ NUM_DOUBLE_BLOCKS = 19
353
+ NUM_SINGLE_BLOCKS = 38
354
+
355
+ BFL_TO_DIFFUSERS_MAP = {
356
+ "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
357
+ "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
358
+ "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
359
+ "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
360
+ "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
361
+ "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
362
+ "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
363
+ "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
364
+ "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
365
+ "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
366
+ "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
367
+ "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
368
+ "txt_in.weight": ["context_embedder.weight"],
369
+ "txt_in.bias": ["context_embedder.bias"],
370
+ "img_in.weight": ["x_embedder.weight"],
371
+ "img_in.bias": ["x_embedder.bias"],
372
+ "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
373
+ "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
374
+ "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
375
+ "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
376
+ "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
377
+ "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
378
+ "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
379
+ "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
380
+ "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
381
+ "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
382
+ "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
383
+ "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
384
+ "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
385
+ "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
386
+ "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
387
+ "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
388
+ "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
389
+ "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
390
+ "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
391
+ "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
392
+ "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
393
+ "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
394
+ "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
395
+ "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
396
+ "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
397
+ "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
398
+ "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
399
+ "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
400
+ "single_blocks.().linear2.weight": ["proj_out.weight"],
401
+ "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
402
+ "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
403
+ "single_blocks.().linear2.weight": ["proj_out.weight"],
404
+ "single_blocks.().linear2.bias": ["proj_out.bias"],
405
+ "final_layer.linear.weight": ["proj_out.weight"],
406
+ "final_layer.linear.bias": ["proj_out.bias"],
407
+ "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
408
+ "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
409
+ }
410
+
411
+
412
+ def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
413
+ # make reverse map from diffusers map
414
+ diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
415
+ for b in range(num_double_blocks):
416
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
417
+ if key.startswith("double_blocks."):
418
+ block_prefix = f"transformer_blocks.{b}."
419
+ for i, weight in enumerate(weights):
420
+ diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
421
+ for b in range(num_single_blocks):
422
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
423
+ if key.startswith("single_blocks."):
424
+ block_prefix = f"single_transformer_blocks.{b}."
425
+ for i, weight in enumerate(weights):
426
+ diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
427
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
428
+ if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
429
+ for i, weight in enumerate(weights):
430
+ diffusers_to_bfl_map[weight] = (i, key)
431
+ return diffusers_to_bfl_map
432
+
433
+
434
+ def convert_diffusers_sd_to_bfl(
435
+ diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
436
+ ) -> dict[str, torch.Tensor]:
437
+ diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
438
+
439
+ # iterate over three safetensors files to reduce memory usage
440
+ flux_sd = {}
441
+ for diffusers_key, tensor in diffusers_sd.items():
442
+ if diffusers_key in diffusers_to_bfl_map:
443
+ index, bfl_key = diffusers_to_bfl_map[diffusers_key]
444
+ if bfl_key not in flux_sd:
445
+ flux_sd[bfl_key] = []
446
+ flux_sd[bfl_key].append((index, tensor))
447
+ else:
448
+ logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
449
+ raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
450
+
451
+ # concat tensors if multiple tensors are mapped to a single key, sort by index
452
+ for key, values in flux_sd.items():
453
+ if len(values) == 1:
454
+ flux_sd[key] = values[0][1]
455
+ else:
456
+ flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
457
+
458
+ # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
459
+ def swap_scale_shift(weight):
460
+ shift, scale = weight.chunk(2, dim=0)
461
+ new_weight = torch.cat([scale, shift], dim=0)
462
+ return new_weight
463
+
464
+ if "final_layer.adaLN_modulation.1.weight" in flux_sd:
465
+ flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
466
+ if "final_layer.adaLN_modulation.1.bias" in flux_sd:
467
+ flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
468
+
469
+ return flux_sd
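+
+ # Example (illustrative sketch): the Diffusers tensors transformer_blocks.0.attn.to_q/to_k/to_v.weight
+ # all map to double_blocks.0.img_attn.qkv.weight with indices 0, 1, 2 in BFL_TO_DIFFUSERS_MAP,
+ # so the loop above concatenates them along dim 0 in q, k, v order.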
470
+
471
+
472
+ # endregion
library/huggingface_util.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Union, BinaryIO
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+ from library.utils import fire_in_thread
7
+ from library.utils import setup_logging
8
+ setup_logging()
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13
+ api = HfApi(
14
+ token=token,
15
+ )
16
+ try:
17
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18
+ return True
19
+ except:
20
+ return False
21
+
22
+
23
+ def upload(
24
+ args: argparse.Namespace,
25
+ src: Union[str, Path, bytes, BinaryIO],
26
+ dest_suffix: str = "",
27
+ force_sync_upload: bool = False,
28
+ ):
29
+ repo_id = args.huggingface_repo_id
30
+ repo_type = args.huggingface_repo_type
31
+ token = args.huggingface_token
32
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34
+ api = HfApi(token=token)
35
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36
+ try:
37
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38
+ except Exception as e: # RepositoryNotFoundError has been confirmed, but other exceptions may occur, so catch broadly
39
+ logger.error("===========================================")
40
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41
+ logger.error("===========================================")
42
+
43
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44
+
45
+ def uploader():
46
+ try:
47
+ if is_folder:
48
+ api.upload_folder(
49
+ repo_id=repo_id,
50
+ repo_type=repo_type,
51
+ folder_path=src,
52
+ path_in_repo=path_in_repo,
53
+ )
54
+ else:
55
+ api.upload_file(
56
+ repo_id=repo_id,
57
+ repo_type=repo_type,
58
+ path_or_fileobj=src,
59
+ path_in_repo=path_in_repo,
60
+ )
61
+ except Exception as e: # RuntimeError has been confirmed, but other exceptions may occur, so catch broadly
62
+ logger.error("===========================================")
63
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64
+ logger.error("===========================================")
65
+
66
+ if args.async_upload and not force_sync_upload:
67
+ fire_in_thread(uploader)
68
+ else:
69
+ uploader()
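+
+ # Example (illustrative sketch, the values are hypothetical): with args.huggingface_repo_id = "user/my-model",
+ # args.huggingface_repo_type = "model" and args.async_upload = True,
+ # upload(args, "output/last.safetensors") creates the repo if it does not exist and pushes the file
+ # in a background thread via fire_in_thread; otherwise the upload runs synchronously.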
70
+
71
+
72
+ def list_dir(
73
+ repo_id: str,
74
+ subfolder: str,
75
+ repo_type: str,
76
+ revision: str = "main",
77
+ token: str = None,
78
+ ):
79
+ api = HfApi(
80
+ token=token,
81
+ )
82
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84
+ return file_list
library/hypernetwork.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers.models.attention_processor import (
4
+ Attention,
5
+ AttnProcessor2_0,
6
+ SlicedAttnProcessor,
7
+ XFormersAttnProcessor
8
+ )
9
+
10
+ try:
11
+ import xformers.ops
12
+ except:
13
+ xformers = None
14
+
15
+
16
+ loaded_networks = []
17
+
18
+
19
+ def apply_single_hypernetwork(
20
+ hypernetwork, hidden_states, encoder_hidden_states
21
+ ):
22
+ context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
23
+ return context_k, context_v
24
+
25
+
26
+ def apply_hypernetworks(context_k, context_v, layer=None):
27
+ if len(loaded_networks) == 0:
28
+ return context_v, context_v
29
+ for hypernetwork in loaded_networks:
30
+ context_k, context_v = hypernetwork.forward(context_k, context_v)
31
+
32
+ context_k = context_k.to(dtype=context_k.dtype)
33
+ context_v = context_v.to(dtype=context_k.dtype)
34
+
35
+ return context_k, context_v
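+
+ # Note: when no hypernetworks are loaded, both returned contexts fall back to context_v
+ # (the encoder hidden states passed by the callers below), which reproduces plain cross-attention
+ # for the key and value projections.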
36
+
37
+
38
+
39
+ def xformers_forward(
40
+ self: XFormersAttnProcessor,
41
+ attn: Attention,
42
+ hidden_states: torch.Tensor,
43
+ encoder_hidden_states: torch.Tensor = None,
44
+ attention_mask: torch.Tensor = None,
45
+ ):
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape
48
+ if encoder_hidden_states is None
49
+ else encoder_hidden_states.shape
50
+ )
51
+
52
+ attention_mask = attn.prepare_attention_mask(
53
+ attention_mask, sequence_length, batch_size
54
+ )
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
64
+
65
+ key = attn.to_k(context_k)
66
+ value = attn.to_v(context_v)
67
+
68
+ query = attn.head_to_batch_dim(query).contiguous()
69
+ key = attn.head_to_batch_dim(key).contiguous()
70
+ value = attn.head_to_batch_dim(value).contiguous()
71
+
72
+ hidden_states = xformers.ops.memory_efficient_attention(
73
+ query,
74
+ key,
75
+ value,
76
+ attn_bias=attention_mask,
77
+ op=self.attention_op,
78
+ scale=attn.scale,
79
+ )
80
+ hidden_states = hidden_states.to(query.dtype)
81
+ hidden_states = attn.batch_to_head_dim(hidden_states)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+ return hidden_states
88
+
89
+
90
+ def sliced_attn_forward(
91
+ self: SlicedAttnProcessor,
92
+ attn: Attention,
93
+ hidden_states: torch.Tensor,
94
+ encoder_hidden_states: torch.Tensor = None,
95
+ attention_mask: torch.Tensor = None,
96
+ ):
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+ attention_mask = attn.prepare_attention_mask(
103
+ attention_mask, sequence_length, batch_size
104
+ )
105
+
106
+ query = attn.to_q(hidden_states)
107
+ dim = query.shape[-1]
108
+ query = attn.head_to_batch_dim(query)
109
+
110
+ if encoder_hidden_states is None:
111
+ encoder_hidden_states = hidden_states
112
+ elif attn.norm_cross:
113
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
114
+
115
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
116
+
117
+ key = attn.to_k(context_k)
118
+ value = attn.to_v(context_v)
119
+ key = attn.head_to_batch_dim(key)
120
+ value = attn.head_to_batch_dim(value)
121
+
122
+ batch_size_attention, query_tokens, _ = query.shape
123
+ hidden_states = torch.zeros(
124
+ (batch_size_attention, query_tokens, dim // attn.heads),
125
+ device=query.device,
126
+ dtype=query.dtype,
127
+ )
128
+
129
+ for i in range(batch_size_attention // self.slice_size):
130
+ start_idx = i * self.slice_size
131
+ end_idx = (i + 1) * self.slice_size
132
+
133
+ query_slice = query[start_idx:end_idx]
134
+ key_slice = key[start_idx:end_idx]
135
+ attn_mask_slice = (
136
+ attention_mask[start_idx:end_idx] if attention_mask is not None else None
137
+ )
138
+
139
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
140
+
141
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
142
+
143
+ hidden_states[start_idx:end_idx] = attn_slice
144
+
145
+ hidden_states = attn.batch_to_head_dim(hidden_states)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ return hidden_states
153
+
154
+
155
+ def v2_0_forward(
156
+ self: AttnProcessor2_0,
157
+ attn: Attention,
158
+ hidden_states,
159
+ encoder_hidden_states=None,
160
+ attention_mask=None,
161
+ ):
162
+ batch_size, sequence_length, _ = (
163
+ hidden_states.shape
164
+ if encoder_hidden_states is None
165
+ else encoder_hidden_states.shape
166
+ )
167
+ inner_dim = hidden_states.shape[-1]
168
+
169
+ if attention_mask is not None:
170
+ attention_mask = attn.prepare_attention_mask(
171
+ attention_mask, sequence_length, batch_size
172
+ )
173
+ # scaled_dot_product_attention expects attention_mask shape to be
174
+ # (batch, heads, source_length, target_length)
175
+ attention_mask = attention_mask.view(
176
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
177
+ )
178
+
179
+ query = attn.to_q(hidden_states)
180
+
181
+ if encoder_hidden_states is None:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
187
+
188
+ key = attn.to_k(context_k)
189
+ value = attn.to_v(context_v)
190
+
191
+ head_dim = inner_dim // attn.heads
192
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
197
+ # TODO: add support for attn.scale when we move to Torch 2.1
198
+ hidden_states = F.scaled_dot_product_attention(
199
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
200
+ )
201
+
202
+ hidden_states = hidden_states.transpose(1, 2).reshape(
203
+ batch_size, -1, attn.heads * head_dim
204
+ )
205
+ hidden_states = hidden_states.to(query.dtype)
206
+
207
+ # linear proj
208
+ hidden_states = attn.to_out[0](hidden_states)
209
+ # dropout
210
+ hidden_states = attn.to_out[1](hidden_states)
211
+ return hidden_states
212
+
213
+
214
+ def replace_attentions_for_hypernetwork():
215
+ import diffusers.models.attention_processor
216
+
217
+ diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
218
+ xformers_forward
219
+ )
220
+ diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
221
+ sliced_attn_forward
222
+ )
223
+ diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
library/ipex/__init__.py ADDED
@@ -0,0 +1,180 @@
1
+ import os
2
+ import sys
3
+ import contextlib
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ from .hijacks import ipex_hijacks
7
+
8
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
9
+
10
+ def ipex_init(): # pylint: disable=too-many-statements
11
+ try:
12
+ if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
13
+ return True, "Skipping IPEX hijack"
14
+ else:
15
+ # Replace cuda with xpu:
16
+ torch.cuda.current_device = torch.xpu.current_device
17
+ torch.cuda.current_stream = torch.xpu.current_stream
18
+ torch.cuda.device = torch.xpu.device
19
+ torch.cuda.device_count = torch.xpu.device_count
20
+ torch.cuda.device_of = torch.xpu.device_of
21
+ torch.cuda.get_device_name = torch.xpu.get_device_name
22
+ torch.cuda.get_device_properties = torch.xpu.get_device_properties
23
+ torch.cuda.init = torch.xpu.init
24
+ torch.cuda.is_available = torch.xpu.is_available
25
+ torch.cuda.is_initialized = torch.xpu.is_initialized
26
+ torch.cuda.is_current_stream_capturing = lambda: False
27
+ torch.cuda.set_device = torch.xpu.set_device
28
+ torch.cuda.stream = torch.xpu.stream
29
+ torch.cuda.synchronize = torch.xpu.synchronize
30
+ torch.cuda.Event = torch.xpu.Event
31
+ torch.cuda.Stream = torch.xpu.Stream
32
+ torch.cuda.FloatTensor = torch.xpu.FloatTensor
33
+ torch.Tensor.cuda = torch.Tensor.xpu
34
+ torch.Tensor.is_cuda = torch.Tensor.is_xpu
35
+ torch.nn.Module.cuda = torch.nn.Module.xpu
36
+ torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
37
+ torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
38
+ torch.cuda._initialized = torch.xpu.lazy_init._initialized
39
+ torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
40
+ torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
41
+ torch.cuda._tls = torch.xpu.lazy_init._tls
42
+ torch.cuda.threading = torch.xpu.lazy_init.threading
43
+ torch.cuda.traceback = torch.xpu.lazy_init.traceback
44
+ torch.cuda.Optional = torch.xpu.Optional
45
+ torch.cuda.__cached__ = torch.xpu.__cached__
46
+ torch.cuda.__loader__ = torch.xpu.__loader__
47
+ torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
48
+ torch.cuda.Tuple = torch.xpu.Tuple
49
+ torch.cuda.streams = torch.xpu.streams
50
+ torch.cuda._lazy_new = torch.xpu._lazy_new
51
+ torch.cuda.FloatStorage = torch.xpu.FloatStorage
52
+ torch.cuda.Any = torch.xpu.Any
53
+ torch.cuda.__doc__ = torch.xpu.__doc__
54
+ torch.cuda.default_generators = torch.xpu.default_generators
55
+ torch.cuda.HalfTensor = torch.xpu.HalfTensor
56
+ torch.cuda._get_device_index = torch.xpu._get_device_index
57
+ torch.cuda.__path__ = torch.xpu.__path__
58
+ torch.cuda.Device = torch.xpu.Device
59
+ torch.cuda.IntTensor = torch.xpu.IntTensor
60
+ torch.cuda.ByteStorage = torch.xpu.ByteStorage
61
+ torch.cuda.set_stream = torch.xpu.set_stream
62
+ torch.cuda.BoolStorage = torch.xpu.BoolStorage
63
+ torch.cuda.os = torch.xpu.os
64
+ torch.cuda.torch = torch.xpu.torch
65
+ torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
66
+ torch.cuda.Union = torch.xpu.Union
67
+ torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
68
+ torch.cuda.ShortTensor = torch.xpu.ShortTensor
69
+ torch.cuda.LongTensor = torch.xpu.LongTensor
70
+ torch.cuda.IntStorage = torch.xpu.IntStorage
71
+ torch.cuda.LongStorage = torch.xpu.LongStorage
72
+ torch.cuda.__annotations__ = torch.xpu.__annotations__
73
+ torch.cuda.__package__ = torch.xpu.__package__
74
+ torch.cuda.__builtins__ = torch.xpu.__builtins__
75
+ torch.cuda.CharTensor = torch.xpu.CharTensor
76
+ torch.cuda.List = torch.xpu.List
77
+ torch.cuda._lazy_init = torch.xpu._lazy_init
78
+ torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
79
+ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
80
+ torch.cuda.ByteTensor = torch.xpu.ByteTensor
81
+ torch.cuda.StreamContext = torch.xpu.StreamContext
82
+ torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
83
+ torch.cuda.ShortStorage = torch.xpu.ShortStorage
84
+ torch.cuda._lazy_call = torch.xpu._lazy_call
85
+ torch.cuda.HalfStorage = torch.xpu.HalfStorage
86
+ torch.cuda.random = torch.xpu.random
87
+ torch.cuda._device = torch.xpu._device
88
+ torch.cuda.classproperty = torch.xpu.classproperty
89
+ torch.cuda.__name__ = torch.xpu.__name__
90
+ torch.cuda._device_t = torch.xpu._device_t
91
+ torch.cuda.warnings = torch.xpu.warnings
92
+ torch.cuda.__spec__ = torch.xpu.__spec__
93
+ torch.cuda.BoolTensor = torch.xpu.BoolTensor
94
+ torch.cuda.CharStorage = torch.xpu.CharStorage
95
+ torch.cuda.__file__ = torch.xpu.__file__
96
+ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
97
+ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
98
+
99
+ # Memory:
100
+ torch.cuda.memory = torch.xpu.memory
101
+ if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
102
+ torch.xpu.empty_cache = lambda: None
103
+ torch.cuda.empty_cache = torch.xpu.empty_cache
104
+ torch.cuda.memory_stats = torch.xpu.memory_stats
105
+ torch.cuda.memory_summary = torch.xpu.memory_summary
106
+ torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
107
+ torch.cuda.memory_allocated = torch.xpu.memory_allocated
108
+ torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
109
+ torch.cuda.memory_reserved = torch.xpu.memory_reserved
110
+ torch.cuda.memory_cached = torch.xpu.memory_reserved
111
+ torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
112
+ torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
113
+ torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
114
+ torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
115
+ torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
116
+ torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
117
+ torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
118
+
119
+ # RNG:
120
+ torch.cuda.get_rng_state = torch.xpu.get_rng_state
121
+ torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
122
+ torch.cuda.set_rng_state = torch.xpu.set_rng_state
123
+ torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
124
+ torch.cuda.manual_seed = torch.xpu.manual_seed
125
+ torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
126
+ torch.cuda.seed = torch.xpu.seed
127
+ torch.cuda.seed_all = torch.xpu.seed_all
128
+ torch.cuda.initial_seed = torch.xpu.initial_seed
129
+
130
+ # AMP:
131
+ torch.cuda.amp = torch.xpu.amp
132
+ torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
133
+ torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
134
+
135
+ if not hasattr(torch.cuda.amp, "common"):
136
+ torch.cuda.amp.common = contextlib.nullcontext()
137
+ torch.cuda.amp.common.amp_definitely_not_available = lambda: False
138
+
139
+ try:
140
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
141
+ except Exception: # pylint: disable=broad-exception-caught
142
+ try:
143
+ from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
144
+ gradscaler_init()
145
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
146
+ except Exception: # pylint: disable=broad-exception-caught
147
+ torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
148
+
149
+ # C
150
+ torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
151
+ ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
152
+ ipex._C._DeviceProperties.major = 2024
153
+ ipex._C._DeviceProperties.minor = 0
154
+
155
+ # Fix functions with ipex:
156
+ torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
157
+ torch._utils._get_available_device_type = lambda: "xpu"
158
+ torch.has_cuda = True
159
+ torch.cuda.has_half = True
160
+ torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
161
+ torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
162
+ torch.backends.cuda.is_built = lambda *args, **kwargs: True
163
+ torch.version.cuda = "12.1"
164
+ torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
165
+ torch.cuda.get_device_properties.major = 12
166
+ torch.cuda.get_device_properties.minor = 1
167
+ torch.cuda.ipc_collect = lambda *args, **kwargs: None
168
+ torch.cuda.utilization = lambda *args, **kwargs: 0
169
+
170
+ ipex_hijacks()
171
+ if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
172
+ try:
173
+ from .diffusers import ipex_diffusers
174
+ ipex_diffusers()
175
+ except Exception: # pylint: disable=broad-exception-caught
176
+ pass
177
+ torch.cuda.is_xpu_hijacked = True
178
+ except Exception as e:
179
+ return False, e
180
+ return True, None
library/ipex/attention.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ from functools import cache
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
9
+
10
+ sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ # Find something divisible with the input_tokens
14
+ @cache
15
+ def find_slice_size(slice_size, slice_block_size):
16
+ while (slice_size * slice_block_size) > attention_slice_rate:
17
+ slice_size = slice_size // 2
18
+ if slice_size <= 1:
19
+ slice_size = 1
20
+ break
21
+ return slice_size
22
+
23
+ # Find slice sizes for SDPA
24
+ @cache
25
+ def find_sdpa_slice_sizes(query_shape, query_element_size):
26
+ if len(query_shape) == 3:
27
+ batch_size_attention, query_tokens, shape_three = query_shape
28
+ shape_four = 1
29
+ else:
30
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if block_size > sdpa_slice_trigger_rate:
44
+ do_split = True
45
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
46
+ if split_slice_size * slice_block_size > attention_slice_rate:
47
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
48
+ do_split_2 = True
49
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
50
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
51
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
52
+ do_split_3 = True
53
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
54
+
55
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
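+
+ # Example (illustrative sketch, shapes are hypothetical): for an fp16 query of shape [8, 4096, 512]
+ # (element_size 2), slice_block_size = 4096 * 512 * 1 / 1024 / 1024 * 2 = 4 and
+ # block_size = 8 * 4 = 32 > sdpa_slice_trigger_rate (default 4), so the batch dimension is split
+ # down to split_slice_size = 1, at which point 1 * 4 <= attention_slice_rate and no further
+ # splitting along the remaining dimensions is needed.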
56
+
57
+ # Find slice sizes for BMM
58
+ @cache
59
+ def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
60
+ batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
61
+ slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
62
+ block_size = batch_size_attention * slice_block_size
63
+
64
+ split_slice_size = batch_size_attention
65
+ split_2_slice_size = input_tokens
66
+ split_3_slice_size = mat2_atten_shape
67
+
68
+ do_split = False
69
+ do_split_2 = False
70
+ do_split_3 = False
71
+
72
+ if block_size > attention_slice_rate:
73
+ do_split = True
74
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
75
+ if split_slice_size * slice_block_size > attention_slice_rate:
76
+ slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
77
+ do_split_2 = True
78
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
79
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
80
+ slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
81
+ do_split_3 = True
82
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
83
+
84
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
85
+
86
+
87
+ original_torch_bmm = torch.bmm
88
+ def torch_bmm_32_bit(input, mat2, *, out=None):
89
+ if input.device.type != "xpu":
90
+ return original_torch_bmm(input, mat2, out=out)
91
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
92
+
93
+ # Slice BMM
94
+ if do_split:
95
+ batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
96
+ hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
97
+ for i in range(batch_size_attention // split_slice_size):
98
+ start_idx = i * split_slice_size
99
+ end_idx = (i + 1) * split_slice_size
100
+ if do_split_2:
101
+ for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
102
+ start_idx_2 = i2 * split_2_slice_size
103
+ end_idx_2 = (i2 + 1) * split_2_slice_size
104
+ if do_split_3:
105
+ for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
106
+ start_idx_3 = i3 * split_3_slice_size
107
+ end_idx_3 = (i3 + 1) * split_3_slice_size
108
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
109
+ input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
110
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
111
+ out=out
112
+ )
113
+ else:
114
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
115
+ input[start_idx:end_idx, start_idx_2:end_idx_2],
116
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2],
117
+ out=out
118
+ )
119
+ else:
120
+ hidden_states[start_idx:end_idx] = original_torch_bmm(
121
+ input[start_idx:end_idx],
122
+ mat2[start_idx:end_idx],
123
+ out=out
124
+ )
125
+ torch.xpu.synchronize(input.device)
126
+ else:
127
+ return original_torch_bmm(input, mat2, out=out)
128
+ return hidden_states
129
+
130
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
131
+ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
132
+ if query.device.type != "xpu":
133
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
134
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
135
+
136
+ # Slice SDPA
137
+ if do_split:
138
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
139
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
140
+ for i in range(batch_size_attention // split_slice_size):
141
+ start_idx = i * split_slice_size
142
+ end_idx = (i + 1) * split_slice_size
143
+ if do_split_2:
144
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
145
+ start_idx_2 = i2 * split_2_slice_size
146
+ end_idx_2 = (i2 + 1) * split_2_slice_size
147
+ if do_split_3:
148
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
149
+ start_idx_3 = i3 * split_3_slice_size
150
+ end_idx_3 = (i3 + 1) * split_3_slice_size
151
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
152
+ query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
153
+ key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
154
+ value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
155
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
156
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
157
+ )
158
+ else:
159
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
160
+ query[start_idx:end_idx, start_idx_2:end_idx_2],
161
+ key[start_idx:end_idx, start_idx_2:end_idx_2],
162
+ value[start_idx:end_idx, start_idx_2:end_idx_2],
163
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
164
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
165
+ )
166
+ else:
167
+ hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
168
+ query[start_idx:end_idx],
169
+ key[start_idx:end_idx],
170
+ value[start_idx:end_idx],
171
+ attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
172
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
173
+ )
174
+ torch.xpu.synchronize(query.device)
175
+ else:
176
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
177
+ return hidden_states
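To make the slicing arithmetic above concrete, here is a small self-contained illustration; the tensor shape, fp16 element size, and the 4.0 threshold are assumptions for the example (4 is the default IPEX_ATTENTION_SLICE_RATE used in the diffusers patch below; attention.py's own trigger constants are defined earlier in the file, outside this excerpt). Block sizes are tokens * dims * element_size scaled by 1/1024/1024, and each nested split halves a dimension until the per-slice size drops under the threshold.

# Hedged illustration of the slicing math behind find_sdpa_slice_sizes / find_slice_size.
# All numbers below are made-up example values, not taken from this commit.
def halve_until_under(slice_size, slice_block_size, threshold=4.0):
    # mirrors find_slice_size: halve until slice_size * slice_block_size <= threshold
    while slice_size * slice_block_size > threshold:
        slice_size //= 2
        if slice_size <= 1:
            return 1
    return slice_size

batch, tokens, heads, head_dim = 16, 4096, 40, 64                          # hypothetical fp16 query
element_size = 2                                                           # bytes per fp16 element
slice_block_size = tokens * heads * head_dim / 1024 / 1024 * element_size  # 20.0 per batch row
print(batch * slice_block_size)                    # 320.0 -> far above the 4.0 threshold, so do_split
print(halve_until_under(batch, slice_block_size))  # 1 -> batch slicing alone is not enough,
                                                   # so the second-level (token) split also triggers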
library/ipex/diffusers.py ADDED
@@ -0,0 +1,312 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import diffusers #0.24.0 # pylint: disable=import-error
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import USE_PEFT_BACKEND
7
+ from functools import cache
8
+
9
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
10
+
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ @cache
14
+ def find_slice_size(slice_size, slice_block_size):
15
+ while (slice_size * slice_block_size) > attention_slice_rate:
16
+ slice_size = slice_size // 2
17
+ if slice_size <= 1:
18
+ slice_size = 1
19
+ break
20
+ return slice_size
21
+
22
+ @cache
23
+ def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
24
+ if len(query_shape) == 3:
25
+ batch_size_attention, query_tokens, shape_three = query_shape
26
+ shape_four = 1
27
+ else:
28
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
29
+ if slice_size is not None:
30
+ batch_size_attention = slice_size
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if query_device_type != "xpu":
44
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
45
+
46
+ if block_size > attention_slice_rate:
47
+ do_split = True
48
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
49
+ if split_slice_size * slice_block_size > attention_slice_rate:
50
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
51
+ do_split_2 = True
52
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
53
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
54
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
55
+ do_split_3 = True
56
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
57
+
58
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
59
+
60
+ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
61
+ r"""
62
+ Processor for implementing sliced attention.
63
+
64
+ Args:
65
+ slice_size (`int`, *optional*):
66
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
67
+ `attention_head_dim` must be a multiple of the `slice_size`.
68
+ """
69
+
70
+ def __init__(self, slice_size):
71
+ self.slice_size = slice_size
72
+
73
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
74
+ encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
75
+
76
+ residual = hidden_states
77
+
78
+ input_ndim = hidden_states.ndim
79
+
80
+ if input_ndim == 4:
81
+ batch_size, channel, height, width = hidden_states.shape
82
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
83
+
84
+ batch_size, sequence_length, _ = (
85
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
86
+ )
87
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
88
+
89
+ if attn.group_norm is not None:
90
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
91
+
92
+ query = attn.to_q(hidden_states)
93
+ dim = query.shape[-1]
94
+ query = attn.head_to_batch_dim(query)
95
+
96
+ if encoder_hidden_states is None:
97
+ encoder_hidden_states = hidden_states
98
+ elif attn.norm_cross:
99
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
100
+
101
+ key = attn.to_k(encoder_hidden_states)
102
+ value = attn.to_v(encoder_hidden_states)
103
+ key = attn.head_to_batch_dim(key)
104
+ value = attn.head_to_batch_dim(value)
105
+
106
+ batch_size_attention, query_tokens, shape_three = query.shape
107
+ hidden_states = torch.zeros(
108
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
109
+ )
110
+
111
+ ####################################################################
112
+ # ARC GPUs can't allocate more than 4GB to a single block, so slice it:
113
+ _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
114
+
115
+ for i in range(batch_size_attention // split_slice_size):
116
+ start_idx = i * split_slice_size
117
+ end_idx = (i + 1) * split_slice_size
118
+ if do_split_2:
119
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
120
+ start_idx_2 = i2 * split_2_slice_size
121
+ end_idx_2 = (i2 + 1) * split_2_slice_size
122
+ if do_split_3:
123
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
124
+ start_idx_3 = i3 * split_3_slice_size
125
+ end_idx_3 = (i3 + 1) * split_3_slice_size
126
+
127
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
128
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
129
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
130
+
131
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
132
+ del query_slice
133
+ del key_slice
134
+ del attn_mask_slice
135
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
136
+
137
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
138
+ del attn_slice
139
+ else:
140
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
141
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
142
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
143
+
144
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
145
+ del query_slice
146
+ del key_slice
147
+ del attn_mask_slice
148
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
149
+
150
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
151
+ del attn_slice
152
+ torch.xpu.synchronize(query.device)
153
+ else:
154
+ query_slice = query[start_idx:end_idx]
155
+ key_slice = key[start_idx:end_idx]
156
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
157
+
158
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
159
+ del query_slice
160
+ del key_slice
161
+ del attn_mask_slice
162
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
163
+
164
+ hidden_states[start_idx:end_idx] = attn_slice
165
+ del attn_slice
166
+ ####################################################################
167
+
168
+ hidden_states = attn.batch_to_head_dim(hidden_states)
169
+
170
+ # linear proj
171
+ hidden_states = attn.to_out[0](hidden_states)
172
+ # dropout
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+
175
+ if input_ndim == 4:
176
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
177
+
178
+ if attn.residual_connection:
179
+ hidden_states = hidden_states + residual
180
+
181
+ hidden_states = hidden_states / attn.rescale_output_factor
182
+
183
+ return hidden_states
184
+
185
+
186
+ class AttnProcessor:
187
+ r"""
188
+ Default processor for performing attention-related computations.
189
+ """
190
+
191
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
192
+ encoder_hidden_states=None, attention_mask=None,
193
+ temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
194
+
195
+ residual = hidden_states
196
+
197
+ args = () if USE_PEFT_BACKEND else (scale,)
198
+
199
+ if attn.spatial_norm is not None:
200
+ hidden_states = attn.spatial_norm(hidden_states, temb)
201
+
202
+ input_ndim = hidden_states.ndim
203
+
204
+ if input_ndim == 4:
205
+ batch_size, channel, height, width = hidden_states.shape
206
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
207
+
208
+ batch_size, sequence_length, _ = (
209
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
210
+ )
211
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
212
+
213
+ if attn.group_norm is not None:
214
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
215
+
216
+ query = attn.to_q(hidden_states, *args)
217
+
218
+ if encoder_hidden_states is None:
219
+ encoder_hidden_states = hidden_states
220
+ elif attn.norm_cross:
221
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
222
+
223
+ key = attn.to_k(encoder_hidden_states, *args)
224
+ value = attn.to_v(encoder_hidden_states, *args)
225
+
226
+ query = attn.head_to_batch_dim(query)
227
+ key = attn.head_to_batch_dim(key)
228
+ value = attn.head_to_batch_dim(value)
229
+
230
+ ####################################################################
231
+ # ARC GPUs can't allocate more than 4GB to a single block, so slice it:
232
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
233
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
234
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
235
+
236
+ if do_split:
237
+ for i in range(batch_size_attention // split_slice_size):
238
+ start_idx = i * split_slice_size
239
+ end_idx = (i + 1) * split_slice_size
240
+ if do_split_2:
241
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
242
+ start_idx_2 = i2 * split_2_slice_size
243
+ end_idx_2 = (i2 + 1) * split_2_slice_size
244
+ if do_split_3:
245
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
246
+ start_idx_3 = i3 * split_3_slice_size
247
+ end_idx_3 = (i3 + 1) * split_3_slice_size
248
+
249
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
250
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
251
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
252
+
253
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
254
+ del query_slice
255
+ del key_slice
256
+ del attn_mask_slice
257
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
258
+
259
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
260
+ del attn_slice
261
+ else:
262
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
263
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
264
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
265
+
266
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
267
+ del query_slice
268
+ del key_slice
269
+ del attn_mask_slice
270
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
271
+
272
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
273
+ del attn_slice
274
+ else:
275
+ query_slice = query[start_idx:end_idx]
276
+ key_slice = key[start_idx:end_idx]
277
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
278
+
279
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
280
+ del query_slice
281
+ del key_slice
282
+ del attn_mask_slice
283
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
284
+
285
+ hidden_states[start_idx:end_idx] = attn_slice
286
+ del attn_slice
287
+ torch.xpu.synchronize(query.device)
288
+ else:
289
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
290
+ hidden_states = torch.bmm(attention_probs, value)
291
+ ####################################################################
292
+ hidden_states = attn.batch_to_head_dim(hidden_states)
293
+
294
+ # linear proj
295
+ hidden_states = attn.to_out[0](hidden_states, *args)
296
+ # dropout
297
+ hidden_states = attn.to_out[1](hidden_states)
298
+
299
+ if input_ndim == 4:
300
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
301
+
302
+ if attn.residual_connection:
303
+ hidden_states = hidden_states + residual
304
+
305
+ hidden_states = hidden_states / attn.rescale_output_factor
306
+
307
+ return hidden_states
308
+
309
+ def ipex_diffusers():
310
+ # ARC GPUs can't allocate more than 4GB to a single block:
311
+ diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
312
+ diffusers.models.attention_processor.AttnProcessor = AttnProcessor
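A minimal usage sketch for the patch above, assuming the module is importable as library.ipex.diffusers; the model id, dtype, and pipeline choice are placeholders, not part of this commit. Because ipex_diffusers() rebinds AttnProcessor and SlicedAttnProcessor inside diffusers.models.attention_processor, it has to run before the pipeline is built so that attention slicing picks up the XPU-aware classes.

# Illustrative only: import path, model id and dtype are assumptions.
import torch
from diffusers import StableDiffusionPipeline
from library.ipex.diffusers import ipex_diffusers

ipex_diffusers()  # swap in the XPU-sliced attention processors defined above

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("xpu")
pipe.enable_attention_slicing()  # slicing now routes through the patched SlicedAttnProcessor
image = pipe("a watercolor landscape").images[0]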
library/ipex/gradscaler.py ADDED
@@ -0,0 +1,183 @@
1
+ from collections import defaultdict
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
12
+
13
+ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
14
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
15
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
16
+
17
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
18
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
19
+ # However, we don't know their devices or dtypes in advance.
20
+
21
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
22
+ # Google says mypy struggles with defaultdicts type annotations.
23
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
24
+ # sync grad to master weight
25
+ if hasattr(optimizer, "sync_grad"):
26
+ optimizer.sync_grad()
27
+ with torch.no_grad():
28
+ for group in optimizer.param_groups:
29
+ for param in group["params"]:
30
+ if param.grad is None:
31
+ continue
32
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
33
+ raise ValueError("Attempting to unscale FP16 gradients.")
34
+ if param.grad.is_sparse:
35
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
36
+ # coalesce() deduplicates indices and adds all values that have the same index.
37
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
38
+ # so we should check the coalesced _values().
39
+ if param.grad.dtype is torch.float16:
40
+ param.grad = param.grad.coalesce()
41
+ to_unscale = param.grad._values()
42
+ else:
43
+ to_unscale = param.grad
44
+
45
+ # -: is there a way to split by device and dtype without appending in the inner loop?
46
+ to_unscale = to_unscale.to("cpu")
47
+ per_device_and_dtype_grads[to_unscale.device][
48
+ to_unscale.dtype
49
+ ].append(to_unscale)
50
+
51
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
52
+ for grads in per_dtype_grads.values():
53
+ core._amp_foreach_non_finite_check_and_unscale_(
54
+ grads,
55
+ per_device_found_inf.get("cpu"),
56
+ per_device_inv_scale.get("cpu"),
57
+ )
58
+
59
+ return per_device_found_inf._per_device_tensors
60
+
61
+ def unscale_(self, optimizer):
62
+ """
63
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
64
+ :meth:`unscale_` is optional, serving cases where you need to
65
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
66
+ between the backward pass(es) and :meth:`step`.
67
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
68
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
69
+ ...
70
+ scaler.scale(loss).backward()
71
+ scaler.unscale_(optimizer)
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
73
+ scaler.step(optimizer)
74
+ scaler.update()
75
+ Args:
76
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
77
+ .. warning::
78
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
79
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
80
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
81
+ .. warning::
82
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
83
+ """
84
+ if not self._enabled:
85
+ return
86
+
87
+ self._check_scale_growth_tracker("unscale_")
88
+
89
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
90
+
91
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
92
+ raise RuntimeError(
93
+ "unscale_() has already been called on this optimizer since the last update()."
94
+ )
95
+ elif optimizer_state["stage"] is OptState.STEPPED:
96
+ raise RuntimeError("unscale_() is being called after step().")
97
+
98
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
99
+ assert self._scale is not None
100
+ if device_supports_fp64:
101
+ inv_scale = self._scale.double().reciprocal().float()
102
+ else:
103
+ inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
104
+ found_inf = torch.full(
105
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
106
+ )
107
+
108
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
109
+ optimizer, inv_scale, found_inf, False
110
+ )
111
+ optimizer_state["stage"] = OptState.UNSCALED
112
+
113
+ def update(self, new_scale=None):
114
+ """
115
+ Updates the scale factor.
116
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
117
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
118
+ the scale is multiplied by ``growth_factor`` to increase it.
119
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
120
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
121
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
122
+ affect the scale GradScaler uses internally.)
123
+ Args:
124
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
125
+ .. warning::
126
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
127
+ been invoked for all optimizers used this iteration.
128
+ """
129
+ if not self._enabled:
130
+ return
131
+
132
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
133
+
134
+ if new_scale is not None:
135
+ # Accept a new user-defined scale.
136
+ if isinstance(new_scale, float):
137
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
138
+ else:
139
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
140
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
141
+ assert new_scale.numel() == 1, reason
142
+ assert new_scale.requires_grad is False, reason
143
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
144
+ else:
145
+ # Consume shared inf/nan data collected from optimizers to update the scale.
146
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
147
+ found_infs = [
148
+ found_inf.to(device="cpu", non_blocking=True)
149
+ for state in self._per_optimizer_states.values()
150
+ for found_inf in state["found_inf_per_device"].values()
151
+ ]
152
+
153
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
154
+
155
+ found_inf_combined = found_infs[0]
156
+ if len(found_infs) > 1:
157
+ for i in range(1, len(found_infs)):
158
+ found_inf_combined += found_infs[i]
159
+
160
+ to_device = _scale.device
161
+ _scale = _scale.to("cpu")
162
+ _growth_tracker = _growth_tracker.to("cpu")
163
+
164
+ core._amp_update_scale_(
165
+ _scale,
166
+ _growth_tracker,
167
+ found_inf_combined,
168
+ self._growth_factor,
169
+ self._backoff_factor,
170
+ self._growth_interval,
171
+ )
172
+
173
+ _scale = _scale.to(to_device)
174
+ _growth_tracker = _growth_tracker.to(to_device)
175
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
176
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
177
+
178
+ def gradscaler_init():
179
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
180
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
181
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
182
+ torch.xpu.amp.GradScaler.update = update
183
+ return torch.xpu.amp.GradScaler
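For context, a hedged sketch of how the patched scaler is meant to be driven once gradscaler_init() has run; the model, optimizer, and loss function are placeholders. The call pattern is the ordinary GradScaler one described in the unscale_ docstring above, only the inf-check and scale-update bookkeeping is routed through CPU tensors.

# Illustrative training step; model, optimizer and loss_fn are placeholders.
import torch
from library.ipex.gradscaler import gradscaler_init

GradScaler = gradscaler_init()   # patched torch.xpu.amp.GradScaler
scaler = GradScaler()

def train_step(model, optimizer, loss_fn, batch, target):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="xpu", dtype=torch.float16):
        loss = loss_fn(model(batch), target)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                               # optional: expose unscaled grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # e.g. for gradient clipping
    scaler.step(optimizer)
    scaler.update()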
library/ipex/hijacks.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from functools import wraps
3
+ from contextlib import nullcontext
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ import numpy as np
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+
10
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
11
+
12
+ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
13
+ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
14
+ if isinstance(device_ids, list) and len(device_ids) > 1:
15
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
16
+ return module.to("xpu")
17
+
18
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
19
+ return nullcontext()
20
+
21
+ @property
22
+ def is_cuda(self):
23
+ return self.device.type == 'xpu' or self.device.type == 'cuda'
24
+
25
+ def check_device(device):
26
+ return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
27
+
28
+ def return_xpu(device):
29
+ return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
30
+
31
+
32
+ # Autocast
33
+ original_autocast_init = torch.amp.autocast_mode.autocast.__init__
34
+ @wraps(torch.amp.autocast_mode.autocast.__init__)
35
+ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
36
+ if device_type == "cuda":
37
+ return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
38
+ else:
39
+ return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
40
+
41
+ # Latent Antialias CPU Offload:
42
+ original_interpolate = torch.nn.functional.interpolate
43
+ @wraps(torch.nn.functional.interpolate)
44
+ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
45
+ if antialias or align_corners is not None or mode == 'bicubic':
46
+ return_device = tensor.device
47
+ return_dtype = tensor.dtype
48
+ return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
49
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
50
+ else:
51
+ return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
52
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
53
+
54
+
55
+ # Diffusers Float64 (Alchemist GPUs don't support 64-bit):
56
+ original_from_numpy = torch.from_numpy
57
+ @wraps(torch.from_numpy)
58
+ def from_numpy(ndarray):
59
+ if ndarray.dtype == float:
60
+ return original_from_numpy(ndarray.astype('float32'))
61
+ else:
62
+ return original_from_numpy(ndarray)
63
+
64
+ original_as_tensor = torch.as_tensor
65
+ @wraps(torch.as_tensor)
66
+ def as_tensor(data, dtype=None, device=None):
67
+ if check_device(device):
68
+ device = return_xpu(device)
69
+ if isinstance(data, np.ndarray) and data.dtype == float and not (
70
+ (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
71
+ return original_as_tensor(data, dtype=torch.float32, device=device)
72
+ else:
73
+ return original_as_tensor(data, dtype=dtype, device=device)
74
+
75
+
76
+ if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
77
+ original_torch_bmm = torch.bmm
78
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
79
+ else:
80
+ # 32 bit attention workarounds for Alchemist:
81
+ try:
82
+ from .attention import torch_bmm_32_bit as original_torch_bmm
83
+ from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
84
+ except Exception: # pylint: disable=broad-exception-caught
85
+ original_torch_bmm = torch.bmm
86
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
87
+
88
+
89
+ # Data Type Errors:
90
+ @wraps(torch.bmm)
91
+ def torch_bmm(input, mat2, *, out=None):
92
+ if input.dtype != mat2.dtype:
93
+ mat2 = mat2.to(input.dtype)
94
+ return original_torch_bmm(input, mat2, out=out)
95
+
96
+ @wraps(torch.nn.functional.scaled_dot_product_attention)
97
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
98
+ if query.dtype != key.dtype:
99
+ key = key.to(dtype=query.dtype)
100
+ if query.dtype != value.dtype:
101
+ value = value.to(dtype=query.dtype)
102
+ if attn_mask is not None and query.dtype != attn_mask.dtype:
103
+ attn_mask = attn_mask.to(dtype=query.dtype)
104
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
105
+
106
+ # A1111 FP16
107
+ original_functional_group_norm = torch.nn.functional.group_norm
108
+ @wraps(torch.nn.functional.group_norm)
109
+ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
110
+ if weight is not None and input.dtype != weight.data.dtype:
111
+ input = input.to(dtype=weight.data.dtype)
112
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
113
+ bias.data = bias.data.to(dtype=weight.data.dtype)
114
+ return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
115
+
116
+ # A1111 BF16
117
+ original_functional_layer_norm = torch.nn.functional.layer_norm
118
+ @wraps(torch.nn.functional.layer_norm)
119
+ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
120
+ if weight is not None and input.dtype != weight.data.dtype:
121
+ input = input.to(dtype=weight.data.dtype)
122
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
123
+ bias.data = bias.data.to(dtype=weight.data.dtype)
124
+ return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
125
+
126
+ # Training
127
+ original_functional_linear = torch.nn.functional.linear
128
+ @wraps(torch.nn.functional.linear)
129
+ def functional_linear(input, weight, bias=None):
130
+ if input.dtype != weight.data.dtype:
131
+ input = input.to(dtype=weight.data.dtype)
132
+ if bias is not None and bias.data.dtype != weight.data.dtype:
133
+ bias.data = bias.data.to(dtype=weight.data.dtype)
134
+ return original_functional_linear(input, weight, bias=bias)
135
+
136
+ original_functional_conv2d = torch.nn.functional.conv2d
137
+ @wraps(torch.nn.functional.conv2d)
138
+ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
139
+ if input.dtype != weight.data.dtype:
140
+ input = input.to(dtype=weight.data.dtype)
141
+ if bias is not None and bias.data.dtype != weight.data.dtype:
142
+ bias.data = bias.data.to(dtype=weight.data.dtype)
143
+ return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
144
+
145
+ # A1111 Embedding BF16
146
+ original_torch_cat = torch.cat
147
+ @wraps(torch.cat)
148
+ def torch_cat(tensor, *args, **kwargs):
149
+ if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
150
+ return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
151
+ else:
152
+ return original_torch_cat(tensor, *args, **kwargs)
153
+
154
+ # SwinIR BF16:
155
+ original_functional_pad = torch.nn.functional.pad
156
+ @wraps(torch.nn.functional.pad)
157
+ def functional_pad(input, pad, mode='constant', value=None):
158
+ if mode == 'reflect' and input.dtype == torch.bfloat16:
159
+ return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
160
+ else:
161
+ return original_functional_pad(input, pad, mode=mode, value=value)
162
+
163
+
164
+ original_torch_tensor = torch.tensor
165
+ @wraps(torch.tensor)
166
+ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
167
+ if check_device(device):
168
+ device = return_xpu(device)
169
+ if not device_supports_fp64:
170
+ if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
171
+ if dtype == torch.float64:
172
+ dtype = torch.float32
173
+ elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
174
+ dtype = torch.float32
175
+ return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
176
+
177
+ original_Tensor_to = torch.Tensor.to
178
+ @wraps(torch.Tensor.to)
179
+ def Tensor_to(self, device=None, *args, **kwargs):
180
+ if check_device(device):
181
+ return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
182
+ else:
183
+ return original_Tensor_to(self, device, *args, **kwargs)
184
+
185
+ original_Tensor_cuda = torch.Tensor.cuda
186
+ @wraps(torch.Tensor.cuda)
187
+ def Tensor_cuda(self, device=None, *args, **kwargs):
188
+ if check_device(device):
189
+ return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
190
+ else:
191
+ return original_Tensor_cuda(self, device, *args, **kwargs)
192
+
193
+ original_Tensor_pin_memory = torch.Tensor.pin_memory
194
+ @wraps(torch.Tensor.pin_memory)
195
+ def Tensor_pin_memory(self, device=None, *args, **kwargs):
196
+ if device is None:
197
+ device = "xpu"
198
+ if check_device(device):
199
+ return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
200
+ else:
201
+ return original_Tensor_pin_memory(self, device, *args, **kwargs)
202
+
203
+ original_UntypedStorage_init = torch.UntypedStorage.__init__
204
+ @wraps(torch.UntypedStorage.__init__)
205
+ def UntypedStorage_init(*args, device=None, **kwargs):
206
+ if check_device(device):
207
+ return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
208
+ else:
209
+ return original_UntypedStorage_init(*args, device=device, **kwargs)
210
+
211
+ original_UntypedStorage_cuda = torch.UntypedStorage.cuda
212
+ @wraps(torch.UntypedStorage.cuda)
213
+ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
214
+ if check_device(device):
215
+ return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
216
+ else:
217
+ return original_UntypedStorage_cuda(self, device, *args, **kwargs)
218
+
219
+ original_torch_empty = torch.empty
220
+ @wraps(torch.empty)
221
+ def torch_empty(*args, device=None, **kwargs):
222
+ if check_device(device):
223
+ return original_torch_empty(*args, device=return_xpu(device), **kwargs)
224
+ else:
225
+ return original_torch_empty(*args, device=device, **kwargs)
226
+
227
+ original_torch_randn = torch.randn
228
+ @wraps(torch.randn)
229
+ def torch_randn(*args, device=None, dtype=None, **kwargs):
230
+ if dtype == bytes:
231
+ dtype = None
232
+ if check_device(device):
233
+ return original_torch_randn(*args, device=return_xpu(device), **kwargs)
234
+ else:
235
+ return original_torch_randn(*args, device=device, **kwargs)
236
+
237
+ original_torch_ones = torch.ones
238
+ @wraps(torch.ones)
239
+ def torch_ones(*args, device=None, **kwargs):
240
+ if check_device(device):
241
+ return original_torch_ones(*args, device=return_xpu(device), **kwargs)
242
+ else:
243
+ return original_torch_ones(*args, device=device, **kwargs)
244
+
245
+ original_torch_zeros = torch.zeros
246
+ @wraps(torch.zeros)
247
+ def torch_zeros(*args, device=None, **kwargs):
248
+ if check_device(device):
249
+ return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
250
+ else:
251
+ return original_torch_zeros(*args, device=device, **kwargs)
252
+
253
+ original_torch_linspace = torch.linspace
254
+ @wraps(torch.linspace)
255
+ def torch_linspace(*args, device=None, **kwargs):
256
+ if check_device(device):
257
+ return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
258
+ else:
259
+ return original_torch_linspace(*args, device=device, **kwargs)
260
+
261
+ original_torch_Generator = torch.Generator
262
+ @wraps(torch.Generator)
263
+ def torch_Generator(device=None):
264
+ if check_device(device):
265
+ return original_torch_Generator(return_xpu(device))
266
+ else:
267
+ return original_torch_Generator(device)
268
+
269
+ original_torch_load = torch.load
270
+ @wraps(torch.load)
271
+ def torch_load(f, map_location=None, *args, **kwargs):
272
+ if map_location is None:
273
+ map_location = "xpu"
274
+ if check_device(map_location):
275
+ return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
276
+ else:
277
+ return original_torch_load(f, *args, map_location=map_location, **kwargs)
278
+
279
+
280
+ # Hijack Functions:
281
+ def ipex_hijacks():
282
+ torch.tensor = torch_tensor
283
+ torch.Tensor.to = Tensor_to
284
+ torch.Tensor.cuda = Tensor_cuda
285
+ torch.Tensor.pin_memory = Tensor_pin_memory
286
+ torch.UntypedStorage.__init__ = UntypedStorage_init
287
+ torch.UntypedStorage.cuda = UntypedStorage_cuda
288
+ torch.empty = torch_empty
289
+ torch.randn = torch_randn
290
+ torch.ones = torch_ones
291
+ torch.zeros = torch_zeros
292
+ torch.linspace = torch_linspace
293
+ torch.Generator = torch_Generator
294
+ torch.load = torch_load
295
+
296
+ torch.backends.cuda.sdp_kernel = return_null_context
297
+ torch.nn.DataParallel = DummyDataParallel
298
+ torch.UntypedStorage.is_cuda = is_cuda
299
+ torch.amp.autocast_mode.autocast.__init__ = autocast_init
300
+
301
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
302
+ torch.nn.functional.group_norm = functional_group_norm
303
+ torch.nn.functional.layer_norm = functional_layer_norm
304
+ torch.nn.functional.linear = functional_linear
305
+ torch.nn.functional.conv2d = functional_conv2d
306
+ torch.nn.functional.interpolate = interpolate
307
+ torch.nn.functional.pad = functional_pad
308
+
309
+ torch.bmm = torch_bmm
310
+ torch.cat = torch_cat
311
+ if not device_supports_fp64:
312
+ torch.from_numpy = from_numpy
313
+ torch.as_tensor = as_tensor
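Putting the pieces together, a rough outline of the expected initialization order; the package __init__ that actually wires this up is not shown in this excerpt, so treat the snippet as an assumption-laden sketch rather than a copy of it. The torch-level hijacks go first so that everything imported afterwards sees the XPU-flavored functions, and the diffusers attention patch is mainly relevant on GPUs without FP64 support, the same condition the 32-bit attention fallback above keys on.

# Sketch only; assumes the modules are importable as library.ipex.* as laid out in this commit.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers torch.xpu)

from library.ipex.hijacks import ipex_hijacks

def ipex_setup():
    ipex_hijacks()                       # redirect cuda-targeted calls to xpu
    if not torch.xpu.has_fp64_dtype():   # Alchemist/ARC-class GPUs: also slice attention
        from library.ipex.diffusers import ipex_diffusers
        ipex_diffusers()

ipex_setup()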
library/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1233 @@
1
+ # copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modified to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+ try:
21
+ from diffusers.utils import PIL_INTERPOLATION
22
+ except ImportError:
23
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
24
+ PIL_INTERPOLATION = {
25
+ "linear": PIL.Image.Resampling.BILINEAR,
26
+ "bilinear": PIL.Image.Resampling.BILINEAR,
27
+ "bicubic": PIL.Image.Resampling.BICUBIC,
28
+ "lanczos": PIL.Image.Resampling.LANCZOS,
29
+ "nearest": PIL.Image.Resampling.NEAREST,
30
+ }
31
+ else:
32
+ PIL_INTERPOLATION = {
33
+ "linear": PIL.Image.LINEAR,
34
+ "bilinear": PIL.Image.BILINEAR,
35
+ "bicubic": PIL.Image.BICUBIC,
36
+ "lanczos": PIL.Image.LANCZOS,
37
+ "nearest": PIL.Image.NEAREST,
38
+ }
39
+ # ------------------------------------------------------------------------------
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ re_attention = re.compile(
44
+ r"""
45
+ \\\(|
46
+ \\\)|
47
+ \\\[|
48
+ \\]|
49
+ \\\\|
50
+ \\|
51
+ \(|
52
+ \[|
53
+ :([+-]?[.\d]+)\)|
54
+ \)|
55
+ ]|
56
+ [^\\()\[\]:]+|
57
+ :
58
+ """,
59
+ re.X,
60
+ )
61
+
62
+
63
+ def parse_prompt_attention(text):
64
+ """
65
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
66
+ Accepted tokens are:
67
+ (abc) - increases attention to abc by a multiplier of 1.1
68
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
69
+ [abc] - decreases attention to abc by a multiplier of 1.1
70
+ \( - literal character '('
71
+ \[ - literal character '['
72
+ \) - literal character ')'
73
+ \] - literal character ']'
74
+ \\ - literal character '\'
75
+ anything else - just text
76
+ >>> parse_prompt_attention('normal text')
77
+ [['normal text', 1.0]]
78
+ >>> parse_prompt_attention('an (important) word')
79
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
80
+ >>> parse_prompt_attention('(unbalanced')
81
+ [['unbalanced', 1.1]]
82
+ >>> parse_prompt_attention('\(literal\]')
83
+ [['(literal]', 1.0]]
84
+ >>> parse_prompt_attention('(unnecessary)(parens)')
85
+ [['unnecessaryparens', 1.1]]
86
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
87
+ [['a ', 1.0],
88
+ ['house', 1.5730000000000004],
89
+ [' ', 1.1],
90
+ ['on', 1.0],
91
+ [' a ', 1.1],
92
+ ['hill', 0.55],
93
+ [', sun, ', 1.1],
94
+ ['sky', 1.4641000000000006],
95
+ ['.', 1.1]]
96
+ """
97
+
98
+ res = []
99
+ round_brackets = []
100
+ square_brackets = []
101
+
102
+ round_bracket_multiplier = 1.1
103
+ square_bracket_multiplier = 1 / 1.1
104
+
105
+ def multiply_range(start_position, multiplier):
106
+ for p in range(start_position, len(res)):
107
+ res[p][1] *= multiplier
108
+
109
+ for m in re_attention.finditer(text):
110
+ text = m.group(0)
111
+ weight = m.group(1)
112
+
113
+ if text.startswith("\\"):
114
+ res.append([text[1:], 1.0])
115
+ elif text == "(":
116
+ round_brackets.append(len(res))
117
+ elif text == "[":
118
+ square_brackets.append(len(res))
119
+ elif weight is not None and len(round_brackets) > 0:
120
+ multiply_range(round_brackets.pop(), float(weight))
121
+ elif text == ")" and len(round_brackets) > 0:
122
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
123
+ elif text == "]" and len(square_brackets) > 0:
124
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
125
+ else:
126
+ res.append([text, 1.0])
127
+
128
+ for pos in round_brackets:
129
+ multiply_range(pos, round_bracket_multiplier)
130
+
131
+ for pos in square_brackets:
132
+ multiply_range(pos, square_bracket_multiplier)
133
+
134
+ if len(res) == 0:
135
+ res = [["", 1.0]]
136
+
137
+ # merge runs of identical weights
138
+ i = 0
139
+ while i + 1 < len(res):
140
+ if res[i][1] == res[i + 1][1]:
141
+ res[i][0] += res[i + 1][0]
142
+ res.pop(i + 1)
143
+ else:
144
+ i += 1
145
+
146
+ return res
147
+
148
+
149
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
150
+ r"""
151
+ Tokenize a list of prompts and return its tokens with weights of each token.
152
+
153
+ No padding, starting or ending token is included.
154
+ """
155
+ tokens = []
156
+ weights = []
157
+ truncated = False
158
+ for text in prompt:
159
+ texts_and_weights = parse_prompt_attention(text)
160
+ text_token = []
161
+ text_weight = []
162
+ for word, weight in texts_and_weights:
163
+ # tokenize and discard the starting and the ending token
164
+ token = pipe.tokenizer(word).input_ids[1:-1]
165
+ text_token += token
166
+ # copy the weight by length of token
167
+ text_weight += [weight] * len(token)
168
+ # stop if the text is too long (longer than truncation limit)
169
+ if len(text_token) > max_length:
170
+ truncated = True
171
+ break
172
+ # truncate
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ text_token = text_token[:max_length]
176
+ text_weight = text_weight[:max_length]
177
+ tokens.append(text_token)
178
+ weights.append(text_weight)
179
+ if truncated:
180
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
181
+ return tokens, weights
182
+
183
+
184
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
185
+ r"""
186
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
187
+ """
188
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
189
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
190
+ for i in range(len(tokens)):
191
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
192
+ if no_boseos_middle:
193
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
194
+ else:
195
+ w = []
196
+ if len(weights[i]) == 0:
197
+ w = [1.0] * weights_length
198
+ else:
199
+ for j in range(max_embeddings_multiples):
200
+ w.append(1.0) # weight for starting token in this chunk
201
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
202
+ w.append(1.0) # weight for ending token in this chunk
203
+ w += [1.0] * (weights_length - len(w))
204
+ weights[i] = w[:]
205
+
206
+ return tokens, weights
207
+
208
+
209
+ def get_unweighted_text_embeddings(
210
+ pipe: StableDiffusionPipeline,
211
+ text_input: torch.Tensor,
212
+ chunk_length: int,
213
+ clip_skip: int,
214
+ eos: int,
215
+ pad: int,
216
+ no_boseos_middle: Optional[bool] = True,
217
+ ):
218
+ """
219
+ When the length of tokens is a multiple of the capacity of the text encoder,
220
+ it should be split into chunks and sent to the text encoder individually.
221
+ """
222
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
223
+ if max_embeddings_multiples > 1:
224
+ text_embeddings = []
225
+ for i in range(max_embeddings_multiples):
226
+ # extract the i-th chunk
227
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
228
+
229
+ # cover the head and the tail by the starting and the ending tokens
230
+ text_input_chunk[:, 0] = text_input[0, 0]
231
+ if pad == eos: # v1
232
+ text_input_chunk[:, -1] = text_input[0, -1]
233
+ else: # v2
234
+ for j in range(len(text_input_chunk)):
235
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad:  # the last token is a regular character
236
+ text_input_chunk[j, -1] = eos
237
+ if text_input_chunk[j, 1] == pad:  # only BOS, the rest is PAD
238
+ text_input_chunk[j, 1] = eos
239
+
240
+ if clip_skip is None or clip_skip == 1:
241
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
242
+ else:
243
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
244
+ text_embedding = enc_out["hidden_states"][-clip_skip]
245
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
246
+
247
+ if no_boseos_middle:
248
+ if i == 0:
249
+ # discard the ending token
250
+ text_embedding = text_embedding[:, :-1]
251
+ elif i == max_embeddings_multiples - 1:
252
+ # discard the starting token
253
+ text_embedding = text_embedding[:, 1:]
254
+ else:
255
+ # discard both starting and ending tokens
256
+ text_embedding = text_embedding[:, 1:-1]
257
+
258
+ text_embeddings.append(text_embedding)
259
+ text_embeddings = torch.concat(text_embeddings, axis=1)
260
+ else:
261
+ if clip_skip is None or clip_skip == 1:
262
+ text_embeddings = pipe.text_encoder(text_input)[0]
263
+ else:
264
+ enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
265
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
266
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
267
+ return text_embeddings
268
+
269
+
270
+ def get_weighted_text_embeddings(
271
+ pipe: StableDiffusionPipeline,
272
+ prompt: Union[str, List[str]],
273
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
274
+ max_embeddings_multiples: Optional[int] = 3,
275
+ no_boseos_middle: Optional[bool] = False,
276
+ skip_parsing: Optional[bool] = False,
277
+ skip_weighting: Optional[bool] = False,
278
+ clip_skip=None,
279
+ ):
280
+ r"""
281
+ Prompts can be assigned local weights using brackets. For example,
282
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
283
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
284
+
285
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
286
+
287
+ Args:
288
+ pipe (`StableDiffusionPipeline`):
289
+ Pipe to provide access to the tokenizer and the text encoder.
290
+ prompt (`str` or `List[str]`):
291
+ The prompt or prompts to guide the image generation.
292
+ uncond_prompt (`str` or `List[str]`):
293
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
294
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
295
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
296
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
297
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
298
+ If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the starting and
299
+ ending tokens in each of the middle chunks.
300
+ skip_parsing (`bool`, *optional*, defaults to `False`):
301
+ Skip the parsing of brackets.
302
+ skip_weighting (`bool`, *optional*, defaults to `False`):
303
+ Skip the weighting. When the parsing is skipped, it is forced True.
304
+ """
305
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
306
+ if isinstance(prompt, str):
307
+ prompt = [prompt]
308
+
309
+ if not skip_parsing:
310
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
311
+ if uncond_prompt is not None:
312
+ if isinstance(uncond_prompt, str):
313
+ uncond_prompt = [uncond_prompt]
314
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
315
+ else:
316
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
317
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
318
+ if uncond_prompt is not None:
319
+ if isinstance(uncond_prompt, str):
320
+ uncond_prompt = [uncond_prompt]
321
+ uncond_tokens = [
322
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
323
+ ]
324
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
325
+
326
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
327
+ max_length = max([len(token) for token in prompt_tokens])
328
+ if uncond_prompt is not None:
329
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
330
+
331
+ max_embeddings_multiples = min(
332
+ max_embeddings_multiples,
333
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
334
+ )
335
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
336
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
337
+
338
+ # pad the length of tokens and weights
339
+ bos = pipe.tokenizer.bos_token_id
340
+ eos = pipe.tokenizer.eos_token_id
341
+ pad = pipe.tokenizer.pad_token_id
342
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
343
+ prompt_tokens,
344
+ prompt_weights,
345
+ max_length,
346
+ bos,
347
+ eos,
348
+ no_boseos_middle=no_boseos_middle,
349
+ chunk_length=pipe.tokenizer.model_max_length,
350
+ )
351
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
352
+ if uncond_prompt is not None:
353
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
354
+ uncond_tokens,
355
+ uncond_weights,
356
+ max_length,
357
+ bos,
358
+ eos,
359
+ no_boseos_middle=no_boseos_middle,
360
+ chunk_length=pipe.tokenizer.model_max_length,
361
+ )
362
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
363
+
364
+ # get the embeddings
365
+ text_embeddings = get_unweighted_text_embeddings(
366
+ pipe,
367
+ prompt_tokens,
368
+ pipe.tokenizer.model_max_length,
369
+ clip_skip,
370
+ eos,
371
+ pad,
372
+ no_boseos_middle=no_boseos_middle,
373
+ )
374
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
375
+ if uncond_prompt is not None:
376
+ uncond_embeddings = get_unweighted_text_embeddings(
377
+ pipe,
378
+ uncond_tokens,
379
+ pipe.tokenizer.model_max_length,
380
+ clip_skip,
381
+ eos,
382
+ pad,
383
+ no_boseos_middle=no_boseos_middle,
384
+ )
385
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
386
+
387
+ # assign weights to the prompts and normalize in the sense of mean
388
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
389
+ if (not skip_parsing) and (not skip_weighting):
390
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
391
+ text_embeddings *= prompt_weights.unsqueeze(-1)
392
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
394
+ if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
396
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
397
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
398
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
399
+
400
+ if uncond_prompt is not None:
401
+ return text_embeddings, uncond_embeddings
402
+ return text_embeddings, None
403
+
404
+
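Editorial sketch (not part of the uploaded file): the core of the weighting step inside `get_weighted_text_embeddings` above is "multiply each token embedding by its weight, then rescale so the overall mean is unchanged". A minimal, self-contained illustration with dummy tensors, assuming only `torch` is installed:

import torch

embeddings = torch.randn(1, 77, 768)  # stand-in for CLIP hidden states: (batch, tokens, dim)
weights = torch.ones(1, 77)
weights[0, 2:5] = 1.1                 # tokens inside one "(...)" group get weight 1.1

previous_mean = embeddings.float().mean(dim=[-2, -1])
weighted = embeddings * weights.unsqueeze(-1)
current_mean = weighted.float().mean(dim=[-2, -1])
weighted = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
# the rescaling restores the original mean, so weighting only changes the relative emphasis of tokens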
405
+ def preprocess_image(image):
406
+ w, h = image.size
407
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
408
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
409
+ image = np.array(image).astype(np.float32) / 255.0
410
+ image = image[None].transpose(0, 3, 1, 2)
411
+ image = torch.from_numpy(image)
412
+ return 2.0 * image - 1.0
413
+
414
+
415
+ def preprocess_mask(mask, scale_factor=8):
416
+ mask = mask.convert("L")
417
+ w, h = mask.size
418
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
419
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
420
+ mask = np.array(mask).astype(np.float32) / 255.0
421
+ mask = np.tile(mask, (4, 1, 1))
422
+     mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; this transpose is an identity permutation (no-op)
423
+ mask = 1 - mask # repaint white, keep black
424
+ mask = torch.from_numpy(mask)
425
+ return mask
426
+
427
+
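Editorial sketch (not part of the uploaded file): expected shapes and value ranges of the two preprocessing helpers above, using synthetic PIL images and assuming the helpers (and the module's `PIL_INTERPOLATION` import) are in scope:

import numpy as np
from PIL import Image

img = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))   # dummy RGB image
msk = Image.fromarray(np.full((512, 512), 255, dtype=np.uint8))  # dummy all-white mask

x = preprocess_image(img)  # torch.FloatTensor of shape (1, 3, 512, 512), values in [-1, 1]
m = preprocess_mask(msk)   # shape (1, 4, 64, 64); white pixels become 0, i.e. the region is repainted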
428
+ def prepare_controlnet_image(
429
+ image: PIL.Image.Image,
430
+ width: int,
431
+ height: int,
432
+ batch_size: int,
433
+ num_images_per_prompt: int,
434
+ device: torch.device,
435
+ dtype: torch.dtype,
436
+ do_classifier_free_guidance: bool = False,
437
+ guess_mode: bool = False,
438
+ ):
439
+ if not isinstance(image, torch.Tensor):
440
+ if isinstance(image, PIL.Image.Image):
441
+ image = [image]
442
+
443
+ if isinstance(image[0], PIL.Image.Image):
444
+ images = []
445
+
446
+ for image_ in image:
447
+ image_ = image_.convert("RGB")
448
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
449
+ image_ = np.array(image_)
450
+ image_ = image_[None, :]
451
+ images.append(image_)
452
+
453
+ image = images
454
+
455
+ image = np.concatenate(image, axis=0)
456
+ image = np.array(image).astype(np.float32) / 255.0
457
+ image = image.transpose(0, 3, 1, 2)
458
+ image = torch.from_numpy(image)
459
+ elif isinstance(image[0], torch.Tensor):
460
+ image = torch.cat(image, dim=0)
461
+
462
+ image_batch_size = image.shape[0]
463
+
464
+ if image_batch_size == 1:
465
+ repeat_by = batch_size
466
+ else:
467
+ # image batch size is the same as prompt batch size
468
+ repeat_by = num_images_per_prompt
469
+
470
+ image = image.repeat_interleave(repeat_by, dim=0)
471
+
472
+ image = image.to(device=device, dtype=dtype)
473
+
474
+ if do_classifier_free_guidance and not guess_mode:
475
+ image = torch.cat([image] * 2)
476
+
477
+ return image
478
+
479
+
480
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
481
+ r"""
482
+     Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
483
+     parsing weights in the prompt.
484
+
485
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
486
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
487
+
488
+ Args:
489
+ vae ([`AutoencoderKL`]):
490
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
491
+ text_encoder ([`CLIPTextModel`]):
492
+ Frozen text-encoder. Stable Diffusion uses the text portion of
493
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
494
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
495
+ tokenizer (`CLIPTokenizer`):
496
+ Tokenizer of class
497
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
498
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
499
+ scheduler ([`SchedulerMixin`]):
500
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
501
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
502
+ safety_checker ([`StableDiffusionSafetyChecker`]):
503
+ Classification module that estimates whether generated images could be considered offensive or harmful.
504
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
505
+ feature_extractor ([`CLIPFeatureExtractor`]):
506
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
507
+ """
508
+
509
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
510
+
511
+ def __init__(
512
+ self,
513
+ vae: AutoencoderKL,
514
+ text_encoder: CLIPTextModel,
515
+ tokenizer: CLIPTokenizer,
516
+ unet: UNet2DConditionModel,
517
+ scheduler: SchedulerMixin,
518
+ # clip_skip: int,
519
+ safety_checker: StableDiffusionSafetyChecker,
520
+ feature_extractor: CLIPFeatureExtractor,
521
+ requires_safety_checker: bool = True,
522
+ image_encoder: CLIPVisionModelWithProjection = None,
523
+ clip_skip: int = 1,
524
+ ):
525
+ super().__init__(
526
+ vae=vae,
527
+ text_encoder=text_encoder,
528
+ tokenizer=tokenizer,
529
+ unet=unet,
530
+ scheduler=scheduler,
531
+ safety_checker=safety_checker,
532
+ feature_extractor=feature_extractor,
533
+ requires_safety_checker=requires_safety_checker,
534
+ image_encoder=image_encoder,
535
+ )
536
+ self.custom_clip_skip = clip_skip
537
+ self.__init__additional__()
538
+
539
+ def __init__additional__(self):
540
+ if not hasattr(self, "vae_scale_factor"):
541
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
542
+
543
+ @property
544
+ def _execution_device(self):
545
+ r"""
546
+ Returns the device on which the pipeline's models will be executed. After calling
547
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
548
+ hooks.
549
+ """
550
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
551
+ return self.device
552
+ for module in self.unet.modules():
553
+ if (
554
+ hasattr(module, "_hf_hook")
555
+ and hasattr(module._hf_hook, "execution_device")
556
+ and module._hf_hook.execution_device is not None
557
+ ):
558
+ return torch.device(module._hf_hook.execution_device)
559
+ return self.device
560
+
561
+ def _encode_prompt(
562
+ self,
563
+ prompt,
564
+ device,
565
+ num_images_per_prompt,
566
+ do_classifier_free_guidance,
567
+ negative_prompt,
568
+ max_embeddings_multiples,
569
+ ):
570
+ r"""
571
+ Encodes the prompt into text encoder hidden states.
572
+
573
+ Args:
574
+ prompt (`str` or `list(int)`):
575
+ prompt to be encoded
576
+ device: (`torch.device`):
577
+ torch device
578
+ num_images_per_prompt (`int`):
579
+ number of images that should be generated per prompt
580
+ do_classifier_free_guidance (`bool`):
581
+ whether to use classifier free guidance or not
582
+ negative_prompt (`str` or `List[str]`):
583
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
584
+ if `guidance_scale` is less than `1`).
585
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
586
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
587
+ """
588
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
589
+
590
+ if negative_prompt is None:
591
+ negative_prompt = [""] * batch_size
592
+ elif isinstance(negative_prompt, str):
593
+ negative_prompt = [negative_prompt] * batch_size
594
+ if batch_size != len(negative_prompt):
595
+ raise ValueError(
596
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
597
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
598
+ " the batch size of `prompt`."
599
+ )
600
+
601
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
602
+ pipe=self,
603
+ prompt=prompt,
604
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
605
+ max_embeddings_multiples=max_embeddings_multiples,
606
+ clip_skip=self.custom_clip_skip,
607
+ )
608
+ bs_embed, seq_len, _ = text_embeddings.shape
609
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
610
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
611
+
612
+ if do_classifier_free_guidance:
613
+ bs_embed, seq_len, _ = uncond_embeddings.shape
614
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
615
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
616
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
617
+
618
+ return text_embeddings
619
+
620
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
621
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
622
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
623
+
624
+ if strength < 0 or strength > 1:
625
+             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
626
+
627
+ if height % 8 != 0 or width % 8 != 0:
628
+ logger.info(f'{height} {width}')
629
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
630
+
631
+ if (callback_steps is None) or (
632
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
633
+ ):
634
+ raise ValueError(
635
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
636
+ )
637
+
638
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
639
+ if is_text2img:
640
+ return self.scheduler.timesteps.to(device), num_inference_steps
641
+ else:
642
+ # get the original timestep using init_timestep
643
+ offset = self.scheduler.config.get("steps_offset", 0)
644
+ init_timestep = int(num_inference_steps * strength) + offset
645
+ init_timestep = min(init_timestep, num_inference_steps)
646
+
647
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
648
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
649
+ return timesteps, num_inference_steps - t_start
650
+
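    # Editorial note (not in the original): worked example for get_timesteps above.
    # For img2img with num_inference_steps=50, strength=0.8 and scheduler steps_offset=0:
    #   init_timestep = int(50 * 0.8) + 0 = 40, t_start = max(50 - 40 + 0, 0) = 10,
    # so the method returns (self.scheduler.timesteps[10:], 40) and only the last 40
    # of the 50 scheduled timesteps are actually run.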
651
+ def run_safety_checker(self, image, device, dtype):
652
+ if self.safety_checker is not None:
653
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
654
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
655
+ else:
656
+ has_nsfw_concept = None
657
+ return image, has_nsfw_concept
658
+
659
+ def decode_latents(self, latents):
660
+ latents = 1 / 0.18215 * latents
661
+ image = self.vae.decode(latents).sample
662
+ image = (image / 2 + 0.5).clamp(0, 1)
663
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
664
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
665
+ return image
666
+
667
+ def prepare_extra_step_kwargs(self, generator, eta):
668
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
669
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
670
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
671
+ # and should be between [0, 1]
672
+
673
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
674
+ extra_step_kwargs = {}
675
+ if accepts_eta:
676
+ extra_step_kwargs["eta"] = eta
677
+
678
+ # check if the scheduler accepts generator
679
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
680
+ if accepts_generator:
681
+ extra_step_kwargs["generator"] = generator
682
+ return extra_step_kwargs
683
+
684
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
685
+ if image is None:
686
+ shape = (
687
+ batch_size,
688
+ self.unet.in_channels,
689
+ height // self.vae_scale_factor,
690
+ width // self.vae_scale_factor,
691
+ )
692
+
693
+ if latents is None:
694
+ if device.type == "mps":
695
+ # randn does not work reproducibly on mps
696
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
697
+ else:
698
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
699
+ else:
700
+ if latents.shape != shape:
701
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
702
+ latents = latents.to(device)
703
+
704
+ # scale the initial noise by the standard deviation required by the scheduler
705
+ latents = latents * self.scheduler.init_noise_sigma
706
+ return latents, None, None
707
+ else:
708
+ init_latent_dist = self.vae.encode(image).latent_dist
709
+ init_latents = init_latent_dist.sample(generator=generator)
710
+ init_latents = 0.18215 * init_latents
711
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
712
+ init_latents_orig = init_latents
713
+ shape = init_latents.shape
714
+
715
+ # add noise to latents using the timesteps
716
+ if device.type == "mps":
717
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
718
+ else:
719
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
720
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
721
+ return latents, init_latents_orig, noise
722
+
723
+ @torch.no_grad()
724
+ def __call__(
725
+ self,
726
+ prompt: Union[str, List[str]],
727
+ negative_prompt: Optional[Union[str, List[str]]] = None,
728
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
729
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
730
+ height: int = 512,
731
+ width: int = 512,
732
+ num_inference_steps: int = 50,
733
+ guidance_scale: float = 7.5,
734
+ strength: float = 0.8,
735
+ num_images_per_prompt: Optional[int] = 1,
736
+ eta: float = 0.0,
737
+ generator: Optional[torch.Generator] = None,
738
+ latents: Optional[torch.FloatTensor] = None,
739
+ max_embeddings_multiples: Optional[int] = 3,
740
+ output_type: Optional[str] = "pil",
741
+ return_dict: bool = True,
742
+ controlnet=None,
743
+ controlnet_image=None,
744
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
745
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
746
+ callback_steps: int = 1,
747
+ ):
748
+ r"""
749
+ Function invoked when calling the pipeline for generation.
750
+
751
+ Args:
752
+ prompt (`str` or `List[str]`):
753
+ The prompt or prompts to guide the image generation.
754
+ negative_prompt (`str` or `List[str]`, *optional*):
755
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
756
+ if `guidance_scale` is less than `1`).
757
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
758
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
759
+ process.
760
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
761
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
762
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
763
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
764
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
765
+ height (`int`, *optional*, defaults to 512):
766
+ The height in pixels of the generated image.
767
+ width (`int`, *optional*, defaults to 512):
768
+ The width in pixels of the generated image.
769
+ num_inference_steps (`int`, *optional*, defaults to 50):
770
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
771
+ expense of slower inference.
772
+ guidance_scale (`float`, *optional*, defaults to 7.5):
773
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
774
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
775
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
776
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
777
+ usually at the expense of lower image quality.
778
+ strength (`float`, *optional*, defaults to 0.8):
779
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
780
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
781
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
782
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
783
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
784
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
785
+ The number of images to generate per prompt.
786
+ eta (`float`, *optional*, defaults to 0.0):
787
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
788
+ [`schedulers.DDIMScheduler`], will be ignored for others.
789
+ generator (`torch.Generator`, *optional*):
790
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
791
+ deterministic.
792
+ latents (`torch.FloatTensor`, *optional*):
793
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
794
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
795
+                 tensor will be generated by sampling using the supplied random `generator`.
796
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
797
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
798
+ output_type (`str`, *optional*, defaults to `"pil"`):
799
+                 The output format of the generated image. Choose between
800
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
801
+ return_dict (`bool`, *optional*, defaults to `True`):
802
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
803
+ plain tuple.
804
+ controlnet (`diffusers.ControlNetModel`, *optional*):
805
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
806
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
807
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
808
+ inference.
809
+ callback (`Callable`, *optional*):
810
+ A function that will be called every `callback_steps` steps during inference. The function will be
811
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
812
+ is_cancelled_callback (`Callable`, *optional*):
813
+ A function that will be called every `callback_steps` steps during inference. If the function returns
814
+ `True`, the inference will be cancelled.
815
+ callback_steps (`int`, *optional*, defaults to 1):
816
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
817
+ called at every step.
818
+
819
+ Returns:
820
+             `None` if cancelled by `is_cancelled_callback`; otherwise the denoised latents as a `torch.FloatTensor`.
+             Unlike the standard Stable Diffusion pipeline, this `__call__` does not decode the latents or run the
+             `safety_checker`; use `latents_to_image` to convert the returned latents into `PIL.Image.Image`s.
826
+ """
827
+ if controlnet is not None and controlnet_image is None:
828
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
829
+
830
+ # 0. Default height and width to unet
831
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
832
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
833
+
834
+ # 1. Check inputs. Raise error if not correct
835
+ self.check_inputs(prompt, height, width, strength, callback_steps)
836
+
837
+ # 2. Define call parameters
838
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
839
+ device = self._execution_device
840
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
841
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
842
+ # corresponds to doing no classifier free guidance.
843
+ do_classifier_free_guidance = guidance_scale > 1.0
844
+
845
+ # 3. Encode input prompt
846
+ text_embeddings = self._encode_prompt(
847
+ prompt,
848
+ device,
849
+ num_images_per_prompt,
850
+ do_classifier_free_guidance,
851
+ negative_prompt,
852
+ max_embeddings_multiples,
853
+ )
854
+ dtype = text_embeddings.dtype
855
+
856
+ # 4. Preprocess image and mask
857
+ if isinstance(image, PIL.Image.Image):
858
+ image = preprocess_image(image)
859
+ if image is not None:
860
+ image = image.to(device=self.device, dtype=dtype)
861
+ if isinstance(mask_image, PIL.Image.Image):
862
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
863
+ if mask_image is not None:
864
+ mask = mask_image.to(device=self.device, dtype=dtype)
865
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
866
+ else:
867
+ mask = None
868
+
869
+ if controlnet_image is not None:
870
+ controlnet_image = prepare_controlnet_image(
871
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
872
+ )
873
+
874
+ # 5. set timesteps
875
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
876
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
877
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
878
+
879
+ # 6. Prepare latent variables
880
+ latents, init_latents_orig, noise = self.prepare_latents(
881
+ image,
882
+ latent_timestep,
883
+ batch_size * num_images_per_prompt,
884
+ height,
885
+ width,
886
+ dtype,
887
+ device,
888
+ generator,
889
+ latents,
890
+ )
891
+
892
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
893
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
894
+
895
+ # 8. Denoising loop
896
+ for i, t in enumerate(self.progress_bar(timesteps)):
897
+ # expand the latents if we are doing classifier free guidance
898
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
899
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
900
+
901
+ unet_additional_args = {}
902
+ if controlnet is not None:
903
+ down_block_res_samples, mid_block_res_sample = controlnet(
904
+ latent_model_input,
905
+ t,
906
+ encoder_hidden_states=text_embeddings,
907
+ controlnet_cond=controlnet_image,
908
+ conditioning_scale=1.0,
909
+ guess_mode=False,
910
+ return_dict=False,
911
+ )
912
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
913
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
914
+
915
+ # predict the noise residual
916
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
917
+
918
+ # perform guidance
919
+ if do_classifier_free_guidance:
920
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
921
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
922
+
923
+ # compute the previous noisy sample x_t -> x_t-1
924
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
925
+
926
+ if mask is not None:
927
+ # masking
928
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
929
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
930
+
931
+ # call the callback, if provided
932
+ if i % callback_steps == 0:
933
+ if callback is not None:
934
+ callback(i, t, latents)
935
+ if is_cancelled_callback is not None and is_cancelled_callback():
936
+ return None
937
+
938
+ return latents
939
+
940
+ def latents_to_image(self, latents):
941
+ # 9. Post-processing
942
+ image = self.decode_latents(latents.to(self.vae.dtype))
943
+ image = self.numpy_to_pil(image)
944
+ return image
945
+
946
+ def text2img(
947
+ self,
948
+ prompt: Union[str, List[str]],
949
+ negative_prompt: Optional[Union[str, List[str]]] = None,
950
+ height: int = 512,
951
+ width: int = 512,
952
+ num_inference_steps: int = 50,
953
+ guidance_scale: float = 7.5,
954
+ num_images_per_prompt: Optional[int] = 1,
955
+ eta: float = 0.0,
956
+ generator: Optional[torch.Generator] = None,
957
+ latents: Optional[torch.FloatTensor] = None,
958
+ max_embeddings_multiples: Optional[int] = 3,
959
+ output_type: Optional[str] = "pil",
960
+ return_dict: bool = True,
961
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
962
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
963
+ callback_steps: int = 1,
964
+ ):
965
+ r"""
966
+ Function for text-to-image generation.
967
+ Args:
968
+ prompt (`str` or `List[str]`):
969
+ The prompt or prompts to guide the image generation.
970
+ negative_prompt (`str` or `List[str]`, *optional*):
971
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
972
+ if `guidance_scale` is less than `1`).
973
+ height (`int`, *optional*, defaults to 512):
974
+ The height in pixels of the generated image.
975
+ width (`int`, *optional*, defaults to 512):
976
+ The width in pixels of the generated image.
977
+ num_inference_steps (`int`, *optional*, defaults to 50):
978
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
979
+ expense of slower inference.
980
+ guidance_scale (`float`, *optional*, defaults to 7.5):
981
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
982
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
983
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
984
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
985
+ usually at the expense of lower image quality.
986
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
987
+ The number of images to generate per prompt.
988
+ eta (`float`, *optional*, defaults to 0.0):
989
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
990
+ [`schedulers.DDIMScheduler`], will be ignored for others.
991
+ generator (`torch.Generator`, *optional*):
992
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
993
+ deterministic.
994
+ latents (`torch.FloatTensor`, *optional*):
995
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
996
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
997
+                 tensor will be generated by sampling using the supplied random `generator`.
998
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
999
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1000
+ output_type (`str`, *optional*, defaults to `"pil"`):
1001
+                 The output format of the generated image. Choose between
1002
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1003
+ return_dict (`bool`, *optional*, defaults to `True`):
1004
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1005
+ plain tuple.
1006
+ callback (`Callable`, *optional*):
1007
+ A function that will be called every `callback_steps` steps during inference. The function will be
1008
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1009
+ is_cancelled_callback (`Callable`, *optional*):
1010
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1011
+ `True`, the inference will be cancelled.
1012
+ callback_steps (`int`, *optional*, defaults to 1):
1013
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1014
+ called at every step.
1015
+ Returns:
1016
+             The denoised latents as a `torch.FloatTensor` (or `None` if cancelled by `is_cancelled_callback`);
+             decode them with `latents_to_image` to obtain `PIL.Image.Image`s.
1021
+ """
1022
+ return self.__call__(
1023
+ prompt=prompt,
1024
+ negative_prompt=negative_prompt,
1025
+ height=height,
1026
+ width=width,
1027
+ num_inference_steps=num_inference_steps,
1028
+ guidance_scale=guidance_scale,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ eta=eta,
1031
+ generator=generator,
1032
+ latents=latents,
1033
+ max_embeddings_multiples=max_embeddings_multiples,
1034
+ output_type=output_type,
1035
+ return_dict=return_dict,
1036
+ callback=callback,
1037
+ is_cancelled_callback=is_cancelled_callback,
1038
+ callback_steps=callback_steps,
1039
+ )
1040
+
1041
+ def img2img(
1042
+ self,
1043
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1044
+ prompt: Union[str, List[str]],
1045
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1046
+ strength: float = 0.8,
1047
+ num_inference_steps: Optional[int] = 50,
1048
+ guidance_scale: Optional[float] = 7.5,
1049
+ num_images_per_prompt: Optional[int] = 1,
1050
+ eta: Optional[float] = 0.0,
1051
+ generator: Optional[torch.Generator] = None,
1052
+ max_embeddings_multiples: Optional[int] = 3,
1053
+ output_type: Optional[str] = "pil",
1054
+ return_dict: bool = True,
1055
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1056
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1057
+ callback_steps: int = 1,
1058
+ ):
1059
+ r"""
1060
+ Function for image-to-image generation.
1061
+ Args:
1062
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1063
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1064
+ process.
1065
+ prompt (`str` or `List[str]`):
1066
+ The prompt or prompts to guide the image generation.
1067
+ negative_prompt (`str` or `List[str]`, *optional*):
1068
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1069
+ if `guidance_scale` is less than `1`).
1070
+ strength (`float`, *optional*, defaults to 0.8):
1071
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1072
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1073
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1074
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1075
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1076
+ num_inference_steps (`int`, *optional*, defaults to 50):
1077
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1078
+ expense of slower inference. This parameter will be modulated by `strength`.
1079
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1080
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1081
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1082
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1083
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1084
+ usually at the expense of lower image quality.
1085
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1086
+ The number of images to generate per prompt.
1087
+ eta (`float`, *optional*, defaults to 0.0):
1088
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1089
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1090
+ generator (`torch.Generator`, *optional*):
1091
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1092
+ deterministic.
1093
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1094
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1095
+ output_type (`str`, *optional*, defaults to `"pil"`):
1096
+                 The output format of the generated image. Choose between
1097
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1098
+ return_dict (`bool`, *optional*, defaults to `True`):
1099
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1100
+ plain tuple.
1101
+ callback (`Callable`, *optional*):
1102
+ A function that will be called every `callback_steps` steps during inference. The function will be
1103
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1104
+ is_cancelled_callback (`Callable`, *optional*):
1105
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1106
+ `True`, the inference will be cancelled.
1107
+ callback_steps (`int`, *optional*, defaults to 1):
1108
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1109
+ called at every step.
1110
+ Returns:
1111
+             The denoised latents as a `torch.FloatTensor` (or `None` if cancelled by `is_cancelled_callback`);
+             decode them with `latents_to_image` to obtain `PIL.Image.Image`s.
1116
+ """
1117
+ return self.__call__(
1118
+ prompt=prompt,
1119
+ negative_prompt=negative_prompt,
1120
+ image=image,
1121
+ num_inference_steps=num_inference_steps,
1122
+ guidance_scale=guidance_scale,
1123
+ strength=strength,
1124
+ num_images_per_prompt=num_images_per_prompt,
1125
+ eta=eta,
1126
+ generator=generator,
1127
+ max_embeddings_multiples=max_embeddings_multiples,
1128
+ output_type=output_type,
1129
+ return_dict=return_dict,
1130
+ callback=callback,
1131
+ is_cancelled_callback=is_cancelled_callback,
1132
+ callback_steps=callback_steps,
1133
+ )
1134
+
1135
+ def inpaint(
1136
+ self,
1137
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1138
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1139
+ prompt: Union[str, List[str]],
1140
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1141
+ strength: float = 0.8,
1142
+ num_inference_steps: Optional[int] = 50,
1143
+ guidance_scale: Optional[float] = 7.5,
1144
+ num_images_per_prompt: Optional[int] = 1,
1145
+ eta: Optional[float] = 0.0,
1146
+ generator: Optional[torch.Generator] = None,
1147
+ max_embeddings_multiples: Optional[int] = 3,
1148
+ output_type: Optional[str] = "pil",
1149
+ return_dict: bool = True,
1150
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1151
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1152
+ callback_steps: int = 1,
1153
+ ):
1154
+ r"""
1155
+ Function for inpaint.
1156
+ Args:
1157
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1158
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1159
+ process. This is the image whose masked region will be inpainted.
1160
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1161
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1162
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1163
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1164
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1165
+ prompt (`str` or `List[str]`):
1166
+ The prompt or prompts to guide the image generation.
1167
+ negative_prompt (`str` or `List[str]`, *optional*):
1168
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1169
+ if `guidance_scale` is less than `1`).
1170
+ strength (`float`, *optional*, defaults to 0.8):
1171
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1172
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1173
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1174
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1175
+ num_inference_steps (`int`, *optional*, defaults to 50):
1176
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1177
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1178
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1179
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1180
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1181
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1182
+                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1183
+ usually at the expense of lower image quality.
1184
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1185
+ The number of images to generate per prompt.
1186
+ eta (`float`, *optional*, defaults to 0.0):
1187
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1188
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1189
+ generator (`torch.Generator`, *optional*):
1190
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1191
+ deterministic.
1192
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1193
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1194
+ output_type (`str`, *optional*, defaults to `"pil"`):
1195
+                 The output format of the generated image. Choose between
1196
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1197
+ return_dict (`bool`, *optional*, defaults to `True`):
1198
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1199
+ plain tuple.
1200
+ callback (`Callable`, *optional*):
1201
+ A function that will be called every `callback_steps` steps during inference. The function will be
1202
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1203
+ is_cancelled_callback (`Callable`, *optional*):
1204
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1205
+ `True`, the inference will be cancelled.
1206
+ callback_steps (`int`, *optional*, defaults to 1):
1207
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1208
+ called at every step.
1209
+ Returns:
1210
+             The denoised latents as a `torch.FloatTensor` (or `None` if cancelled by `is_cancelled_callback`);
+             decode them with `latents_to_image` to obtain `PIL.Image.Image`s.
1215
+ """
1216
+ return self.__call__(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ image=image,
1220
+ mask_image=mask_image,
1221
+ num_inference_steps=num_inference_steps,
1222
+ guidance_scale=guidance_scale,
1223
+ strength=strength,
1224
+ num_images_per_prompt=num_images_per_prompt,
1225
+ eta=eta,
1226
+ generator=generator,
1227
+ max_embeddings_multiples=max_embeddings_multiples,
1228
+ output_type=output_type,
1229
+ return_dict=return_dict,
1230
+ callback=callback,
1231
+ is_cancelled_callback=is_cancelled_callback,
1232
+ callback_steps=callback_steps,
1233
+ )
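Editorial usage sketch (not part of the uploaded file). It assumes a recent `diffusers` install, a CUDA GPU, and access to the `runwayml/stable-diffusion-v1-5` weights; the pipeline is assembled from a stock `StableDiffusionPipeline`'s components. Note that `__call__` (and therefore `text2img`/`img2img`/`inpaint`) returns latents here, so `latents_to_image` is used to decode them:

import torch
from diffusers import StableDiffusionPipeline

base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = StableDiffusionLongPromptWeightingPipeline(
    vae=base.vae,
    text_encoder=base.text_encoder,
    tokenizer=base.tokenizer,
    unet=base.unet,
    scheduler=base.scheduler,
    safety_checker=base.safety_checker,
    feature_extractor=base.feature_extractor,
    requires_safety_checker=False,
    clip_skip=1,
).to("cuda")

# "(cherry blossoms)" raises the weight of those tokens; the prompt may also exceed 77 tokens
latents = pipe.text2img("a watercolor landscape with (cherry blossoms), highly detailed")
image = pipe.latents_to_image(latents)[0]
image.save("lpw_sample.png")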
library/model_util.py ADDED
@@ -0,0 +1,1356 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+
7
+ import torch
8
+ from library.device_utils import init_ipex
9
+ init_ipex()
10
+
11
+ import diffusers
12
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
13
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
14
+ from safetensors.torch import load_file, save_file
15
+ from library.original_unet import UNet2DConditionModel
16
+ from library.utils import setup_logging
17
+ setup_logging()
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # model parameters for the Diffusers version of Stable Diffusion
22
+ NUM_TRAIN_TIMESTEPS = 1000
23
+ BETA_START = 0.00085
24
+ BETA_END = 0.0120
25
+
26
+ UNET_PARAMS_MODEL_CHANNELS = 320
27
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
28
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
29
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
30
+ UNET_PARAMS_IN_CHANNELS = 4
31
+ UNET_PARAMS_OUT_CHANNELS = 4
32
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
33
+ UNET_PARAMS_CONTEXT_DIM = 768
34
+ UNET_PARAMS_NUM_HEADS = 8
35
+ # UNET_PARAMS_USE_LINEAR_PROJECTION = False
36
+
37
+ VAE_PARAMS_Z_CHANNELS = 4
38
+ VAE_PARAMS_RESOLUTION = 256
39
+ VAE_PARAMS_IN_CHANNELS = 3
40
+ VAE_PARAMS_OUT_CH = 3
41
+ VAE_PARAMS_CH = 128
42
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
43
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
44
+
45
+ # V2
46
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
47
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
48
+ # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
49
+
50
+ # reference models used to load the Diffusers configuration
51
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
52
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
53
+
54
+
55
+ # region conversion code: Stable Diffusion -> Diffusers
56
+ # copied from convert_original_stable_diffusion_to_diffusers and modified (ASL 2.0)
57
+
58
+
59
+ def shave_segments(path, n_shave_prefix_segments=1):
60
+ """
61
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
62
+ """
63
+ if n_shave_prefix_segments >= 0:
64
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
65
+ else:
66
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
67
+
68
+
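# Editorial example (not in the original): shave_segments drops or keeps leading
# dot-separated segments of a key, e.g.
#   shave_segments("input_blocks.0.0.weight", 1)  -> "0.0.weight"
#   shave_segments("input_blocks.0.0.weight", -1) -> "input_blocks.0.0"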
69
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
70
+ """
71
+ Updates paths inside resnets to the new naming scheme (local renaming)
72
+ """
73
+ mapping = []
74
+ for old_item in old_list:
75
+ new_item = old_item.replace("in_layers.0", "norm1")
76
+ new_item = new_item.replace("in_layers.2", "conv1")
77
+
78
+ new_item = new_item.replace("out_layers.0", "norm2")
79
+ new_item = new_item.replace("out_layers.3", "conv2")
80
+
81
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
82
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
83
+
84
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
85
+
86
+ mapping.append({"old": old_item, "new": new_item})
87
+
88
+ return mapping
89
+
90
+
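# Editorial example (not in the original): renew_resnet_paths performs the local
# LDM -> Diffusers renaming, e.g.
#   renew_resnet_paths(["input_blocks.1.0.in_layers.0.weight"])
#   -> [{"old": "input_blocks.1.0.in_layers.0.weight",
#        "new": "input_blocks.1.0.norm1.weight"}]
# The block-level prefix ("input_blocks.1.0" -> "down_blocks.0.resnets.0") is applied
# later by assign_to_checkpoint via its additional_replacements argument.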
91
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
92
+ """
93
+ Updates paths inside resnets to the new naming scheme (local renaming)
94
+ """
95
+ mapping = []
96
+ for old_item in old_list:
97
+ new_item = old_item
98
+
99
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
100
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
101
+
102
+ mapping.append({"old": old_item, "new": new_item})
103
+
104
+ return mapping
105
+
106
+
107
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
108
+ """
109
+ Updates paths inside attentions to the new naming scheme (local renaming)
110
+ """
111
+ mapping = []
112
+ for old_item in old_list:
113
+ new_item = old_item
114
+
115
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
116
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
117
+
118
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
119
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
120
+
121
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
122
+
123
+ mapping.append({"old": old_item, "new": new_item})
124
+
125
+ return mapping
126
+
127
+
128
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
129
+ """
130
+ Updates paths inside attentions to the new naming scheme (local renaming)
131
+ """
132
+ mapping = []
133
+ for old_item in old_list:
134
+ new_item = old_item
135
+
136
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
137
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
138
+
139
+ if diffusers.__version__ < "0.17.0":
140
+ new_item = new_item.replace("q.weight", "query.weight")
141
+ new_item = new_item.replace("q.bias", "query.bias")
142
+
143
+ new_item = new_item.replace("k.weight", "key.weight")
144
+ new_item = new_item.replace("k.bias", "key.bias")
145
+
146
+ new_item = new_item.replace("v.weight", "value.weight")
147
+ new_item = new_item.replace("v.bias", "value.bias")
148
+
149
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
150
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
151
+ else:
152
+ new_item = new_item.replace("q.weight", "to_q.weight")
153
+ new_item = new_item.replace("q.bias", "to_q.bias")
154
+
155
+ new_item = new_item.replace("k.weight", "to_k.weight")
156
+ new_item = new_item.replace("k.bias", "to_k.bias")
157
+
158
+ new_item = new_item.replace("v.weight", "to_v.weight")
159
+ new_item = new_item.replace("v.bias", "to_v.bias")
160
+
161
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
162
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
163
+
164
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
165
+
166
+ mapping.append({"old": old_item, "new": new_item})
167
+
168
+ return mapping
169
+
170
+
171
+ def assign_to_checkpoint(
172
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
173
+ ):
174
+ """
175
+ This does the final conversion step: take locally converted weights and apply a global renaming
176
+ to them. It splits attention layers, and takes into account additional replacements
177
+ that may arise.
178
+
179
+ Assigns the weights to the new checkpoint.
180
+ """
181
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
182
+
183
+ # Splits the attention layers into three variables.
184
+ if attention_paths_to_split is not None:
185
+ for path, path_map in attention_paths_to_split.items():
186
+ old_tensor = old_checkpoint[path]
187
+ channels = old_tensor.shape[0] // 3
188
+
189
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
190
+
191
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
192
+
193
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
194
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
195
+
196
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
197
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
198
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
199
+
200
+ for path in paths:
201
+ new_path = path["new"]
202
+
203
+ # These have already been assigned
204
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
205
+ continue
206
+
207
+ # Global renaming happens here
208
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
209
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
210
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
211
+
212
+ if additional_replacements is not None:
213
+ for replacement in additional_replacements:
214
+ new_path = new_path.replace(replacement["old"], replacement["new"])
215
+
216
+ # proj_attn.weight has to be converted from conv 1D to linear
217
+ reshaping = False
218
+ if diffusers.__version__ < "0.17.0":
219
+ if "proj_attn.weight" in new_path:
220
+ reshaping = True
221
+ else:
222
+ if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
223
+ reshaping = True
224
+
225
+ if reshaping:
226
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
227
+ else:
228
+ checkpoint[new_path] = old_checkpoint[path["old"]]
229
+
230
+
231
+ def conv_attn_to_linear(checkpoint):
232
+ keys = list(checkpoint.keys())
233
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
234
+ for key in keys:
235
+ if ".".join(key.split(".")[-2:]) in attn_keys:
236
+ if checkpoint[key].ndim > 2:
237
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
238
+ elif "proj_attn.weight" in key:
239
+ if checkpoint[key].ndim > 2:
240
+ checkpoint[key] = checkpoint[key][:, :, 0]
241
+
242
+
243
+ def linear_transformer_to_conv(checkpoint):
244
+ keys = list(checkpoint.keys())
245
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
246
+ for key in keys:
247
+ if ".".join(key.split(".")[-2:]) in tf_keys:
248
+ if checkpoint[key].ndim == 2:
249
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
250
+
251
+
252
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
253
+ """
254
+ Takes a state dict and a config, and returns a converted checkpoint.
255
+ """
256
+
257
+ # extract state_dict for UNet
258
+ unet_state_dict = {}
259
+ unet_key = "model.diffusion_model."
260
+ keys = list(checkpoint.keys())
261
+ for key in keys:
262
+ if key.startswith(unet_key):
263
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
264
+
265
+ new_checkpoint = {}
266
+
267
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
268
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
269
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
270
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
271
+
272
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
273
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
274
+
275
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
276
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
277
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
278
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
279
+
280
+ # Retrieves the keys for the input blocks only
281
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
282
+ input_blocks = {
283
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
284
+ }
285
+
286
+ # Retrieves the keys for the middle blocks only
287
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
288
+ middle_blocks = {
289
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
290
+ }
291
+
292
+ # Retrieves the keys for the output blocks only
293
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
294
+ output_blocks = {
295
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
296
+ }
297
+
298
+ for i in range(1, num_input_blocks):
299
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
300
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
301
+
302
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
303
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
304
+
305
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
306
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
307
+ f"input_blocks.{i}.0.op.weight"
308
+ )
309
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
310
+
311
+ paths = renew_resnet_paths(resnets)
312
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
313
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
314
+
315
+ if len(attentions):
316
+ paths = renew_attention_paths(attentions)
317
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
318
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
319
+
320
+ resnet_0 = middle_blocks[0]
321
+ attentions = middle_blocks[1]
322
+ resnet_1 = middle_blocks[2]
323
+
324
+ resnet_0_paths = renew_resnet_paths(resnet_0)
325
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
326
+
327
+ resnet_1_paths = renew_resnet_paths(resnet_1)
328
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
329
+
330
+ attentions_paths = renew_attention_paths(attentions)
331
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
332
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
333
+
334
+ for i in range(num_output_blocks):
335
+ block_id = i // (config["layers_per_block"] + 1)
336
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
337
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
338
+ output_block_list = {}
339
+
340
+ for layer in output_block_layers:
341
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
342
+ if layer_id in output_block_list:
343
+ output_block_list[layer_id].append(layer_name)
344
+ else:
345
+ output_block_list[layer_id] = [layer_name]
346
+
347
+ if len(output_block_list) > 1:
348
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
349
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
350
+
351
+ resnet_0_paths = renew_resnet_paths(resnets)
352
+ paths = renew_resnet_paths(resnets)
353
+
354
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
355
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
356
+
357
+ # original:
358
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
359
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
360
+
361
+ # make this independent of the order of bias and weight; there is probably a better way
362
+ for l in output_block_list.values():
363
+ l.sort()
364
+
365
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
366
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
367
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
368
+ f"output_blocks.{i}.{index}.conv.bias"
369
+ ]
370
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
371
+ f"output_blocks.{i}.{index}.conv.weight"
372
+ ]
373
+
374
+ # Clear attentions as they have been attributed above.
375
+ if len(attentions) == 2:
376
+ attentions = []
377
+
378
+ if len(attentions):
379
+ paths = renew_attention_paths(attentions)
380
+ meta_path = {
381
+ "old": f"output_blocks.{i}.1",
382
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
383
+ }
384
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
385
+ else:
386
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
387
+ for path in resnet_0_paths:
388
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
389
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
390
+
391
+ new_checkpoint[new_path] = unet_state_dict[old_path]
392
+
393
+ # in SD v2, the 1x1 conv2d has been replaced with a linear layer
394
+ # the Diffusers side was mistakenly left as conv2d, so conversion is required
395
+ if v2 and not config.get("use_linear_projection", False):
396
+ linear_transformer_to_conv(new_checkpoint)
397
+
398
+ return new_checkpoint
399
+
400
+
401
+ def convert_ldm_vae_checkpoint(checkpoint, config):
402
+ # extract state dict for VAE
403
+ vae_state_dict = {}
404
+ vae_key = "first_stage_model."
405
+ keys = list(checkpoint.keys())
406
+ for key in keys:
407
+ if key.startswith(vae_key):
408
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
409
+ # if len(vae_state_dict) == 0:
410
+ # # the checkpoint passed in is not one loaded from a .ckpt but a VAE state_dict
411
+ # vae_state_dict = checkpoint
412
+
413
+ new_checkpoint = {}
414
+
415
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
416
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
417
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
418
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
419
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
420
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
421
+
422
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
423
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
424
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
425
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
426
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
427
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
428
+
429
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
430
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
431
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
432
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
433
+
434
+ # Retrieves the keys for the encoder down blocks only
435
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
436
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
437
+
438
+ # Retrieves the keys for the decoder up blocks only
439
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
440
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
441
+
442
+ for i in range(num_down_blocks):
443
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
444
+
445
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
446
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
447
+ f"encoder.down.{i}.downsample.conv.weight"
448
+ )
449
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
450
+ f"encoder.down.{i}.downsample.conv.bias"
451
+ )
452
+
453
+ paths = renew_vae_resnet_paths(resnets)
454
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
455
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
456
+
457
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
458
+ num_mid_res_blocks = 2
459
+ for i in range(1, num_mid_res_blocks + 1):
460
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
461
+
462
+ paths = renew_vae_resnet_paths(resnets)
463
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
464
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
465
+
466
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
467
+ paths = renew_vae_attention_paths(mid_attentions)
468
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
469
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
470
+ conv_attn_to_linear(new_checkpoint)
471
+
472
+ for i in range(num_up_blocks):
473
+ block_id = num_up_blocks - 1 - i
474
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
475
+
476
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
477
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
478
+ f"decoder.up.{block_id}.upsample.conv.weight"
479
+ ]
480
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
481
+ f"decoder.up.{block_id}.upsample.conv.bias"
482
+ ]
483
+
484
+ paths = renew_vae_resnet_paths(resnets)
485
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
486
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
487
+
488
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
489
+ num_mid_res_blocks = 2
490
+ for i in range(1, num_mid_res_blocks + 1):
491
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
492
+
493
+ paths = renew_vae_resnet_paths(resnets)
494
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
495
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
496
+
497
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
498
+ paths = renew_vae_attention_paths(mid_attentions)
499
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
500
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
501
+ conv_attn_to_linear(new_checkpoint)
502
+ return new_checkpoint
503
+
504
+
505
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
506
+ """
507
+ Creates a config for the diffusers based on the config of the LDM model.
508
+ """
509
+ # unet_params = original_config.model.params.unet_config.params
510
+
511
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
512
+
513
+ down_block_types = []
514
+ resolution = 1
515
+ for i in range(len(block_out_channels)):
516
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
517
+ down_block_types.append(block_type)
518
+ if i != len(block_out_channels) - 1:
519
+ resolution *= 2
520
+
521
+ up_block_types = []
522
+ for i in range(len(block_out_channels)):
523
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
524
+ up_block_types.append(block_type)
525
+ resolution //= 2
526
+
527
+ config = dict(
528
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
529
+ in_channels=UNET_PARAMS_IN_CHANNELS,
530
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
531
+ down_block_types=tuple(down_block_types),
532
+ up_block_types=tuple(up_block_types),
533
+ block_out_channels=tuple(block_out_channels),
534
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
535
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
536
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
537
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
538
+ )
539
+ if v2 and use_linear_projection_in_v2:
540
+ config["use_linear_projection"] = True
541
+
542
+ return config
543
+
544
+
545
+ def create_vae_diffusers_config():
546
+ """
547
+ Creates a config for the diffusers based on the config of the LDM model.
548
+ """
549
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
550
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
551
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
552
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
553
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
554
+
555
+ config = dict(
556
+ sample_size=VAE_PARAMS_RESOLUTION,
557
+ in_channels=VAE_PARAMS_IN_CHANNELS,
558
+ out_channels=VAE_PARAMS_OUT_CH,
559
+ down_block_types=tuple(down_block_types),
560
+ up_block_types=tuple(up_block_types),
561
+ block_out_channels=tuple(block_out_channels),
562
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
563
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
564
+ )
565
+ return config
566
+
567
+
568
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
569
+ keys = list(checkpoint.keys())
570
+ text_model_dict = {}
571
+ for key in keys:
572
+ if key.startswith("cond_stage_model.transformer"):
573
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
574
+
575
+ # remove position_ids for newer versions of transformers, which otherwise raise an error :(
576
+ if "text_model.embeddings.position_ids" in text_model_dict:
577
+ text_model_dict.pop("text_model.embeddings.position_ids")
578
+
579
+ return text_model_dict
580
+
581
+
582
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
583
+ # the key layout is annoyingly different!
584
+ def convert_key(key):
585
+ if not key.startswith("cond_stage_model"):
586
+ return None
587
+
588
+ # common conversion
589
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
590
+ key = key.replace("cond_stage_model.model.", "text_model.")
591
+
592
+ if "resblocks" in key:
593
+ # resblocks conversion
594
+ key = key.replace(".resblocks.", ".layers.")
595
+ if ".ln_" in key:
596
+ key = key.replace(".ln_", ".layer_norm")
597
+ elif ".mlp." in key:
598
+ key = key.replace(".c_fc.", ".fc1.")
599
+ key = key.replace(".c_proj.", ".fc2.")
600
+ elif ".attn.out_proj" in key:
601
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
602
+ elif ".attn.in_proj" in key:
603
+ key = None # special case, handled later
604
+ else:
605
+ raise ValueError(f"unexpected key in SD: {key}")
606
+ elif ".positional_embedding" in key:
607
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
608
+ elif ".text_projection" in key:
609
+ key = None # apparently not used???
610
+ elif ".logit_scale" in key:
611
+ key = None # apparently not used???
612
+ elif ".token_embedding" in key:
613
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
614
+ elif ".ln_final" in key:
615
+ key = key.replace(".ln_final", ".final_layer_norm")
616
+ return key
617
+
618
+ keys = list(checkpoint.keys())
619
+ new_sd = {}
620
+ for key in keys:
621
+ # remove resblocks 23
622
+ if ".resblocks.23." in key:
623
+ continue
624
+ new_key = convert_key(key)
625
+ if new_key is None:
626
+ continue
627
+ new_sd[new_key] = checkpoint[key]
628
+
629
+ # convert attention weights
630
+ for key in keys:
631
+ if ".resblocks.23." in key:
632
+ continue
633
+ if ".resblocks" in key and ".attn.in_proj_" in key:
634
+ # split into three (q, k, v)
635
+ values = torch.chunk(checkpoint[key], 3)
636
+
637
+ key_suffix = ".weight" if "weight" in key else ".bias"
638
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
639
+ key_pfx = key_pfx.replace("_weight", "")
640
+ key_pfx = key_pfx.replace("_bias", "")
641
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
642
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
643
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
644
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
645
+
646
+ # rename or add position_ids
647
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
648
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
649
+ # waifu diffusion v1.4
650
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
651
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
652
+ else:
653
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
654
+
655
+ new_sd["text_model.embeddings.position_ids"] = position_ids
656
+ return new_sd
657
+
658
+
659
+ # endregion
660
+
661
+
662
+ # region conversion code: Diffusers -> StableDiffusion
663
+ # copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
664
+
665
+
666
+ def conv_transformer_to_linear(checkpoint):
667
+ keys = list(checkpoint.keys())
668
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
669
+ for key in keys:
670
+ if ".".join(key.split(".")[-2:]) in tf_keys:
671
+ if checkpoint[key].ndim > 2:
672
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
673
+
674
+
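+ # Hypothetical round-trip check (illustrative only, not part of the conversion pipeline):
+ # a 1x1 conv weight of shape (out, in, 1, 1) and a linear weight of shape (out, in)
+ # hold the same values, which is all that linear_transformer_to_conv /
+ # conv_transformer_to_linear rely on.
+ def _example_transformer_proj_roundtrip():
+     sd = {"proj_in.weight": torch.randn(320, 320)}
+     original = sd["proj_in.weight"].clone()
+     linear_transformer_to_conv(sd)  # (320, 320) -> (320, 320, 1, 1)
+     assert sd["proj_in.weight"].shape == (320, 320, 1, 1)
+     conv_transformer_to_linear(sd)  # back to (320, 320)
+     assert torch.equal(sd["proj_in.weight"], original)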
675
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
676
+ unet_conversion_map = [
677
+ # (stable-diffusion, HF Diffusers)
678
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
679
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
680
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
681
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
682
+ ("input_blocks.0.0.weight", "conv_in.weight"),
683
+ ("input_blocks.0.0.bias", "conv_in.bias"),
684
+ ("out.0.weight", "conv_norm_out.weight"),
685
+ ("out.0.bias", "conv_norm_out.bias"),
686
+ ("out.2.weight", "conv_out.weight"),
687
+ ("out.2.bias", "conv_out.bias"),
688
+ ]
689
+
690
+ unet_conversion_map_resnet = [
691
+ # (stable-diffusion, HF Diffusers)
692
+ ("in_layers.0", "norm1"),
693
+ ("in_layers.2", "conv1"),
694
+ ("out_layers.0", "norm2"),
695
+ ("out_layers.3", "conv2"),
696
+ ("emb_layers.1", "time_emb_proj"),
697
+ ("skip_connection", "conv_shortcut"),
698
+ ]
699
+
700
+ unet_conversion_map_layer = []
701
+ for i in range(4):
702
+ # loop over downblocks/upblocks
703
+
704
+ for j in range(2):
705
+ # loop over resnets/attentions for downblocks
706
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
707
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
708
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
709
+
710
+ if i < 3:
711
+ # no attention layers in down_blocks.3
712
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
713
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
714
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
715
+
716
+ for j in range(3):
717
+ # loop over resnets/attentions for upblocks
718
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
719
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
720
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
721
+
722
+ if i > 0:
723
+ # no attention layers in up_blocks.0
724
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
725
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
726
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
727
+
728
+ if i < 3:
729
+ # no downsample in down_blocks.3
730
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
731
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
732
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
733
+
734
+ # no upsample in up_blocks.3
735
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
736
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
737
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
738
+
739
+ hf_mid_atn_prefix = "mid_block.attentions.0."
740
+ sd_mid_atn_prefix = "middle_block.1."
741
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
742
+
743
+ for j in range(2):
744
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
745
+ sd_mid_res_prefix = f"middle_block.{2*j}."
746
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
747
+
748
+ # buyer beware: this is a *brittle* function,
749
+ # and correct output requires that all of these pieces interact in
750
+ # the exact order in which I have arranged them.
751
+ mapping = {k: k for k in unet_state_dict.keys()}
752
+ for sd_name, hf_name in unet_conversion_map:
753
+ mapping[hf_name] = sd_name
754
+ for k, v in mapping.items():
755
+ if "resnets" in k:
756
+ for sd_part, hf_part in unet_conversion_map_resnet:
757
+ v = v.replace(hf_part, sd_part)
758
+ mapping[k] = v
759
+ for k, v in mapping.items():
760
+ for sd_part, hf_part in unet_conversion_map_layer:
761
+ v = v.replace(hf_part, sd_part)
762
+ mapping[k] = v
763
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
764
+
765
+ if v2:
766
+ conv_transformer_to_linear(new_state_dict)
767
+
768
+ return new_state_dict
769
+
770
+
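+ # Hypothetical single-key sketch (illustrative only; the tensor is a dummy): shows how
+ # convert_unet_state_dict_to_sd maps a Diffusers UNet key back to the original
+ # Stable Diffusion layout.
+ def _example_convert_unet_state_dict_to_sd():
+     hf_sd = {"time_embedding.linear_1.weight": torch.zeros(1280, 320)}
+     sd_sd = convert_unet_state_dict_to_sd(False, hf_sd)
+     assert "time_embed.0.weight" in sd_sd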
771
+ def controlnet_conversion_map():
772
+ unet_conversion_map = [
773
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
774
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
775
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
776
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
777
+ ("input_blocks.0.0.weight", "conv_in.weight"),
778
+ ("input_blocks.0.0.bias", "conv_in.bias"),
779
+ ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
780
+ ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
781
+ ]
782
+
783
+ unet_conversion_map_resnet = [
784
+ ("in_layers.0", "norm1"),
785
+ ("in_layers.2", "conv1"),
786
+ ("out_layers.0", "norm2"),
787
+ ("out_layers.3", "conv2"),
788
+ ("emb_layers.1", "time_emb_proj"),
789
+ ("skip_connection", "conv_shortcut"),
790
+ ]
791
+
792
+ unet_conversion_map_layer = []
793
+ for i in range(4):
794
+ for j in range(2):
795
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
796
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
797
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
798
+
799
+ if i < 3:
800
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
801
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
802
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
803
+
804
+ if i < 3:
805
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
806
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
807
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
808
+
809
+ hf_mid_atn_prefix = "mid_block.attentions.0."
810
+ sd_mid_atn_prefix = "middle_block.1."
811
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
812
+
813
+ for j in range(2):
814
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
815
+ sd_mid_res_prefix = f"middle_block.{2*j}."
816
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
817
+
818
+ controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
819
+ for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
820
+ hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
821
+ sd_prefix = f"input_hint_block.{i*2}."
822
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
823
+
824
+ for i in range(12):
825
+ hf_prefix = f"controlnet_down_blocks.{i}."
826
+ sd_prefix = f"zero_convs.{i}.0."
827
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
828
+
829
+ return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
830
+
831
+
832
+ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
833
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
834
+
835
+ mapping = {k: k for k in controlnet_state_dict.keys()}
836
+ for sd_name, diffusers_name in unet_conversion_map:
837
+ mapping[diffusers_name] = sd_name
838
+ for k, v in mapping.items():
839
+ if "resnets" in k:
840
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
841
+ v = v.replace(diffusers_part, sd_part)
842
+ mapping[k] = v
843
+ for k, v in mapping.items():
844
+ for sd_part, diffusers_part in unet_conversion_map_layer:
845
+ v = v.replace(diffusers_part, sd_part)
846
+ mapping[k] = v
847
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
848
+ return new_state_dict
849
+
850
+
851
+ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
852
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
853
+
854
+ mapping = {k: k for k in controlnet_state_dict.keys()}
855
+ for sd_name, diffusers_name in unet_conversion_map:
856
+ mapping[sd_name] = diffusers_name
857
+ for k, v in mapping.items():
858
+ for sd_part, diffusers_part in unet_conversion_map_layer:
859
+ v = v.replace(sd_part, diffusers_part)
860
+ mapping[k] = v
861
+ for k, v in mapping.items():
862
+ if "resnets" in v:
863
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
864
+ v = v.replace(sd_part, diffusers_part)
865
+ mapping[k] = v
866
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
867
+ return new_state_dict
868
+
869
+
870
+ # ================#
871
+ # VAE Conversion #
872
+ # ================#
873
+
874
+
875
+ def reshape_weight_for_sd(w):
876
+ # convert HF linear weights to SD conv2d weights
877
+ return w.reshape(*w.shape, 1, 1)
878
+
879
+
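+ # Hypothetical sanity check (illustrative only): the HF attention projection is a Linear
+ # with weight (out, in); the SD VAE stores the same values as a 1x1 Conv2d weight of
+ # shape (out, in, 1, 1), which is exactly what reshape_weight_for_sd produces.
+ def _example_reshape_weight_for_sd():
+     w = torch.randn(512, 512)
+     w_sd = reshape_weight_for_sd(w)
+     assert w_sd.shape == (512, 512, 1, 1)
+     assert torch.equal(w_sd[:, :, 0, 0], w)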
880
+ def convert_vae_state_dict(vae_state_dict):
881
+ vae_conversion_map = [
882
+ # (stable-diffusion, HF Diffusers)
883
+ ("nin_shortcut", "conv_shortcut"),
884
+ ("norm_out", "conv_norm_out"),
885
+ ("mid.attn_1.", "mid_block.attentions.0."),
886
+ ]
887
+
888
+ for i in range(4):
889
+ # down_blocks have two resnets
890
+ for j in range(2):
891
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
892
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
893
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
894
+
895
+ if i < 3:
896
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
897
+ sd_downsample_prefix = f"down.{i}.downsample."
898
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
899
+
900
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
901
+ sd_upsample_prefix = f"up.{3-i}.upsample."
902
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
903
+
904
+ # up_blocks have three resnets
905
+ # also, up blocks in hf are numbered in reverse from sd
906
+ for j in range(3):
907
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
908
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
909
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
910
+
911
+ # this part accounts for mid blocks in both the encoder and the decoder
912
+ for i in range(2):
913
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
914
+ sd_mid_res_prefix = f"mid.block_{i+1}."
915
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
916
+
917
+ if diffusers.__version__ < "0.17.0":
918
+ vae_conversion_map_attn = [
919
+ # (stable-diffusion, HF Diffusers)
920
+ ("norm.", "group_norm."),
921
+ ("q.", "query."),
922
+ ("k.", "key."),
923
+ ("v.", "value."),
924
+ ("proj_out.", "proj_attn."),
925
+ ]
926
+ else:
927
+ vae_conversion_map_attn = [
928
+ # (stable-diffusion, HF Diffusers)
929
+ ("norm.", "group_norm."),
930
+ ("q.", "to_q."),
931
+ ("k.", "to_k."),
932
+ ("v.", "to_v."),
933
+ ("proj_out.", "to_out.0."),
934
+ ]
935
+
936
+ mapping = {k: k for k in vae_state_dict.keys()}
937
+ for k, v in mapping.items():
938
+ for sd_part, hf_part in vae_conversion_map:
939
+ v = v.replace(hf_part, sd_part)
940
+ mapping[k] = v
941
+ for k, v in mapping.items():
942
+ if "attentions" in k:
943
+ for sd_part, hf_part in vae_conversion_map_attn:
944
+ v = v.replace(hf_part, sd_part)
945
+ mapping[k] = v
946
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
947
+ weights_to_convert = ["q", "k", "v", "proj_out"]
948
+ for k, v in new_state_dict.items():
949
+ for weight_name in weights_to_convert:
950
+ if f"mid.attn_1.{weight_name}.weight" in k:
951
+ # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
952
+ new_state_dict[k] = reshape_weight_for_sd(v)
953
+
954
+ return new_state_dict
955
+
956
+
957
+ # endregion
958
+
959
+ # region custom model loading/saving, etc.
960
+
961
+
962
+ def is_safetensors(path):
963
+ return os.path.splitext(path)[1].lower() == ".safetensors"
964
+
965
+
966
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
967
+ # handle models whose text encoder is stored in a different format (without 'text_model')
968
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
969
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
970
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
971
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
972
+ ]
973
+
974
+ if is_safetensors(ckpt_path):
975
+ checkpoint = None
976
+ state_dict = load_file(ckpt_path) # , device) # may cause an error
977
+ else:
978
+ checkpoint = torch.load(ckpt_path, map_location=device)
979
+ if "state_dict" in checkpoint:
980
+ state_dict = checkpoint["state_dict"]
981
+ else:
982
+ state_dict = checkpoint
983
+ checkpoint = None
984
+
985
+ key_reps = []
986
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
987
+ for key in state_dict.keys():
988
+ if key.startswith(rep_from):
989
+ new_key = rep_to + key[len(rep_from) :]
990
+ key_reps.append((key, new_key))
991
+
992
+ for key, new_key in key_reps:
993
+ state_dict[new_key] = state_dict[key]
994
+ del state_dict[key]
995
+
996
+ return checkpoint, state_dict
997
+
998
+
999
+ # TODO the behavior of the dtype option is questionable and should be checked; whether the text_encoder can be created in the specified format is unverified
1000
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
1001
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
1002
+
1003
+ # Convert the UNet2DConditionModel model.
1004
+ unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
1005
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
1006
+
1007
+ unet = UNet2DConditionModel(**unet_config).to(device)
1008
+ info = unet.load_state_dict(converted_unet_checkpoint)
1009
+ logger.info(f"loading u-net: {info}")
1010
+
1011
+ # Convert the VAE model.
1012
+ vae_config = create_vae_diffusers_config()
1013
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
1014
+
1015
+ vae = AutoencoderKL(**vae_config).to(device)
1016
+ info = vae.load_state_dict(converted_vae_checkpoint)
1017
+ logger.info(f"loading vae: {info}")
1018
+
1019
+ # convert text_model
1020
+ if v2:
1021
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
1022
+ cfg = CLIPTextConfig(
1023
+ vocab_size=49408,
1024
+ hidden_size=1024,
1025
+ intermediate_size=4096,
1026
+ num_hidden_layers=23,
1027
+ num_attention_heads=16,
1028
+ max_position_embeddings=77,
1029
+ hidden_act="gelu",
1030
+ layer_norm_eps=1e-05,
1031
+ dropout=0.0,
1032
+ attention_dropout=0.0,
1033
+ initializer_range=0.02,
1034
+ initializer_factor=1.0,
1035
+ pad_token_id=1,
1036
+ bos_token_id=0,
1037
+ eos_token_id=2,
1038
+ model_type="clip_text_model",
1039
+ projection_dim=512,
1040
+ torch_dtype="float32",
1041
+ transformers_version="4.25.0.dev0",
1042
+ )
1043
+ text_model = CLIPTextModel._from_config(cfg)
1044
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1045
+ else:
1046
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
1047
+
1048
+ # logging.set_verbosity_error() # don't show annoying warning
1049
+ # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
1050
+ # logging.set_verbosity_warning()
1051
+ # logger.info(f"config: {text_model.config}")
1052
+ cfg = CLIPTextConfig(
1053
+ vocab_size=49408,
1054
+ hidden_size=768,
1055
+ intermediate_size=3072,
1056
+ num_hidden_layers=12,
1057
+ num_attention_heads=12,
1058
+ max_position_embeddings=77,
1059
+ hidden_act="quick_gelu",
1060
+ layer_norm_eps=1e-05,
1061
+ dropout=0.0,
1062
+ attention_dropout=0.0,
1063
+ initializer_range=0.02,
1064
+ initializer_factor=1.0,
1065
+ pad_token_id=1,
1066
+ bos_token_id=0,
1067
+ eos_token_id=2,
1068
+ model_type="clip_text_model",
1069
+ projection_dim=768,
1070
+ torch_dtype="float32",
1071
+ )
1072
+ text_model = CLIPTextModel._from_config(cfg)
1073
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1074
+ logger.info(f"loading text encoder: {info}")
1075
+
1076
+ return text_model, vae, unet
1077
+
1078
+
1079
+ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
1080
+ # only for reference
1081
+ version_str = "sd"
1082
+ if v2:
1083
+ version_str += "_v2"
1084
+ else:
1085
+ version_str += "_v1"
1086
+ if v_parameterization:
1087
+ version_str += "_v"
1088
+ return version_str
1089
+
1090
+
1091
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
1092
+ def convert_key(key):
1093
+ # remove position_ids
1094
+ if ".position_ids" in key:
1095
+ return None
1096
+
1097
+ # common
1098
+ key = key.replace("text_model.encoder.", "transformer.")
1099
+ key = key.replace("text_model.", "")
1100
+ if "layers" in key:
1101
+ # resblocks conversion
1102
+ key = key.replace(".layers.", ".resblocks.")
1103
+ if ".layer_norm" in key:
1104
+ key = key.replace(".layer_norm", ".ln_")
1105
+ elif ".mlp." in key:
1106
+ key = key.replace(".fc1.", ".c_fc.")
1107
+ key = key.replace(".fc2.", ".c_proj.")
1108
+ elif ".self_attn.out_proj" in key:
1109
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
1110
+ elif ".self_attn." in key:
1111
+ key = None # special case, handled later
1112
+ else:
1113
+ raise ValueError(f"unexpected key in Diffusers model: {key}")
1114
+ elif ".position_embedding" in key:
1115
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
1116
+ elif ".token_embedding" in key:
1117
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
1118
+ elif "final_layer_norm" in key:
1119
+ key = key.replace("final_layer_norm", "ln_final")
1120
+ return key
1121
+
1122
+ keys = list(checkpoint.keys())
1123
+ new_sd = {}
1124
+ for key in keys:
1125
+ new_key = convert_key(key)
1126
+ if new_key is None:
1127
+ continue
1128
+ new_sd[new_key] = checkpoint[key]
1129
+
1130
+ # convert attention weights
1131
+ for key in keys:
1132
+ if "layers" in key and "q_proj" in key:
1133
+ # concatenate the three (q, k, v)
1134
+ key_q = key
1135
+ key_k = key.replace("q_proj", "k_proj")
1136
+ key_v = key.replace("q_proj", "v_proj")
1137
+
1138
+ value_q = checkpoint[key_q]
1139
+ value_k = checkpoint[key_k]
1140
+ value_v = checkpoint[key_v]
1141
+ value = torch.cat([value_q, value_k, value_v])
1142
+
1143
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
1144
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
1145
+ new_sd[new_key] = value
1146
+
1147
+ # whether to fabricate dummy weights for the last layer etc.
1148
+ if make_dummy_weights:
1149
+ logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
1150
+ keys = list(new_sd.keys())
1151
+ for key in keys:
1152
+ if key.startswith("transformer.resblocks.22."):
1153
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without cloning, saving with safetensors fails
1154
+
1155
+ # create weights that are not included in Diffusers
1156
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
1157
+ new_sd["logit_scale"] = torch.tensor(1)
1158
+
1159
+ return new_sd
1160
+
1161
+
1162
+ def save_stable_diffusion_checkpoint(
1163
+ v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
1164
+ ):
1165
+ if ckpt_path is not None:
1166
+ # refer to epoch/step; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
1167
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1168
+ if checkpoint is None: # safetensors or a state_dict-only ckpt
1169
+ checkpoint = {}
1170
+ strict = False
1171
+ else:
1172
+ strict = True
1173
+ if "state_dict" in state_dict:
1174
+ del state_dict["state_dict"]
1175
+ else:
1176
+ # create from scratch
1177
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1178
+ checkpoint = {}
1179
+ state_dict = {}
1180
+ strict = False
1181
+
1182
+ def update_sd(prefix, sd):
1183
+ for k, v in sd.items():
1184
+ key = prefix + k
1185
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1186
+ if save_dtype is not None:
1187
+ v = v.detach().clone().to("cpu").to(save_dtype)
1188
+ state_dict[key] = v
1189
+
1190
+ # Convert the UNet model
1191
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1192
+ update_sd("model.diffusion_model.", unet_state_dict)
1193
+
1194
+ # Convert the text encoder model
1195
+ if v2:
1196
+ make_dummy = ckpt_path is None # if there is no reference checkpoint, insert dummy weights, e.g. build the last layer by duplicating the previous one
1197
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1198
+ update_sd("cond_stage_model.model.", text_enc_dict)
1199
+ else:
1200
+ text_enc_dict = text_encoder.state_dict()
1201
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1202
+
1203
+ # Convert the VAE
1204
+ if vae is not None:
1205
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1206
+ update_sd("first_stage_model.", vae_dict)
1207
+
1208
+ # Put together new checkpoint
1209
+ key_count = len(state_dict.keys())
1210
+ new_ckpt = {"state_dict": state_dict}
1211
+
1212
+ # epoch and global_step are sometimes not int
1213
+ try:
1214
+ if "epoch" in checkpoint:
1215
+ epochs += checkpoint["epoch"]
1216
+ if "global_step" in checkpoint:
1217
+ steps += checkpoint["global_step"]
1218
+ except:
1219
+ pass
1220
+
1221
+ new_ckpt["epoch"] = epochs
1222
+ new_ckpt["global_step"] = steps
1223
+
1224
+ if is_safetensors(output_file):
1225
+ # TODO should non-Tensor values be removed from the dict?
1226
+ save_file(state_dict, output_file, metadata)
1227
+ else:
1228
+ torch.save(new_ckpt, output_file)
1229
+
1230
+ return key_count
1231
+
1232
+
1233
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1234
+ if pretrained_model_name_or_path is None:
1235
+ # load default settings for v1/v2
1236
+ if v2:
1237
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1238
+ else:
1239
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1240
+
1241
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1242
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1243
+ if vae is None:
1244
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1245
+
1246
+ # original U-Net cannot be saved, so we need to convert it to the Diffusers version
1247
+ # TODO this consumes a lot of memory
1248
+ diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
1249
+ diffusers_unet.load_state_dict(unet.state_dict())
1250
+
1251
+ pipeline = StableDiffusionPipeline(
1252
+ unet=diffusers_unet,
1253
+ text_encoder=text_encoder,
1254
+ vae=vae,
1255
+ scheduler=scheduler,
1256
+ tokenizer=tokenizer,
1257
+ safety_checker=None,
1258
+ feature_extractor=None,
1259
+ requires_safety_checker=None,
1260
+ )
1261
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1262
+
1263
+
1264
+ VAE_PREFIX = "first_stage_model."
1265
+
1266
+
1267
+ def load_vae(vae_id, dtype):
1268
+ logger.info(f"load VAE: {vae_id}")
1269
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1270
+ # Diffusers local/remote
1271
+ try:
1272
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1273
+ except EnvironmentError as e:
1274
+ logger.error(f"exception occurs in loading vae: {e}")
1275
+ logger.error("retry with subfolder='vae'")
1276
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1277
+ return vae
1278
+
1279
+ # local
1280
+ vae_config = create_vae_diffusers_config()
1281
+
1282
+ if vae_id.endswith(".bin"):
1283
+ # SD 1.5 VAE on Huggingface
1284
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1285
+ else:
1286
+ # StableDiffusion
1287
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1288
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1289
+
1290
+ # vae only or full model
1291
+ full_model = False
1292
+ for vae_key in vae_sd:
1293
+ if vae_key.startswith(VAE_PREFIX):
1294
+ full_model = True
1295
+ break
1296
+ if not full_model:
1297
+ sd = {}
1298
+ for key, value in vae_sd.items():
1299
+ sd[VAE_PREFIX + key] = value
1300
+ vae_sd = sd
1301
+ del sd
1302
+
1303
+ # Convert the VAE model.
1304
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1305
+
1306
+ vae = AutoencoderKL(**vae_config)
1307
+ vae.load_state_dict(converted_vae_checkpoint)
1308
+ return vae
1309
+
1310
+
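+ # Hypothetical usage sketch (illustrative only): load a Diffusers-format VAE from the
+ # Hugging Face Hub in fp16. The repository id is an assumption made for this example
+ # and downloading it requires network access.
+ def _example_load_vae():
+     return load_vae("stabilityai/sd-vae-ft-mse", torch.float16)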
1311
+ # endregion
1312
+
1313
+
1314
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1315
+ max_width, max_height = max_reso
1316
+ max_area = max_width * max_height
1317
+
1318
+ resos = set()
1319
+
1320
+ width = int(math.sqrt(max_area) // divisible) * divisible
1321
+ resos.add((width, width))
1322
+
1323
+ width = min_size
1324
+ while width <= max_size:
1325
+ height = min(max_size, int((max_area // width) // divisible) * divisible)
1326
+ if height >= min_size:
1327
+ resos.add((width, height))
1328
+ resos.add((height, width))
1329
+
1330
+ # # make additional resos
1331
+ # if width >= height and width - divisible >= min_size:
1332
+ # resos.add((width - divisible, height))
1333
+ # resos.add((height, width - divisible))
1334
+ # if height >= width and height - divisible >= min_size:
1335
+ # resos.add((width, height - divisible))
1336
+ # resos.add((height - divisible, width))
1337
+
1338
+ width += divisible
1339
+
1340
+ resos = list(resos)
1341
+ resos.sort()
1342
+ return resos
1343
+
1344
+
1345
+ if __name__ == "__main__":
1346
+ resos = make_bucket_resolutions((512, 768))
1347
+ logger.info(f"{len(resos)}")
1348
+ logger.info(f"{resos}")
1349
+ aspect_ratios = [w / h for w, h in resos]
1350
+ logger.info(f"{aspect_ratios}")
1351
+
1352
+ ars = set()
1353
+ for ar in aspect_ratios:
1354
+ if ar in ars:
1355
+ logger.error(f"error! duplicate ar: {ar}")
1356
+ ars.add(ar)
library/original_unet.py ADDED
@@ -0,0 +1,1919 @@
1
+ # bring over only the parts needed for Stable Diffusion from Diffusers 0.10.2
2
+ # unnecessary parts are removed by condition branching, etc.
3
+ # most of the code is copied from Diffusers
4
+ # as a constraint, the model's state_dict must be in the same format as that of Diffusers 0.10.2
5
+
6
+ # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
7
+ # Unnecessary parts are deleted by condition branching.
8
+ # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
9
+
10
+ """
11
+ Differences between v1.5 and v2.1:
12
+ - attention_head_dim is an int vs. a list[int]
13
+ - cross_attention_dim is 768 vs. 1024
14
+ - use_linear_projection: absent (=False, 1.5) vs. present (true)
15
+ - upcast_attention is False (1.5) vs. True (2.1)
16
+ - (the following can probably be ignored)
17
+ - sample_size is 64 vs. 96
18
+ - whether dual_cross_attention is present
19
+ - whether num_class_embeds is present
20
+ - whether only_cross_attention is present
21
+
22
+ v1.5
23
+ {
24
+ "_class_name": "UNet2DConditionModel",
25
+ "_diffusers_version": "0.6.0",
26
+ "act_fn": "silu",
27
+ "attention_head_dim": 8,
28
+ "block_out_channels": [
29
+ 320,
30
+ 640,
31
+ 1280,
32
+ 1280
33
+ ],
34
+ "center_input_sample": false,
35
+ "cross_attention_dim": 768,
36
+ "down_block_types": [
37
+ "CrossAttnDownBlock2D",
38
+ "CrossAttnDownBlock2D",
39
+ "CrossAttnDownBlock2D",
40
+ "DownBlock2D"
41
+ ],
42
+ "downsample_padding": 1,
43
+ "flip_sin_to_cos": true,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_scale_factor": 1,
48
+ "norm_eps": 1e-05,
49
+ "norm_num_groups": 32,
50
+ "out_channels": 4,
51
+ "sample_size": 64,
52
+ "up_block_types": [
53
+ "UpBlock2D",
54
+ "CrossAttnUpBlock2D",
55
+ "CrossAttnUpBlock2D",
56
+ "CrossAttnUpBlock2D"
57
+ ]
58
+ }
59
+
60
+ v2.1
61
+ {
62
+ "_class_name": "UNet2DConditionModel",
63
+ "_diffusers_version": "0.10.0.dev0",
64
+ "act_fn": "silu",
65
+ "attention_head_dim": [
66
+ 5,
67
+ 10,
68
+ 20,
69
+ 20
70
+ ],
71
+ "block_out_channels": [
72
+ 320,
73
+ 640,
74
+ 1280,
75
+ 1280
76
+ ],
77
+ "center_input_sample": false,
78
+ "cross_attention_dim": 1024,
79
+ "down_block_types": [
80
+ "CrossAttnDownBlock2D",
81
+ "CrossAttnDownBlock2D",
82
+ "CrossAttnDownBlock2D",
83
+ "DownBlock2D"
84
+ ],
85
+ "downsample_padding": 1,
86
+ "dual_cross_attention": false,
87
+ "flip_sin_to_cos": true,
88
+ "freq_shift": 0,
89
+ "in_channels": 4,
90
+ "layers_per_block": 2,
91
+ "mid_block_scale_factor": 1,
92
+ "norm_eps": 1e-05,
93
+ "norm_num_groups": 32,
94
+ "num_class_embeds": null,
95
+ "only_cross_attention": false,
96
+ "out_channels": 4,
97
+ "sample_size": 96,
98
+ "up_block_types": [
99
+ "UpBlock2D",
100
+ "CrossAttnUpBlock2D",
101
+ "CrossAttnUpBlock2D",
102
+ "CrossAttnUpBlock2D"
103
+ ],
104
+ "use_linear_projection": true,
105
+ "upcast_attention": true
106
+ }
107
+ """
108
+
109
+ import math
110
+ from types import SimpleNamespace
111
+ from typing import Dict, Optional, Tuple, Union
112
+ import torch
113
+ from torch import nn
114
+ from torch.nn import functional as F
115
+ from einops import rearrange
116
+ from library.utils import setup_logging
117
+ setup_logging()
118
+ import logging
119
+ logger = logging.getLogger(__name__)
120
+
121
+ BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
122
+ TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
123
+ TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
124
+ IN_CHANNELS: int = 4
125
+ OUT_CHANNELS: int = 4
126
+ LAYERS_PER_BLOCK: int = 2
127
+ LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
128
+ TIME_EMBED_FLIP_SIN_TO_COS: bool = True
129
+ TIME_EMBED_FREQ_SHIFT: int = 0
130
+ NORM_GROUPS: int = 32
131
+ NORM_EPS: float = 1e-5
132
+ TRANSFORMER_NORM_NUM_GROUPS = 32
133
+
134
+ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
135
+ UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
136
+
137
+
138
+ # region memory efficient attention
139
+
140
+ # CrossAttention that uses FlashAttention
141
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
142
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
143
+
144
+ # constants
145
+
146
+ EPSILON = 1e-6
147
+
148
+ # helper functions
149
+
150
+
151
+ def exists(val):
152
+ return val is not None
153
+
154
+
155
+ def default(val, d):
156
+ return val if exists(val) else d
157
+
158
+
159
+ # flash attention forwards and backwards
160
+
161
+ # https://arxiv.org/abs/2205.14135
162
+
163
+
164
+ class FlashAttentionFunction(torch.autograd.Function):
165
+ @staticmethod
166
+ @torch.no_grad()
167
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
168
+ """Algorithm 2 in the paper"""
169
+
170
+ device = q.device
171
+ dtype = q.dtype
172
+ max_neg_value = -torch.finfo(q.dtype).max
173
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
174
+
175
+ o = torch.zeros_like(q)
176
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
177
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
178
+
179
+ scale = q.shape[-1] ** -0.5
180
+
181
+ if not exists(mask):
182
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
183
+ else:
184
+ mask = rearrange(mask, "b n -> b 1 1 n")
185
+ mask = mask.split(q_bucket_size, dim=-1)
186
+
187
+ row_splits = zip(
188
+ q.split(q_bucket_size, dim=-2),
189
+ o.split(q_bucket_size, dim=-2),
190
+ mask,
191
+ all_row_sums.split(q_bucket_size, dim=-2),
192
+ all_row_maxes.split(q_bucket_size, dim=-2),
193
+ )
194
+
195
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
196
+ q_start_index = ind * q_bucket_size - qk_len_diff
197
+
198
+ col_splits = zip(
199
+ k.split(k_bucket_size, dim=-2),
200
+ v.split(k_bucket_size, dim=-2),
201
+ )
202
+
203
+ for k_ind, (kc, vc) in enumerate(col_splits):
204
+ k_start_index = k_ind * k_bucket_size
205
+
206
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
207
+
208
+ if exists(row_mask):
209
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
210
+
211
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
212
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
213
+ q_start_index - k_start_index + 1
214
+ )
215
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
216
+
217
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
218
+ attn_weights -= block_row_maxes
219
+ exp_weights = torch.exp(attn_weights)
220
+
221
+ if exists(row_mask):
222
+ exp_weights.masked_fill_(~row_mask, 0.0)
223
+
224
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
225
+
226
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
227
+
228
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
229
+
230
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
231
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
232
+
233
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
234
+
235
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
236
+
237
+ row_maxes.copy_(new_row_maxes)
238
+ row_sums.copy_(new_row_sums)
239
+
240
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
241
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
242
+
243
+ return o
244
+
245
+ @staticmethod
246
+ @torch.no_grad()
247
+ def backward(ctx, do):
248
+ """Algorithm 4 in the paper"""
249
+
250
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
251
+ q, k, v, o, l, m = ctx.saved_tensors
252
+
253
+ device = q.device
254
+
255
+ max_neg_value = -torch.finfo(q.dtype).max
256
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
257
+
258
+ dq = torch.zeros_like(q)
259
+ dk = torch.zeros_like(k)
260
+ dv = torch.zeros_like(v)
261
+
262
+ row_splits = zip(
263
+ q.split(q_bucket_size, dim=-2),
264
+ o.split(q_bucket_size, dim=-2),
265
+ do.split(q_bucket_size, dim=-2),
266
+ mask,
267
+ l.split(q_bucket_size, dim=-2),
268
+ m.split(q_bucket_size, dim=-2),
269
+ dq.split(q_bucket_size, dim=-2),
270
+ )
271
+
272
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
273
+ q_start_index = ind * q_bucket_size - qk_len_diff
274
+
275
+ col_splits = zip(
276
+ k.split(k_bucket_size, dim=-2),
277
+ v.split(k_bucket_size, dim=-2),
278
+ dk.split(k_bucket_size, dim=-2),
279
+ dv.split(k_bucket_size, dim=-2),
280
+ )
281
+
282
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
283
+ k_start_index = k_ind * k_bucket_size
284
+
285
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
286
+
287
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
288
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
289
+ q_start_index - k_start_index + 1
290
+ )
291
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
292
+
293
+ exp_attn_weights = torch.exp(attn_weights - mc)
294
+
295
+ if exists(row_mask):
296
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
297
+
298
+ p = exp_attn_weights / lc
299
+
300
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
301
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
302
+
303
+ D = (doc * oc).sum(dim=-1, keepdims=True)
304
+ ds = p * scale * (dp - D)
305
+
306
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
307
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
308
+
309
+ dqc.add_(dq_chunk)
310
+ dkc.add_(dk_chunk)
311
+ dvc.add_(dv_chunk)
312
+
313
+ return dq, dk, dv, None, None, None, None
314
+
315
+
316
+ # endregion
317
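A minimal usage sketch of the FlashAttentionFunction defined above, mirroring how CrossAttention.forward_memory_efficient_mem_eff calls it later in this file; the tensor shapes are illustrative assumptions.

import torch

# q, k, v laid out as (batch, heads, sequence length, head dim); shapes are assumed for illustration
q = torch.randn(2, 8, 1024, 40)
k = torch.randn(2, 8, 1024, 40)
v = torch.randn(2, 8, 1024, 40)

# mask=None, causal=False, q_bucket_size=512, k_bucket_size=1024 match the values
# used by forward_memory_efficient_mem_eff below
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
print(out.shape)  # torch.Size([2, 8, 1024, 40])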
+
318
+
319
+ def get_parameter_dtype(parameter: torch.nn.Module):
320
+ return next(parameter.parameters()).dtype
321
+
322
+
323
+ def get_parameter_device(parameter: torch.nn.Module):
324
+ return next(parameter.parameters()).device
325
+
326
+
327
+ def get_timestep_embedding(
328
+ timesteps: torch.Tensor,
329
+ embedding_dim: int,
330
+ flip_sin_to_cos: bool = False,
331
+ downscale_freq_shift: float = 1,
332
+ scale: float = 1,
333
+ max_period: int = 10000,
334
+ ):
335
+ """
336
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
337
+
338
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
339
+ These may be fractional.
340
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
342
+ """
343
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
344
+
345
+ half_dim = embedding_dim // 2
346
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
347
+ exponent = exponent / (half_dim - downscale_freq_shift)
348
+
349
+ emb = torch.exp(exponent)
350
+ emb = timesteps[:, None].float() * emb[None, :]
351
+
352
+ # scale embeddings
353
+ emb = scale * emb
354
+
355
+ # concat sine and cosine embeddings
356
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
357
+
358
+ # flip sine and cosine embeddings
359
+ if flip_sin_to_cos:
360
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
361
+
362
+ # zero pad
363
+ if embedding_dim % 2 == 1:
364
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
365
+ return emb
366
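As a quick shape check of the helper above (a sketch; the embedding width 320 is an arbitrary assumption):

import torch

timesteps = torch.tensor([0, 250, 999])       # 1-D tensor with one timestep per batch element
emb = get_timestep_embedding(timesteps, 320)  # -> shape [3, 320], sine half followed by cosine half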
+
367
+
368
+ # Deep Shrink: this helper is deliberately not shared with other modules, to minimize dependencies.
369
+ def resize_like(x, target, mode="bicubic", align_corners=False):
370
+ org_dtype = x.dtype
371
+ if org_dtype == torch.bfloat16:
372
+ x = x.to(torch.float32)
373
+
374
+ if x.shape[-2:] != target.shape[-2:]:
375
+ if mode == "nearest":
376
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
377
+ else:
378
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
379
+
380
+ if org_dtype == torch.bfloat16:
381
+ x = x.to(org_dtype)
382
+ return x
383
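A short sketch of resize_like on 4-D feature maps; the channel count is illustrative:

import torch

x = torch.randn(1, 320, 32, 32)
target = torch.randn(1, 320, 64, 64)
y = resize_like(x, target)  # bicubic by default -> (1, 320, 64, 64); bfloat16 inputs are upcast internally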
+
384
+
385
+ class SampleOutput:
386
+ def __init__(self, sample):
387
+ self.sample = sample
388
+
389
+
390
+ class TimestepEmbedding(nn.Module):
391
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
392
+ super().__init__()
393
+
394
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
395
+ self.act = None
396
+ if act_fn == "silu":
397
+ self.act = nn.SiLU()
398
+ elif act_fn == "mish":
399
+ self.act = nn.Mish()
400
+
401
+ if out_dim is not None:
402
+ time_embed_dim_out = out_dim
403
+ else:
404
+ time_embed_dim_out = time_embed_dim
405
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
406
+
407
+ def forward(self, sample):
408
+ sample = self.linear_1(sample)
409
+
410
+ if self.act is not None:
411
+ sample = self.act(sample)
412
+
413
+ sample = self.linear_2(sample)
414
+ return sample
415
+
416
+
417
+ class Timesteps(nn.Module):
418
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
419
+ super().__init__()
420
+ self.num_channels = num_channels
421
+ self.flip_sin_to_cos = flip_sin_to_cos
422
+ self.downscale_freq_shift = downscale_freq_shift
423
+
424
+ def forward(self, timesteps):
425
+ t_emb = get_timestep_embedding(
426
+ timesteps,
427
+ self.num_channels,
428
+ flip_sin_to_cos=self.flip_sin_to_cos,
429
+ downscale_freq_shift=self.downscale_freq_shift,
430
+ )
431
+ return t_emb
432
+
433
+
434
+ class ResnetBlock2D(nn.Module):
435
+ def __init__(
436
+ self,
437
+ in_channels,
438
+ out_channels,
439
+ ):
440
+ super().__init__()
441
+ self.in_channels = in_channels
442
+ self.out_channels = out_channels
443
+
444
+ self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
445
+
446
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
447
+
448
+ self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
449
+
450
+ self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
451
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
452
+
453
+ # if non_linearity == "swish":
454
+ self.nonlinearity = lambda x: F.silu(x)
455
+
456
+ self.use_in_shortcut = self.in_channels != self.out_channels
457
+
458
+ self.conv_shortcut = None
459
+ if self.use_in_shortcut:
460
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
461
+
462
+ def forward(self, input_tensor, temb):
463
+ hidden_states = input_tensor
464
+
465
+ hidden_states = self.norm1(hidden_states)
466
+ hidden_states = self.nonlinearity(hidden_states)
467
+
468
+ hidden_states = self.conv1(hidden_states)
469
+
470
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
471
+ hidden_states = hidden_states + temb
472
+
473
+ hidden_states = self.norm2(hidden_states)
474
+ hidden_states = self.nonlinearity(hidden_states)
475
+
476
+ hidden_states = self.conv2(hidden_states)
477
+
478
+ if self.conv_shortcut is not None:
479
+ input_tensor = self.conv_shortcut(input_tensor)
480
+
481
+ output_tensor = input_tensor + hidden_states
482
+
483
+ return output_tensor
484
+
485
+
486
+ class DownBlock2D(nn.Module):
487
+ def __init__(
488
+ self,
489
+ in_channels: int,
490
+ out_channels: int,
491
+ add_downsample=True,
492
+ ):
493
+ super().__init__()
494
+
495
+ self.has_cross_attention = False
496
+ resnets = []
497
+
498
+ for i in range(LAYERS_PER_BLOCK):
499
+ in_channels = in_channels if i == 0 else out_channels
500
+ resnets.append(
501
+ ResnetBlock2D(
502
+ in_channels=in_channels,
503
+ out_channels=out_channels,
504
+ )
505
+ )
506
+ self.resnets = nn.ModuleList(resnets)
507
+
508
+ if add_downsample:
509
+ self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
510
+ else:
511
+ self.downsamplers = None
512
+
513
+ self.gradient_checkpointing = False
514
+
515
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
516
+ pass
517
+
518
+ def set_use_sdpa(self, sdpa):
519
+ pass
520
+
521
+ def forward(self, hidden_states, temb=None):
522
+ output_states = ()
523
+
524
+ for resnet in self.resnets:
525
+ if self.training and self.gradient_checkpointing:
526
+
527
+ def create_custom_forward(module):
528
+ def custom_forward(*inputs):
529
+ return module(*inputs)
530
+
531
+ return custom_forward
532
+
533
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
534
+ else:
535
+ hidden_states = resnet(hidden_states, temb)
536
+
537
+ output_states += (hidden_states,)
538
+
539
+ if self.downsamplers is not None:
540
+ for downsampler in self.downsamplers:
541
+ hidden_states = downsampler(hidden_states)
542
+
543
+ output_states += (hidden_states,)
544
+
545
+ return hidden_states, output_states
546
+
547
+
548
+ class Downsample2D(nn.Module):
549
+ def __init__(self, channels, out_channels):
550
+ super().__init__()
551
+
552
+ self.channels = channels
553
+ self.out_channels = out_channels
554
+
555
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
556
+
557
+ def forward(self, hidden_states):
558
+ assert hidden_states.shape[1] == self.channels
559
+ hidden_states = self.conv(hidden_states)
560
+
561
+ return hidden_states
562
+
563
+
564
+ class CrossAttention(nn.Module):
565
+ def __init__(
566
+ self,
567
+ query_dim: int,
568
+ cross_attention_dim: Optional[int] = None,
569
+ heads: int = 8,
570
+ dim_head: int = 64,
571
+ upcast_attention: bool = False,
572
+ ):
573
+ super().__init__()
574
+ inner_dim = dim_head * heads
575
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
576
+ self.upcast_attention = upcast_attention
577
+
578
+ self.scale = dim_head**-0.5
579
+ self.heads = heads
580
+
581
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
582
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
583
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
584
+
585
+ self.to_out = nn.ModuleList([])
586
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
587
+ # no dropout here
588
+
589
+ self.use_memory_efficient_attention_xformers = False
590
+ self.use_memory_efficient_attention_mem_eff = False
591
+ self.use_sdpa = False
592
+
593
+ # Attention processor
594
+ self.processor = None
595
+
596
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
597
+ self.use_memory_efficient_attention_xformers = xformers
598
+ self.use_memory_efficient_attention_mem_eff = mem_eff
599
+
600
+ def set_use_sdpa(self, sdpa):
601
+ self.use_sdpa = sdpa
602
+
603
+ def reshape_heads_to_batch_dim(self, tensor):
604
+ batch_size, seq_len, dim = tensor.shape
605
+ head_size = self.heads
606
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
607
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
608
+ return tensor
609
+
610
+ def reshape_batch_dim_to_heads(self, tensor):
611
+ batch_size, seq_len, dim = tensor.shape
612
+ head_size = self.heads
613
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
614
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
615
+ return tensor
616
+
617
+ def set_processor(self):
618
+ return self.processor
619
+
620
+ def get_processor(self):
621
+ return self.processor
622
+
623
+ def forward(self, hidden_states, context=None, mask=None, **kwargs):
624
+ if self.processor is not None:
625
+ (
626
+ hidden_states,
627
+ encoder_hidden_states,
628
+ attention_mask,
629
+ ) = translate_attention_names_from_diffusers(
630
+ hidden_states=hidden_states, context=context, mask=mask, **kwargs
631
+ )
632
+ return self.processor(
633
+ attn=self,
634
+ hidden_states=hidden_states,
635
+ encoder_hidden_states=context,
636
+ attention_mask=mask,
637
+ **kwargs
638
+ )
639
+ if self.use_memory_efficient_attention_xformers:
640
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
641
+ if self.use_memory_efficient_attention_mem_eff:
642
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
643
+ if self.use_sdpa:
644
+ return self.forward_sdpa(hidden_states, context, mask)
645
+
646
+ query = self.to_q(hidden_states)
647
+ context = context if context is not None else hidden_states
648
+ key = self.to_k(context)
649
+ value = self.to_v(context)
650
+
651
+ query = self.reshape_heads_to_batch_dim(query)
652
+ key = self.reshape_heads_to_batch_dim(key)
653
+ value = self.reshape_heads_to_batch_dim(value)
654
+
655
+ hidden_states = self._attention(query, key, value)
656
+
657
+ # linear proj
658
+ hidden_states = self.to_out[0](hidden_states)
659
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
660
+ return hidden_states
661
+
662
+ def _attention(self, query, key, value):
663
+ if self.upcast_attention:
664
+ query = query.float()
665
+ key = key.float()
666
+
667
+ attention_scores = torch.baddbmm(
668
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
669
+ query,
670
+ key.transpose(-1, -2),
671
+ beta=0,
672
+ alpha=self.scale,
673
+ )
674
+ attention_probs = attention_scores.softmax(dim=-1)
675
+
676
+ # cast back to the original dtype
677
+ attention_probs = attention_probs.to(value.dtype)
678
+
679
+ # compute attention output
680
+ hidden_states = torch.bmm(attention_probs, value)
681
+
682
+ # reshape hidden_states
683
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
+ return hidden_states
685
+
686
+ # TODO support Hypernetworks
687
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
688
+ import xformers.ops
689
+
690
+ h = self.heads
691
+ q_in = self.to_q(x)
692
+ context = context if context is not None else x
693
+ context = context.to(x.dtype)
694
+ k_in = self.to_k(context)
695
+ v_in = self.to_v(context)
696
+
697
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
698
+ del q_in, k_in, v_in
699
+
700
+ q = q.contiguous()
701
+ k = k.contiguous()
702
+ v = v.contiguous()
703
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best kernel automatically
704
+
705
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
706
+
707
+ out = self.to_out[0](out)
708
+ return out
709
+
710
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
711
+ flash_func = FlashAttentionFunction
712
+
713
+ q_bucket_size = 512
714
+ k_bucket_size = 1024
715
+
716
+ h = self.heads
717
+ q = self.to_q(x)
718
+ context = context if context is not None else x
719
+ context = context.to(x.dtype)
720
+ k = self.to_k(context)
721
+ v = self.to_v(context)
722
+ del context, x
723
+
724
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
725
+
726
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
727
+
728
+ out = rearrange(out, "b h n d -> b n (h d)")
729
+
730
+ out = self.to_out[0](out)
731
+ return out
732
+
733
+ def forward_sdpa(self, x, context=None, mask=None):
734
+ h = self.heads
735
+ q_in = self.to_q(x)
736
+ context = context if context is not None else x
737
+ context = context.to(x.dtype)
738
+ k_in = self.to_k(context)
739
+ v_in = self.to_v(context)
740
+
741
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
742
+ del q_in, k_in, v_in
743
+
744
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
745
+
746
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
747
+
748
+ out = self.to_out[0](out)
749
+ return out
750
+
751
+ def translate_attention_names_from_diffusers(
752
+ hidden_states: torch.FloatTensor,
753
+ context: Optional[torch.FloatTensor] = None,
754
+ mask: Optional[torch.FloatTensor] = None,
755
+ # HF naming
756
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
757
+ attention_mask: Optional[torch.FloatTensor] = None
758
+ ):
759
+ # translate from hugging face diffusers
760
+ context = context if context is not None else encoder_hidden_states
761
+
762
+ # translate from hugging face diffusers
763
+ mask = mask if mask is not None else attention_mask
764
+
765
+ return hidden_states, context, mask
766
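The optimized attention paths above are selected through flags rather than call arguments. A minimal sketch of switching a CrossAttention module to the PyTorch SDPA backend; the dimensions are illustrative assumptions:

import torch

attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
attn.set_use_sdpa(True)                                # route forward() through forward_sdpa
attn.set_use_memory_efficient_attention(False, False)  # keep the xformers / FlashAttention paths off

x = torch.randn(1, 4096, 320)  # (batch, query tokens, query_dim)
ctx = torch.randn(1, 77, 768)  # (batch, context tokens, cross_attention_dim)
out = attn(x, context=ctx)     # -> (1, 4096, 320)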
+
767
+ # feedforward
768
+ class GEGLU(nn.Module):
769
+ r"""
770
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
771
+
772
+ Parameters:
773
+ dim_in (`int`): The number of channels in the input.
774
+ dim_out (`int`): The number of channels in the output.
775
+ """
776
+
777
+ def __init__(self, dim_in: int, dim_out: int):
778
+ super().__init__()
779
+ self.proj = nn.Linear(dim_in, dim_out * 2)
780
+
781
+ def gelu(self, gate):
782
+ if gate.device.type != "mps":
783
+ return F.gelu(gate)
784
+ # mps: gelu is not implemented for float16
785
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
786
+
787
+ def forward(self, hidden_states):
788
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
789
+ return hidden_states * self.gelu(gate)
790
+
791
+
792
+ class FeedForward(nn.Module):
793
+ def __init__(
794
+ self,
795
+ dim: int,
796
+ ):
797
+ super().__init__()
798
+ inner_dim = int(dim * 4) # mult is always 4
799
+
800
+ self.net = nn.ModuleList([])
801
+ # project in
802
+ self.net.append(GEGLU(dim, inner_dim))
803
+ # project dropout
804
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
805
+ # project out
806
+ self.net.append(nn.Linear(inner_dim, dim))
807
+
808
+ def forward(self, hidden_states):
809
+ for module in self.net:
810
+ hidden_states = module(hidden_states)
811
+ return hidden_states
812
+
813
+
814
+ class BasicTransformerBlock(nn.Module):
815
+ def __init__(
816
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
817
+ ):
818
+ super().__init__()
819
+
820
+ # 1. Self-Attn
821
+ self.attn1 = CrossAttention(
822
+ query_dim=dim,
823
+ cross_attention_dim=None,
824
+ heads=num_attention_heads,
825
+ dim_head=attention_head_dim,
826
+ upcast_attention=upcast_attention,
827
+ )
828
+ self.ff = FeedForward(dim)
829
+
830
+ # 2. Cross-Attn
831
+ self.attn2 = CrossAttention(
832
+ query_dim=dim,
833
+ cross_attention_dim=cross_attention_dim,
834
+ heads=num_attention_heads,
835
+ dim_head=attention_head_dim,
836
+ upcast_attention=upcast_attention,
837
+ )
838
+
839
+ self.norm1 = nn.LayerNorm(dim)
840
+ self.norm2 = nn.LayerNorm(dim)
841
+
842
+ # 3. Feed-forward
843
+ self.norm3 = nn.LayerNorm(dim)
844
+
845
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
846
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
847
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
848
+
849
+ def set_use_sdpa(self, sdpa: bool):
850
+ self.attn1.set_use_sdpa(sdpa)
851
+ self.attn2.set_use_sdpa(sdpa)
852
+
853
+ def forward(self, hidden_states, context=None, timestep=None):
854
+ # 1. Self-Attention
855
+ norm_hidden_states = self.norm1(hidden_states)
856
+
857
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
858
+
859
+ # 2. Cross-Attention
860
+ norm_hidden_states = self.norm2(hidden_states)
861
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
862
+
863
+ # 3. Feed-forward
864
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
865
+
866
+ return hidden_states
867
+
868
+
869
+ class Transformer2DModel(nn.Module):
870
+ def __init__(
871
+ self,
872
+ num_attention_heads: int = 16,
873
+ attention_head_dim: int = 88,
874
+ in_channels: Optional[int] = None,
875
+ cross_attention_dim: Optional[int] = None,
876
+ use_linear_projection: bool = False,
877
+ upcast_attention: bool = False,
878
+ ):
879
+ super().__init__()
880
+ self.in_channels = in_channels
881
+ self.num_attention_heads = num_attention_heads
882
+ self.attention_head_dim = attention_head_dim
883
+ inner_dim = num_attention_heads * attention_head_dim
884
+ self.use_linear_projection = use_linear_projection
885
+
886
+ self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
887
+
888
+ if use_linear_projection:
889
+ self.proj_in = nn.Linear(in_channels, inner_dim)
890
+ else:
891
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
892
+
893
+ self.transformer_blocks = nn.ModuleList(
894
+ [
895
+ BasicTransformerBlock(
896
+ inner_dim,
897
+ num_attention_heads,
898
+ attention_head_dim,
899
+ cross_attention_dim=cross_attention_dim,
900
+ upcast_attention=upcast_attention,
901
+ )
902
+ ]
903
+ )
904
+
905
+ if use_linear_projection:
906
+ self.proj_out = nn.Linear(in_channels, inner_dim)
907
+ else:
908
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
909
+
910
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
911
+ for transformer in self.transformer_blocks:
912
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
913
+
914
+ def set_use_sdpa(self, sdpa):
915
+ for transformer in self.transformer_blocks:
916
+ transformer.set_use_sdpa(sdpa)
917
+
918
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
919
+ # 1. Input
920
+ batch, _, height, weight = hidden_states.shape
921
+ residual = hidden_states
922
+
923
+ hidden_states = self.norm(hidden_states)
924
+ if not self.use_linear_projection:
925
+ hidden_states = self.proj_in(hidden_states)
926
+ inner_dim = hidden_states.shape[1]
927
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
928
+ else:
929
+ inner_dim = hidden_states.shape[1]
930
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
931
+ hidden_states = self.proj_in(hidden_states)
932
+
933
+ # 2. Blocks
934
+ for block in self.transformer_blocks:
935
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
936
+
937
+ # 3. Output
938
+ if not self.use_linear_projection:
939
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
940
+ hidden_states = self.proj_out(hidden_states)
941
+ else:
942
+ hidden_states = self.proj_out(hidden_states)
943
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
944
+
945
+ output = hidden_states + residual
946
+
947
+ if not return_dict:
948
+ return (output,)
949
+
950
+ return SampleOutput(sample=output)
951
+
952
+
953
+ class CrossAttnDownBlock2D(nn.Module):
954
+ def __init__(
955
+ self,
956
+ in_channels: int,
957
+ out_channels: int,
958
+ add_downsample=True,
959
+ cross_attention_dim=1280,
960
+ attn_num_head_channels=1,
961
+ use_linear_projection=False,
962
+ upcast_attention=False,
963
+ ):
964
+ super().__init__()
965
+ self.has_cross_attention = True
966
+ resnets = []
967
+ attentions = []
968
+
969
+ self.attn_num_head_channels = attn_num_head_channels
970
+
971
+ for i in range(LAYERS_PER_BLOCK):
972
+ in_channels = in_channels if i == 0 else out_channels
973
+
974
+ resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
975
+ attentions.append(
976
+ Transformer2DModel(
977
+ attn_num_head_channels,
978
+ out_channels // attn_num_head_channels,
979
+ in_channels=out_channels,
980
+ cross_attention_dim=cross_attention_dim,
981
+ use_linear_projection=use_linear_projection,
982
+ upcast_attention=upcast_attention,
983
+ )
984
+ )
985
+ self.attentions = nn.ModuleList(attentions)
986
+ self.resnets = nn.ModuleList(resnets)
987
+
988
+ if add_downsample:
989
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
990
+ else:
991
+ self.downsamplers = None
992
+
993
+ self.gradient_checkpointing = False
994
+
995
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
996
+ for attn in self.attentions:
997
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
998
+
999
+ def set_use_sdpa(self, sdpa):
1000
+ for attn in self.attentions:
1001
+ attn.set_use_sdpa(sdpa)
1002
+
1003
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1004
+ output_states = ()
1005
+
1006
+ for resnet, attn in zip(self.resnets, self.attentions):
1007
+ if self.training and self.gradient_checkpointing:
1008
+
1009
+ def create_custom_forward(module, return_dict=None):
1010
+ def custom_forward(*inputs):
1011
+ if return_dict is not None:
1012
+ return module(*inputs, return_dict=return_dict)
1013
+ else:
1014
+ return module(*inputs)
1015
+
1016
+ return custom_forward
1017
+
1018
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1019
+ hidden_states = torch.utils.checkpoint.checkpoint(
1020
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1021
+ )[0]
1022
+ else:
1023
+ hidden_states = resnet(hidden_states, temb)
1024
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1025
+
1026
+ output_states += (hidden_states,)
1027
+
1028
+ if self.downsamplers is not None:
1029
+ for downsampler in self.downsamplers:
1030
+ hidden_states = downsampler(hidden_states)
1031
+
1032
+ output_states += (hidden_states,)
1033
+
1034
+ return hidden_states, output_states
1035
+
1036
+
1037
+ class UNetMidBlock2DCrossAttn(nn.Module):
1038
+ def __init__(
1039
+ self,
1040
+ in_channels: int,
1041
+ attn_num_head_channels=1,
1042
+ cross_attention_dim=1280,
1043
+ use_linear_projection=False,
1044
+ ):
1045
+ super().__init__()
1046
+
1047
+ self.has_cross_attention = True
1048
+ self.attn_num_head_channels = attn_num_head_channels
1049
+
1050
+ # Middle block has two resnets and one attention
1051
+ resnets = [
1052
+ ResnetBlock2D(
1053
+ in_channels=in_channels,
1054
+ out_channels=in_channels,
1055
+ ),
1056
+ ResnetBlock2D(
1057
+ in_channels=in_channels,
1058
+ out_channels=in_channels,
1059
+ ),
1060
+ ]
1061
+ attentions = [
1062
+ Transformer2DModel(
1063
+ attn_num_head_channels,
1064
+ in_channels // attn_num_head_channels,
1065
+ in_channels=in_channels,
1066
+ cross_attention_dim=cross_attention_dim,
1067
+ use_linear_projection=use_linear_projection,
1068
+ )
1069
+ ]
1070
+
1071
+ self.attentions = nn.ModuleList(attentions)
1072
+ self.resnets = nn.ModuleList(resnets)
1073
+
1074
+ self.gradient_checkpointing = False
1075
+
1076
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1077
+ for attn in self.attentions:
1078
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1079
+
1080
+ def set_use_sdpa(self, sdpa):
1081
+ for attn in self.attentions:
1082
+ attn.set_use_sdpa(sdpa)
1083
+
1084
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1085
+ for i, resnet in enumerate(self.resnets):
1086
+ attn = None if i == 0 else self.attentions[i - 1]
1087
+
1088
+ if self.training and self.gradient_checkpointing:
1089
+
1090
+ def create_custom_forward(module, return_dict=None):
1091
+ def custom_forward(*inputs):
1092
+ if return_dict is not None:
1093
+ return module(*inputs, return_dict=return_dict)
1094
+ else:
1095
+ return module(*inputs)
1096
+
1097
+ return custom_forward
1098
+
1099
+ if attn is not None:
1100
+ hidden_states = torch.utils.checkpoint.checkpoint(
1101
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1102
+ )[0]
1103
+
1104
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1105
+ else:
1106
+ if attn is not None:
1107
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
1108
+ hidden_states = resnet(hidden_states, temb)
1109
+
1110
+ return hidden_states
1111
+
1112
+
1113
+ class Upsample2D(nn.Module):
1114
+ def __init__(self, channels, out_channels):
1115
+ super().__init__()
1116
+ self.channels = channels
1117
+ self.out_channels = out_channels
1118
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
1119
+
1120
+ def forward(self, hidden_states, output_size):
1121
+ assert hidden_states.shape[1] == self.channels
1122
+
1123
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
1124
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
1125
+ # https://github.com/pytorch/pytorch/issues/86679
1126
+ dtype = hidden_states.dtype
1127
+ if dtype == torch.bfloat16:
1128
+ hidden_states = hidden_states.to(torch.float32)
1129
+
1130
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1131
+ if hidden_states.shape[0] >= 64:
1132
+ hidden_states = hidden_states.contiguous()
1133
+
1134
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
1135
+ if output_size is None:
1136
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
1137
+ else:
1138
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
1139
+
1140
+ # If the input is bfloat16, we cast back to bfloat16
1141
+ if dtype == torch.bfloat16:
1142
+ hidden_states = hidden_states.to(dtype)
1143
+
1144
+ hidden_states = self.conv(hidden_states)
1145
+
1146
+ return hidden_states
1147
+
1148
+
1149
+ class UpBlock2D(nn.Module):
1150
+ def __init__(
1151
+ self,
1152
+ in_channels: int,
1153
+ prev_output_channel: int,
1154
+ out_channels: int,
1155
+ add_upsample=True,
1156
+ ):
1157
+ super().__init__()
1158
+
1159
+ self.has_cross_attention = False
1160
+ resnets = []
1161
+
1162
+ for i in range(LAYERS_PER_BLOCK_UP):
1163
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1164
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1165
+
1166
+ resnets.append(
1167
+ ResnetBlock2D(
1168
+ in_channels=resnet_in_channels + res_skip_channels,
1169
+ out_channels=out_channels,
1170
+ )
1171
+ )
1172
+
1173
+ self.resnets = nn.ModuleList(resnets)
1174
+
1175
+ if add_upsample:
1176
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1177
+ else:
1178
+ self.upsamplers = None
1179
+
1180
+ self.gradient_checkpointing = False
1181
+
1182
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1183
+ pass
1184
+
1185
+ def set_use_sdpa(self, sdpa):
1186
+ pass
1187
+
1188
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1189
+ for resnet in self.resnets:
1190
+ # pop res hidden states
1191
+ res_hidden_states = res_hidden_states_tuple[-1]
1192
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1193
+
1194
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1195
+
1196
+ if self.training and self.gradient_checkpointing:
1197
+
1198
+ def create_custom_forward(module):
1199
+ def custom_forward(*inputs):
1200
+ return module(*inputs)
1201
+
1202
+ return custom_forward
1203
+
1204
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1205
+ else:
1206
+ hidden_states = resnet(hidden_states, temb)
1207
+
1208
+ if self.upsamplers is not None:
1209
+ for upsampler in self.upsamplers:
1210
+ hidden_states = upsampler(hidden_states, upsample_size)
1211
+
1212
+ return hidden_states
1213
+
1214
+
1215
+ class CrossAttnUpBlock2D(nn.Module):
1216
+ def __init__(
1217
+ self,
1218
+ in_channels: int,
1219
+ out_channels: int,
1220
+ prev_output_channel: int,
1221
+ attn_num_head_channels=1,
1222
+ cross_attention_dim=1280,
1223
+ add_upsample=True,
1224
+ use_linear_projection=False,
1225
+ upcast_attention=False,
1226
+ ):
1227
+ super().__init__()
1228
+ resnets = []
1229
+ attentions = []
1230
+
1231
+ self.has_cross_attention = True
1232
+ self.attn_num_head_channels = attn_num_head_channels
1233
+
1234
+ for i in range(LAYERS_PER_BLOCK_UP):
1235
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1236
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1237
+
1238
+ resnets.append(
1239
+ ResnetBlock2D(
1240
+ in_channels=resnet_in_channels + res_skip_channels,
1241
+ out_channels=out_channels,
1242
+ )
1243
+ )
1244
+ attentions.append(
1245
+ Transformer2DModel(
1246
+ attn_num_head_channels,
1247
+ out_channels // attn_num_head_channels,
1248
+ in_channels=out_channels,
1249
+ cross_attention_dim=cross_attention_dim,
1250
+ use_linear_projection=use_linear_projection,
1251
+ upcast_attention=upcast_attention,
1252
+ )
1253
+ )
1254
+
1255
+ self.attentions = nn.ModuleList(attentions)
1256
+ self.resnets = nn.ModuleList(resnets)
1257
+
1258
+ if add_upsample:
1259
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1260
+ else:
1261
+ self.upsamplers = None
1262
+
1263
+ self.gradient_checkpointing = False
1264
+
1265
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1266
+ for attn in self.attentions:
1267
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1268
+
1269
+ def set_use_sdpa(self, sdpa):
1270
+ for attn in self.attentions:
1271
+ attn.set_use_sdpa(sdpa)
1272
+
1273
+ def forward(
1274
+ self,
1275
+ hidden_states,
1276
+ res_hidden_states_tuple,
1277
+ temb=None,
1278
+ encoder_hidden_states=None,
1279
+ upsample_size=None,
1280
+ ):
1281
+ for resnet, attn in zip(self.resnets, self.attentions):
1282
+ # pop res hidden states
1283
+ res_hidden_states = res_hidden_states_tuple[-1]
1284
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1285
+
1286
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1287
+
1288
+ if self.training and self.gradient_checkpointing:
1289
+
1290
+ def create_custom_forward(module, return_dict=None):
1291
+ def custom_forward(*inputs):
1292
+ if return_dict is not None:
1293
+ return module(*inputs, return_dict=return_dict)
1294
+ else:
1295
+ return module(*inputs)
1296
+
1297
+ return custom_forward
1298
+
1299
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1300
+ hidden_states = torch.utils.checkpoint.checkpoint(
1301
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1302
+ )[0]
1303
+ else:
1304
+ hidden_states = resnet(hidden_states, temb)
1305
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1306
+
1307
+ if self.upsamplers is not None:
1308
+ for upsampler in self.upsamplers:
1309
+ hidden_states = upsampler(hidden_states, upsample_size)
1310
+
1311
+ return hidden_states
1312
+
1313
+
1314
+ def get_down_block(
1315
+ down_block_type,
1316
+ in_channels,
1317
+ out_channels,
1318
+ add_downsample,
1319
+ attn_num_head_channels,
1320
+ cross_attention_dim,
1321
+ use_linear_projection,
1322
+ upcast_attention,
1323
+ ):
1324
+ if down_block_type == "DownBlock2D":
1325
+ return DownBlock2D(
1326
+ in_channels=in_channels,
1327
+ out_channels=out_channels,
1328
+ add_downsample=add_downsample,
1329
+ )
1330
+ elif down_block_type == "CrossAttnDownBlock2D":
1331
+ return CrossAttnDownBlock2D(
1332
+ in_channels=in_channels,
1333
+ out_channels=out_channels,
1334
+ add_downsample=add_downsample,
1335
+ cross_attention_dim=cross_attention_dim,
1336
+ attn_num_head_channels=attn_num_head_channels,
1337
+ use_linear_projection=use_linear_projection,
1338
+ upcast_attention=upcast_attention,
1339
+ )
1340
+
1341
+
1342
+ def get_up_block(
1343
+ up_block_type,
1344
+ in_channels,
1345
+ out_channels,
1346
+ prev_output_channel,
1347
+ add_upsample,
1348
+ attn_num_head_channels,
1349
+ cross_attention_dim=None,
1350
+ use_linear_projection=False,
1351
+ upcast_attention=False,
1352
+ ):
1353
+ if up_block_type == "UpBlock2D":
1354
+ return UpBlock2D(
1355
+ in_channels=in_channels,
1356
+ prev_output_channel=prev_output_channel,
1357
+ out_channels=out_channels,
1358
+ add_upsample=add_upsample,
1359
+ )
1360
+ elif up_block_type == "CrossAttnUpBlock2D":
1361
+ return CrossAttnUpBlock2D(
1362
+ in_channels=in_channels,
1363
+ out_channels=out_channels,
1364
+ prev_output_channel=prev_output_channel,
1365
+ attn_num_head_channels=attn_num_head_channels,
1366
+ cross_attention_dim=cross_attention_dim,
1367
+ add_upsample=add_upsample,
1368
+ use_linear_projection=use_linear_projection,
1369
+ upcast_attention=upcast_attention,
1370
+ )
1371
+
1372
+
1373
+ class UNet2DConditionModel(nn.Module):
1374
+ _supports_gradient_checkpointing = True
1375
+
1376
+ def __init__(
1377
+ self,
1378
+ sample_size: Optional[int] = None,
1379
+ attention_head_dim: Union[int, Tuple[int]] = 8,
1380
+ cross_attention_dim: int = 1280,
1381
+ use_linear_projection: bool = False,
1382
+ upcast_attention: bool = False,
1383
+ **kwargs,
1384
+ ):
1385
+ super().__init__()
1386
+ assert sample_size is not None, "sample_size must be specified"
1387
+ logger.info(
1388
+ f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
1389
+ )
1390
+
1391
+ # defined here so they can be referenced from outside
1392
+ self.in_channels = IN_CHANNELS
1393
+ self.out_channels = OUT_CHANNELS
1394
+
1395
+ self.sample_size = sample_size
1396
+ self.prepare_config(sample_size=sample_size)
1397
+
1398
+ # the way modules are held cannot be changed, or the state_dict layout would change
1399
+
1400
+ # input
1401
+ self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
1402
+
1403
+ # time
1404
+ self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
1405
+
1406
+ self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
1407
+
1408
+ self.down_blocks = nn.ModuleList([])
1409
+ self.mid_block = None
1410
+ self.up_blocks = nn.ModuleList([])
1411
+
1412
+ if isinstance(attention_head_dim, int):
1413
+ attention_head_dim = (attention_head_dim,) * 4
1414
+
1415
+ # down
1416
+ output_channel = BLOCK_OUT_CHANNELS[0]
1417
+ for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
1418
+ input_channel = output_channel
1419
+ output_channel = BLOCK_OUT_CHANNELS[i]
1420
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1421
+
1422
+ down_block = get_down_block(
1423
+ down_block_type,
1424
+ in_channels=input_channel,
1425
+ out_channels=output_channel,
1426
+ add_downsample=not is_final_block,
1427
+ attn_num_head_channels=attention_head_dim[i],
1428
+ cross_attention_dim=cross_attention_dim,
1429
+ use_linear_projection=use_linear_projection,
1430
+ upcast_attention=upcast_attention,
1431
+ )
1432
+ self.down_blocks.append(down_block)
1433
+
1434
+ # mid
1435
+ self.mid_block = UNetMidBlock2DCrossAttn(
1436
+ in_channels=BLOCK_OUT_CHANNELS[-1],
1437
+ attn_num_head_channels=attention_head_dim[-1],
1438
+ cross_attention_dim=cross_attention_dim,
1439
+ use_linear_projection=use_linear_projection,
1440
+ )
1441
+
1442
+ # count how many layers upsample the images
1443
+ self.num_upsamplers = 0
1444
+
1445
+ # up
1446
+ reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
1447
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
1448
+ output_channel = reversed_block_out_channels[0]
1449
+ for i, up_block_type in enumerate(UP_BLOCK_TYPES):
1450
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1451
+
1452
+ prev_output_channel = output_channel
1453
+ output_channel = reversed_block_out_channels[i]
1454
+ input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
1455
+
1456
+ # add upsample block for all BUT final layer
1457
+ if not is_final_block:
1458
+ add_upsample = True
1459
+ self.num_upsamplers += 1
1460
+ else:
1461
+ add_upsample = False
1462
+
1463
+ up_block = get_up_block(
1464
+ up_block_type,
1465
+ in_channels=input_channel,
1466
+ out_channels=output_channel,
1467
+ prev_output_channel=prev_output_channel,
1468
+ add_upsample=add_upsample,
1469
+ attn_num_head_channels=reversed_attention_head_dim[i],
1470
+ cross_attention_dim=cross_attention_dim,
1471
+ use_linear_projection=use_linear_projection,
1472
+ upcast_attention=upcast_attention,
1473
+ )
1474
+ self.up_blocks.append(up_block)
1475
+ prev_output_channel = output_channel
1476
+
1477
+ # out
1478
+ self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
1479
+ self.conv_act = nn.SiLU()
1480
+ self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
1481
+
1482
+ # region diffusers compatibility
1483
+ def prepare_config(self, *args, **kwargs):
1484
+ self.config = SimpleNamespace(**kwargs)
1485
+
1486
+ @property
1487
+ def dtype(self) -> torch.dtype:
1488
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1489
+ return get_parameter_dtype(self)
1490
+
1491
+ @property
1492
+ def device(self) -> torch.device:
1493
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1494
+ return get_parameter_device(self)
1495
+
1496
+ def set_attention_slice(self, slice_size):
1497
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1498
+
1499
+ def is_gradient_checkpointing(self) -> bool:
1500
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1501
+
1502
+ def enable_gradient_checkpointing(self):
1503
+ self.set_gradient_checkpointing(value=True)
1504
+
1505
+ def disable_gradient_checkpointing(self):
1506
+ self.set_gradient_checkpointing(value=False)
1507
+
1508
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1509
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1510
+ for module in modules:
1511
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1512
+
1513
+ def set_use_sdpa(self, sdpa: bool) -> None:
1514
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1515
+ for module in modules:
1516
+ module.set_use_sdpa(sdpa)
1517
+
1518
+ def set_gradient_checkpointing(self, value=False):
1519
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1520
+ for module in modules:
1521
+ logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
1522
+ module.gradient_checkpointing = value
1523
+
1524
+ # endregion
1525
+
1526
+ def forward(
1527
+ self,
1528
+ sample: torch.FloatTensor,
1529
+ timestep: Union[torch.Tensor, float, int],
1530
+ encoder_hidden_states: torch.Tensor,
1531
+ class_labels: Optional[torch.Tensor] = None,
1532
+ return_dict: bool = True,
1533
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1534
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1535
+ ) -> Union[Dict, Tuple]:
1536
+ r"""
1537
+ Args:
1538
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1539
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1540
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1541
+ return_dict (`bool`, *optional*, defaults to `True`):
1542
+ Whether or not to return a dict instead of a plain tuple.
1543
+
1544
+ Returns:
1545
+ `SampleOutput` or `tuple`:
1546
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1547
+ """
1548
+ # By default, samples have to be at least a multiple of the overall upsampling factor,
+ # which is 2 ** (number of upsampling layers), i.e. 64. However, the upsampling
+ # interpolation output size can be forced to fit any size on the fly if necessary.
+ # Image quality will probably degrade for such sizes, so keeping them divisible by 64 is recommended.
1555
+ default_overall_up_factor = 2**self.num_upsamplers
1556
+
1557
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1559
+ forward_upsample_size = False
1560
+ upsample_size = None
1561
+
1562
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1563
+ # logger.info("Forward upsample size to force interpolation output size.")
1564
+ forward_upsample_size = True
1565
+
1566
+ # 1. time
1567
+ timesteps = timestep
1568
+ timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only handles unusual inputs
1569
+
1570
+ t_emb = self.time_proj(timesteps)
1571
+
1572
+ # timesteps does not contain any weights and will always return float32 tensors,
+ # but time_embedding might actually be running in fp16, so we need to cast here.
+ # There might be better ways to encapsulate this (e.g. casting inside time_proj).
1578
+ t_emb = t_emb.to(dtype=self.dtype)
1579
+ emb = self.time_embedding(t_emb)
1580
+
1581
+ # 2. pre-process
1582
+ sample = self.conv_in(sample)
1583
+
1584
+ down_block_res_samples = (sample,)
1585
+ for downsample_block in self.down_blocks:
1586
+ # the down blocks could be made to always take encoder_hidden_states in forward,
+ # but branching here is arguably easier to follow
1588
+ if downsample_block.has_cross_attention:
1589
+ sample, res_samples = downsample_block(
1590
+ hidden_states=sample,
1591
+ temb=emb,
1592
+ encoder_hidden_states=encoder_hidden_states,
1593
+ )
1594
+ else:
1595
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1596
+
1597
+ down_block_res_samples += res_samples
1598
+
1599
+ # add the ControlNet outputs to the skip connections
1600
+ if down_block_additional_residuals is not None:
1601
+ down_block_res_samples = list(down_block_res_samples)
1602
+ for i in range(len(down_block_res_samples)):
1603
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1604
+ down_block_res_samples = tuple(down_block_res_samples)
1605
+
1606
+ # 4. mid
1607
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1608
+
1609
+ # add the ControlNet output
1610
+ if mid_block_additional_residual is not None:
1611
+ sample += mid_block_additional_residual
1612
+
1613
+ # 5. up
1614
+ for i, upsample_block in enumerate(self.up_blocks):
1615
+ is_final_block = i == len(self.up_blocks) - 1
1616
+
1617
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1618
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1619
+
1620
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1622
+ if not is_final_block and forward_upsample_size:
1623
+ upsample_size = down_block_res_samples[-1].shape[2:]
1624
+
1625
+ if upsample_block.has_cross_attention:
1626
+ sample = upsample_block(
1627
+ hidden_states=sample,
1628
+ temb=emb,
1629
+ res_hidden_states_tuple=res_samples,
1630
+ encoder_hidden_states=encoder_hidden_states,
1631
+ upsample_size=upsample_size,
1632
+ )
1633
+ else:
1634
+ sample = upsample_block(
1635
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1636
+ )
1637
+
1638
+ # 6. post-process
1639
+ sample = self.conv_norm_out(sample)
1640
+ sample = self.conv_act(sample)
1641
+ sample = self.conv_out(sample)
1642
+
1643
+ if not return_dict:
1644
+ return (sample,)
1645
+
1646
+ return SampleOutput(sample=sample)
1647
+
1648
+ def handle_unusual_timesteps(self, sample, timesteps):
1649
+ r"""
1650
+ If timesteps is not a Tensor, convert it to one, and broadcast it to the batch size for ONNX/Core ML compatibility.
1651
+ """
1652
+ if not torch.is_tensor(timesteps):
1653
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1654
+ # This would be a good case for the `match` statement (Python 3.10+)
1655
+ is_mps = sample.device.type == "mps"
1656
+ if isinstance(timesteps, float):
1657
+ dtype = torch.float32 if is_mps else torch.float64
1658
+ else:
1659
+ dtype = torch.int32 if is_mps else torch.int64
1660
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1661
+ elif len(timesteps.shape) == 0:
1662
+ timesteps = timesteps[None].to(sample.device)
1663
+
1664
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1665
+ timesteps = timesteps.expand(sample.shape[0])
1666
+
1667
+ return timesteps
1668
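A minimal end-to-end sketch of the UNet defined above. The constructor arguments and tensor shapes assume a Stable Diffusion 1.x style configuration; the actual channel counts come from module-level constants (IN_CHANNELS, BLOCK_OUT_CHANNELS, ...) defined earlier in this file.

import torch

unet = UNet2DConditionModel(sample_size=64, attention_head_dim=8, cross_attention_dim=768)

latents = torch.randn(1, unet.in_channels, 64, 64)  # noisy latents
timestep = torch.tensor([981])                      # one timestep per batch element
text_emb = torch.randn(1, 77, 768)                  # encoder hidden states (e.g. a CLIP text encoder output)

noise_pred = unet(latents, timestep, text_emb).sample  # same spatial size as the input latents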
+
1669
+
1670
+ class InferUNet2DConditionModel:
1671
+ def __init__(self, original_unet: UNet2DConditionModel):
1672
+ self.delegate = original_unet
1673
+
1674
+ # override original model's forward method: because forward is not called by `__call__`
1675
+ # overriding `__call__` alone is not enough, because nn.Module.forward has special handling
1676
+ self.delegate.forward = self.forward
1677
+
1678
+ # override original model's up blocks' forward method
1679
+ for up_block in self.delegate.up_blocks:
1680
+ if up_block.__class__.__name__ == "UpBlock2D":
1681
+
1682
+ def resnet_wrapper(func, block):
1683
+ def forward(*args, **kwargs):
1684
+ return func(block, *args, **kwargs)
1685
+
1686
+ return forward
1687
+
1688
+ up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
1689
+
1690
+ elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
1691
+
1692
+ def cross_attn_up_wrapper(func, block):
1693
+ def forward(*args, **kwargs):
1694
+ return func(block, *args, **kwargs)
1695
+
1696
+ return forward
1697
+
1698
+ up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
1699
+
1700
+ # Deep Shrink
1701
+ self.ds_depth_1 = None
1702
+ self.ds_depth_2 = None
1703
+ self.ds_timesteps_1 = None
1704
+ self.ds_timesteps_2 = None
1705
+ self.ds_ratio = None
1706
+
1707
+ # call original model's methods
1708
+ def __getattr__(self, name):
1709
+ return getattr(self.delegate, name)
1710
+
1711
+ def __call__(self, *args, **kwargs):
1712
+ return self.delegate(*args, **kwargs)
1713
+
1714
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1715
+ if ds_depth_1 is None:
1716
+ logger.info("Deep Shrink is disabled.")
1717
+ self.ds_depth_1 = None
1718
+ self.ds_timesteps_1 = None
1719
+ self.ds_depth_2 = None
1720
+ self.ds_timesteps_2 = None
1721
+ self.ds_ratio = None
1722
+ else:
1723
+ logger.info(
1724
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1725
+ )
1726
+ self.ds_depth_1 = ds_depth_1
1727
+ self.ds_timesteps_1 = ds_timesteps_1
1728
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1729
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1730
+ self.ds_ratio = ds_ratio
1731
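A hedged usage sketch: wrap an existing UNet2DConditionModel for inference and enable Deep Shrink. The depth, timestep threshold and ratio below are illustrative choices, and `unet`, `latents`, `timestep`, `text_emb` are assumed to be set up as in the earlier sketch.

infer_unet = InferUNet2DConditionModel(unet)
infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)

# the wrapper replaces the delegate's forward, so existing sampling code can call it unchanged
noise_pred = infer_unet(latents, timestep, text_emb).sample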
+
1732
+ def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1733
+ for resnet in _self.resnets:
1734
+ # pop res hidden states
1735
+ res_hidden_states = res_hidden_states_tuple[-1]
1736
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1737
+
1738
+ # Deep Shrink
1739
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1740
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1741
+
1742
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1743
+ hidden_states = resnet(hidden_states, temb)
1744
+
1745
+ if _self.upsamplers is not None:
1746
+ for upsampler in _self.upsamplers:
1747
+ hidden_states = upsampler(hidden_states, upsample_size)
1748
+
1749
+ return hidden_states
1750
+
1751
+ def cross_attn_up_block_forward(
1752
+ self,
1753
+ _self,
1754
+ hidden_states,
1755
+ res_hidden_states_tuple,
1756
+ temb=None,
1757
+ encoder_hidden_states=None,
1758
+ upsample_size=None,
1759
+ ):
1760
+ for resnet, attn in zip(_self.resnets, _self.attentions):
1761
+ # pop res hidden states
1762
+ res_hidden_states = res_hidden_states_tuple[-1]
1763
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1764
+
1765
+ # Deep Shrink
1766
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1767
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1768
+
1769
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1770
+ hidden_states = resnet(hidden_states, temb)
1771
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1772
+
1773
+ if _self.upsamplers is not None:
1774
+ for upsampler in _self.upsamplers:
1775
+ hidden_states = upsampler(hidden_states, upsample_size)
1776
+
1777
+ return hidden_states
1778
+
1779
+ def forward(
1780
+ self,
1781
+ sample: torch.FloatTensor,
1782
+ timestep: Union[torch.Tensor, float, int],
1783
+ encoder_hidden_states: torch.Tensor,
1784
+ class_labels: Optional[torch.Tensor] = None,
1785
+ return_dict: bool = True,
1786
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1787
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1788
+ ) -> Union[Dict, Tuple]:
1789
+ r"""
1790
+ The current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink added.
1791
+ """
1792
+
1793
+ r"""
1794
+ Args:
1795
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1796
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1797
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1798
+ return_dict (`bool`, *optional*, defaults to `True`):
1799
+ Whether or not to return a dict instead of a plain tuple.
1800
+
1801
+ Returns:
1802
+ `SampleOutput` or `tuple`:
1803
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1804
+ """
1805
+
1806
+ _self = self.delegate
1807
+
1808
+ # By default, samples have to be at least a multiple of the overall upsampling factor,
+ # which is 2 ** (number of upsampling layers), i.e. 64. However, the upsampling
+ # interpolation output size can be forced to fit any size on the fly if necessary.
+ # Image quality will probably degrade for such sizes, so keeping them divisible by 64 is recommended.
1815
+ default_overall_up_factor = 2**_self.num_upsamplers
1816
+
1817
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1819
+ forward_upsample_size = False
1820
+ upsample_size = None
1821
+
1822
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1823
+ # logger.info("Forward upsample size to force interpolation output size.")
1824
+ forward_upsample_size = True
1825
+
1826
+ # 1. time
1827
+ timesteps = timestep
1828
+ timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only handles unusual inputs
1829
+
1830
+ t_emb = _self.time_proj(timesteps)
1831
+
1832
+ # timesteps does not contain any weights and will always return float32 tensors,
+ # but time_embedding might actually be running in fp16, so we need to cast here.
+ # There might be better ways to encapsulate this (e.g. casting inside time_proj).
1838
+ t_emb = t_emb.to(dtype=_self.dtype)
1839
+ emb = _self.time_embedding(t_emb)
1840
+
1841
+ # 2. pre-process
1842
+ sample = _self.conv_in(sample)
1843
+
1844
+ down_block_res_samples = (sample,)
1845
+ for depth, downsample_block in enumerate(_self.down_blocks):
1846
+ # Deep Shrink
1847
+ if self.ds_depth_1 is not None:
1848
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1849
+ self.ds_depth_2 is not None
1850
+ and depth == self.ds_depth_2
1851
+ and timesteps[0] < self.ds_timesteps_1
1852
+ and timesteps[0] >= self.ds_timesteps_2
1853
+ ):
1854
+ org_dtype = sample.dtype
1855
+ if org_dtype == torch.bfloat16:
1856
+ sample = sample.to(torch.float32)
1857
+ sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1858
+
1859
+ # the down blocks could be changed to always take encoder_hidden_states in forward,
1860
+ # but this branching is arguably easier to follow
1861
+ if downsample_block.has_cross_attention:
1862
+ sample, res_samples = downsample_block(
1863
+ hidden_states=sample,
1864
+ temb=emb,
1865
+ encoder_hidden_states=encoder_hidden_states,
1866
+ )
1867
+ else:
1868
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1869
+
1870
+ down_block_res_samples += res_samples
1871
+
1872
+ # add the ControlNet outputs to the skip connections
1873
+ if down_block_additional_residuals is not None:
1874
+ down_block_res_samples = list(down_block_res_samples)
1875
+ for i in range(len(down_block_res_samples)):
1876
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1877
+ down_block_res_samples = tuple(down_block_res_samples)
1878
+
1879
+ # 4. mid
1880
+ sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1881
+
1882
+ # add the ControlNet output
1883
+ if mid_block_additional_residual is not None:
1884
+ sample += mid_block_additional_residual
1885
+
1886
+ # 5. up
1887
+ for i, upsample_block in enumerate(_self.up_blocks):
1888
+ is_final_block = i == len(_self.up_blocks) - 1
1889
+
1890
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1891
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1892
+
1893
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1894
+ # as noted above, pass upsample_size to every block except the final one
1895
+ if not is_final_block and forward_upsample_size:
1896
+ upsample_size = down_block_res_samples[-1].shape[2:]
1897
+
1898
+ if upsample_block.has_cross_attention:
1899
+ sample = upsample_block(
1900
+ hidden_states=sample,
1901
+ temb=emb,
1902
+ res_hidden_states_tuple=res_samples,
1903
+ encoder_hidden_states=encoder_hidden_states,
1904
+ upsample_size=upsample_size,
1905
+ )
1906
+ else:
1907
+ sample = upsample_block(
1908
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1909
+ )
1910
+
1911
+ # 6. post-process
1912
+ sample = _self.conv_norm_out(sample)
1913
+ sample = _self.conv_act(sample)
1914
+ sample = _self.conv_out(sample)
1915
+
1916
+ if not return_dict:
1917
+ return (sample,)
1918
+
1919
+ return SampleOutput(sample=sample)
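For reference, the Deep Shrink branch above only downsamples the latent with bicubic interpolation (routing through float32 for bfloat16 inputs) while the timestep is still above the configured threshold. A minimal standalone sketch of that single step, assuming a hypothetical ds_ratio of 0.5:

    import torch
    import torch.nn.functional as F

    def deep_shrink_downsample(sample: torch.Tensor, ds_ratio: float = 0.5) -> torch.Tensor:
        # round-trip through float32 for bfloat16 inputs, as the forward() above does
        org_dtype = sample.dtype
        if org_dtype == torch.bfloat16:
            sample = sample.to(torch.float32)
        sample = F.interpolate(sample, scale_factor=ds_ratio, mode="bicubic", align_corners=False)
        return sample.to(org_dtype)

    latent = torch.randn(1, 4, 64, 64, dtype=torch.bfloat16)
    print(deep_shrink_downsample(latent).shape)  # torch.Size([1, 4, 32, 32])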
library/sai_model_spec.py ADDED
@@ -0,0 +1,334 @@
1
+ # based on https://github.com/Stability-AI/ModelSpec
2
+ import datetime
3
+ import hashlib
4
+ from io import BytesIO
5
+ import os
6
+ from typing import List, Optional, Tuple, Union
7
+ import safetensors
8
+ from library.utils import setup_logging
9
+
10
+ setup_logging()
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ r"""
16
+ # Metadata Example
17
+ metadata = {
18
+ # === Must ===
19
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
20
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
21
+ "modelspec.implementation": "sgm",
22
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
23
+ # === Should ===
24
+ "modelspec.author": "Example Corp", # Your name or company name
25
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
26
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
27
+ # === Can ===
28
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
29
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
30
+ }
31
+ """
32
+
33
+ BASE_METADATA = {
34
+ # === Must ===
35
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
36
+ "modelspec.architecture": None,
37
+ "modelspec.implementation": None,
38
+ "modelspec.title": None,
39
+ "modelspec.resolution": None,
40
+ # === Should ===
41
+ "modelspec.description": None,
42
+ "modelspec.author": None,
43
+ "modelspec.date": None,
44
+ # === Can ===
45
+ "modelspec.license": None,
46
+ "modelspec.tags": None,
47
+ "modelspec.merged_from": None,
48
+ "modelspec.prediction_type": None,
49
+ "modelspec.timestep_range": None,
50
+ "modelspec.encoder_layer": None,
51
+ }
52
+
53
+ # define names only for the keys that are referenced separately
54
+ MODELSPEC_TITLE = "modelspec.title"
55
+
56
+ ARCH_SD_V1 = "stable-diffusion-v1"
57
+ ARCH_SD_V2_512 = "stable-diffusion-v2-512"
58
+ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
59
+ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
60
+ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
61
+ # ARCH_SD3_UNKNOWN = "stable-diffusion-3"
62
+ ARCH_FLUX_1_DEV = "flux-1-dev"
63
+ ARCH_FLUX_1_UNKNOWN = "flux-1"
64
+
65
+ ADAPTER_LORA = "lora"
66
+ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
67
+
68
+ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
69
+ IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
70
+ IMPL_DIFFUSERS = "diffusers"
71
+ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
72
+
73
+ PRED_TYPE_EPSILON = "epsilon"
74
+ PRED_TYPE_V = "v"
75
+
76
+
77
+ def load_bytes_in_safetensors(tensors):
78
+ bytes = safetensors.torch.save(tensors)
79
+ b = BytesIO(bytes)
80
+
81
+ b.seek(0)
82
+ header = b.read(8)
83
+ n = int.from_bytes(header, "little")
84
+
85
+ offset = n + 8
86
+ b.seek(offset)
87
+
88
+ return b.read()
89
+
90
+
91
+ def precalculate_safetensors_hashes(state_dict):
92
+ # calculate each tensor one by one to reduce memory usage
93
+ hash_sha256 = hashlib.sha256()
94
+ for tensor in state_dict.values():
95
+ single_tensor_sd = {"tensor": tensor}
96
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
97
+ hash_sha256.update(bytes_for_tensor)
98
+
99
+ return f"0x{hash_sha256.hexdigest()}"
100
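precalculate_safetensors_hashes serializes each tensor to the safetensors format, strips the header, and feeds only the raw tensor bytes into a running sha256, keeping memory usage bounded. A hedged usage sketch with a toy state dict (the import path is an assumption):

    import torch
    import safetensors.torch  # the helpers above rely on safetensors.torch.save
    from library.sai_model_spec import precalculate_safetensors_hashes  # assumed import path

    state_dict = {"linear.weight": torch.ones(4, 4), "linear.bias": torch.zeros(4)}
    digest = precalculate_safetensors_hashes(state_dict)
    print(digest)  # a 0x-prefixed 64-character sha256 hex digest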
+
101
+
102
+ def update_hash_sha256(metadata: dict, state_dict: dict):
103
+ raise NotImplementedError
104
+
105
+
106
+ def build_metadata(
107
+ state_dict: Optional[dict],
108
+ v2: bool,
109
+ v_parameterization: bool,
110
+ sdxl: bool,
111
+ lora: bool,
112
+ textual_inversion: bool,
113
+ timestamp: float,
114
+ title: Optional[str] = None,
115
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
116
+ is_stable_diffusion_ckpt: Optional[bool] = None,
117
+ author: Optional[str] = None,
118
+ description: Optional[str] = None,
119
+ license: Optional[str] = None,
120
+ tags: Optional[str] = None,
121
+ merged_from: Optional[str] = None,
122
+ timesteps: Optional[Tuple[int, int]] = None,
123
+ clip_skip: Optional[int] = None,
124
+ sd3: Optional[str] = None,
125
+ flux: Optional[str] = None,
126
+ ):
127
+ """
128
+ sd3: only supports "m", flux: only supports "dev"
129
+ """
130
+ # if state_dict is None, hash is not calculated
131
+
132
+ metadata = {}
133
+ metadata.update(BASE_METADATA)
134
+
135
+ # TODO implement this once a correct hash calculation method that does not consume extra memory is found
136
+ # if state_dict is not None:
137
+ # hash = precalculate_safetensors_hashes(state_dict)
138
+ # metadata["modelspec.hash_sha256"] = hash
139
+
140
+ if sdxl:
141
+ arch = ARCH_SD_XL_V1_BASE
142
+ elif sd3 is not None:
143
+ arch = ARCH_SD3_M + "-" + sd3
144
+ elif flux is not None:
145
+ if flux == "dev":
146
+ arch = ARCH_FLUX_1_DEV
147
+ else:
148
+ arch = ARCH_FLUX_1_UNKNOWN
149
+ elif v2:
150
+ if v_parameterization:
151
+ arch = ARCH_SD_V2_768_V
152
+ else:
153
+ arch = ARCH_SD_V2_512
154
+ else:
155
+ arch = ARCH_SD_V1
156
+
157
+ if lora:
158
+ arch += f"/{ADAPTER_LORA}"
159
+ elif textual_inversion:
160
+ arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
161
+
162
+ metadata["modelspec.architecture"] = arch
163
+
164
+ if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
165
+ is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
166
+
167
+ if flux is not None:
168
+ # Flux
169
+ impl = IMPL_FLUX
170
+ elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
171
+ # Stable Diffusion ckpt, TI, SDXL LoRA
172
+ impl = IMPL_STABILITY_AI
173
+ else:
174
+ # v1/v2 LoRA or Diffusers
175
+ impl = IMPL_DIFFUSERS
176
+ metadata["modelspec.implementation"] = impl
177
+
178
+ if title is None:
179
+ if lora:
180
+ title = "LoRA"
181
+ elif textual_inversion:
182
+ title = "TextualInversion"
183
+ else:
184
+ title = "Checkpoint"
185
+ title += f"@{timestamp}"
186
+ metadata[MODELSPEC_TITLE] = title
187
+
188
+ if author is not None:
189
+ metadata["modelspec.author"] = author
190
+ else:
191
+ del metadata["modelspec.author"]
192
+
193
+ if description is not None:
194
+ metadata["modelspec.description"] = description
195
+ else:
196
+ del metadata["modelspec.description"]
197
+
198
+ if merged_from is not None:
199
+ metadata["modelspec.merged_from"] = merged_from
200
+ else:
201
+ del metadata["modelspec.merged_from"]
202
+
203
+ if license is not None:
204
+ metadata["modelspec.license"] = license
205
+ else:
206
+ del metadata["modelspec.license"]
207
+
208
+ if tags is not None:
209
+ metadata["modelspec.tags"] = tags
210
+ else:
211
+ del metadata["modelspec.tags"]
212
+
213
+ # remove microsecond from time
214
+ int_ts = int(timestamp)
215
+
216
+ # time to iso-8601 compliant date
217
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
218
+ metadata["modelspec.date"] = date
219
+
220
+ if reso is not None:
221
+ # comma separated to tuple
222
+ if isinstance(reso, str):
223
+ reso = tuple(map(int, reso.split(",")))
224
+ if len(reso) == 1:
225
+ reso = (reso[0], reso[0])
226
+ else:
227
+ # resolution is defined in dataset, so use default
228
+ if sdxl or sd3 is not None or flux is not None:
229
+ reso = 1024
230
+ elif v2 and v_parameterization:
231
+ reso = 768
232
+ else:
233
+ reso = 512
234
+ if isinstance(reso, int):
235
+ reso = (reso, reso)
236
+
237
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
238
+
239
+ if flux is not None:
240
+ del metadata["modelspec.prediction_type"]
241
+ elif v_parameterization:
242
+ metadata["modelspec.prediction_type"] = PRED_TYPE_V
243
+ else:
244
+ metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
245
+
246
+ if timesteps is not None:
247
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
248
+ timesteps = (timesteps, timesteps)
249
+ if len(timesteps) == 1:
250
+ timesteps = (timesteps[0], timesteps[0])
251
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
252
+ else:
253
+ del metadata["modelspec.timestep_range"]
254
+
255
+ if clip_skip is not None:
256
+ metadata["modelspec.encoder_layer"] = f"{clip_skip}"
257
+ else:
258
+ del metadata["modelspec.encoder_layer"]
259
+
260
+ # # assert all values are filled
261
+ # assert all([v is not None for v in metadata.values()]), metadata
262
+ if not all([v is not None for v in metadata.values()]):
263
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
264
+
265
+ return metadata
266
+
267
+
268
+ # region utils
269
+
270
+
271
+ def get_title(metadata: dict) -> Optional[str]:
272
+ return metadata.get(MODELSPEC_TITLE, None)
273
+
274
+
275
+ def load_metadata_from_safetensors(model: str) -> dict:
276
+ if not model.endswith(".safetensors"):
277
+ return {}
278
+
279
+ with safetensors.safe_open(model, framework="pt") as f:
280
+ metadata = f.metadata()
281
+ if metadata is None:
282
+ metadata = {}
283
+ return metadata
284
+
285
+
286
+ def build_merged_from(models: List[str]) -> str:
287
+ def get_title(model: str):
288
+ metadata = load_metadata_from_safetensors(model)
289
+ title = metadata.get(MODELSPEC_TITLE, None)
290
+ if title is None:
291
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
292
+ return title
293
+
294
+ titles = [get_title(model) for model in models]
295
+ return ", ".join(titles)
296
+
297
+
298
+ # endregion
299
+
300
+
301
+ r"""
302
+ if __name__ == "__main__":
303
+ import argparse
304
+ import torch
305
+ from safetensors.torch import load_file
306
+ from library import train_util
307
+
308
+ parser = argparse.ArgumentParser()
309
+ parser.add_argument("--ckpt", type=str, required=True)
310
+ args = parser.parse_args()
311
+
312
+ print(f"Loading {args.ckpt}")
313
+ state_dict = load_file(args.ckpt)
314
+
315
+ print(f"Calculating metadata")
316
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
317
+ print(metadata)
318
+ del state_dict
319
+
320
+ # by reference implementation
321
+ with open(args.ckpt, mode="rb") as file_data:
322
+ file_hash = hashlib.sha256()
323
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
324
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
325
+ content = (
326
+ file_data.read()
327
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
328
+ file_hash.update(content)
329
+ # ===== Update the hash for modelspec =====
330
+ by_ref = f"0x{file_hash.hexdigest()}"
331
+ print(by_ref)
332
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
333
+
334
+ """
library/sd3_models.py ADDED
@@ -0,0 +1,1413 @@
1
+ # some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
2
+ # the original code is licensed under the MIT License
3
+
4
+ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
5
+
6
+ from typing import Tuple
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ import math
11
+ from types import SimpleNamespace
12
+ from typing import Dict, List, Optional, Union
13
+ import einops
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.checkpoint import checkpoint
19
+ from transformers import CLIPTokenizer, T5TokenizerFast
20
+
21
+ from library import custom_offloading_utils
22
+ from library.device_utils import clean_memory_on_device
23
+
24
+ from .utils import setup_logging
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ memory_efficient_attention = None
33
+ try:
34
+ import xformers
35
+ except:
36
+ pass
37
+
38
+ try:
39
+ from xformers.ops import memory_efficient_attention
40
+ except:
41
+ memory_efficient_attention = None
42
+
43
+
44
+ # region mmdit
45
+
46
+
47
+ @dataclass
48
+ class SD3Params:
49
+ patch_size: int
50
+ depth: int
51
+ num_patches: int
52
+ pos_embed_max_size: int
53
+ adm_in_channels: int
54
+ qk_norm: Optional[str]
55
+ x_block_self_attn_layers: list[int]
56
+ context_embedder_in_features: int
57
+ context_embedder_out_features: int
58
+ model_type: str
59
+
60
+
61
+ def get_2d_sincos_pos_embed(
62
+ embed_dim,
63
+ grid_size,
64
+ scaling_factor=None,
65
+ offset=None,
66
+ ):
67
+ grid_h = np.arange(grid_size, dtype=np.float32)
68
+ grid_w = np.arange(grid_size, dtype=np.float32)
69
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
70
+ grid = np.stack(grid, axis=0)
71
+ if scaling_factor is not None:
72
+ grid = grid / scaling_factor
73
+ if offset is not None:
74
+ grid = grid - offset
75
+
76
+ grid = grid.reshape([2, 1, grid_size, grid_size])
77
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78
+ return pos_embed
79
+
80
+
81
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
82
+ assert embed_dim % 2 == 0
83
+
84
+ # use half of dimensions to encode grid_h
85
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
86
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
87
+
88
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
89
+ return emb
90
+
91
+
92
+ def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
93
+ """
94
+ This function is contributed by KohakuBlueleaf. Thanks for the contribution!
95
+
96
+ Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
97
+ when the resolution differs from the training resolution.
98
+
99
+ Args:
100
+ embed_dim (int): Dimension of the positional embedding.
101
+ grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
102
+ cls_token (bool): Whether to include class token. Defaults to False.
103
+ extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
104
+ sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
105
+ base_size (int): Base grid size used during training. Defaults to 16.
106
+
107
+ Returns:
108
+ numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
109
+ (H*W + extra_tokens, embed_dim) if cls_token is True.
110
+ """
111
+ # Convert grid_size to tuple if it's an integer
112
+ if isinstance(grid_size, int):
113
+ grid_size = (grid_size, grid_size)
114
+
115
+ # Create normalized grid coordinates (0 to 1)
116
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
117
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]
118
+
119
+ # Calculate scaling factors for height and width
120
+ # This ensures that the central region matches the original resolution's embeddings
121
+ scale_h = base_size * grid_size[0] / (sample_size)
122
+ scale_w = base_size * grid_size[1] / (sample_size)
123
+
124
+ # Calculate shift values to center the original resolution's embedding region
125
+ # This ensures that the central sample_size x sample_size region has similar
126
+ # positional embeddings to the original resolution
127
+ shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
128
+ shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])
129
+
130
+ # Apply scaling and shifting to create the final grid coordinates
131
+ grid_h = grid_h * scale_h - shift_h
132
+ grid_w = grid_w * scale_w - shift_w
133
+
134
+ # Create 2D grid using meshgrid (note: w goes first)
135
+ grid = np.meshgrid(grid_w, grid_h)
136
+ grid = np.stack(grid, axis=0)
137
+
138
+ # # Calculate the starting indices for the central region
139
+ # # This is used for debugging/visualization of the central region
140
+ # st_h = (grid_size[0] - sample_size) // 2
141
+ # st_w = (grid_size[1] - sample_size) // 2
142
+ # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])
143
+
144
+ # Reshape grid for positional embedding calculation
145
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
146
+
147
+ # Generate the sinusoidal positional embeddings
148
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
149
+
150
+ # Add zeros for extra tokens (e.g., [CLS] token) if required
151
+ if cls_token and extra_tokens > 0:
152
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
153
+
154
+ return pos_embed
155
+
156
+
157
+ # if __name__ == "__main__":
158
+ # # This is what you get when you load SD3.5 state dict
159
+ # pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
160
+ # 1536, [384, 384], sample_size=64, base_size=16
161
+ # )).float().unsqueeze(0)
162
+
163
+
164
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
165
+ """
166
+ embed_dim: output dimension for each position
167
+ pos: a list of positions to be encoded: size (M,)
168
+ out: (M, D)
169
+ """
170
+ assert embed_dim % 2 == 0
171
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
172
+ omega /= embed_dim / 2.0
173
+ omega = 1.0 / 10000**omega # (D/2,)
174
+
175
+ pos = pos.reshape(-1) # (M,)
176
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
177
+
178
+ emb_sin = np.sin(out) # (M, D/2)
179
+ emb_cos = np.cos(out) # (M, D/2)
180
+
181
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
182
+ return emb
183
+
184
+
185
+ def get_1d_sincos_pos_embed_from_grid_torch(
186
+ embed_dim,
187
+ pos,
188
+ device=None,
189
+ dtype=torch.float32,
190
+ ):
191
+ omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
192
+ omega *= 2.0 / embed_dim
193
+ omega = 1.0 / 10000**omega
194
+ out = torch.outer(pos.reshape(-1), omega)
195
+ emb = torch.cat([out.sin(), out.cos()], dim=1)
196
+ return emb
197
+
198
+
199
+ def get_2d_sincos_pos_embed_torch(
200
+ embed_dim,
201
+ w,
202
+ h,
203
+ val_center=7.5,
204
+ val_magnitude=7.5,
205
+ device=None,
206
+ dtype=torch.float32,
207
+ ):
208
+ small = min(h, w)
209
+ val_h = (h / small) * val_magnitude
210
+ val_w = (w / small) * val_magnitude
211
+ grid_h, grid_w = torch.meshgrid(
212
+ torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype),
213
+ torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype),
214
+ indexing="ij",
215
+ )
216
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
217
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
218
+ emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
219
+ return emb
220
+
221
+
222
+ def modulate(x, shift, scale):
223
+ if shift is None:
224
+ shift = torch.zeros_like(scale)
225
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
226
+
227
+
228
+ def default(x, default_value):
229
+ if x is None:
230
+ return default_value
231
+ return x
232
+
233
+
234
+ def timestep_embedding(t, dim, max_period=10000):
235
+ half = dim // 2
236
+ # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
237
+ # device=t.device, dtype=t.dtype
238
+ # )
239
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
240
+ args = t[:, None].float() * freqs[None]
241
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
242
+ if dim % 2:
243
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
244
+ if torch.is_floating_point(t):
245
+ embedding = embedding.to(dtype=t.dtype)
246
+ return embedding
247
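timestep_embedding is the standard sinusoidal scheme: frequencies 10000 ** (-i / (dim // 2)) for i in [0, dim // 2), with cosines in the first half of the vector and sines in the second. A quick shape check (the import path is an assumption):

    import torch
    from library.sd3_models import timestep_embedding  # assumed import path

    t = torch.tensor([0, 500, 999])
    emb = timestep_embedding(t, dim=256)
    print(emb.shape)  # torch.Size([3, 256])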
+
248
+
249
+ class PatchEmbed(nn.Module):
250
+ def __init__(
251
+ self,
252
+ img_size=256,
253
+ patch_size=4,
254
+ in_channels=3,
255
+ embed_dim=512,
256
+ norm_layer=None,
257
+ flatten=True,
258
+ bias=True,
259
+ strict_img_size=True,
260
+ dynamic_img_pad=False,
261
+ ):
262
+ # dynamic_img_pad and norm are omitted in SD3.5
263
+ super().__init__()
264
+ self.patch_size = patch_size
265
+ self.flatten = flatten
266
+ self.strict_img_size = strict_img_size
267
+ self.dynamic_img_pad = dynamic_img_pad
268
+ if img_size is not None:
269
+ self.img_size = img_size
270
+ self.grid_size = img_size // patch_size
271
+ self.num_patches = self.grid_size**2
272
+ else:
273
+ self.img_size = None
274
+ self.grid_size = None
275
+ self.num_patches = None
276
+
277
+ self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
278
+ self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)
279
+
280
+ def forward(self, x):
281
+ B, C, H, W = x.shape
282
+
283
+ if self.dynamic_img_pad:
284
+ # Pad input so we won't have partial patch
285
+ pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
286
+ pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
287
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
288
+ x = self.proj(x)
289
+ if self.flatten:
290
+ x = x.flatten(2).transpose(1, 2)
291
+ x = self.norm(x)
292
+ return x
293
+
294
+
295
+ # FinalLayer in mmdit.py
296
+ class UnPatch(nn.Module):
297
+ def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
298
+ super().__init__()
299
+ self.patch_size = patch_size
300
+ self.c = out_channels
301
+
302
+ # eps is default in mmdit.py
303
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
304
+ self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
305
+ self.adaLN_modulation = nn.Sequential(
306
+ nn.SiLU(),
307
+ nn.Linear(hidden_size, 2 * hidden_size),
308
+ )
309
+
310
+ def forward(self, x: torch.Tensor, cmod, H=None, W=None):
311
+ b, n, _ = x.shape
312
+ p = self.patch_size
313
+ c = self.c
314
+ if H is None and W is None:
315
+ w = h = int(n**0.5)
316
+ assert h * w == n
317
+ else:
318
+ h = H // p if H else n // (W // p)
319
+ w = W // p if W else n // h
320
+ assert h * w == n
321
+
322
+ shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
323
+ x = modulate(self.norm_final(x), shift, scale)
324
+ x = self.linear(x)
325
+
326
+ x = x.view(b, h, w, p, p, c)
327
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
328
+ x = x.view(b, c, h * p, w * p)
329
+ return x
330
+
331
+
332
+ class MLP(nn.Module):
333
+ def __init__(
334
+ self,
335
+ in_features,
336
+ hidden_features=None,
337
+ out_features=None,
338
+ act_layer=lambda: nn.GELU(),
339
+ norm_layer=None,
340
+ bias=True,
341
+ use_conv=False,
342
+ ):
343
+ super().__init__()
344
+ out_features = out_features or in_features
345
+ hidden_features = hidden_features or in_features
346
+ self.use_conv = use_conv
347
+
348
+ layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear
349
+
350
+ self.fc1 = layer(in_features, hidden_features, bias=bias)
351
+ self.fc2 = layer(hidden_features, out_features, bias=bias)
352
+ self.act = act_layer()
353
+ self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
354
+
355
+ def forward(self, x):
356
+ x = self.fc1(x)
357
+ x = self.act(x)
358
+ x = self.norm(x)
359
+ x = self.fc2(x)
360
+ return x
361
+
362
+
363
+ class TimestepEmbedding(nn.Module):
364
+ def __init__(self, hidden_size, freq_embed_size=256):
365
+ super().__init__()
366
+ self.mlp = nn.Sequential(
367
+ nn.Linear(freq_embed_size, hidden_size),
368
+ nn.SiLU(),
369
+ nn.Linear(hidden_size, hidden_size),
370
+ )
371
+ self.freq_embed_size = freq_embed_size
372
+
373
+ def forward(self, t, dtype=None, **kwargs):
374
+ t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype)
375
+ t_emb = self.mlp(t_freq)
376
+ return t_emb
377
+
378
+
379
+ class Embedder(nn.Module):
380
+ def __init__(self, input_dim, hidden_size):
381
+ super().__init__()
382
+ self.mlp = nn.Sequential(
383
+ nn.Linear(input_dim, hidden_size),
384
+ nn.SiLU(),
385
+ nn.Linear(hidden_size, hidden_size),
386
+ )
387
+
388
+ def forward(self, x):
389
+ return self.mlp(x)
390
+
391
+
392
+ def rmsnorm(x, eps=1e-6):
393
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
394
+
395
+
396
+ class RMSNorm(torch.nn.Module):
397
+ def __init__(
398
+ self,
399
+ dim: int,
400
+ elementwise_affine: bool = False,
401
+ eps: float = 1e-6,
402
+ device=None,
403
+ dtype=None,
404
+ ):
405
+ """
406
+ Initialize the RMSNorm normalization layer.
407
+ Args:
408
+ dim (int): The dimension of the input tensor.
409
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
410
+ Attributes:
411
+ eps (float): A small value added to the denominator for numerical stability.
412
+ weight (nn.Parameter): Learnable scaling parameter.
413
+ """
414
+ super().__init__()
415
+ self.eps = eps
416
+ self.learnable_scale = elementwise_affine
417
+ if self.learnable_scale:
418
+ self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
419
+ else:
420
+ self.register_parameter("weight", None)
421
+
422
+ def forward(self, x):
423
+ """
424
+ Forward pass through the RMSNorm layer.
425
+ Args:
426
+ x (torch.Tensor): The input tensor.
427
+ Returns:
428
+ torch.Tensor: The output tensor after applying RMSNorm.
429
+ """
430
+ x = rmsnorm(x, eps=self.eps)
431
+ if self.learnable_scale:
432
+ return x * self.weight.to(device=x.device, dtype=x.dtype)
433
+ else:
434
+ return x
435
+
436
+
437
+ class SwiGLUFeedForward(nn.Module):
438
+ def __init__(
439
+ self,
440
+ dim: int,
441
+ hidden_dim: int,
442
+ multiple_of: int,
443
+ ffn_dim_multiplier: float = None,
444
+ ):
445
+ super().__init__()
446
+ hidden_dim = int(2 * hidden_dim / 3)
447
+ # custom dim factor multiplier
448
+ if ffn_dim_multiplier is not None:
449
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
450
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
451
+
452
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
453
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
454
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
455
+
456
+ def forward(self, x):
457
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
458
+
459
+
460
+ # Linears for SelfAttention in mmdit.py
461
+ class AttentionLinears(nn.Module):
462
+ def __init__(
463
+ self,
464
+ dim: int,
465
+ num_heads: int = 8,
466
+ qkv_bias: bool = False,
467
+ pre_only: bool = False,
468
+ qk_norm: Optional[str] = None,
469
+ ):
470
+ super().__init__()
471
+ self.num_heads = num_heads
472
+ self.head_dim = dim // num_heads
473
+
474
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
475
+ if not pre_only:
476
+ self.proj = nn.Linear(dim, dim)
477
+ self.pre_only = pre_only
478
+
479
+ if qk_norm == "rms":
480
+ self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
481
+ self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
482
+ elif qk_norm == "ln":
483
+ self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
484
+ self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
485
+ elif qk_norm is None:
486
+ self.ln_q = nn.Identity()
487
+ self.ln_k = nn.Identity()
488
+ else:
489
+ raise ValueError(qk_norm)
490
+
491
+ def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
492
+ """
493
+ output:
494
+ q, k, v: [B, L, D]
495
+ """
496
+ B, L, C = x.shape
497
+ qkv: torch.Tensor = self.qkv(x)
498
+ q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2)
499
+ q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
500
+ k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
501
+ return (q, k, v)
502
+
503
+ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
504
+ assert not self.pre_only
505
+ x = self.proj(x)
506
+ return x
507
+
508
+
509
+ MEMORY_LAYOUTS = {
510
+ "torch": (
511
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
512
+ lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
513
+ lambda x: (1, x, 1, 1),
514
+ ),
515
+ "xformers": (
516
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim),
517
+ lambda x: x.reshape(x.shape[0], x.shape[1], -1),
518
+ lambda x: (1, 1, x, 1),
519
+ ),
520
+ "math": (
521
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
522
+ lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
523
+ lambda x: (1, x, 1, 1),
524
+ ),
525
+ }
526
+ # ATTN_FUNCTION = {
527
+ # "torch": F.scaled_dot_product_attention,
528
+ # "xformers": memory_efficient_attention,
529
+ # }
530
+
531
+
532
+ def vanilla_attention(q, k, v, mask, scale=None):
533
+ if scale is None:
534
+ scale = math.sqrt(q.size(-1))
535
+ scores = torch.bmm(q, k.transpose(-1, -2)) / scale
536
+ if mask is not None:
537
+ mask = einops.rearrange(mask, "b ... -> b (...)")
538
+ max_neg_value = -torch.finfo(scores.dtype).max
539
+ mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
540
+ scores = scores.masked_fill(~mask, max_neg_value)
541
+ p_attn = F.softmax(scores, dim=-1)
542
+ return torch.bmm(p_attn, v)
543
+
544
+
545
+ def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
546
+ """
547
+ q, k, v: [B, L, D]
548
+ """
549
+ pre_attn_layout = MEMORY_LAYOUTS[mode][0]
550
+ post_attn_layout = MEMORY_LAYOUTS[mode][1]
551
+ q = pre_attn_layout(q, head_dim)
552
+ k = pre_attn_layout(k, head_dim)
553
+ v = pre_attn_layout(v, head_dim)
554
+
555
+ # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
556
+ if mode == "torch":
557
+ assert scale is None
558
+ scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale)
559
+ elif mode == "xformers":
560
+ scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
561
+ else:
562
+ scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)
563
+
564
+ scores = post_attn_layout(scores)
565
+ return scores
566
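attention() keeps q, k and v in a flat [B, L, D] layout; MEMORY_LAYOUTS reshapes them to the backend-specific head layout and back. A small sketch with the PyTorch SDPA backend (mode="torch"), assuming head_dim divides D and the module import path:

    import torch
    from library.sd3_models import attention  # assumed import path

    B, L, D, head_dim = 2, 16, 64, 16  # i.e. 4 heads of dimension 16
    q, k, v = (torch.randn(B, L, D) for _ in range(3))
    out = attention(q, k, v, head_dim=head_dim, mode="torch")
    print(out.shape)  # torch.Size([2, 16, 64])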
+
567
+
568
+ # DismantledBlock in mmdit.py
569
+ class SingleDiTBlock(nn.Module):
570
+ """
571
+ A DiT block with gated adaptive layer norm (adaLN) conditioning.
572
+ """
573
+
574
+ def __init__(
575
+ self,
576
+ hidden_size: int,
577
+ num_heads: int,
578
+ mlp_ratio: float = 4.0,
579
+ attn_mode: str = "xformers",
580
+ qkv_bias: bool = False,
581
+ pre_only: bool = False,
582
+ rmsnorm: bool = False,
583
+ scale_mod_only: bool = False,
584
+ swiglu: bool = False,
585
+ qk_norm: Optional[str] = None,
586
+ x_block_self_attn: bool = False,
587
+ **block_kwargs,
588
+ ):
589
+ super().__init__()
590
+ assert attn_mode in MEMORY_LAYOUTS
591
+ self.attn_mode = attn_mode
592
+ if not rmsnorm:
593
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
594
+ else:
595
+ self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
596
+ self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)
597
+
598
+ self.x_block_self_attn = x_block_self_attn
599
+ if self.x_block_self_attn:
600
+ assert not pre_only
601
+ assert not scale_mod_only
602
+ self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)
603
+
604
+ if not pre_only:
605
+ if not rmsnorm:
606
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
607
+ else:
608
+ self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
609
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
610
+ if not pre_only:
611
+ if not swiglu:
612
+ self.mlp = MLP(
613
+ in_features=hidden_size,
614
+ hidden_features=mlp_hidden_dim,
615
+ act_layer=lambda: nn.GELU(approximate="tanh"),
616
+ )
617
+ else:
618
+ self.mlp = SwiGLUFeedForward(
619
+ dim=hidden_size,
620
+ hidden_dim=mlp_hidden_dim,
621
+ multiple_of=256,
622
+ )
623
+ self.scale_mod_only = scale_mod_only
624
+ if self.x_block_self_attn:
625
+ n_mods = 9
626
+ elif not scale_mod_only:
627
+ n_mods = 6 if not pre_only else 2
628
+ else:
629
+ n_mods = 4 if not pre_only else 1
630
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
631
+ self.pre_only = pre_only
632
+
633
+ def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
634
+ if not self.pre_only:
635
+ if not self.scale_mod_only:
636
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
637
+ else:
638
+ shift_msa = None
639
+ shift_mlp = None
640
+ (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
641
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
642
+ return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
643
+ else:
644
+ if not self.scale_mod_only:
645
+ (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
646
+ else:
647
+ shift_msa = None
648
+ scale_msa = self.adaLN_modulation(c)
649
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
650
+ return qkv, None
651
+
652
+ def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
653
+ assert self.x_block_self_attn
654
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
655
+ c
656
+ ).chunk(9, dim=1)
657
+ x_norm = self.norm1(x)
658
+ qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
659
+ qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
660
+ return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
661
+
662
+ def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
663
+ assert not self.pre_only
664
+ x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
665
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
666
+ return x
667
+
668
+ def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
669
+ assert not self.pre_only
670
+ if attn1_dropout > 0.0:
671
+ # Use torch.bernoulli to implement dropout, only dropout the batch dimension
672
+ attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
673
+ attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
674
+ else:
675
+ attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
676
+ x = x + attn_
677
+ attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
678
+ x = x + attn2_
679
+ mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
680
+ x = x + mlp_
681
+ return x
682
+
683
+
684
+ # JointBlock + block_mixing in mmdit.py
685
+ class MMDiTBlock(nn.Module):
686
+ def __init__(self, *args, **kwargs):
687
+ super().__init__()
688
+ pre_only = kwargs.pop("pre_only")
689
+ x_block_self_attn = kwargs.pop("x_block_self_attn")
690
+
691
+ self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
692
+ self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
693
+
694
+ self.head_dim = self.x_block.attn.head_dim
695
+ self.mode = self.x_block.attn_mode
696
+ self.gradient_checkpointing = False
697
+
698
+ def enable_gradient_checkpointing(self):
699
+ self.gradient_checkpointing = True
700
+
701
+ def _forward(self, context, x, c):
702
+ ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
703
+
704
+ if self.x_block.x_block_self_attn:
705
+ x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
706
+ else:
707
+ x_qkv, x_intermediates = self.x_block.pre_attention(x, c)
708
+
709
+ ctx_len = ctx_qkv[0].size(1)
710
+
711
+ q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
712
+ k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
713
+ v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)
714
+
715
+ attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
716
+ ctx_attn_out = attn[:, :ctx_len]
717
+ x_attn_out = attn[:, ctx_len:]
718
+
719
+ if self.x_block.x_block_self_attn:
720
+ x_q2, x_k2, x_v2 = x_qkv2
721
+ attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode)
722
+ x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
723
+ else:
724
+ x = self.x_block.post_attention(x_attn_out, *x_intermediates)
725
+
726
+ if not self.context_block.pre_only:
727
+ context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
728
+ else:
729
+ context = None
730
+
731
+ return context, x
732
+
733
+ def forward(self, *args, **kwargs):
734
+ if self.training and self.gradient_checkpointing:
735
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
736
+ else:
737
+ return self._forward(*args, **kwargs)
738
+
739
+
740
+ class MMDiT(nn.Module):
741
+ """
742
+ Diffusion model with a Transformer backbone.
743
+ """
744
+
745
+ # prepare pos_embed for latent size * 2
746
+ POS_EMBED_MAX_RATIO = 1.5
747
+
748
+ def __init__(
749
+ self,
750
+ input_size: int = 32,
751
+ patch_size: int = 2,
752
+ in_channels: int = 4,
753
+ depth: int = 28,
754
+ # hidden_size: Optional[int] = None,
755
+ # num_heads: Optional[int] = None,
756
+ mlp_ratio: float = 4.0,
757
+ learn_sigma: bool = False,
758
+ adm_in_channels: Optional[int] = None,
759
+ context_embedder_in_features: Optional[int] = None,
760
+ context_embedder_out_features: Optional[int] = None,
761
+ use_checkpoint: bool = False,
762
+ register_length: int = 0,
763
+ attn_mode: str = "torch",
764
+ rmsnorm: bool = False,
765
+ scale_mod_only: bool = False,
766
+ swiglu: bool = False,
767
+ out_channels: Optional[int] = None,
768
+ pos_embed_scaling_factor: Optional[float] = None,
769
+ pos_embed_offset: Optional[float] = None,
770
+ pos_embed_max_size: Optional[int] = None,
771
+ num_patches=None,
772
+ qk_norm: Optional[str] = None,
773
+ x_block_self_attn_layers: Optional[list[int]] = [],
774
+ qkv_bias: bool = True,
775
+ pos_emb_random_crop_rate: float = 0.0,
776
+ use_scaled_pos_embed: bool = False,
777
+ pos_embed_latent_sizes: Optional[list[int]] = None,
778
+ model_type: str = "sd3m",
779
+ ):
780
+ super().__init__()
781
+ self._model_type = model_type
782
+ self.learn_sigma = learn_sigma
783
+ self.in_channels = in_channels
784
+ default_out_channels = in_channels * 2 if learn_sigma else in_channels
785
+ self.out_channels = default(out_channels, default_out_channels)
786
+ self.patch_size = patch_size
787
+ self.pos_embed_scaling_factor = pos_embed_scaling_factor
788
+ self.pos_embed_offset = pos_embed_offset
789
+ self.pos_embed_max_size = pos_embed_max_size
790
+ self.x_block_self_attn_layers = x_block_self_attn_layers
791
+ self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
792
+ self.gradient_checkpointing = use_checkpoint
793
+
794
+ # hidden_size = default(hidden_size, 64 * depth)
795
+ # num_heads = default(num_heads, hidden_size // 64)
796
+
797
+ # apply magic --> this defines a head_size of 64
798
+ self.hidden_size = 64 * depth
799
+ num_heads = depth
800
+
801
+ self.num_heads = num_heads
802
+
803
+ self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
804
+
805
+ self.x_embedder = PatchEmbed(
806
+ input_size,
807
+ patch_size,
808
+ in_channels,
809
+ self.hidden_size,
810
+ bias=True,
811
+ strict_img_size=self.pos_embed_max_size is None,
812
+ )
813
+ self.t_embedder = TimestepEmbedding(self.hidden_size)
814
+
815
+ self.y_embedder = None
816
+ if adm_in_channels is not None:
817
+ assert isinstance(adm_in_channels, int)
818
+ self.y_embedder = Embedder(adm_in_channels, self.hidden_size)
819
+
820
+ if context_embedder_in_features is not None:
821
+ self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
822
+ else:
823
+ self.context_embedder = nn.Identity()
824
+
825
+ self.register_length = register_length
826
+ if self.register_length > 0:
827
+ self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))
828
+
829
+ # num_patches = self.x_embedder.num_patches
830
+ # Will use fixed sin-cos embedding:
831
+ # just use a buffer already
832
+ if num_patches is not None:
833
+ self.register_buffer(
834
+ "pos_embed",
835
+ torch.empty(1, num_patches, self.hidden_size),
836
+ )
837
+ else:
838
+ self.pos_embed = None
839
+
840
+ self.use_checkpoint = use_checkpoint
841
+ self.joint_blocks = nn.ModuleList(
842
+ [
843
+ MMDiTBlock(
844
+ self.hidden_size,
845
+ num_heads,
846
+ mlp_ratio=mlp_ratio,
847
+ attn_mode=attn_mode,
848
+ qkv_bias=qkv_bias,
849
+ pre_only=i == depth - 1,
850
+ rmsnorm=rmsnorm,
851
+ scale_mod_only=scale_mod_only,
852
+ swiglu=swiglu,
853
+ qk_norm=qk_norm,
854
+ x_block_self_attn=(i in self.x_block_self_attn_layers),
855
+ )
856
+ for i in range(depth)
857
+ ]
858
+ )
859
+ for block in self.joint_blocks:
860
+ block.gradient_checkpointing = use_checkpoint
861
+
862
+ self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
863
+ # self.initialize_weights()
864
+
865
+ self.blocks_to_swap = None
866
+ self.offloader = None
867
+ self.num_blocks = len(self.joint_blocks)
868
+
869
+ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
870
+ self.use_scaled_pos_embed = use_scaled_pos_embed
871
+
872
+ if self.use_scaled_pos_embed:
873
+ # remove pos_embed to free up memory up to 0.4 GB
874
+ self.pos_embed = None
875
+
876
+ # remove duplicates and sort latent sizes in ascending order
877
+ latent_sizes = list(set(latent_sizes))
878
+ latent_sizes = sorted(latent_sizes)
879
+
880
+ patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
881
+
882
+ # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
883
+ max_areas = []
884
+ for i in range(1, len(patched_sizes)):
885
+ prev_area = patched_sizes[i - 1] ** 2
886
+ area = patched_sizes[i] ** 2
887
+ max_areas.append((prev_area + area) // 2)
888
+
889
+ # area of the last latent size, if the latent size exceeds this, error will be raised
890
+ max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
891
+ # print("max_areas", max_areas)
892
+
893
+ self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]
894
+
895
+ self.resolution_pos_embeds = {}
896
+ for patched_size in patched_sizes:
897
+ grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
898
+ pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
899
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
900
+ self.resolution_pos_embeds[patched_size] = pos_embed
901
+ # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
902
+
903
+ else:
904
+ self.resolution_area_to_latent_size = None
905
+ self.resolution_pos_embeds = None
906
+
907
+ @property
908
+ def model_type(self):
909
+ return self._model_type
910
+
911
+ @property
912
+ def device(self):
913
+ return next(self.parameters()).device
914
+
915
+ @property
916
+ def dtype(self):
917
+ return next(self.parameters()).dtype
918
+
919
+ def enable_gradient_checkpointing(self):
920
+ self.gradient_checkpointing = True
921
+ for block in self.joint_blocks:
922
+ block.enable_gradient_checkpointing()
923
+
924
+ def disable_gradient_checkpointing(self):
925
+ self.gradient_checkpointing = False
926
+ for block in self.joint_blocks:
927
+ block.disable_gradient_checkpointing()
928
+
929
+ def initialize_weights(self):
930
+ # TODO: Init context_embedder?
931
+ # Initialize transformer layers:
932
+ def _basic_init(module):
933
+ if isinstance(module, nn.Linear):
934
+ torch.nn.init.xavier_uniform_(module.weight)
935
+ if module.bias is not None:
936
+ nn.init.constant_(module.bias, 0)
937
+
938
+ self.apply(_basic_init)
939
+
940
+ # Initialize (and freeze) pos_embed by sin-cos embedding
941
+ if self.pos_embed is not None:
942
+ pos_embed = get_2d_sincos_pos_embed(
943
+ self.pos_embed.shape[-1],
944
+ int(self.pos_embed.shape[-2] ** 0.5),
945
+ scaling_factor=self.pos_embed_scaling_factor,
946
+ )
947
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
948
+
949
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
950
+ w = self.x_embedder.proj.weight.data
951
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
952
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
953
+
954
+ if getattr(self, "y_embedder", None) is not None:
955
+ nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
956
+ nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
957
+
958
+ # Initialize timestep embedding MLP:
959
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
960
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
961
+
962
+ # Zero-out adaLN modulation layers in DiT blocks:
963
+ for block in self.joint_blocks:
964
+ nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
965
+ nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
966
+ nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
967
+ nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
968
+
969
+ # Zero-out output layers:
970
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
971
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
972
+ nn.init.constant_(self.final_layer.linear.weight, 0)
973
+ nn.init.constant_(self.final_layer.linear.bias, 0)
974
+
975
+ def set_pos_emb_random_crop_rate(self, rate: float):
976
+ self.pos_emb_random_crop_rate = rate
977
+
978
+ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
979
+ p = self.x_embedder.patch_size
980
+ # patched size
981
+ h = (h + 1) // p
982
+ w = (w + 1) // p
983
+ if self.pos_embed is None: # should not happen
984
+ return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
985
+ assert self.pos_embed_max_size is not None
986
+ assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
987
+ assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
988
+
989
+ if not random_crop:
990
+ top = (self.pos_embed_max_size - h) // 2
991
+ left = (self.pos_embed_max_size - w) // 2
992
+ else:
993
+ top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
994
+ left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()
995
+
996
+ spatial_pos_embed = self.pos_embed.reshape(
997
+ 1,
998
+ self.pos_embed_max_size,
999
+ self.pos_embed_max_size,
1000
+ self.pos_embed.shape[-1],
1001
+ )
1002
+ spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
1003
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
1004
+ return spatial_pos_embed
1005
+
1006
+ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
1007
+ p = self.x_embedder.patch_size
1008
+ # patched size
1009
+ h = (h + 1) // p
1010
+ w = (w + 1) // p
1011
+
1012
+ # select pos_embed size based on area
1013
+ area = h * w
1014
+ patched_size = None
1015
+ for area_, patched_size_ in self.resolution_area_to_latent_size:
1016
+ if area <= area_:
1017
+ patched_size = patched_size_
1018
+ break
1019
+ if patched_size is None:
1020
+ raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
1021
+
1022
+ pos_embed = self.resolution_pos_embeds[patched_size]
1023
+ pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
1024
+ if h > pos_embed_size or w > pos_embed_size:
1025
+ # # fallback to normal pos_embed
1026
+ # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
1027
+ # extend pos_embed size
1028
+ logger.warning(
1029
+ f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
1030
+ )
1031
+ pos_embed_size = max(h, w)
1032
+ pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
1033
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
1034
+ self.resolution_pos_embeds[patched_size] = pos_embed
1035
+ logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
1036
+
1037
+ if not random_crop:
1038
+ top = (pos_embed_size - h) // 2
1039
+ left = (pos_embed_size - w) // 2
1040
+ else:
1041
+ top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
1042
+ left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
1043
+
1044
+ if pos_embed.device != device:
1045
+ pos_embed = pos_embed.to(device)
1046
+ # cache on the device or transfer on every call? -> a 64x64 latent's emb is 96*96*1536*4 bytes ≈ 56MB, so caching it on the device is fine
1047
+ self.resolution_pos_embeds[patched_size] = pos_embed # update device
1048
+ if pos_embed.dtype != dtype:
1049
+ pos_embed = pos_embed.to(dtype)
1050
+ self.resolution_pos_embeds[patched_size] = pos_embed # update dtype
1051
+
1052
+ spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
1053
+ spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
1054
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
1055
+ # print(
1056
+ # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
1057
+ # )
1058
+ return spatial_pos_embed
1059
+
+    def enable_block_swap(self, num_blocks: int, device: torch.device):
+        self.blocks_to_swap = num_blocks
+
+        assert (
+            self.blocks_to_swap <= self.num_blocks - 2
+        ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
+
+        self.offloader = custom_offloading_utils.ModelOffloader(
+            self.joint_blocks, self.num_blocks, self.blocks_to_swap, device  # , debug=True
+        )
+        print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
+
+    def move_to_device_except_swap_blocks(self, device: torch.device):
+        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
+        if self.blocks_to_swap:
+            save_blocks = self.joint_blocks
+            self.joint_blocks = None
+
+        self.to(device)
+
+        if self.blocks_to_swap:
+            self.joint_blocks = save_blocks
+
+    def prepare_block_swap_before_forward(self):
+        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+            return
+        self.offloader.prepare_block_devices_before_forward(self.joint_blocks)
+
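The three methods above form the block-swap API. A rough usage sketch follows; the training scripts' exact call sequence may differ, and `num_blocks=8` is an arbitrary example.

# Rough usage sketch (assumption: mmdit is an MMDiT instance, e.g. from create_sd3_mmdit below, still on CPU).
import torch

device = torch.device("cuda")
mmdit.enable_block_swap(num_blocks=8, device=device)  # keep 8 joint blocks eligible for CPU offload
mmdit.move_to_device_except_swap_blocks(device)       # move the rest of the model to the GPU
mmdit.prepare_block_swap_before_forward()             # stage block devices before the next forward pass
# ... run a forward pass, then call prepare_block_swap_before_forward() again before the next one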
+    def forward(
+        self,
+        x: torch.Tensor,
+        t: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Forward pass of DiT.
+        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        t: (N,) tensor of diffusion timesteps
+        y: (N, D) tensor of class labels
+        context: (N, L, D) tensor of text conditioning (optional)
+        """
+        pos_emb_random_crop = (
+            False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
+        )
+
+        B, C, H, W = x.shape
+
+        # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
+        if not self.use_scaled_pos_embed:
+            pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
+        else:
+            # print(f"Using scaled pos_embed for size {H}x{W}")
+            pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
+        x = self.x_embedder(x) + pos_embed
+        del pos_embed
+
+        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
+        if y is not None and self.y_embedder is not None:
+            y = self.y_embedder(y)  # (N, D)
+            c = c + y  # (N, D)
+
+        if context is not None:
+            context = self.context_embedder(context)
+
+        if self.register_length > 0:
+            context = torch.cat(
+                (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
+            )
+
+        if not self.blocks_to_swap:
+            for block in self.joint_blocks:
+                context, x = block(context, x, c)
+        else:
+            for block_idx, block in enumerate(self.joint_blocks):
+                self.offloader.wait_for_block(block_idx)
+
+                context, x = block(context, x, c)
+
+                self.offloader.submit_move_blocks(self.joint_blocks, block_idx)
+
+        x = self.final_layer(x, c, H, W)  # our final layer also performs the unpatchify step
+        return x[:, :, :H, :W]
+
+
+def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
+    mmdit = MMDiT(
+        input_size=None,
+        pos_embed_max_size=params.pos_embed_max_size,
+        patch_size=params.patch_size,
+        in_channels=16,
+        adm_in_channels=params.adm_in_channels,
+        context_embedder_in_features=params.context_embedder_in_features,
+        context_embedder_out_features=params.context_embedder_out_features,
+        depth=params.depth,
+        mlp_ratio=4,
+        qk_norm=params.qk_norm,
+        x_block_self_attn_layers=params.x_block_self_attn_layers,
+        num_patches=params.num_patches,
+        attn_mode=attn_mode,
+        model_type=params.model_type,
+    )
+    return mmdit
+
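A hedged end-to-end sketch of building and calling the model. It assumes `params` is a populated SD3Params instance (typically derived from the checkpoint elsewhere in this library), and every shape below is illustrative rather than canonical.

# Hedged sketch: all shapes are illustrative; `params` must be a populated SD3Params instance.
import torch

model = create_sd3_mmdit(params, attn_mode="torch")
model.eval()

latents = torch.randn(1, 16, 128, 128)                              # (N, 16, H, W) VAE latents
timesteps = torch.randint(0, 1000, (1,))                            # (N,) timesteps; the scaling convention depends on the sampler
pooled = torch.randn(1, params.adm_in_channels)                     # (N, D) pooled text embedding, passed as y
context = torch.randn(1, 154, params.context_embedder_in_features)  # (N, L, D_ctx) token-level text embedding; L is arbitrary here

with torch.no_grad():
    pred = model(latents, timesteps, y=pooled, context=context)     # -> (N, 16, H, W)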
+
+# endregion
+
+# region VAE
+
+VAE_SCALE_FACTOR = 1.5305
+VAE_SHIFT_FACTOR = 0.0609
+
+
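These two constants match SDVAE.process_in / process_out at the end of this region: after encoding, latents are mapped to the model's working range as (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR, and the inverse, latent / VAE_SCALE_FACTOR + VAE_SHIFT_FACTOR, is applied before decoding.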
+def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
+
+
+class ResnetBlock(torch.nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        if self.in_channels != self.out_channels:
+            self.nin_shortcut = torch.nn.Conv2d(
+                in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device
+            )
+        else:
+            self.nin_shortcut = None
+        self.swish = torch.nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        hidden = x
+        hidden = self.norm1(hidden)
+        hidden = self.swish(hidden)
+        hidden = self.conv1(hidden)
+        hidden = self.norm2(hidden)
+        hidden = self.swish(hidden)
+        hidden = self.conv2(hidden)
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+        return x + hidden
+
+
+class AttnBlock(torch.nn.Module):
+    def __init__(self, in_channels, dtype=torch.float32, device=None):
+        super().__init__()
+        self.norm = Normalize(in_channels, dtype=dtype, device=device)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
+
+    def forward(self, x):
+        hidden = self.norm(x)
+        q = self.q(hidden)
+        k = self.k(hidden)
+        v = self.v(hidden)
+        b, c, h, w = q.shape
+        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # the default scale is dim ** -0.5
+        hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+        hidden = self.proj_out(hidden)
+        return x + hidden
+
+
+class Downsample(torch.nn.Module):
+    def __init__(self, in_channels, dtype=torch.float32, device=None):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
+
+    def forward(self, x):
+        pad = (0, 1, 0, 1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
+class Upsample(torch.nn.Module):
+    def __init__(self, in_channels, dtype=torch.float32, device=None):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+
+    def forward(self, x):
+        org_dtype = x.dtype
+        if x.dtype == torch.bfloat16:
+            # nearest-neighbor interpolate does not support bfloat16 in some PyTorch builds, so upcast temporarily
+            x = x.to(torch.float32)
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if x.dtype != org_dtype:
+            x = x.to(org_dtype)
+        x = self.conv(x)
+        return x
+
+
+class VAEEncoder(torch.nn.Module):
+    def __init__(
+        self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = torch.nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = torch.nn.ModuleList()
+            attn = torch.nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
+                block_in = block_out
+            down = torch.nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, dtype=dtype, device=device)
+            self.down.append(down)
+        # middle
+        self.mid = torch.nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        # end
+        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
+        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        self.swish = torch.nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        # end
+        h = self.norm_out(h)
+        h = self.swish(h)
+        h = self.conv_out(h)
+        return h
+
+
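Under the default arguments, the encoder halves the spatial resolution three times (8x overall) and `conv_out` emits `2 * z_channels = 32` channels, which `SDVAE.encode` below splits into mean and logvar. A shape sketch, with an illustrative 256x256 input:

# Shape sketch (illustrative input size).
import torch

enc = VAEEncoder()
x = torch.randn(1, 3, 256, 256)
print(enc(x).shape)  # torch.Size([1, 32, 32, 32]) -> mean/logvar, each (1, 16, 32, 32)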
+class VAEDecoder(torch.nn.Module):
+    def __init__(
+        self,
+        ch=128,
+        out_ch=3,
+        ch_mult=(1, 2, 4, 4),
+        num_res_blocks=2,
+        resolution=256,
+        z_channels=16,
+        dtype=torch.float32,
+        device=None,
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        # middle
+        self.mid = torch.nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        # upsampling
+        self.up = torch.nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = torch.nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
+                block_in = block_out
+            up = torch.nn.Module()
+            up.block = block
+            if i_level != 0:
+                up.upsample = Upsample(block_in, dtype=dtype, device=device)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+        # end
+        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
+        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        self.swish = torch.nn.SiLU(inplace=True)
+
+    def forward(self, z):
+        # z to block_in
+        hidden = self.conv_in(z)
+        # middle
+        hidden = self.mid.block_1(hidden)
+        hidden = self.mid.attn_1(hidden)
+        hidden = self.mid.block_2(hidden)
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                hidden = self.up[i_level].block[i_block](hidden)
+            if i_level != 0:
+                hidden = self.up[i_level].upsample(hidden)
+        # end
+        hidden = self.norm_out(hidden)
+        hidden = self.swish(hidden)
+        hidden = self.conv_out(hidden)
+        return hidden
+
+
+class SDVAE(torch.nn.Module):
+    def __init__(self, dtype=torch.float32, device=None):
+        super().__init__()
+        self.encoder = VAEEncoder(dtype=dtype, device=device)
+        self.decoder = VAEDecoder(dtype=dtype, device=device)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    # @torch.autocast("cuda", dtype=torch.float16)
+    def decode(self, latent):
+        return self.decoder(latent)
+
+    # @torch.autocast("cuda", dtype=torch.float16)
+    def encode(self, image):
+        hidden = self.encoder(image)
+        mean, logvar = torch.chunk(hidden, 2, dim=1)
+        logvar = torch.clamp(logvar, -30.0, 20.0)
+        std = torch.exp(0.5 * logvar)
+        return mean + std * torch.randn_like(mean)
+
+    @staticmethod
+    def process_in(latent):
+        return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR
+
+    @staticmethod
+    def process_out(latent):
+        return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
+
+
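A hedged encode/decode round trip using the helpers above; the input range, size, and dtype handling are assumptions, not this library's canonical preprocessing path.

# Hedged round-trip sketch (input range and size are assumptions).
import torch

vae = SDVAE()
vae.eval()

image = torch.rand(1, 3, 256, 256) * 2.0 - 1.0    # assume images normalized to [-1, 1]

with torch.no_grad():
    latent = SDVAE.process_in(vae.encode(image))  # (1, 16, 32, 32), scaled for the diffusion model
    # ... the MMDiT above operates on `latent` ...
    recon = vae.decode(SDVAE.process_out(latent)) # back to (1, 3, 256, 256), roughly in [-1, 1]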
+ # endregion