recoilme committed
Commit 7434657 · 1 Parent(s): f288df0
eval_alchemist.py CHANGED
@@ -15,8 +15,8 @@ DTYPE = torch.float16
 IMAGE_FOLDER = "/workspace/alchemist" #wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
 MIN_SIZE = 1280
 CROP_SIZE = 512
-BATCH_SIZE = 1
-MAX_IMAGES = 100
+BATCH_SIZE = 10
+MAX_IMAGES = 0
 NUM_WORKERS = 4
 NUM_SAMPLES_TO_SAVE = 2  # how many examples to save (0 = do not save)
 SAMPLES_FOLDER = "vaetest"
@@ -32,9 +32,10 @@ VAE_LIST = [
     # ("Lightricks/LTX-Video", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
     # ("Wan2.2-TI2V-5B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
     # ("Wan2.2-T2V-A14B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
-    ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
-    # ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
-    # ("simple_vae", AutoencoderKL, "/workspace/sdxl_vae/simple_vae", None),
+    # ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
+    ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
+    ("simple_vae", AutoencoderKL, "AiArtLab/simplevae", "vae"),
+    ("simple_vae2", AutoencoderKL, "AiArtLab/simplevae", None),
     ("simple_vae_nightly", AutoencoderKL, "/workspace/sdxl_vae/simple_vae_nightly", None),
 
 ]
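Each VAE_LIST entry is (display name, diffusers class, model path or repo id, subfolder). A minimal sketch of how one of the newly enabled entries is consumed, mirroring the loader loop in eval_alchemist2.py added below:

from diffusers import AutoencoderKL
import torch

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="vae", torch_dtype=torch.float16
).to("cuda").eval()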
eval_alchemist2.py ADDED
@@ -0,0 +1,514 @@
import os
import json
import random
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop
from torchvision.utils import save_image
import lpips

from diffusers import (
    AutoencoderKL,
    AutoencoderKLWan,
    AutoencoderKLLTXVideo,
    AutoencoderKLQwenImage
)

from scipy.stats import skew, kurtosis


# ========================== Config ==========================
DEVICE = "cuda"
DTYPE = torch.float16
IMAGE_FOLDER = "/home/recoilme/dataset/alchemist"
MIN_SIZE = 1280
CROP_SIZE = 512
BATCH_SIZE = 10
MAX_IMAGES = 500
NUM_WORKERS = 4
SAMPLES_DIR = "vaetest"

VAE_LIST = [
    # ("SD15 VAE", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
    # ("SDXL VAE fp16 fix", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
    ("Wan2.2-TI2V-5B", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
    ("Wan2.2-T2V-A14B", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
    # ("SimpleVAE1", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly"),
    # ("SimpleVAE2", AutoencoderKL, "/home/recoilme/simplevae/simplevae", "simple_vae_nightly2"),
    # ("SimpleVAE nightly", AutoencoderKL, "AiArtLab/simplevae", "simple_vae_nightly"),
    # ("FLUX.1-schnell VAE", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
    # ("LTX-Video VAE", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
    ("QwenImage", AutoencoderKLQwenImage, "Qwen/Qwen-Image", "vae"),
]


# ========================== Utilities ==========================
def to_neg1_1(x: torch.Tensor) -> torch.Tensor:
    return x * 2 - 1


def to_0_1(x: torch.Tensor) -> torch.Tensor:
    return (x + 1) * 0.5


def safe_psnr(mse: float) -> float:
    if mse <= 1e-12:
        return float("inf")
    return 10.0 * float(np.log10(1.0 / mse))


def is_video_like_vae(vae) -> bool:
    # Wan / LTX-Video / Qwen-Image expect [B, C, T, H, W]
    return isinstance(vae, (AutoencoderKLWan, AutoencoderKLLTXVideo, AutoencoderKLQwenImage))


def add_time_dim_if_needed(x: torch.Tensor, vae) -> torch.Tensor:
    if is_video_like_vae(vae) and x.ndim == 4:
        return x.unsqueeze(2)  # -> [B, C, 1, H, W]
    return x


def strip_time_dim_if_possible(x: torch.Tensor, vae) -> torch.Tensor:
    if is_video_like_vae(vae) and x.ndim == 5 and x.shape[2] == 1:
        return x.squeeze(2)  # -> [B, C, H, W]
    return x


@torch.no_grad()
def sobel_edge_l1(real_0_1: torch.Tensor, fake_0_1: torch.Tensor) -> float:
    real = to_neg1_1(real_0_1)
    fake = to_neg1_1(fake_0_1)
    kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32, device=real.device).view(1, 1, 3, 3)
    C = real.shape[1]
    kx = kx.to(real.dtype).repeat(C, 1, 1, 1)
    ky = ky.to(real.dtype).repeat(C, 1, 1, 1)

    def grad_mag(x):
        gx = F.conv2d(x, kx, padding=1, groups=C)
        gy = F.conv2d(x, ky, padding=1, groups=C)
        return torch.sqrt(gx * gx + gy * gy + 1e-12)

    return F.l1_loss(grad_mag(fake), grad_mag(real)).item()


def flatten_channels(x: torch.Tensor) -> torch.Tensor:
    # -> [C, N*H*W] or [C, N*T*H*W]
    if x.ndim == 4:
        return x.permute(1, 0, 2, 3).reshape(x.shape[1], -1)
    elif x.ndim == 5:
        return x.permute(1, 0, 2, 3, 4).reshape(x.shape[1], -1)
    else:
        raise ValueError(f"Unexpected tensor ndim={x.ndim}")


def _to_numpy_1d(x: Any) -> Optional[np.ndarray]:
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return None
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().float().numpy()
    elif isinstance(x, (list, tuple)):
        x = np.array(x, dtype=np.float32)
    elif isinstance(x, np.ndarray):
        x = x.astype(np.float32, copy=False)
    else:
        return None
    x = x.reshape(-1)
    return x


def _to_float(x: Any) -> Optional[float]:
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if isinstance(x, np.ndarray) and x.size == 1:
        return float(x.item())
    if isinstance(x, torch.Tensor) and x.numel() == 1:
        return float(x.item())
    return None


def get_norm_tensors_and_summary(vae, latent_like: torch.Tensor):
    """
    Latent normalization: global and per-channel.
    Application order: global (scalar) first, then per-channel (vector).
    If the config carries several keys, they are accumulated.
    """
    cfg = getattr(vae, "config", vae)

    scale_keys = [
        "latents_std"
    ]
    shift_keys = [
        "latents_mean"
    ]

    C = latent_like.shape[1]
    nd = latent_like.ndim  # 4 or 5
    dev = latent_like.device
    dt = latent_like.dtype

    scale_global = getattr(vae.config, "scaling_factor", 1.0)
    shift_global = getattr(vae.config, "shift_factor", 0.0)
    if scale_global is None:
        scale_global = 1.0
    if shift_global is None:
        shift_global = 0.0

    scale_channel = np.ones(C, dtype=np.float32)
    shift_channel = np.zeros(C, dtype=np.float32)

    for k in scale_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            scale_channel *= vec
        else:
            s = _to_float(v)
            if s is not None:
                scale_global *= s

    for k in shift_keys:
        v = getattr(cfg, k, None)
        if v is None:
            continue
        vec = _to_numpy_1d(v)
        if vec is not None and vec.size == C:
            shift_channel += vec
        else:
            s = _to_float(v)
            if s is not None:
                shift_global += s

    g_shape = [1] * nd
    c_shape = [1] * nd
    c_shape[1] = C

    t_scale_g = torch.tensor(scale_global, dtype=dt, device=dev).view(*g_shape)
    t_shift_g = torch.tensor(shift_global, dtype=dt, device=dev).view(*g_shape)
    t_scale_c = torch.from_numpy(scale_channel).to(device=dev, dtype=dt).view(*c_shape)
    t_shift_c = torch.from_numpy(shift_channel).to(device=dev, dtype=dt).view(*c_shape)

    summary = {
        "scale_global": float(scale_global),
        "shift_global": float(shift_global),
        "scale_channel_min": float(scale_channel.min()),
        "scale_channel_mean": float(scale_channel.mean()),
        "scale_channel_max": float(scale_channel.max()),
        "shift_channel_min": float(shift_channel.min()),
        "shift_channel_mean": float(shift_channel.mean()),
        "shift_channel_max": float(shift_channel.max()),
    }
    return t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary


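# Illustrative note (hypothetical helper, not used below): the tensors returned above are
# applied as z_model = ((z - t_shift_g) * t_scale_g - t_shift_c) * t_scale_c, so undoing
# the mapping before vae.decode() would look like:
#
#   def unnormalize(z_model, t_shift_g, t_scale_g, t_shift_c, t_scale_c):
#       z = z_model / t_scale_c + t_shift_c   # undo the per-channel step
#       return z / t_scale_g + t_shift_g      # undo the global step

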
@torch.no_grad()
def kl_divergence_per_image(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    kl_map = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # [B, ...]
    return kl_map.float().view(kl_map.shape[0], -1).mean(dim=1)  # [B]


def sanitize_filename(name: str) -> str:
    name = name.replace("/", "_").replace("\\", "_").replace(" ", "_")
    return "".join(ch if (ch.isalnum() or ch in "._-") else "_" for ch in name)


# ========================== Dataset ==========================
class ImageFolderDataset(Dataset):
    def __init__(self, root_dir: str, extensions=(".png", ".jpg", ".jpeg", ".webp"), min_size=1024, crop_size=512, limit=None):
        paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(extensions):
                    paths.append(os.path.join(root, fname))
        if limit:
            paths = paths[:limit]

        valid = []
        for p in tqdm(paths, desc="Validating files"):
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except Exception:
                pass
        if not valid:
            raise RuntimeError(f"No valid images in {root_dir}")
        random.shuffle(valid)
        self.paths = valid
        print(f"Found {len(self.paths)} images")

        self.transform = Compose([
            Resize(min_size),
            CenterCrop(crop_size),
            ToTensor(),  # 0..1, float32
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        with Image.open(self.paths[idx]) as img:
            img = img.convert("RGB")
            return self.transform(img)


# ========================== Main ==========================
def main():
    torch.set_grad_enabled(False)
    os.makedirs(SAMPLES_DIR, exist_ok=True)

    dataset = ImageFolderDataset(IMAGE_FOLDER, min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    lpips_net = lpips.LPIPS(net="vgg").to(DEVICE).eval()

    # Load the VAEs
    vaes: List[Tuple[str, object]] = []
    print("\nLoading VAEs...")
    for human_name, vae_class, model_path, subfolder in VAE_LIST:
        try:
            vae = vae_class.from_pretrained(model_path, subfolder=subfolder, torch_dtype=DTYPE)
            vae = vae.to(DEVICE).eval()
            vaes.append((human_name, vae))
            print(f"  ✅ {human_name}")
        except Exception as e:
            print(f"  ❌ {human_name}: {e}")

    if not vaes:
        print("No VAE loaded successfully. Exiting.")
        return

    # Accumulators
    per_model_metrics: Dict[str, Dict[str, float]] = {
        name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "kl": 0.0, "count": 0.0}
        for name, _ in vaes
    }

    buffers_zmodel: Dict[str, List[torch.Tensor]] = {name: [] for name, _ in vaes}
    norm_summaries: Dict[str, Dict[str, float]] = {}

    # Flag for saving the first image
    saved_first_for: Dict[str, bool] = {name: False for name, _ in vaes}

    for batch_0_1 in tqdm(loader, desc="Batches"):
        batch_0_1 = batch_0_1.to(DEVICE, torch.float32)
        batch_neg1_1 = to_neg1_1(batch_0_1).to(DTYPE)

        for model_name, vae in vaes:
            x_in = add_time_dim_if_needed(batch_neg1_1, vae)

            posterior = vae.encode(x_in).latent_dist
            mu, logvar = posterior.mean, posterior.logvar

            # Reconstruction (deterministic)
            z_raw_mode = posterior.mode()
            x_dec = vae.decode(z_raw_mode).sample  # [-1, 1]
            x_dec = strip_time_dim_if_possible(x_dec, vae)
            x_rec_0_1 = to_0_1(x_dec.float()).clamp(0, 1)

            # Latents for the UNet: global -> channelwise
            z_raw_sample = posterior.sample()
            t_shift_g, t_scale_g, t_shift_c, t_scale_c, summary = get_norm_tensors_and_summary(vae, z_raw_sample)

            if model_name not in norm_summaries:
                norm_summaries[model_name] = summary

            z_tmp = (z_raw_sample - t_shift_g) * t_scale_g
            z_model = (z_tmp - t_shift_c) * t_scale_c
            z_model = strip_time_dim_if_possible(z_model, vae)

            buffers_zmodel[model_name].append(z_model.detach().to("cpu", torch.float32))

            # Save the first image (original and reconstruction) for each VAE
            if not saved_first_for[model_name]:
                safe = sanitize_filename(model_name)
                orig_path = os.path.join(SAMPLES_DIR, f"{safe}_original.png")
                dec_path = os.path.join(SAMPLES_DIR, f"{safe}_decoded.png")
                save_image(batch_0_1[0:1].cpu(), orig_path)
                save_image(x_rec_0_1[0:1].cpu(), dec_path)
                saved_first_for[model_name] = True

            # Per-image metrics
            B = batch_0_1.shape[0]
            for i in range(B):
                gt = batch_0_1[i:i+1]
                rec = x_rec_0_1[i:i+1]

                mse = F.mse_loss(gt, rec).item()
                psnr = safe_psnr(mse)
                lp = float(lpips_net(gt, rec, normalize=True).mean().item())
                edge = sobel_edge_l1(gt, rec)

                per_model_metrics[model_name]["mse"] += mse
                per_model_metrics[model_name]["psnr"] += psnr
                per_model_metrics[model_name]["lpips"] += lp
                per_model_metrics[model_name]["edge"] += edge

            # KL per image
            kl_pi = kl_divergence_per_image(mu, logvar)  # [B]
            per_model_metrics[model_name]["kl"] += float(kl_pi.sum().item())
            per_model_metrics[model_name]["count"] += B

    # Average the metrics
    for name in per_model_metrics:
        c = max(1.0, per_model_metrics[name]["count"])
        for k in ["mse", "psnr", "lpips", "edge", "kl"]:
            per_model_metrics[name][k] /= c

    # Latent statistics and normality
    per_model_latent_stats = {}
    for name, _ in vaes:
        if not buffers_zmodel[name]:
            continue
        Z = torch.cat(buffers_zmodel[name], dim=0)  # [N, C, H, W]

        # Global stats
        z_min = float(Z.min().item())
        z_mean = float(Z.mean().item())
        z_max = float(Z.max().item())
        z_std = float(Z.std(unbiased=True).item())

        # Per channel: skew/kurtosis
        Z_ch = flatten_channels(Z).numpy()  # [C, *]
        C = Z_ch.shape[0]
        sk = np.zeros(C, dtype=np.float64)
        ku = np.zeros(C, dtype=np.float64)
        for c in range(C):
            v = Z_ch[c]
            sk[c] = float(skew(v, bias=False))
            ku[c] = float(kurtosis(v, fisher=True, bias=False))

        skew_min, skew_mean, skew_max = float(sk.min()), float(sk.mean()), float(sk.max())
        kurt_min, kurt_mean, kurt_max = float(ku.min()), float(ku.mean()), float(ku.max())
        mean_abs_skew = float(np.mean(np.abs(sk)))
        mean_abs_kurt = float(np.mean(np.abs(ku)))

        per_model_latent_stats[name] = {
            "Z_min": z_min, "Z_mean": z_mean, "Z_max": z_max, "Z_std": z_std,
            "skew_min": skew_min, "skew_mean": skew_mean, "skew_max": skew_max,
            "kurt_min": kurt_min, "kurt_mean": kurt_mean, "kurt_max": kurt_max,
            "mean_abs_skew": mean_abs_skew, "mean_abs_kurt": mean_abs_kurt,
        }

    # Print the normalization parameters (shift/scale)
    print("\n=== Latent normalization parameters (as applied) ===")
    for name, _ in vaes:
        if name not in norm_summaries:
            continue
        s = norm_summaries[name]
        print(
            f"{name:26s} | "
            f"shift_g={s['shift_global']:.6g} scale_g={s['scale_global']:.6g} | "
            f"shift_c[min/mean/max]=[{s['shift_channel_min']:.6g}, {s['shift_channel_mean']:.6g}, {s['shift_channel_max']:.6g}] | "
            f"scale_c[min/mean/max]=[{s['scale_channel_min']:.6g}, {s['scale_channel_mean']:.6g}, {s['scale_channel_max']:.6g}]"
        )

    # Absolute metrics
    print("\n=== Absolute reconstruction and latent metrics ===")
    for name, _ in vaes:
        if name not in per_model_latent_stats:
            continue
        m = per_model_metrics[name]
        s = per_model_latent_stats[name]
        print(
            f"{name:26s} | "
            f"MSE={m['mse']:.3e} PSNR={m['psnr']:.2f} LPIPS={m['lpips']:.3f} Edge={m['edge']:.3f} KL={m['kl']:.3f} | "
            f"Z[min/mean/max/std]=[{s['Z_min']:.3f}, {s['Z_mean']:.3f}, {s['Z_max']:.3f}, {s['Z_std']:.3f}] | "
            f"Skew[min/mean/max]=[{s['skew_min']:.3f}, {s['skew_mean']:.3f}, {s['skew_max']:.3f}] | "
            f"Kurt[min/mean/max]=[{s['kurt_min']:.3f}, {s['kurt_mean']:.3f}, {s['kurt_max']:.3f}]"
        )

    # Comparison against the first model
    baseline = vaes[0][0]
    print("\n=== Comparison with the first model (percentages) ===")
    print(f"| {'Model':26s} | {'MSE':>9s} | {'PSNR':>9s} | {'LPIPS':>9s} | {'Edge':>9s} | {'Skew|0':>9s} | {'Kurt|0':>9s} |")
    print(f"|{'-'*28}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|{'-'*11}|")

    b_m = per_model_metrics[baseline]
    b_s = per_model_latent_stats[baseline]

    for name, _ in vaes:
        m = per_model_metrics[name]
        s = per_model_latent_stats[name]

        mse_pct = (b_m["mse"] / max(1e-12, m["mse"])) * 100.0        # lower is better
        psnr_pct = (m["psnr"] / max(1e-12, b_m["psnr"])) * 100.0     # higher is better
        lpips_pct = (b_m["lpips"] / max(1e-12, m["lpips"])) * 100.0  # lower is better
        edge_pct = (b_m["edge"] / max(1e-12, m["edge"])) * 100.0     # lower is better

        skew0_pct = (b_s["mean_abs_skew"] / max(1e-12, s["mean_abs_skew"])) * 100.0
        kurt0_pct = (b_s["mean_abs_kurt"] / max(1e-12, s["mean_abs_kurt"])) * 100.0

        if name == baseline:
            print(f"| {name:26s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} | {'100%':>9s} |")
        else:
            print(f"| {name:26s} | {mse_pct:8.1f}% | {psnr_pct:8.1f}% | {lpips_pct:8.1f}% | {edge_pct:8.1f}% | {skew0_pct:8.1f}% | {kurt0_pct:8.1f}% |")

    # ============ Corrections for the last VAE + save to JSON ============
    last_name = vaes[-1][0]
    if buffers_zmodel[last_name]:
        Z = torch.cat(buffers_zmodel[last_name], dim=0)  # [N, C, H, W]

        # Global correction (over all channels/pixels)
        z_mean = float(Z.mean().item())
        z_std = float(Z.std(unbiased=True).item())
        correction_global = {
            "shift": -z_mean,
            "scale": (1.0 / z_std) if z_std > 1e-12 else 1.0
        }

        # Per-channel correction
        Z_ch = flatten_channels(Z)  # [C, M]
        ch_means_t = Z_ch.mean(dim=1)  # [C]
        ch_stds_t = Z_ch.std(dim=1, unbiased=True) + 1e-12  # [C]
        ch_means = [float(x) for x in ch_means_t.tolist()]
        ch_stds = [float(x) for x in ch_stds_t.tolist()]

        correction_per_channel = [
            {"shift": float(-m), "scale": float(1.0 / s)}
            for m, s in zip(ch_means, ch_stds)
        ]

        print(f"\n=== Extra correction for {last_name} (on top of the VAE normalization) ===")
        print(f"global_correction = {correction_global}")
        print(f"channelwise_means = {ch_means}")
        print(f"channelwise_stds = {ch_stds}")
        print(f"channelwise_correction = {correction_per_channel}")

        # Save to JSON
        json_path = os.path.join(SAMPLES_DIR, f"{sanitize_filename(last_name)}_correction.json")
        to_save = {
            "model_name": last_name,
            "vae_normalization_summary": norm_summaries.get(last_name, {}),
            "global_correction": correction_global,
            "per_channel_means": ch_means,
            "per_channel_stds": ch_stds,
            "per_channel_correction": correction_per_channel,
            "apply_order": {
                "forward": "z_model -> (z - global_shift)*global_scale -> (per-channel: (z - mean_c)/std_c)",
                "inverse": "z_corr -> (per-channel: z*std_c + mean_c) -> (z/global_scale + global_shift)"
            },
            "note": "These coefficients are computed from z_model (after the built-in VAE shift/scale) to bring the distribution to N(0,1)."
        }
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(to_save, f, ensure_ascii=False, indent=2)
        print("Corrections JSON saved to:", os.path.abspath(json_path))

    print("\n✅ Done. Samples saved to:", os.path.abspath(SAMPLES_DIR))


if __name__ == "__main__":
    main()
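For reference, a minimal sketch of consuming the correction JSON written above. The file name corresponds to the last VAE_LIST entry ("QwenImage"), and the add-then-multiply reading of global_correction (shift = -mean, scale = 1/std) is an assumption based on how the coefficients are computed:

import json
import torch

with open("vaetest/QwenImage_correction.json", encoding="utf-8") as f:
    corr = json.load(f)

g = corr["global_correction"]  # {"shift": -mean, "scale": 1/std}
mean_c = torch.tensor(corr["per_channel_means"]).view(1, -1, 1, 1)
std_c = torch.tensor(corr["per_channel_stds"]).view(1, -1, 1, 1)

def correct(z_model):
    z = (z_model + g["shift"]) * g["scale"]  # global step
    return (z - mean_c) / std_c              # per-channel step

def uncorrect(z_corr):
    z = z_corr * std_c + mean_c              # undo per-channel step
    return z / g["scale"] - g["shift"]       # undo global step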
samples/sample_0.jpg ADDED

Git LFS Details

  • SHA256: cb43df876fea0ab69a3fa63399c378aad4dda308a1534071796834acc26c71a6
  • Pointer size: 130 Bytes
  • Size of remote file: 84.9 kB
samples/sample_1.jpg ADDED

Git LFS Details

  • SHA256: fc0b8542e55bc97fb988441631c9e80543aef8ce0796c6416280282d73da427f
  • Pointer size: 130 Bytes
  • Size of remote file: 75.7 kB
samples/sample_2.jpg ADDED

Git LFS Details

  • SHA256: 6d7969e2ba962645308392a623d1bc8b8573472aae631a68ac2996c31f2dd8af
  • Pointer size: 130 Bytes
  • Size of remote file: 71.2 kB
samples/sample_decoded.jpg ADDED

Git LFS Details

  • SHA256: cb43df876fea0ab69a3fa63399c378aad4dda308a1534071796834acc26c71a6
  • Pointer size: 130 Bytes
  • Size of remote file: 84.9 kB
samples/sample_real.jpg ADDED

Git LFS Details

  • SHA256: b187738cf82a8633e1409e6ed3db35fb5930681957ed8d69ae8cce6da881371f
  • Pointer size: 130 Bytes
  • Size of remote file: 89.9 kB
simple_vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:010d2cb8824a347425be4e41d662b22492965ffb61393621eb1253be8b7fa0ce
+oid sha256:f5f0a20e403669e880b510514ee575a2a9cb74a1b36ab0e31fc68ef66c2173d7
 size 335311892
simple_vae_nightly/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ccd57f2cd9455d6c66ed2fee9396dbb53cbeb675fa0c1fbee87a9b0f94c3de79
+oid sha256:7b705da7f401289eefa22570514d7c1b9b2f9fd32a71159e2d3d5888f74e41cd
 size 335311892
train_sdxl_vae_full.py CHANGED
@@ -58,7 +58,7 @@ device = None  # accelerator sets the device
 # CHANGED: added a flag for training the full VAE (not just the decoder).
 # If False, the previous behaviour: train only decoder.* (up_blocks + mid_block).
 # If True, unfreeze the WHOLE model and add a KL loss for the encoder.
-full_training = False
+full_training = True
 
 # CHANGED: added a weight (as a share in the normalizer) for KL, used only when full_training=True.
 kl_ratio = 0.05  # simple share of KL in the overall mix (KISS). Ignored if full_training=False.
@@ -66,12 +66,12 @@ kl_ratio = 0.05  # simple share of KL in the overall mix (KISS)
 # --- Loss proportions and the median-normalization window (COEFFICIENTS, not values) ---
 # Final shares in the total loss (sum = 1.0 after normalization).
 loss_ratios = {
-    "lpips": 0.85,
+    "lpips": 0.80,
     "edge": 0.05,
     "mse": 0.05,
     "mae": 0.05,
     # CHANGED: the "kl" key is added up front (0.0 by default). It is activated below when full_training is enabled.
-    "kl": 0.00,
+    "kl": 0.05,
 }
 median_coeff_steps = 256  # over how many steps to compute the median coefficients
 
train_sdxl_vae_qwen.py ADDED
@@ -0,0 +1,526 @@
# -*- coding: utf-8 -*-
import os
import math
import re
import torch
import numpy as np
import random
import gc
from datetime import datetime
from pathlib import Path

import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
# QWEN: import the class
from diffusers import AutoencoderKLQwenImage

from accelerate import Accelerator
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import bitsandbytes as bnb
import wandb
import lpips  # pip install lpips
from collections import deque

# --------------------------- Parameters ---------------------------
ds_path = "/workspace/png"
project = "qwen_vae"
batch_size = 3
base_learning_rate = 5e-5
min_learning_rate = 9e-7
num_epochs = 16
sample_interval_share = 10
use_wandb = True
save_model = True
use_decay = True
optimizer_type = "adam8bit"
dtype = torch.float32

model_resolution = 512
high_resolution = 512
limit = 0
save_barrier = 1.03
warmup_percent = 0.01
percentile_clipping = 95
beta2 = 0.97
eps = 1e-6
clip_grad_norm = 1.0
mixed_precision = "no"
gradient_accumulation_steps = 5
generated_folder = "samples"
save_as = "wen_vae_nightly"
num_workers = 0
device = None

# --- Training modes ---
# QWEN: train only the decoder
train_decoder_only = True
full_training = False  # if True, train the whole VAE and add KL (below)
kl_ratio = 0.05

# Loss shares
loss_ratios = {
    "lpips": 0.80,
    "edge": 0.05,
    "mse": 0.10,
    "mae": 0.05,
    "kl": 0.00,  # activated when full_training=True
}
median_coeff_steps = 256

resize_long_side = 1280  # resize the long side of the source images

# QWEN: model loading config
vae_kind = "qwen"  # "qwen" or "kl" (regular)
vae_model_id = "Qwen/Qwen-Image"
vae_subfolder = "vae"

Path(generated_folder).mkdir(parents=True, exist_ok=True)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# reproducibility
seed = int(datetime.now().strftime("%Y%m%d"))
torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
torch.backends.cudnn.benchmark = False

# --------------------------- WandB ---------------------------
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
        "model_resolution": model_resolution,
        "high_resolution": high_resolution,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "train_decoder_only": train_decoder_only,
        "full_training": full_training,
        "kl_ratio": kl_ratio,
        "vae_kind": vae_kind,
        "vae_model_id": vae_model_id,
    })

# --------------------------- VAE ---------------------------
def is_qwen_vae(vae) -> bool:
    return isinstance(vae, AutoencoderKLQwenImage) or ("Qwen" in vae.__class__.__name__)

# loading
if vae_kind == "qwen":
    vae = AutoencoderKLQwenImage.from_pretrained(vae_model_id, subfolder=vae_subfolder)
else:
    # previous behaviour (example)
    if model_resolution == high_resolution:
        vae = AutoencoderKL.from_pretrained(project)
    else:
        vae = AsymmetricAutoencoderKL.from_pretrained(project)

vae = vae.to(dtype)

# torch.compile (optional)
if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")

# --------------------------- Freeze/Unfreeze ---------------------------
for p in vae.parameters():
    p.requires_grad = False

unfrozen_param_names = []

if full_training and not train_decoder_only:
    # train the whole model
    for name, p in vae.named_parameters():
        p.requires_grad = True
        unfrozen_param_names.append(name)
    loss_ratios["kl"] = float(kl_ratio)
    trainable_module = vae
else:
    # QWEN: train only the decoder (and post_quant_conv, which belongs to the decode path)
    # generic rule: everything that starts with "decoder." or "post_quant_conv"
    for name, p in vae.named_parameters():
        if name.startswith("decoder.") or name.startswith("post_quant_conv"):
            p.requires_grad = True
            unfrozen_param_names.append(name)
    trainable_module = vae.decoder if hasattr(vae, "decoder") else vae

print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
for nm in unfrozen_param_names[:200]:
    print(" ", nm)

# --------------------------- Dataset ---------------------------
class PngFolderDataset(Dataset):
    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
        self.root_dir = root_dir
        self.resolution = resolution
        self.paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
                    self.paths.append(os.path.join(root, fname))
        if limit:
            self.paths = self.paths[:limit]
        valid = []
        for p in self.paths:
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except (OSError, UnidentifiedImageError):
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"No valid PNG images found under {root_dir}")
        random.shuffle(self.paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        with Image.open(p) as img:
            img = img.convert("RGB")
            if not resize_long_side or resize_long_side <= 0:
                return img
            w, h = img.size
            long = max(w, h)
            if long <= resize_long_side:
                return img
            scale = resize_long_side / float(long)
            new_w = int(round(w * scale))
            new_h = int(round(h * scale))
            return img.resize((new_w, new_h), Image.LANCZOS)

def random_crop(img, sz):
    w, h = img.size
    if w < sz or h < sz:
        img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
    x = random.randint(0, max(1, img.width - sz))
    y = random.randint(0, max(1, img.height - sz))
    return img.crop((x, y, x + sz, y + sz))

tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
if len(dataset) < batch_size:
    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")

def collate_fn(batch):
    imgs = []
    for img in batch:
        img = random_crop(img, high_resolution)
        imgs.append(tfm(img))
    return torch.stack(imgs)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True
)

# --------------------------- Optimizer ---------------------------
def get_param_groups(module, weight_decay=0.001):
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
    decay_params, no_decay_params = [], []
    for n, p in vae.named_parameters():  # over the whole vae, filtered by requires_grad
        if not p.requires_grad:
            continue
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

def create_optimizer(name, param_groups):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
    raise ValueError(name)

param_groups = get_param_groups(trainable_module, weight_decay=0.001)
optimizer = create_optimizer(optimizer_type, param_groups)

# --------------------------- LR schedule ---------------------------
batches_per_epoch = len(dataloader)
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
total_steps = steps_per_epoch * num_epochs

def lr_lambda(step):
    if not use_decay:
        return 1.0
    x = float(step) / float(max(1, total_steps))
    warmup = float(warmup_percent)
    min_ratio = float(min_learning_rate) / float(base_learning_rate)
    if x < warmup:
        return min_ratio + (1.0 - min_ratio) * (x / warmup)
    decay_ratio = (x - warmup) / (1.0 - warmup)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))

scheduler = LambdaLR(optimizer, lr_lambda)

# Preparation
dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
trainable_params = [p for p in vae.parameters() if p.requires_grad]

# --------------------------- LPIPS and helpers ---------------------------
_lpips_net = None
def _get_lpips():
    global _lpips_net
    if _lpips_net is None:
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
    return _lpips_net

_sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
_sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    C = x.shape[1]
    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=C)
    gy = F.conv2d(x, ky, padding=1, groups=C)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)

class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        s = sum(desired_ratios.values())
        self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, abs_losses: dict):
        for k, v in abs_losses.items():
            if k in self.buffers:
                self.buffers[k].append(float(v.detach().abs().cpu()))
        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
        total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
        return total, coeffs, meds

+ if full_training and not train_decoder_only:
318
+ loss_ratios["kl"] = float(kl_ratio)
319
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
320
+
321
+ # --------------------------- Сэмплы ---------------------------
322
+ @torch.no_grad()
323
+ def get_fixed_samples(n=3):
324
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
325
+ pil_imgs = [dataset[i] for i in idx]
326
+ tensors = []
327
+ for img in pil_imgs:
328
+ img = random_crop(img, high_resolution)
329
+ tensors.append(tfm(img))
330
+ return torch.stack(tensors).to(accelerator.device, dtype)
331
+
332
+ fixed_samples = get_fixed_samples()
333
+
334
+ @torch.no_grad()
335
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
336
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
337
+ return Image.fromarray(arr)
338
+
339
+ @torch.no_grad()
340
+ def generate_and_save_samples(step=None):
341
+ try:
342
+ temp_vae = accelerator.unwrap_model(vae).eval()
343
+ lpips_net = _get_lpips()
344
+ with torch.no_grad():
345
+ orig_high = fixed_samples
346
+ orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
347
+ model_dtype = next(temp_vae.parameters()).dtype
348
+ orig_low = orig_low.to(dtype=model_dtype)
349
+
350
+ # QWEN: добавляем T=1 на encode/decode и снимаем при сравнении
351
+ if is_qwen_vae(temp_vae):
352
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
353
+ enc = temp_vae.encode(x_in)
354
+ latents_mean = enc.latent_dist.mean
355
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
356
+ rec = dec.squeeze(2) # [B,3,H,W]
357
+ else:
358
+ enc = temp_vae.encode(orig_low)
359
+ latents_mean = enc.latent_dist.mean
360
+ rec = temp_vae.decode(latents_mean).sample
361
+
362
+ if rec.shape[-2:] != orig_high.shape[-2:]:
363
+ rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
364
+
365
+ first_real = _to_pil_uint8(orig_high[0])
366
+ first_dec = _to_pil_uint8(rec[0])
367
+ first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
368
+ first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)
369
+
370
+ for i in range(rec.shape[0]):
371
+ _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)
372
+
373
+ lpips_scores = []
374
+ for i in range(rec.shape[0]):
375
+ orig_full = orig_high[i:i+1].to(torch.float32)
376
+ rec_full = rec[i:i+1].to(torch.float32)
377
+ if rec_full.shape[-2:] != orig_full.shape[-2:]:
378
+ rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
379
+ lpips_val = lpips_net(orig_full, rec_full).item()
380
+ lpips_scores.append(lpips_val)
381
+ avg_lpips = float(np.mean(lpips_scores))
382
+
383
+ if use_wandb and accelerator.is_main_process:
384
+ wandb.log({"lpips_mean": avg_lpips}, step=step)
385
+ finally:
386
+ gc.collect()
387
+ torch.cuda.empty_cache()
388
+
389
+ if accelerator.is_main_process and save_model:
390
+ print("Генерация сэмплов до старта обучения...")
391
+ generate_and_save_samples(0)
392
+
393
+ accelerator.wait_for_everyone()
394
+
395
+ # --------------------------- Тренировка ---------------------------
396
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
397
+ global_step = 0
398
+ min_loss = float("inf")
399
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
400
+
401
+ for epoch in range(num_epochs):
402
+ vae.train()
403
+ batch_losses, batch_grads = [], []
404
+ track_losses = {k: [] for k in loss_ratios.keys()}
405
+
406
+ for imgs in dataloader:
407
+ with accelerator.accumulate(vae):
408
+ imgs = imgs.to(accelerator.device)
409
+
410
+ if high_resolution != model_resolution:
411
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
412
+ else:
413
+ imgs_low = imgs
414
+
415
+ model_dtype = next(vae.parameters()).dtype
416
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
417
+
418
+ # QWEN: encode/decode с T=1
419
+ if is_qwen_vae(vae):
420
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
421
+ enc = vae.encode(x_in)
422
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
423
+ dec = vae.decode(latents).sample # [B,3,1,H,W]
424
+ rec = dec.squeeze(2) # [B,3,H,W]
425
+ else:
426
+ enc = vae.encode(imgs_low_model)
427
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
428
+ rec = vae.decode(latents).sample
429
+
430
+ if rec.shape[-2:] != imgs.shape[-2:]:
431
+ rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
432
+
433
+ rec_f32 = rec.to(torch.float32)
434
+ imgs_f32 = imgs.to(torch.float32)
435
+
436
+ abs_losses = {
437
+ "mae": F.l1_loss(rec_f32, imgs_f32),
438
+ "mse": F.mse_loss(rec_f32, imgs_f32),
439
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
440
+ "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
441
+ }
442
+
443
+ if full_training and not train_decoder_only:
444
+ mean = enc.latent_dist.mean
445
+ logvar = enc.latent_dist.logvar
446
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
447
+ abs_losses["kl"] = kl
448
+ else:
449
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
450
+
451
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
452
+
453
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
454
+ raise RuntimeError("NaN/Inf loss")
455
+
456
+ accelerator.backward(total_loss)
457
+
458
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
459
+ if accelerator.sync_gradients:
460
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
461
+ optimizer.step()
462
+ scheduler.step()
463
+ optimizer.zero_grad(set_to_none=True)
464
+ global_step += 1
465
+ progress.update(1)
466
+
467
+ if accelerator.is_main_process:
468
+ try:
469
+ current_lr = optimizer.param_groups[0]["lr"]
470
+ except Exception:
471
+ current_lr = scheduler.get_last_lr()[0]
472
+
473
+ batch_losses.append(total_loss.detach().item())
474
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
475
+ for k, v in abs_losses.items():
476
+ track_losses[k].append(float(v.detach().item()))
477
+
478
+ if use_wandb and accelerator.sync_gradients:
479
+ log_dict = {
480
+ "total_loss": float(total_loss.detach().item()),
481
+ "learning_rate": current_lr,
482
+ "epoch": epoch,
483
+ "grad_norm": batch_grads[-1],
484
+ "mode/train_decoder_only": int(train_decoder_only),
485
+ "mode/full_training": int(full_training),
486
+ }
487
+ for k, v in abs_losses.items():
488
+ log_dict[f"loss_{k}"] = float(v.detach().item())
489
+ for k in coeffs:
490
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
491
+ log_dict[f"median_{k}"] = float(meds[k])
492
+ wandb.log(log_dict, step=global_step)
493
+
494
+ if global_step > 0 and global_step % sample_interval == 0:
495
+ if accelerator.is_main_process:
496
+ generate_and_save_samples(global_step)
497
+ accelerator.wait_for_everyone()
498
+
499
+ n_micro = sample_interval * gradient_accumulation_steps
500
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
501
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
502
+
503
+ if accelerator.is_main_process:
504
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
505
+ if save_model and avg_loss < min_loss * save_barrier:
506
+ min_loss = avg_loss
507
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
508
+ if use_wandb:
509
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
510
+
511
+ if accelerator.is_main_process:
512
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
513
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
514
+ if use_wandb:
515
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
516
+
517
+ # --------------------------- Финальное сохранение ---------------------------
518
+ if accelerator.is_main_process:
519
+ print("Training finished – saving final model")
520
+ if save_model:
521
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
522
+
523
+ accelerator.free_memory()
524
+ if torch.distributed.is_initialized():
525
+ torch.distributed.destroy_process_group()
526
+ print("Готово!")
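For reference, a minimal standalone sketch of the T=1 convention this script relies on for AutoencoderKLQwenImage: images are fed as [B, 3, 1, H, W] and the extra dimension is dropped again after decoding (random input, illustrative shapes):

import torch
from diffusers import AutoencoderKLQwenImage

qwen_vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae").eval()
x = torch.rand(1, 3, 512, 512) * 2 - 1                      # fake image batch in [-1, 1]
with torch.no_grad():
    z = qwen_vae.encode(x.unsqueeze(2)).latent_dist.mean    # add T=1 -> [B, C, 1, h, w]
    rec = qwen_vae.decode(z).sample.squeeze(2)              # drop T=1 -> [B, 3, H, W]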
vaetest/001_all.png ADDED

Git LFS Details

  • SHA256: 7b7a8098d61a1525db5ce3eaa5cd50e132a5f846a6053f789edd4801e37b0d18
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
vaetest/001_decoded_FLUX.1_schnell_vae.png ADDED

Git LFS Details

  • SHA256: 21b88d5045d1b9c0a3785d5b96a6dcd225ea92921143fd4a2fe5daabf060ccae
  • Pointer size: 131 Bytes
  • Size of remote file: 494 kB
vaetest/001_decoded_simple_vae.png ADDED

Git LFS Details

  • SHA256: 816836033774e8ad18e4853763cf2f040db44ab08de261ef3b1be95931d7f28d
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_decoded_simple_vae2.png ADDED

Git LFS Details

  • SHA256: 903009ea3ea4344918cf79b06ffc8ba55402275a1fa41ed7d118086c49ec9dd4
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_decoded_simple_vae_nightly.png ADDED

Git LFS Details

  • SHA256: 1246db8b7d3e6a36199dbd83b8532a28917df0528f016429cf499179d8d2bcf4
  • Pointer size: 131 Bytes
  • Size of remote file: 483 kB
vaetest/001_orig.png ADDED

Git LFS Details

  • SHA256: 12c632e0aecc1925185142be560a65e204e12e7167dbcc1a49e3017b371638fe
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
vaetest/002_all.png ADDED

Git LFS Details

  • SHA256: 058a39dde15443d7547a4944df594d285b465c5e8817225c669ea04adf4d6c01
  • Pointer size: 132 Bytes
  • Size of remote file: 1.74 MB
vaetest/002_decoded_FLUX.1_schnell_vae.png ADDED

Git LFS Details

  • SHA256: f5bed96d137dbaa377cebaaceb9919a2d8d51a1793759b643e2d772e4fe33785
  • Pointer size: 131 Bytes
  • Size of remote file: 380 kB
vaetest/002_decoded_simple_vae.png ADDED

Git LFS Details

  • SHA256: 808a6138bc48cd85f752d2bcabd1e7795c89df6794aab686a32e5c9ac2f7214f
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_decoded_simple_vae2.png ADDED

Git LFS Details

  • SHA256: 4f67867e41530bcb498a46b963f65db5c603bd70db042b294450d337a9ddb651
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_decoded_simple_vae_nightly.png ADDED

Git LFS Details

  • SHA256: 7905cccc22fc84bddcf712d9615a069b2dd3999a17b37b747b08d2db9c8719b7
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
vaetest/002_orig.png ADDED

Git LFS Details

  • SHA256: 177599cb0d77d66058bb53146156de2d9654ac98255ae2005f7aadbebaee0fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB