recoilme committed
Commit f288df0 · Parent: c226386

train_sdxl_vae_full

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. eval_alchemist.py +67 -15
  2. samples/sample_2.jpg +0 -3
  3. samples/sample_decoded-Copy1.jpg +0 -3
  4. samples/sample_decoded.jpg +0 -3
  5. samples/sample_real.jpg +0 -3
  6. simple_vae/config.json +38 -0
  7. samples/sample_0.jpg → simple_vae/diffusion_pytorch_model.safetensors +2 -2
  8. simple_vae_nightly/config.json +38 -0
  9. samples/sample_1.jpg → simple_vae_nightly/diffusion_pytorch_model.safetensors +2 -2
  10. train_sdxl_vae.py +6 -6
  11. train_sdxl_vae_full.py +590 -0
  12. train_sdxl_vae_simple.py +547 -0
  13. vaetest/001_all.png +0 -3
  14. vaetest/001_decoded_AiArtLab_sdxl_vae.png +0 -3
  15. vaetest/001_decoded_AiArtLab_sdxlvae_nightly.png +0 -3
  16. vaetest/001_decoded_AiArtLab_sdxs.png +0 -3
  17. vaetest/001_decoded_FLUX.1_schnell_vae.png +0 -3
  18. vaetest/001_decoded_KBlueLeaf_EQ_SDXL_VAE.png +0 -3
  19. vaetest/001_decoded_madebyollin_sdxl_vae_fp16.png +0 -3
  20. vaetest/001_decoded_vae.png +0 -3
  21. vaetest/001_decoded_vae_nightly.png +0 -3
  22. vaetest/001_orig.png +0 -3
  23. vaetest/002_all.png +0 -3
  24. vaetest/002_decoded_AiArtLab_sdxl_vae.png +0 -3
  25. vaetest/002_decoded_AiArtLab_sdxlvae_nightly.png +0 -3
  26. vaetest/002_decoded_AiArtLab_sdxs.png +0 -3
  27. vaetest/002_decoded_FLUX.1_schnell_vae.png +0 -3
  28. vaetest/002_decoded_KBlueLeaf_EQ_SDXL_VAE.png +0 -3
  29. vaetest/002_decoded_madebyollin_sdxl_vae_fp16.png +0 -3
  30. vaetest/002_decoded_vae.png +0 -3
  31. vaetest/002_decoded_vae_nightly.png +0 -3
  32. vaetest/002_orig.png +0 -3
  33. vaetest/003_all.png +0 -3
  34. vaetest/003_decoded_AiArtLab_sdxl_vae.png +0 -3
  35. vaetest/003_decoded_AiArtLab_sdxlvae_nightly.png +0 -3
  36. vaetest/003_decoded_AiArtLab_sdxs.png +0 -3
  37. vaetest/003_decoded_FLUX.1_schnell_vae.png +0 -3
  38. vaetest/003_decoded_KBlueLeaf_EQ_SDXL_VAE.png +0 -3
  39. vaetest/003_decoded_madebyollin_sdxl_vae_fp16.png +0 -3
  40. vaetest/003_decoded_vae.png +0 -3
  41. vaetest/003_decoded_vae_nightly.png +0 -3
  42. vaetest/003_orig.png +0 -3
  43. vaetest/004_all.png +0 -3
  44. vaetest/004_decoded_AiArtLab_sdxl_vae.png +0 -3
  45. vaetest/004_decoded_AiArtLab_sdxlvae_nightly.png +0 -3
  46. vaetest/004_decoded_AiArtLab_sdxs.png +0 -3
  47. vaetest/004_decoded_FLUX.1_schnell_vae.png +0 -3
  48. vaetest/004_decoded_KBlueLeaf_EQ_SDXL_VAE.png +0 -3
  49. vaetest/004_decoded_madebyollin_sdxl_vae_fp16.png +0 -3
  50. vaetest/004_decoded_vae.png +0 -3
eval_alchemist.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image, UnidentifiedImageError
 from tqdm import tqdm
 from torch.utils.data import Dataset, DataLoader
 from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, ToPILImage
-from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+from diffusers import AutoencoderKL, AsymmetricAutoencoderKL, AutoencoderKLWan, AutoencoderKLLTXVideo
 import random
 
 # --------------------------- Parameters ---------------------------
@@ -15,24 +15,28 @@ DTYPE = torch.float16
 IMAGE_FOLDER = "/workspace/alchemist"  # wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
 MIN_SIZE = 1280
 CROP_SIZE = 512
-BATCH_SIZE = 10
-MAX_IMAGES = 0
+BATCH_SIZE = 1
+MAX_IMAGES = 100
 NUM_WORKERS = 4
-NUM_SAMPLES_TO_SAVE = 10  # how many samples to save (0 = don't save)
+NUM_SAMPLES_TO_SAVE = 2  # how many samples to save (0 = don't save)
 SAMPLES_FOLDER = "vaetest"
 
 # List of VAEs to test
 VAE_LIST = [
-
     # ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
     # ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None),
-    ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
-    # ("vae", AutoencoderKL, "/workspace/sdxl_vae/vae", None),
-    ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
-    ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
-    ("AiArtLab/sdxlvae_nightly", AutoencoderKL, "AiArtLab/sdxl_vae", "vae_nightly"),
+    # ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
+    # ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
+    # ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
+    # ("AiArtLab/sdxlvae_nightly", AutoencoderKL, "AiArtLab/sdxl_vae", "vae_nightly"),
+    # ("Lightricks/LTX-Video", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
+    # ("Wan2.2-TI2V-5B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
+    # ("Wan2.2-T2V-A14B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
     ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
-    ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+    # ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+    # ("simple_vae", AutoencoderKL, "/workspace/sdxl_vae/simple_vae", None),
+    ("simple_vae_nightly", AutoencoderKL, "/workspace/sdxl_vae/simple_vae_nightly", None),
+
 ]
 
 # --------------------------- Sobel Edge Detection ---------------------------
@@ -129,6 +133,44 @@ def deprocess(x):
 def _sanitize_name(name: str) -> str:
     return name.replace('/', '_').replace('-', '_')
 
+# --------------------------- VAE analysis ---------------------------
+@torch.no_grad()
+def tensor_stats(name, x: torch.Tensor):
+    finite = torch.isfinite(x)
+    fin_ratio = finite.float().mean().item()
+    x_f = x[finite]
+    minv = x_f.min().item() if x_f.numel() else float('nan')
+    maxv = x_f.max().item() if x_f.numel() else float('nan')
+    mean = x_f.mean().item() if x_f.numel() else float('nan')
+    std = x_f.std().item() if x_f.numel() else float('nan')
+    big = (x_f.abs() > 20).float().mean().item() if x_f.numel() else float('nan')
+    print(f"[{name}] shape={tuple(x.shape)} dtype={x.dtype} "
+          f"finite={fin_ratio:.6f} min={minv:.4g} max={maxv:.4g} mean={mean:.4g} std={std:.4g} |x|>20={big:.6f}")
+
+@torch.no_grad()
+def analyze_vae_latents(vae, name, images):
+    """
+    images: [B,3,H,W] in [-1,1]
+    """
+    try:
+        enc = vae.encode(images)
+        if hasattr(enc, "latent_dist"):
+            mu, logvar = enc.latent_dist.mean, enc.latent_dist.logvar
+            z = enc.latent_dist.sample()
+        else:
+            mu, logvar = enc[0], enc[1]
+            z = mu
+        tensor_stats(f"{name}.mu", mu)
+        tensor_stats(f"{name}.logvar", logvar)
+        tensor_stats(f"{name}.z_raw", z)
+
+        sf = getattr(vae.config, "scaling_factor", 1.0)
+        z_scaled = z * sf
+        tensor_stats(f"{name}.z_scaled(x{sf})", z_scaled)
+    except Exception as e:
+        print(f"⚠️ VAE analysis failed for {name}: {e}")
+
+
 # --------------------------- Main code ---------------------------
 if __name__ == "__main__":
     if NUM_SAMPLES_TO_SAVE > 0:
@@ -182,13 +224,23 @@ if __name__ == "__main__":
     for batch in tqdm(dataloader, desc="Processing batches"):
         batch = batch.to(DEVICE)             # [B,3,H,W] in [0,1]
         test_inp = process(batch).to(DTYPE)  # [-1,1] for the encoder
+        # >>> Analyze each VAE's latents on the first iteration
+        if images_saved == 0:  # first batch only, to keep the log clean
+            for vae, name in zip(vaes, names):
+                analyze_vae_latents(vae, name, test_inp)
 
         # 1) compute reconstructions for all VAEs on the whole batch
         recon_list = []
-        for vae in vaes:
-            latent = vae.encode(test_inp).latent_dist.mode()
-            dec = vae.decode(latent).sample.float()  # [-1,1] (usually)
-            recon = deprocess(dec).clamp(0.0, 1.0)   # -> [0,1]; clamp removes artifacts
+        for vae, name in zip(vaes, names):
+            test_inp_vae = test_inp  # per-VAE view of the batch
+            # if name == "Wan2.2-T2V-A14B-Diffusers" and test_inp_vae.ndim == 4:
+            if (isinstance(vae, AutoencoderKLWan) or isinstance(vae, AutoencoderKLLTXVideo)) and test_inp_vae.ndim == 4:
+                test_inp_vae = test_inp_vae.unsqueeze(2)  # video VAEs only: add a T dimension
+            latent = vae.encode(test_inp_vae).latent_dist.mode()
+            dec = vae.decode(latent).sample.float()
+            if dec.ndim == 5:
+                dec = dec.squeeze(2)
+            recon = deprocess(dec).clamp(0.0, 1.0)
             recon_list.append(recon)
 
         # 2) update the metrics (per VAE)
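
The analysis helpers added here inspect the same latent_dist object the reconstruction loop consumes. A hedged sketch of what that object is, assuming a recent diffusers (the class has moved between modules across versions): mode() of the diagonal Gaussian equals its mean, which is why the eval loop's .mode() path is deterministic, while .sample() draws mean + std * noise.

import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution  # path assumed for recent diffusers

params = torch.randn(1, 8, 4, 4)             # [mu | logvar] stacked along channels -> 4 latent channels
dist = DiagonalGaussianDistribution(params)
assert torch.equal(dist.mode(), dist.mean)   # deterministic branch used by the eval loop
z = dist.sample()                            # stochastic branch: mean + std * noise
print(dist.kl().shape)                       # per-sample KL to N(0, I), the quantity full training penalizes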
samples/sample_2.jpg DELETED

Git LFS Details

  • SHA256: 0162a8e3a2b53bbf1f4e48dc9166c6b7dec6040416c3590383968ac2c89cc133
  • Pointer size: 130 Bytes
  • Size of remote file: 53.9 kB
samples/sample_decoded-Copy1.jpg DELETED

Git LFS Details

  • SHA256: bbc45fc2764868844ce3a13e4e297f42ca75aac376e061de4d6b736d981e5e12
  • Pointer size: 130 Bytes
  • Size of remote file: 79.3 kB
samples/sample_decoded.jpg DELETED

Git LFS Details

  • SHA256: 3dce2033b5ea0f9ce2006d9e3b2c8cb123cdfd1a58a72d3562103b8979ebedd0
  • Pointer size: 130 Bytes
  • Size of remote file: 79.1 kB
samples/sample_real.jpg DELETED

Git LFS Details

  • SHA256: 0811374ab3bc11e881daa2ccee89144532ebd6ccb101989ac15bb2fa1db504d2
  • Pointer size: 130 Bytes
  • Size of remote file: 85.4 kB
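
For reference, each deleted file above is stored in git as a Git LFS pointer rather than the image bytes; the "Pointer size: 130 Bytes" refers to a three-line stub of this shape (oid copied from sample_2.jpg above; the byte-exact size value is an assumption, since this view rounds it to 53.9 kB):

version https://git-lfs.github.com/spec/v1
oid sha256:0162a8e3a2b53bbf1f4e48dc9166c6b7dec6040416c3590383968ac2c89cc133
size 53900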
simple_vae/config.json ADDED
@@ -0,0 +1,38 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.35.0.dev0",
  "_name_or_path": "simple_vae",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 16,
  "latents_mean": null,
  "latents_std": null,
  "layers_per_block": 2,
  "mid_block_add_attention": true,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 1.0,
  "shift_factor": 0,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "use_post_quant_conv": true,
  "use_quant_conv": true
}
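
This config departs from the stock SDXL VAE in two ways worth noting: 16 latent channels instead of 4, and scaling_factor = 1.0, so latents need no rescaling before decode. A minimal loading sketch, assuming the local path used in eval_alchemist.py's VAE_LIST:

import torch
from diffusers import AutoencoderKL

# Path assumed from the "simple_vae" entry in eval_alchemist.py.
vae = AutoencoderKL.from_pretrained("/workspace/sdxl_vae/simple_vae").eval()
img = torch.randn(1, 3, 512, 512)                # image tensor in [-1,1]
with torch.no_grad():
    latent = vae.encode(img).latent_dist.mode()  # -> [1, 16, 64, 64]: 8x downsample, 16 channels
    rec = vae.decode(latent).sample              # -> [1, 3, 512, 512]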
samples/sample_0.jpg → simple_vae/diffusion_pytorch_model.safetensors RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3dce2033b5ea0f9ce2006d9e3b2c8cb123cdfd1a58a72d3562103b8979ebedd0
-size 79093
+oid sha256:010d2cb8824a347425be4e41d662b22492965ffb61393621eb1253be8b7fa0ce
+size 335311892
simple_vae_nightly/config.json ADDED
@@ -0,0 +1,38 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.35.0.dev0",
  "_name_or_path": "simple_vae",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "force_upcast": false,
  "in_channels": 3,
  "latent_channels": 16,
  "latents_mean": null,
  "latents_std": null,
  "layers_per_block": 2,
  "mid_block_add_attention": true,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 1.0,
  "shift_factor": 0,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "use_post_quant_conv": true,
  "use_quant_conv": true
}
samples/sample_1.jpg → simple_vae_nightly/diffusion_pytorch_model.safetensors RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a530cdc04fe70def63768ec9ca898fae8e100ea07bce8536b2d4a1115685e79e
-size 157701
+oid sha256:ccd57f2cd9455d6c66ed2fee9396dbb53cbeb675fa0c1fbee87a9b0f94c3de79
+size 335311892
train_sdxl_vae.py CHANGED
@@ -24,11 +24,11 @@ from collections import deque
 
 # --------------------------- Parameters ---------------------------
 ds_path = "/workspace/png"
-project = "vae"
+project = "simple_vae"
 batch_size = 3
-base_learning_rate = 6e-6
-min_learning_rate = 1e-6
-num_epochs = 8
+base_learning_rate = 5e-5
+min_learning_rate = 9e-7
+num_epochs = 16
 sample_interval_share = 10
 use_wandb = True
 save_model = True
@@ -50,7 +50,7 @@ clip_grad_norm = 1.0
 mixed_precision = "no"  # or "fp16"/"bf16" if supported
 gradient_accumulation_steps = 5
 generated_folder = "samples"
-save_as = "vae_nightly"
+save_as = "simple_vae_nightly"
 num_workers = 0
 device = None  # accelerator will set the device
 
@@ -81,7 +81,7 @@ torch.manual_seed(seed)
 np.random.seed(seed)
 random.seed(seed)
 
-torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.benchmark = False
 
 # --------------------------- WandB ---------------------------
 if use_wandb and accelerator.is_main_process:
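
The learning-rate changes above feed the shared warmup-plus-cosine schedule (lr_lambda) that this script and the two new ones define. A standalone sketch of the resulting curve, assuming the diff's values and a hypothetical 10,000-step run:

import math

base_lr, min_lr = 5e-5, 9e-7
warmup_percent, total_steps = 0.01, 10_000
min_ratio = min_lr / base_lr

def lr_lambda(step: int) -> float:
    # linear warmup from min_ratio over the first 1% of steps, then cosine decay back down
    x = step / max(1, total_steps)
    if x < warmup_percent:
        return min_ratio + (1.0 - min_ratio) * (x / warmup_percent)
    decay = (x - warmup_percent) / (1.0 - warmup_percent)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay))

for s in (0, 100, 5_000, 10_000):
    print(s, f"{base_lr * lr_lambda(s):.2e}")  # 9.00e-07, 5.00e-05 (peak), ~2.5e-05, ~9.0e-07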
train_sdxl_vae_full.py ADDED
@@ -0,0 +1,590 @@
# -*- coding: utf-8 -*-
import os
import math
import re
import torch
import numpy as np
import random
import gc
from datetime import datetime
from pathlib import Path

import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
from accelerate import Accelerator
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import bitsandbytes as bnb
import wandb
import lpips  # pip install lpips
from collections import deque

# --------------------------- Parameters ---------------------------
ds_path = "/workspace/png"
project = "simple_vae"
batch_size = 3
base_learning_rate = 5e-5
min_learning_rate = 9e-7
num_epochs = 16
sample_interval_share = 10
use_wandb = True
save_model = True
use_decay = True
asymmetric = False
optimizer_type = "adam8bit"
dtype = torch.float32
# model_resolution -- what is fed to the VAE (the low resolution)
model_resolution = 512  # formerly `resolution`
# high_resolution -- the actual "high" crop used for metrics and saved samples
high_resolution = 512
limit = 0
save_barrier = 1.03
warmup_percent = 0.01
percentile_clipping = 95
beta2 = 0.97
eps = 1e-6
clip_grad_norm = 1.0
mixed_precision = "no"  # or "fp16"/"bf16" if supported
gradient_accumulation_steps = 5
generated_folder = "samples"
save_as = "simple_vae_nightly"
num_workers = 0
device = None  # accelerator will set the device

# --------------------------- Training modes ---------------------------
# CHANGED: added a flag for full VAE training (not just the decoder).
# If False -- previous behavior: train only decoder.* (up_blocks + mid_block).
# If True -- unfreeze the WHOLE model and add a KL loss for the encoder.
full_training = False

# CHANGED: added a KL weight (as a share in the normalizer); used only when full_training=True.
kl_ratio = 0.05  # simple KL share in the overall mix (KISS). Ignored if full_training=False.

# --- Loss proportions and the median-normalization window (COEFFICIENTS, not values) ---
# Final shares in the total loss (sum = 1.0 after normalization).
loss_ratios = {
    "lpips": 0.85,
    "edge": 0.05,
    "mse": 0.05,
    "mae": 0.05,
    # CHANGED: the "kl" key is pre-registered (0.0 by default); activated below when full_training is on.
    "kl": 0.00,
}
median_coeff_steps = 256  # how many steps to compute the median coefficients over

# --------------------------- Preprocessing parameters ---------------------------
resize_long_side = 1280  # if None or 0 -- no resize; 1280 recommended

Path(generated_folder).mkdir(parents=True, exist_ok=True)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# reproducibility
seed = int(datetime.now().strftime("%Y%m%d"))
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

torch.backends.cudnn.benchmark = False

# --------------------------- WandB ---------------------------
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
        "model_resolution": model_resolution,
        "high_resolution": high_resolution,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "full_training": full_training,  # CHANGED: log the mode
        "kl_ratio": kl_ratio,            # CHANGED: log the KL share
    })

# --------------------------- VAE ---------------------------
if model_resolution == high_resolution and not asymmetric:
    vae = AutoencoderKL.from_pretrained(project).to(dtype)
else:
    vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)

# torch.compile (if available) -- simple, no extra logic
if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")

# >>> Freeze/unfreeze strategy
for p in vae.parameters():
    p.requires_grad = False

decoder = getattr(vae, "decoder", None)
if decoder is None:
    raise RuntimeError("vae.decoder not found -- cannot apply the unfreezing strategy. Check the model structure.")

unfrozen_param_names = []

if not full_training:
    # === Previous behavior: train only decoder.up_blocks and decoder.mid_block ===
    if not hasattr(decoder, "up_blocks"):
        raise RuntimeError("decoder.up_blocks not found -- a list of decoder blocks is expected.")

    n_up = len(decoder.up_blocks)
    start_idx = 0
    for idx in range(start_idx, n_up):
        block = decoder.up_blocks[idx]
        for name, p in block.named_parameters():
            p.requires_grad = True
            unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")

    if hasattr(decoder, "mid_block"):
        for name, p in decoder.mid_block.named_parameters():
            p.requires_grad = True
            unfrozen_param_names.append(f"decoder.mid_block.{name}")
    else:
        print("[WARN] decoder.mid_block not found -- mid_block not unfrozen.")

    # Train only the decoder
    trainable_module = vae.decoder
else:
    # === CHANGED: Full training -- unfreeze ALL VAE layers (encoder, decoder, and post-projection) ===
    for name, p in vae.named_parameters():
        p.requires_grad = True
        unfrozen_param_names.append(name)
    trainable_module = vae  # CHANGED: train the whole model

    # CHANGED: activate the KL share in the normalizer
    loss_ratios["kl"] = float(kl_ratio)

print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
for nm in unfrozen_param_names[:200]:
    print(" ", nm)

# --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
class PngFolderDataset(Dataset):
    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
        self.root_dir = root_dir
        self.resolution = resolution
        self.paths = []
        # collect png files recursively
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
                    self.paths.append(os.path.join(root, fname))
        # optional limit
        if limit:
            self.paths = self.paths[:limit]
        # verify images and keep only valid ones
        valid = []
        for p in self.paths:
            try:
                with Image.open(p) as im:
                    im.verify()  # fast check for truncated/corrupted images
                valid.append(p)
            except (OSError, UnidentifiedImageError):
                # skip corrupted image
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"No valid PNG images found under {root_dir}")
        # final shuffle for randomness
        random.shuffle(self.paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        # open and convert to RGB; ensure file is closed promptly
        with Image.open(p) as img:
            img = img.convert("RGB")
        # shrink the long side down to resize_long_side (Lanczos)
        if not resize_long_side or resize_long_side <= 0:
            return img
        w, h = img.size
        long = max(w, h)
        if long <= resize_long_side:
            return img
        scale = resize_long_side / float(long)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        return img.resize((new_w, new_h), Image.LANCZOS)

# --------------------------- Dataset and transforms ---------------------------

def random_crop(img, sz):
    w, h = img.size
    if w < sz or h < sz:
        img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
    x = random.randint(0, max(1, img.width - sz))
    y = random.randint(0, max(1, img.height - sz))
    return img.crop((x, y, x + sz, y + sz))

tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# build dataset using high_resolution crops
dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
if len(dataset) < batch_size:
    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")

# collate_fn crops to high_resolution
def collate_fn(batch):
    imgs = []
    for img in batch:  # img is PIL.Image
        img = random_crop(img, high_resolution)  # high-res crop
        imgs.append(tfm(img))
    return torch.stack(imgs)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True
)

# --------------------------- Optimizer ---------------------------

def get_param_groups(module, weight_decay=0.001):
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
    decay_params = []
    no_decay_params = []
    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

def create_optimizer(name, param_groups):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
        )
    raise ValueError(name)

param_groups = get_param_groups(trainable_module, weight_decay=0.001)
optimizer = create_optimizer(optimizer_type, param_groups)

# --------------------------- LR schedule ---------------------------

batches_per_epoch = len(dataloader)  # number of micro-batches (dataloader steps)
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))  # optimizer.step() calls per epoch
total_steps = steps_per_epoch * num_epochs

def lr_lambda(step):
    if not use_decay:
        return 1.0
    x = float(step) / float(max(1, total_steps))
    warmup = float(warmup_percent)
    min_ratio = float(min_learning_rate) / float(base_learning_rate)
    if x < warmup:
        return min_ratio + (1.0 - min_ratio) * (x / warmup)
    decay_ratio = (x - warmup) / (1.0 - warmup)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))

scheduler = LambdaLR(optimizer, lr_lambda)

# Preparation
dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)

# CHANGED: build the trainable_params list from the selected trainable_module
trainable_params = [p for p in (trainable_module.parameters() if hasattr(trainable_module, "parameters") else []) if p.requires_grad]

# --------------------------- LPIPS and helper functions ---------------------------
_lpips_net = None

def _get_lpips():
    global _lpips_net
    if _lpips_net is None:
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
    return _lpips_net

# Sobel kernels for the edge loss
_sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
_sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)

def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    # x: [B,C,H,W] in [-1,1]
    C = x.shape[1]
    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=C)
    gy = F.conv2d(x, ky, padding=1, groups=C)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)

# Median-based loss normalization: we compute COEFFICIENTS
class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        # normalize the shares in case they don't sum to 1
        s = sum(desired_ratios.values())
        self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, abs_losses: dict):
        # fill the buffers with the actual ABSOLUTE loss values
        for k, v in abs_losses.items():
            if k in self.buffers:
                self.buffers[k].append(float(v.detach().abs().cpu()))
        # medians (robust to outliers)
        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        # compute COEFFICIENTS as ratio_k / median_k -- i.e., coefficients, not values
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
        # the final total sums over the keys present in abs_losses
        total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
        return total, coeffs, meds

# CHANGED: create the normalizer AFTER the possible kl_ratio activation above
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

# --------------------------- Samples ---------------------------
@torch.no_grad()
def get_fixed_samples(n=3):
    idx = random.sample(range(len(dataset)), min(n, len(dataset)))
    pil_imgs = [dataset[i] for i in idx]  # dataset returns PIL.Image
    tensors = []
    for img in pil_imgs:
        img = random_crop(img, high_resolution)  # high-res fixed samples
        tensors.append(tfm(img))
    return torch.stack(tensors).to(accelerator.device, dtype)

fixed_samples = get_fixed_samples()

@torch.no_grad()
def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
    # img_tensor: [C,H,W] in [-1,1]
    arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
    return Image.fromarray(arr)

@torch.no_grad()
def generate_and_save_samples(step=None):
    try:
        temp_vae = accelerator.unwrap_model(vae).eval()
        lpips_net = _get_lpips()
        with torch.no_grad():
            # the low-res encoder input is ALWAYS prepared at model_resolution
            orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
            orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
            # match the model's dtype
            model_dtype = next(temp_vae.parameters()).dtype
            orig_low = orig_low.to(dtype=model_dtype)
            # encode/decode
            # CHANGED: for validation/samples always use the mean (stable and deterministic)
            enc = temp_vae.encode(orig_low)
            latents_mean = enc.latent_dist.mean
            rec = temp_vae.decode(latents_mean).sample

            # bring the reconstruction's spatial size to high-res (downsample for asymmetric VAEs)
            if rec.shape[-2:] != orig_high.shape[-2:]:
                rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)

            # save the FIRST sample: real and decoded, without the step number in the name
            first_real = _to_pil_uint8(orig_high[0])
            first_dec = _to_pil_uint8(rec[0])
            first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
            first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)

            # also save the current reconstructions without step numbers (files get overwritten instead of piling up)
            for i in range(rec.shape[0]):
                _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)

            # LPIPS on the full (high-res) image -- for logging
            lpips_scores = []
            for i in range(rec.shape[0]):
                orig_full = orig_high[i:i+1].to(torch.float32)
                rec_full = rec[i:i+1].to(torch.float32)
                if rec_full.shape[-2:] != orig_full.shape[-2:]:
                    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
                lpips_val = lpips_net(orig_full, rec_full).item()
                lpips_scores.append(lpips_val)
            avg_lpips = float(np.mean(lpips_scores))

            if use_wandb and accelerator.is_main_process:
                wandb.log({
                    "lpips_mean": avg_lpips,
                }, step=step)
    finally:
        gc.collect()
        torch.cuda.empty_cache()

if accelerator.is_main_process and save_model:
    print("Generating samples before training starts...")
    generate_and_save_samples(0)

accelerator.wait_for_everyone()

# --------------------------- Training ---------------------------

progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
global_step = 0
min_loss = float("inf")
sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))

for epoch in range(num_epochs):
    vae.train()
    batch_losses = []
    batch_grads = []
    # extra per-loss tracking
    track_losses = {k: [] for k in loss_ratios.keys()}
    for imgs in dataloader:
        with accelerator.accumulate(vae):
            # imgs: high-res tensor from dataloader ([-1,1]), move to device
            imgs = imgs.to(accelerator.device)

            # ALWAYS downsample the encoder input to model_resolution
            if high_resolution != model_resolution:
                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
            else:
                imgs_low = imgs

            # ensure dtype matches model params to avoid float/half mismatch
            model_dtype = next(vae.parameters()).dtype
            imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low

            # Encode/decode
            enc = vae.encode(imgs_low_model)

            # CHANGED: when training the full model, use the reparameterization sample() --
            # this matters for stochasticity and for consistency with the KL term.
            latents = enc.latent_dist.sample() if full_training else enc.latent_dist.mean

            rec = vae.decode(latents).sample  # rec may be upscaled (asymmetric VAE)

            # bring the size to high-res
            if rec.shape[-2:] != imgs.shape[-2:]:
                rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)

            # losses are computed at high-res
            rec_f32 = rec.to(torch.float32)
            imgs_f32 = imgs.to(torch.float32)

            # individual losses (absolute values)
            abs_losses = {
                "mae": F.l1_loss(rec_f32, imgs_f32),
                "mse": F.mse_loss(rec_f32, imgs_f32),
                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
            }

            # CHANGED: the KL loss is added ONLY during full training.
            # KL(q(z|x) || N(0,1)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).
            if full_training:
                mean = enc.latent_dist.mean
                logvar = enc.latent_dist.logvar
                # stable averaging over batch and spatial dims
                kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
                abs_losses["kl"] = kl
            else:
                # the key exists in ratios, but its share is 0 during partial training, so it has no effect
                abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)

            # total with median-based COEFFICIENTS
            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                print("NaN/Inf loss – stopping")
                raise RuntimeError("NaN/Inf loss")

            accelerator.backward(total_loss)

            grad_norm = torch.tensor(0.0, device=accelerator.device)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
            progress.update(1)

            # --- Logging ---
            if accelerator.is_main_process:
                try:
                    current_lr = optimizer.param_groups[0]["lr"]
                except Exception:
                    current_lr = scheduler.get_last_lr()[0]

                batch_losses.append(total_loss.detach().item())
                # CHANGED: robustly extract a scalar from the different possible types
                if isinstance(grad_norm, torch.Tensor):
                    batch_grads.append(float(grad_norm.detach().cpu().item()))
                else:
                    batch_grads.append(float(grad_norm))

                for k, v in abs_losses.items():
                    track_losses[k].append(float(v.detach().item()))

                if use_wandb and accelerator.sync_gradients:
                    log_dict = {
                        "total_loss": float(total_loss.detach().item()),
                        "learning_rate": current_lr,
                        "epoch": epoch,
                        "grad_norm": batch_grads[-1],
                        "mode/full_training": int(full_training),  # CHANGED: for clarity in the logs
                    }
                    # add the individual losses
                    for k, v in abs_losses.items():
                        log_dict[f"loss_{k}"] = float(v.detach().item())
                    # log the coefficients and medians
                    for k in coeffs:
                        log_dict[f"coeff_{k}"] = float(coeffs[k])
                        log_dict[f"median_{k}"] = float(meds[k])
                    wandb.log(log_dict, step=global_step)

            # periodic samples and checkpoints
            if global_step > 0 and global_step % sample_interval == 0:
                if accelerator.is_main_process:
                    generate_and_save_samples(global_step)
                accelerator.wait_for_everyone()

                # averages over the most recent iterations
                n_micro = sample_interval * gradient_accumulation_steps
                if len(batch_losses) >= n_micro:
                    avg_loss = float(np.mean(batch_losses[-n_micro:]))
                else:
                    avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")

                avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0

                if accelerator.is_main_process:
                    print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
                    if save_model and avg_loss < min_loss * save_barrier:
                        min_loss = avg_loss
                        accelerator.unwrap_model(vae).save_pretrained(save_as)
                    if use_wandb:
                        wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)

    if accelerator.is_main_process:
        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)

# --------------------------- Final save ---------------------------
if accelerator.is_main_process:
    print("Training finished – saving final model")
    if save_model:
        accelerator.unwrap_model(vae).save_pretrained(save_as)

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Done!")
train_sdxl_vae_simple.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ from accelerate import Accelerator
18
+ from PIL import Image, UnidentifiedImageError
19
+ from tqdm import tqdm
20
+ import bitsandbytes as bnb
21
+ import wandb
22
+ import lpips # pip install lpips
23
+ from collections import deque
24
+
25
+ # --------------------------- Параметры ---------------------------
26
+ ds_path = "/workspace/png"
27
+ project = "simple_vae"
28
+ batch_size = 3
29
+ base_learning_rate = 5e-5
30
+ min_learning_rate = 9e-7
31
+ num_epochs = 16
32
+ sample_interval_share = 10
33
+ use_wandb = True
34
+ save_model = True
35
+ use_decay = True
36
+ asymmetric = False
37
+ optimizer_type = "adam8bit"
38
+ dtype = torch.float32
39
+ # model_resolution — то, что подавается в VAE (низкое разрешение)
40
+ model_resolution = 512 # бывший `resolution`
41
+ # high_resolution — настоящий «высокий» кроп, на котором считаем метрики и сохраняем сэмплы
42
+ high_resolution = 512
43
+ limit = 0
44
+ save_barrier = 1.03
45
+ warmup_percent = 0.01
46
+ percentile_clipping = 95
47
+ beta2 = 0.97
48
+ eps = 1e-6
49
+ clip_grad_norm = 1.0
50
+ mixed_precision = "no" # или "fp16"/"bf16" при поддержке
51
+ gradient_accumulation_steps = 5
52
+ generated_folder = "samples"
53
+ save_as = "simple_vae_nightly"
54
+ num_workers = 0
55
+ device = None # accelerator задаст устройство
56
+
57
+ # --- Пропорции лоссов и окно медианного нормирования (КОЭФ., не значения) ---
58
+ # Итоговые доли в total loss (сумма = 1.0)
59
+ loss_ratios = {
60
+ "lpips": 0.85,
61
+ "edge": 0.05,
62
+ "mse": 0.05,
63
+ "mae": 0.05,
64
+ }
65
+ median_coeff_steps = 256 # за сколько шагов считать медианные коэффициенты
66
+
67
+ # --------------------------- параметры препроцессинга ---------------------------
68
+ resize_long_side = 1280 # если None или 0 — ресайза не будет; рекомендовано 1280
69
+
70
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
71
+
72
+ accelerator = Accelerator(
73
+ mixed_precision=mixed_precision,
74
+ gradient_accumulation_steps=gradient_accumulation_steps
75
+ )
76
+ device = accelerator.device
77
+
78
+ # reproducibility
79
+ seed = int(datetime.now().strftime("%Y%m%d"))
80
+ torch.manual_seed(seed)
81
+ np.random.seed(seed)
82
+ random.seed(seed)
83
+
84
+ torch.backends.cudnn.benchmark = True
85
+
86
+ # --------------------------- WandB ---------------------------
87
+ if use_wandb and accelerator.is_main_process:
88
+ wandb.init(project=project, config={
89
+ "batch_size": batch_size,
90
+ "base_learning_rate": base_learning_rate,
91
+ "num_epochs": num_epochs,
92
+ "optimizer_type": optimizer_type,
93
+ "model_resolution": model_resolution,
94
+ "high_resolution": high_resolution,
95
+ "gradient_accumulation_steps": gradient_accumulation_steps,
96
+ })
97
+
98
+ # --------------------------- VAE ---------------------------
99
+ if model_resolution==high_resolution and not asymmetric:
100
+ vae = AutoencoderKL.from_pretrained(project).to(dtype)
101
+ else:
102
+ vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
103
+
104
+ # torch.compile (если доступно) — просто и без лишней логики
105
+ if hasattr(torch, "compile"):
106
+ try:
107
+ vae = torch.compile(vae)
108
+ except Exception as e:
109
+ print(f"[WARN] torch.compile failed: {e}")
110
+
111
+ # >>> Заморозка всех параметров, затем выборочная разморозка
112
+ for p in vae.parameters():
113
+ p.requires_grad = False
114
+
115
+ decoder = getattr(vae, "decoder", None)
116
+ if decoder is None:
117
+ raise RuntimeError("vae.decoder not found — не могу применить стратегию разморозки. Проверь структуру модели.")
118
+
119
+ unfrozen_param_names = []
120
+
121
+ if not hasattr(decoder, "up_blocks"):
122
+ raise RuntimeError("decoder.up_blocks не найдены — ожидается список блоков декодера.")
123
+
124
+ # >>> Размораживаем все up_blocks и mid_block (как было в твоём варианте start_idx=0)
125
+ n_up = len(decoder.up_blocks)
126
+ start_idx = 0
127
+ for idx in range(start_idx, n_up):
128
+ block = decoder.up_blocks[idx]
129
+ for name, p in block.named_parameters():
130
+ p.requires_grad = True
131
+ unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
132
+
133
+ if hasattr(decoder, "mid_block"):
134
+ for name, p in decoder.mid_block.named_parameters():
135
+ p.requires_grad = True
136
+ unfrozen_param_names.append(f"decoder.mid_block.{name}")
137
+ else:
138
+ print("[WARN] decoder.mid_block не найден — mid_block не разморожен.")
139
+
140
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
141
+ for nm in unfrozen_param_names[:200]:
142
+ print(" ", nm)
143
+
144
+ # сохраняем trainable_module (get_param_groups будет учитывать p.requires_grad)
145
+ trainable_module = vae.decoder
146
+
147
+ # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
148
+ class PngFolderDataset(Dataset):
149
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
150
+ self.root_dir = root_dir
151
+ self.resolution = resolution
152
+ self.paths = []
153
+ # collect png files recursively
154
+ for root, _, files in os.walk(root_dir):
155
+ for fname in files:
156
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
157
+ self.paths.append(os.path.join(root, fname))
158
+ # optional limit
159
+ if limit:
160
+ self.paths = self.paths[:limit]
161
+ # verify images and keep only valid ones
162
+ valid = []
163
+ for p in self.paths:
164
+ try:
165
+ with Image.open(p) as im:
166
+ im.verify() # fast check for truncated/corrupted images
167
+ valid.append(p)
168
+ except (OSError, UnidentifiedImageError):
169
+ # skip corrupted image
170
+ continue
171
+ self.paths = valid
172
+ if len(self.paths) == 0:
173
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
174
+ # final shuffle for randomness
175
+ random.shuffle(self.paths)
176
+
177
+ def __len__(self):
178
+ return len(self.paths)
179
+
180
+ def __getitem__(self, idx):
181
+ p = self.paths[idx % len(self.paths)]
182
+ # open and convert to RGB; ensure file is closed promptly
183
+ with Image.open(p) as img:
184
+ img = img.convert("RGB")
185
+ # пережимаем длинную сторону до resize_long_side (Lanczos)
186
+ if not resize_long_side or resize_long_side <= 0:
187
+ return img
188
+ w, h = img.size
189
+ long = max(w, h)
190
+ if long <= resize_long_side:
191
+ return img
192
+ scale = resize_long_side / float(long)
193
+ new_w = int(round(w * scale))
194
+ new_h = int(round(h * scale))
195
+ return img.resize((new_w, new_h), Image.LANCZOS)
196
+
197
+ # --------------------------- Датасет и трансформы ---------------------------
198
+
199
+ def random_crop(img, sz):
200
+ w, h = img.size
201
+ if w < sz or h < sz:
202
+ img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
203
+ x = random.randint(0, max(1, img.width - sz))
204
+ y = random.randint(0, max(1, img.height - sz))
205
+ return img.crop((x, y, x + sz, y + sz))
206
+
207
+ tfm = transforms.Compose([
208
+ transforms.ToTensor(),
209
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
210
+ ])
211
+
212
+ # build dataset using high_resolution crops
213
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
214
+ if len(dataset) < batch_size:
215
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
216
+
217
+ # collate_fn кропит до high_resolution
218
+
219
+ def collate_fn(batch):
220
+ imgs = []
221
+ for img in batch: # img is PIL.Image
222
+ img = random_crop(img, high_resolution) # кропим high-res
223
+ imgs.append(tfm(img))
224
+ return torch.stack(imgs)
225
+
226
+ dataloader = DataLoader(
227
+ dataset,
228
+ batch_size=batch_size,
229
+ shuffle=True,
230
+ collate_fn=collate_fn,
231
+ num_workers=num_workers,
232
+ pin_memory=True,
233
+ drop_last=True
234
+ )
235
+
236
+ # --------------------------- Оптимизатор ---------------------------
237
+
238
+ def get_param_groups(module, weight_decay=0.001):
239
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
240
+ decay_params = []
241
+ no_decay_params = []
242
+ for n, p in module.named_parameters():
243
+ if not p.requires_grad:
244
+ continue
245
+ if any(nd in n for nd in no_decay):
246
+ no_decay_params.append(p)
247
+ else:
248
+ decay_params.append(p)
249
+ return [
250
+ {"params": decay_params, "weight_decay": weight_decay},
251
+ {"params": no_decay_params, "weight_decay": 0.0},
252
+ ]
253
+
254
+ def create_optimizer(name, param_groups):
255
+ if name == "adam8bit":
256
+ return bnb.optim.AdamW8bit(
257
+ param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
258
+ )
259
+ raise ValueError(name)
260
+
261
+ param_groups = get_param_groups(trainable_module, weight_decay=0.001)
262
+ optimizer = create_optimizer(optimizer_type, param_groups)
263
+
264
+ # --------------------------- Подготовка Accelerate (вместе) ---------------------------
265
+
266
+ batches_per_epoch = len(dataloader) # число микро-батчей (dataloader steps)
267
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps))) # число optimizer.step() за эпоху
268
+ total_steps = steps_per_epoch * num_epochs
269
+
270
+
271
+ def lr_lambda(step):
272
+ if not use_decay:
273
+ return 1.0
274
+ x = float(step) / float(max(1, total_steps))
275
+ warmup = float(warmup_percent)
276
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
277
+ if x < warmup:
278
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
279
+ decay_ratio = (x - warmup) / (1.0 - warmup)
280
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
281
+
282
+ scheduler = LambdaLR(optimizer, lr_lambda)
283
+
284
+ # Подготовка
285
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
286
+
287
+ trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
288
+
289
+ # --------------------------- LPIPS и вспомогательные функции ---------------------------
290
+ _lpips_net = None
291
+
292
+ def _get_lpips():
293
+ global _lpips_net
294
+ if _lpips_net is None:
295
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
296
+ return _lpips_net
297
+
298
+ # Собель для edge loss
299
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
300
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
301
+
302
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
303
+ # x: [B,C,H,W] в [-1,1]
304
+ C = x.shape[1]
305
+ kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
306
+ ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
307
+ gx = F.conv2d(x, kx, padding=1, groups=C)
308
+ gy = F.conv2d(x, ky, padding=1, groups=C)
309
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
310
+
311
+ # Нормализация лоссов по медианам: считаем КОЭФФИЦИЕНТЫ
312
+ class MedianLossNormalizer:
313
+ def __init__(self, desired_ratios: dict, window_steps: int):
314
+ # нормируем доли на случай, если сумма != 1
315
+ s = sum(desired_ratios.values())
316
+ self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
317
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
318
+ self.window = window_steps
319
+
320
+ def update_and_total(self, abs_losses: dict):
321
+ # Заполняем буферы фактическими АБСОЛЮТНЫМИ значениями лоссов
322
+ for k, v in abs_losses.items():
323
+ if k in self.buffers:
324
+ self.buffers[k].append(float(v.detach().cpu()))
325
+ # Медианы (устойчивые к выбросам)
326
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
327
+ # Вычисляем КОЭФФИЦИЕНТЫ как ratio_k / median_k — т.е. именно коэффициенты, а не значения
328
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
329
+ # Важно: при таких коэффициентах сумма (coeff_k * median_k) = сумма(ratio_k) = 1, т.е. масштаб стабилен
330
+ total = sum(coeffs[k] * abs_losses[k] for k in coeffs)
331
+ return total, coeffs, meds
332
+
333
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
334
+
335
+ # --------------------------- Сэмплы ---------------------------
336
+ @torch.no_grad()
337
+ def get_fixed_samples(n=3):
338
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
339
+ pil_imgs = [dataset[i] for i in idx] # dataset returns PIL.Image
340
+ tensors = []
341
+ for img in pil_imgs:
342
+ img = random_crop(img, high_resolution) # high-res fixed samples
343
+ tensors.append(tfm(img))
344
+ return torch.stack(tensors).to(accelerator.device, dtype)
345
+
346
+ fixed_samples = get_fixed_samples()
347
+
348
+ @torch.no_grad()
349
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
350
+ # img_tensor: [C,H,W] in [-1,1]
351
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
352
+ return Image.fromarray(arr)
353
+
354
+ @torch.no_grad()
+ def generate_and_save_samples(step=None):
+     try:
+         temp_vae = accelerator.unwrap_model(vae).eval()
+         lpips_net = _get_lpips()
+         with torch.no_grad():
+             # Prepare the low-res encoder input, ALWAYS at model_resolution
+             orig_high = fixed_samples  # [B,C,H,W] in [-1,1]
+             orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+             # Match the model's dtype
+             model_dtype = next(temp_vae.parameters()).dtype
+             orig_low = orig_low.to(dtype=model_dtype)
+             # encode/decode
+             latents = temp_vae.encode(orig_low).latent_dist.mean
+             rec = temp_vae.decode(latents).sample
+
+             # Bring the reconstruction's spatial size to high-res (a downsample for asymmetric VAEs)
+             if rec.shape[-2:] != orig_high.shape[-2:]:
+                 rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+             # Save the FIRST sample: real and decoded, without a step number in the filename
+             first_real = _to_pil_uint8(orig_high[0])
+             first_dec = _to_pil_uint8(rec[0])
+             first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
+             first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
+
+             # Also save the current reconstructions without a step number (files get overwritten instead of piling up)
+             for i in range(rec.shape[0]):
+                 _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
+
+             # LPIPS on the full (high-res) image, for logging
+             lpips_scores = []
+             for i in range(rec.shape[0]):
+                 orig_full = orig_high[i:i+1].to(torch.float32)
+                 rec_full = rec[i:i+1].to(torch.float32)
+                 if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                     rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                 lpips_val = lpips_net(orig_full, rec_full).item()
+                 lpips_scores.append(lpips_val)
+             avg_lpips = float(np.mean(lpips_scores))
+
+             if use_wandb and accelerator.is_main_process:
+                 wandb.log({
+                     "lpips_mean": avg_lpips,
+                 }, step=step)
+     finally:
+         gc.collect()
+         torch.cuda.empty_cache()
+
+ if accelerator.is_main_process and save_model:
+     print("Generating samples before training starts...")
+     generate_and_save_samples(0)
+
+ accelerator.wait_for_everyone()
+
+ # --------------------------- Training ---------------------------
+
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+ global_step = 0
+ min_loss = float("inf")
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
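+ # Example with hypothetical values: total_steps=10_000, num_epochs=2,
+ # sample_interval_share=5 -> sample_interval = 10_000 // 10 = 1_000 steps.
+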
+ for epoch in range(num_epochs):
+     vae.train()
+     batch_losses = []
+     batch_grads = []
+     # Extra tracking of the individual losses
+     track_losses = {k: [] for k in loss_ratios.keys()}
+     for imgs in dataloader:
+         with accelerator.accumulate(vae):
+             # imgs: high-res tensor from dataloader ([-1,1]), move to device
+             imgs = imgs.to(accelerator.device)
+
+             # ALWAYS downsample the encoder input to model_resolution
+             if high_resolution != model_resolution:
+                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+             else:
+                 imgs_low = imgs
+
+             # ensure dtype matches model params to avoid float/half mismatch
+             model_dtype = next(vae.parameters()).dtype
+             if imgs_low.dtype != model_dtype:
+                 imgs_low_model = imgs_low.to(dtype=model_dtype)
+             else:
+                 imgs_low_model = imgs_low
+
+             # Encode/decode
+             latents = vae.encode(imgs_low_model).latent_dist.mean
+             rec = vae.decode(latents).sample  # rec may come out upscaled (asymmetric VAE)
+
+             # Bring the size up to high-res
+             if rec.shape[-2:] != imgs.shape[-2:]:
+                 rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+             # Losses are computed at high-res
+             rec_f32 = rec.to(torch.float32)
+             imgs_f32 = imgs.to(torch.float32)
+
+             # Individual losses
+             abs_losses = {
+                 "mae": F.l1_loss(rec_f32, imgs_f32),
+                 "mse": F.mse_loss(rec_f32, imgs_f32),
+                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+             }
+
+             # Total with median-based COEFFICIENTS
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+             if torch.isnan(total_loss) or torch.isinf(total_loss):
+                 print("NaN/Inf loss – stopping")
+                 raise RuntimeError("NaN/Inf loss")
+
+             accelerator.backward(total_loss)
+
+             grad_norm = torch.tensor(0.0, device=accelerator.device)
+             if accelerator.sync_gradients:
+                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+             optimizer.step()
+             scheduler.step()
+             optimizer.zero_grad(set_to_none=True)
+
+         global_step += 1
+         progress.update(1)
+
+         # --- Logging ---
+         if accelerator.is_main_process:
+             try:
+                 current_lr = optimizer.param_groups[0]["lr"]
+             except Exception:
+                 current_lr = scheduler.get_last_lr()[0]
+
+             batch_losses.append(total_loss.detach().item())
+             batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
+             for k, v in abs_losses.items():
+                 track_losses[k].append(float(v.detach().item()))
+
+             if use_wandb and accelerator.sync_gradients:
+                 log_dict = {
+                     "total_loss": float(total_loss.detach().item()),
+                     "learning_rate": current_lr,
+                     "epoch": epoch,
+                     "grad_norm": batch_grads[-1],
+                 }
+                 # add the individual losses
+                 for k, v in abs_losses.items():
+                     log_dict[f"loss_{k}"] = float(v.detach().item())
+                 # log the coefficients and medians
+                 for k in coeffs:
+                     log_dict[f"coeff_{k}"] = float(coeffs[k])
+                     log_dict[f"median_{k}"] = float(meds[k])
+                 wandb.log(log_dict, step=global_step)
+
+         # periodic samples and checkpoints
+         if global_step > 0 and global_step % sample_interval == 0:
+             if accelerator.is_main_process:
+                 generate_and_save_samples(global_step)
+             accelerator.wait_for_everyone()
+
+             # Averages over the most recent iterations
+             n_micro = sample_interval * gradient_accumulation_steps
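+             # Example with hypothetical values: sample_interval=1_000 and
+             # gradient_accumulation_steps=4 average over the last 4_000 micro-batches.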
+             if len(batch_losses) >= n_micro:
+                 avg_loss = float(np.mean(batch_losses[-n_micro:]))
+             else:
+                 avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
+
+             # Slicing a shorter list just returns the whole list, so one expression suffices
+             avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
+
+             if accelerator.is_main_process:
+                 print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                 if save_model and avg_loss < min_loss * save_barrier:
+                     min_loss = avg_loss
+                     accelerator.unwrap_model(vae).save_pretrained(save_as)
+                 if use_wandb:
+                     wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+     if accelerator.is_main_process:
+         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+         if use_wandb:
+             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+ # --------------------------- Final save ---------------------------
+ if accelerator.is_main_process:
+     print("Training finished – saving final model")
+     if save_model:
+         accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+ print("Done!")
vaetest/001_all.png DELETED (Git LFS, 2.59 MB)
vaetest/001_decoded_AiArtLab_sdxl_vae.png DELETED (Git LFS, 361 kB)
vaetest/001_decoded_AiArtLab_sdxlvae_nightly.png DELETED (Git LFS, 357 kB)
vaetest/001_decoded_AiArtLab_sdxs.png DELETED (Git LFS, 381 kB)
vaetest/001_decoded_FLUX.1_schnell_vae.png DELETED (Git LFS, 389 kB)
vaetest/001_decoded_KBlueLeaf_EQ_SDXL_VAE.png DELETED (Git LFS, 363 kB)
vaetest/001_decoded_madebyollin_sdxl_vae_fp16.png DELETED (Git LFS, 377 kB)
vaetest/001_decoded_vae.png DELETED (Git LFS, 404 kB)
vaetest/001_decoded_vae_nightly.png DELETED (Git LFS, 288 kB)
vaetest/001_orig.png DELETED (Git LFS, 361 kB)
vaetest/002_all.png DELETED (Git LFS, 4.09 MB)
vaetest/002_decoded_AiArtLab_sdxl_vae.png DELETED (Git LFS, 560 kB)
vaetest/002_decoded_AiArtLab_sdxlvae_nightly.png DELETED (Git LFS, 542 kB)
vaetest/002_decoded_AiArtLab_sdxs.png DELETED (Git LFS, 592 kB)
vaetest/002_decoded_FLUX.1_schnell_vae.png DELETED (Git LFS, 619 kB)
vaetest/002_decoded_KBlueLeaf_EQ_SDXL_VAE.png DELETED (Git LFS, 544 kB)
vaetest/002_decoded_madebyollin_sdxl_vae_fp16.png DELETED (Git LFS, 586 kB)
vaetest/002_decoded_vae.png DELETED (Git LFS, 340 kB)
vaetest/002_decoded_vae_nightly.png DELETED (Git LFS, 338 kB)
vaetest/002_orig.png DELETED (Git LFS, 620 kB)
vaetest/003_all.png DELETED (Git LFS, 3.06 MB)
vaetest/003_decoded_AiArtLab_sdxl_vae.png DELETED (Git LFS, 422 kB)
vaetest/003_decoded_AiArtLab_sdxlvae_nightly.png DELETED (Git LFS, 414 kB)
vaetest/003_decoded_AiArtLab_sdxs.png DELETED (Git LFS, 448 kB)
vaetest/003_decoded_FLUX.1_schnell_vae.png DELETED (Git LFS, 463 kB)
vaetest/003_decoded_KBlueLeaf_EQ_SDXL_VAE.png DELETED (Git LFS, 424 kB)
vaetest/003_decoded_madebyollin_sdxl_vae_fp16.png DELETED (Git LFS, 446 kB)
vaetest/003_decoded_vae.png DELETED (Git LFS, 517 kB)
vaetest/003_decoded_vae_nightly.png DELETED (Git LFS, 337 kB)
vaetest/003_orig.png DELETED (Git LFS, 443 kB)
vaetest/004_all.png DELETED (Git LFS, 1.5 MB)
vaetest/004_decoded_AiArtLab_sdxl_vae.png DELETED (Git LFS, 218 kB)
vaetest/004_decoded_AiArtLab_sdxlvae_nightly.png DELETED (Git LFS, 217 kB)
vaetest/004_decoded_AiArtLab_sdxs.png DELETED (Git LFS, 225 kB)
vaetest/004_decoded_FLUX.1_schnell_vae.png DELETED (Git LFS, 233 kB)
vaetest/004_decoded_KBlueLeaf_EQ_SDXL_VAE.png DELETED (Git LFS, 240 kB)
vaetest/004_decoded_madebyollin_sdxl_vae_fp16.png DELETED (Git LFS, 223 kB)
vaetest/004_decoded_vae.png DELETED (Git LFS, 413 kB)