import os import torch import torch.nn.functional as F import lpips from PIL import Image, UnidentifiedImageError from tqdm import tqdm from torch.utils.data import Dataset, DataLoader from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop,ToPILImage from diffusers import AutoencoderKL, AsymmetricAutoencoderKL, AutoencoderKLWan,AutoencoderKLLTXVideo import random # --------------------------- Параметры --------------------------- DEVICE = "cuda" DTYPE = torch.float16 IMAGE_FOLDER = "/workspace/alchemist" #wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip MIN_SIZE = 1280 CROP_SIZE = 512 BATCH_SIZE = 10 MAX_IMAGES = 0 NUM_WORKERS = 4 NUM_SAMPLES_TO_SAVE = 2 # Сколько примеров сохранить (0 - не сохранять) SAMPLES_FOLDER = "vaetest" # Список VAE для тестирования VAE_LIST = [ # ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"), # ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None), # ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None), # ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None), # ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None), # ("AiArtLab/sdxlvae_nightly", AutoencoderKL, "AiArtLab/sdxl_vae", "vae_nightly"), # ("Lightricks/LTX-Video", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"), # ("Wan2.2-TI2V-5B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"), # ("Wan2.2-T2V-A14B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"), # ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"), ("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"), ("simple_vae", AutoencoderKL, "AiArtLab/simplevae", "vae"), ("simple_vae2", AutoencoderKL, "AiArtLab/simplevae", None), ("simple_vae_nightly", AutoencoderKL, "/workspace/sdxl_vae/simple_vae_nightly", None), ] # --------------------------- Sobel Edge Detection --------------------------- # Определяем фильтры Собеля глобально _sobel_kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3) _sobel_ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3) def sobel_edges(x: torch.Tensor) -> torch.Tensor: """ Вычисляет карту границ с помощью оператора Собеля x: [B,C,H,W] в диапазоне [-1,1] Возвращает: [B,C,H,W] - магнитуда градиента """ C = x.shape[1] kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1) ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1) gx = F.conv2d(x, kx, padding=1, groups=C) gy = F.conv2d(x, ky, padding=1, groups=C) return torch.sqrt(gx * gx + gy * gy + 1e-12) def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float: """ Вычисляет Edge Loss между реальным и сгенерированным изображением real, fake: [B,C,H,W] в диапазоне [0,1] Возвращает: скалярное значение loss """ # Конвертируем в [-1,1] для sobel_edges real_norm = real * 2 - 1 fake_norm = fake * 2 - 1 # Получаем карты границ edges_real = sobel_edges(real_norm) edges_fake = sobel_edges(fake_norm) # L1 loss между картами границ return F.l1_loss(edges_fake, edges_real).item() # --------------------------- Dataset --------------------------- class ImageFolderDataset(Dataset): def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None): self.root_dir = root_dir self.min_size = min_size self.crop_size = crop_size self.paths = [] print("Сканирование папки...") for root, _, files in os.walk(root_dir): for fname in files: if fname.lower().endswith(extensions): self.paths.append(os.path.join(root, fname)) if limit: self.paths = self.paths[:limit] print("Проверка изображений...") valid = [] for p in tqdm(self.paths, desc="Проверка"): try: with Image.open(p) as im: im.verify() valid.append(p) except: continue self.paths = valid if len(self.paths) == 0: raise RuntimeError(f"Не найдено валидных изображений в {root_dir}") random.shuffle(self.paths) print(f"Найдено {len(self.paths)} изображений") self.transform = Compose([ Resize(min_size, interpolation=Image.LANCZOS), CenterCrop(crop_size), ToTensor(), ]) def __len__(self): return len(self.paths) def __getitem__(self, idx): path = self.paths[idx] with Image.open(path) as img: img = img.convert("RGB") return self.transform(img) # --------------------------- Функции --------------------------- def process(x): return x * 2 - 1 def deprocess(x): return x * 0.5 + 0.5 def _sanitize_name(name: str) -> str: return name.replace('/', '_').replace('-', '_') # --------------------------- Анализ VAE --------------------------- @torch.no_grad() def tensor_stats(name, x: torch.Tensor): finite = torch.isfinite(x) fin_ratio = finite.float().mean().item() x_f = x[finite] minv = x_f.min().item() if x_f.numel() else float('nan') maxv = x_f.max().item() if x_f.numel() else float('nan') mean = x_f.mean().item() if x_f.numel() else float('nan') std = x_f.std().item() if x_f.numel() else float('nan') big = (x_f.abs() > 20).float().mean().item() if x_f.numel() else float('nan') print(f"[{name}] shape={tuple(x.shape)} dtype={x.dtype} " f"finite={fin_ratio:.6f} min={minv:.4g} max={maxv:.4g} mean={mean:.4g} std={std:.4g} |x|>20={big:.6f}") @torch.no_grad() def analyze_vae_latents(vae, name, images): """ images: [B,3,H,W] в [-1,1] """ try: enc = vae.encode(images) if hasattr(enc, "latent_dist"): mu, logvar = enc.latent_dist.mean, enc.latent_dist.logvar z = enc.latent_dist.sample() else: mu, logvar = enc[0], enc[1] z = mu tensor_stats(f"{name}.mu", mu) tensor_stats(f"{name}.logvar", logvar) tensor_stats(f"{name}.z_raw", z) sf = getattr(vae.config, "scaling_factor", 1.0) z_scaled = z * sf tensor_stats(f"{name}.z_scaled(x{sf})", z_scaled) except Exception as e: print(f"⚠️ Ошибка анализа VAE {name}: {e}") # --------------------------- Основной код --------------------------- if __name__ == "__main__": if NUM_SAMPLES_TO_SAVE > 0: os.makedirs(SAMPLES_FOLDER, exist_ok=True) dataset = ImageFolderDataset( IMAGE_FOLDER, extensions=('.png',), min_size=MIN_SIZE, crop_size=CROP_SIZE, limit=MAX_IMAGES ) dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, drop_last=False ) lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False) print("\nЗагрузка VAE моделей...") vaes = [] names = [] for name, vae_class, model_path, subfolder in VAE_LIST: try: print(f" Загружаю {name}...") # Исправлена загрузка для variant if "sdxs" in model_path: vae = vae_class.from_pretrained(model_path, subfolder=subfolder, variant="fp16") else: vae = vae_class.from_pretrained(model_path, subfolder=subfolder) vae = vae.to(DEVICE, DTYPE).eval() vaes.append(vae) names.append(name) except Exception as e: print(f" ❌ Ошибка загрузки {name}: {e}") print("\nОценка метрик...") results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "count": 0} for name in names} to_pil = ToPILImage() # >>>>>>>> ОСНОВНЫЕ ИЗМЕНЕНИЯ ЗДЕСЬ (KISS) <<<<<<<< with torch.no_grad(): images_saved = 0 # считаем именно КОЛ-ВО ИЗОБРАЖЕНИЙ, а не сохранённых файлов for batch in tqdm(dataloader, desc="Обработка батчей"): batch = batch.to(DEVICE) # [B,3,H,W] в [0,1] test_inp = process(batch).to(DTYPE) # [-1,1] для энкодера # >>> Анализируем латенты каждой VAE на первой итерации if images_saved == 0: # только для первого батча, чтобы не засорять лог for vae, name in zip(vaes, names): analyze_vae_latents(vae, name, test_inp) # 1) считаем реконструкции для всех VAE на весь батч recon_list = [] for vae, name in zip(vaes, names): test_inp_vae = test_inp # локальная копия #if name == "Wan2.2-T2V-A14B-Diffusers" and test_inp_vae.ndim == 4: if (isinstance(vae, AutoencoderKLWan) or isinstance(vae, AutoencoderKLLTXVideo)) and test_inp_vae.ndim == 4: test_inp_vae = test_inp_vae.unsqueeze(2) # только для Wan latent = vae.encode(test_inp_vae).latent_dist.mode() dec = vae.decode(latent).sample.float() if dec.ndim == 5: dec = dec.squeeze(2) recon = deprocess(dec).clamp(0.0, 1.0) recon_list.append(recon) # 2) обновляем метрики (по каждой VAE) for recon, name in zip(recon_list, names): for i in range(batch.shape[0]): img_orig = batch[i:i+1] img_recon = recon[i:i+1] mse = F.mse_loss(img_orig, img_recon).item() psnr = 10 * torch.log10(1 / torch.tensor(mse)).item() lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item() edge_loss = compute_edge_loss(img_orig, img_recon) results[name]["mse"] += mse results[name]["psnr"] += psnr results[name]["lpips"] += lpips_val results[name]["edge"] += edge_loss results[name]["count"] += 1 # 3) сохраняем ровно NUM_SAMPLES_TO_SAVE изображений (orig + все VAE + общий коллаж) if NUM_SAMPLES_TO_SAVE > 0: for i in range(batch.shape[0]): if images_saved >= NUM_SAMPLES_TO_SAVE: break idx_str = f"{images_saved + 1:03d}" # original orig_pil = to_pil(batch[i].detach().float().cpu()) orig_pil.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_orig.png")) # per-VAE decodes tiles = [orig_pil] for recon, name in zip(recon_list, names): recon_pil = to_pil(recon[i].detach().cpu()) recon_pil.save(os.path.join( SAMPLES_FOLDER, f"{idx_str}_decoded_{_sanitize_name(name)}.png" )) tiles.append(recon_pil) # общий коллаж: [orig | vae1 | vae2 | ...] collage_w = CROP_SIZE * len(tiles) collage_h = CROP_SIZE collage = Image.new("RGB", (collage_w, collage_h)) x = 0 for tile in tiles: collage.paste(tile, (x, 0)) x += CROP_SIZE collage.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_all.png")) images_saved += 1 # Усреднение результатов for name in names: count = results[name]["count"] results[name]["mse"] /= count results[name]["psnr"] /= count results[name]["lpips"] /= count results[name]["edge"] /= count # Вывод абсолютных значений print("\n=== Абсолютные значения ===") for name in names: print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, " f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}") # Вывод таблицы с процентами print("\n=== Сравнение с первой моделью (%) ===") print(f"| {'Модель':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |") print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|") baseline = names[0] for name in names: # Для MSE, LPIPS и Edge: меньше = лучше, поэтому инвертируем mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100 # Для PSNR: больше = лучше psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100 # Для LPIPS и Edge: меньше = лучше lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100 edge_pct = (results[baseline]["edge"] / results[name]["edge"]) * 100 if name == baseline: print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |") else: print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | " f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |") print("\n✅ Готово!")