# sdxl_vae / eval_alchemist.py
import os
import torch
import torch.nn.functional as F
import lpips
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, ToPILImage, InterpolationMode
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL, AutoencoderKLWan, AutoencoderKLLTXVideo
import random
# --------------------------- Parameters ---------------------------
DEVICE = "cuda"
DTYPE = torch.float16
IMAGE_FOLDER = "/workspace/alchemist"  # wget https://huggingface.co/datasets/AiArtLab/alchemist/resolve/main/alchemist.zip
MIN_SIZE = 1280
CROP_SIZE = 512
BATCH_SIZE = 10
MAX_IMAGES = 0  # 0 = no limit
NUM_WORKERS = 4
NUM_SAMPLES_TO_SAVE = 2  # how many sample images to save (0 = don't save)
SAMPLES_FOLDER = "vaetest"
# VAEs to benchmark
VAE_LIST = [
# ("stable-diffusion-v1-5/stable-diffusion-v1-5", AutoencoderKL, "stable-diffusion-v1-5/stable-diffusion-v1-5", "vae"),
# ("cross-attention/asymmetric-autoencoder-kl-x-1-5", AsymmetricAutoencoderKL, "cross-attention/asymmetric-autoencoder-kl-x-1-5", None),
# ("madebyollin/sdxl-vae-fp16", AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix", None),
# ("KBlueLeaf/EQ-SDXL-VAE", AutoencoderKL, "KBlueLeaf/EQ-SDXL-VAE", None),
# ("AiArtLab/sdxl_vae", AutoencoderKL, "AiArtLab/sdxl_vae", None),
# ("AiArtLab/sdxlvae_nightly", AutoencoderKL, "AiArtLab/sdxl_vae", "vae_nightly"),
# ("Lightricks/LTX-Video", AutoencoderKLLTXVideo, "Lightricks/LTX-Video", "vae"),
# ("Wan2.2-TI2V-5B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-TI2V-5B-Diffusers", "vae"),
# ("Wan2.2-T2V-A14B-Diffusers", AutoencoderKLWan, "Wan-AI/Wan2.2-T2V-A14B-Diffusers", "vae"),
# ("AiArtLab/sdxs", AutoencoderKL, "AiArtLab/sdxs", "vae"),
("FLUX.1-schnell-vae", AutoencoderKL, "black-forest-labs/FLUX.1-schnell", "vae"),
("simple_vae", AutoencoderKL, "AiArtLab/simplevae", "vae"),
("simple_vae2", AutoencoderKL, "AiArtLab/simplevae", None),
("simple_vae_nightly", AutoencoderKL, "/workspace/sdxl_vae/simple_vae_nightly", None),
]
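# Each VAE_LIST entry is (display name, diffusers VAE class, HF repo id or local path, subfolder or None).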
# --------------------------- Sobel Edge Detection ---------------------------
# Define the Sobel kernels once at module level
_sobel_kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
_sobel_ky = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
"""
Вычисляет карту границ с помощью оператора Собеля
x: [B,C,H,W] в диапазоне [-1,1]
Возвращает: [B,C,H,W] - магнитуда градиента
"""
C = x.shape[1]
kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
gx = F.conv2d(x, kx, padding=1, groups=C)
gy = F.conv2d(x, ky, padding=1, groups=C)
return torch.sqrt(gx * gx + gy * gy + 1e-12)
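# Quick sanity check (hypothetical example): a vertical step edge yields a strong
# horizontal-gradient response at the sign flip:
#   x = torch.cat([-torch.ones(1, 1, 8, 4), torch.ones(1, 1, 8, 4)], dim=3)
#   sobel_edges(x)  # peaks in the columns around the -1 -> +1 transition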
def compute_edge_loss(real: torch.Tensor, fake: torch.Tensor) -> float:
"""
Вычисляет Edge Loss между реальным и сгенерированным изображением
real, fake: [B,C,H,W] в диапазоне [0,1]
Возвращает: скалярное значение loss
"""
# Конвертируем в [-1,1] для sobel_edges
real_norm = real * 2 - 1
fake_norm = fake * 2 - 1
    # Edge maps
edges_real = sobel_edges(real_norm)
edges_fake = sobel_edges(fake_norm)
    # L1 loss between the edge maps
return F.l1_loss(edges_fake, edges_real).item()
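# Hypothetical usage: identical images give zero edge loss, while a blurred
# reconstruction loses high-frequency detail and scores higher:
#   x = torch.rand(1, 3, 64, 64)
#   compute_edge_loss(x, x)                         # == 0.0
#   compute_edge_loss(x, F.avg_pool2d(x, 3, 1, 1))  # > 0, blur suppresses edges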
# --------------------------- Dataset ---------------------------
class ImageFolderDataset(Dataset):
def __init__(self, root_dir, extensions=('.png',), min_size=1024, crop_size=512, limit=None):
self.root_dir = root_dir
self.min_size = min_size
self.crop_size = crop_size
self.paths = []
print("Сканирование папки...")
for root, _, files in os.walk(root_dir):
for fname in files:
if fname.lower().endswith(extensions):
self.paths.append(os.path.join(root, fname))
if limit:
self.paths = self.paths[:limit]
print("Проверка изображений...")
valid = []
for p in tqdm(self.paths, desc="Проверка"):
try:
with Image.open(p) as im:
im.verify()
valid.append(p)
except:
continue
self.paths = valid
if len(self.paths) == 0:
            raise RuntimeError(f"No valid images found in {root_dir}")
random.shuffle(self.paths)
print(f"Найдено {len(self.paths)} изображений")
self.transform = Compose([
            Resize(min_size, interpolation=InterpolationMode.LANCZOS),
CenterCrop(crop_size),
ToTensor(),
])
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
path = self.paths[idx]
with Image.open(path) as img:
img = img.convert("RGB")
return self.transform(img)
# --------------------------- Helpers ---------------------------
def process(x):
return x * 2 - 1
def deprocess(x):
return x * 0.5 + 0.5
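# process/deprocess map between the [0,1] range used by PIL/metrics and the
# [-1,1] range the VAEs expect; deprocess(process(x)) == x.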
def _sanitize_name(name: str) -> str:
return name.replace('/', '_').replace('-', '_')
# --------------------------- VAE analysis ---------------------------
@torch.no_grad()
def tensor_stats(name, x: torch.Tensor):
finite = torch.isfinite(x)
fin_ratio = finite.float().mean().item()
x_f = x[finite]
minv = x_f.min().item() if x_f.numel() else float('nan')
maxv = x_f.max().item() if x_f.numel() else float('nan')
mean = x_f.mean().item() if x_f.numel() else float('nan')
std = x_f.std().item() if x_f.numel() else float('nan')
big = (x_f.abs() > 20).float().mean().item() if x_f.numel() else float('nan')
print(f"[{name}] shape={tuple(x.shape)} dtype={x.dtype} "
f"finite={fin_ratio:.6f} min={minv:.4g} max={maxv:.4g} mean={mean:.4g} std={std:.4g} |x|>20={big:.6f}")
@torch.no_grad()
def analyze_vae_latents(vae, name, images):
"""
    images: [B,3,H,W] in [-1,1]
"""
try:
enc = vae.encode(images)
if hasattr(enc, "latent_dist"):
mu, logvar = enc.latent_dist.mean, enc.latent_dist.logvar
z = enc.latent_dist.sample()
else:
mu, logvar = enc[0], enc[1]
z = mu
tensor_stats(f"{name}.mu", mu)
tensor_stats(f"{name}.logvar", logvar)
tensor_stats(f"{name}.z_raw", z)
sf = getattr(vae.config, "scaling_factor", 1.0)
z_scaled = z * sf
tensor_stats(f"{name}.z_scaled(x{sf})", z_scaled)
except Exception as e:
print(f"⚠️ Ошибка анализа VAE {name}: {e}")
# --------------------------- Main ---------------------------
if __name__ == "__main__":
if NUM_SAMPLES_TO_SAVE > 0:
os.makedirs(SAMPLES_FOLDER, exist_ok=True)
dataset = ImageFolderDataset(
IMAGE_FOLDER,
extensions=('.png',),
min_size=MIN_SIZE,
crop_size=CROP_SIZE,
limit=MAX_IMAGES
)
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True,
drop_last=False
)
lpips_net = lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
print("\nЗагрузка VAE моделей...")
vaes = []
names = []
for name, vae_class, model_path, subfolder in VAE_LIST:
try:
print(f" Загружаю {name}...")
# Исправлена загрузка для variant
if "sdxs" in model_path:
vae = vae_class.from_pretrained(model_path, subfolder=subfolder, variant="fp16")
else:
vae = vae_class.from_pretrained(model_path, subfolder=subfolder)
vae = vae.to(DEVICE, DTYPE).eval()
vaes.append(vae)
names.append(name)
except Exception as e:
print(f" ❌ Ошибка загрузки {name}: {e}")
print("\nОценка метрик...")
results = {name: {"mse": 0.0, "psnr": 0.0, "lpips": 0.0, "edge": 0.0, "count": 0} for name in names}
to_pil = ToPILImage()
    # >>>>>>>> Main evaluation loop (KISS) <<<<<<<<
with torch.no_grad():
        images_saved = 0  # count saved IMAGES, not saved files
        for batch in tqdm(dataloader, desc="Processing batches"):
            batch = batch.to(DEVICE)             # [B,3,H,W] in [0,1]
            test_inp = process(batch).to(DTYPE)  # [-1,1] for the encoder
            # >>> Analyze each VAE's latents on the first batch only (keeps the log clean)
            if images_saved == 0:
for vae, name in zip(vaes, names):
analyze_vae_latents(vae, name, test_inp)
            # 1) compute reconstructions for every VAE on the whole batch
recon_list = []
for vae, name in zip(vaes, names):
                test_inp_vae = test_inp  # local alias; may gain a temporal dim below
                if isinstance(vae, (AutoencoderKLWan, AutoencoderKLLTXVideo)) and test_inp_vae.ndim == 4:
                    test_inp_vae = test_inp_vae.unsqueeze(2)  # video VAEs expect [B,C,T,H,W]
latent = vae.encode(test_inp_vae).latent_dist.mode()
dec = vae.decode(latent).sample.float()
                if dec.ndim == 5:
                    dec = dec.squeeze(2)  # drop the temporal dim added for video VAEs
recon = deprocess(dec).clamp(0.0, 1.0)
recon_list.append(recon)
            # 2) update the metrics (per VAE)
for recon, name in zip(recon_list, names):
for i in range(batch.shape[0]):
img_orig = batch[i:i+1]
img_recon = recon[i:i+1]
mse = F.mse_loss(img_orig, img_recon).item()
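                    # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 1 since images lie in [0, 1]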
psnr = 10 * torch.log10(1 / torch.tensor(mse)).item()
lpips_val = lpips_net(img_orig, img_recon, normalize=True).mean().item()
edge_loss = compute_edge_loss(img_orig, img_recon)
results[name]["mse"] += mse
results[name]["psnr"] += psnr
results[name]["lpips"] += lpips_val
results[name]["edge"] += edge_loss
results[name]["count"] += 1
            # 3) save exactly NUM_SAMPLES_TO_SAVE images (orig + every VAE + a combined collage)
if NUM_SAMPLES_TO_SAVE > 0:
for i in range(batch.shape[0]):
if images_saved >= NUM_SAMPLES_TO_SAVE:
break
idx_str = f"{images_saved + 1:03d}"
# original
orig_pil = to_pil(batch[i].detach().float().cpu())
orig_pil.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_orig.png"))
# per-VAE decodes
tiles = [orig_pil]
for recon, name in zip(recon_list, names):
recon_pil = to_pil(recon[i].detach().cpu())
recon_pil.save(os.path.join(
SAMPLES_FOLDER, f"{idx_str}_decoded_{_sanitize_name(name)}.png"
))
tiles.append(recon_pil)
                    # combined collage: [orig | vae1 | vae2 | ...]
collage_w = CROP_SIZE * len(tiles)
collage_h = CROP_SIZE
collage = Image.new("RGB", (collage_w, collage_h))
x = 0
for tile in tiles:
collage.paste(tile, (x, 0))
x += CROP_SIZE
collage.save(os.path.join(SAMPLES_FOLDER, f"{idx_str}_all.png"))
images_saved += 1
    # Average the results
for name in names:
count = results[name]["count"]
results[name]["mse"] /= count
results[name]["psnr"] /= count
results[name]["lpips"] /= count
results[name]["edge"] /= count
    # Absolute values
    print("\n=== Absolute values ===")
for name in names:
print(f"{name:30s}: MSE: {results[name]['mse']:.3e}, PSNR: {results[name]['psnr']:.4f}, "
f"LPIPS: {results[name]['lpips']:.4f}, Edge: {results[name]['edge']:.4f}")
    # Comparison table (percent of the first model)
    print("\n=== Comparison with the first model (%) ===")
    print(f"| {'Model':30s} | {'MSE':>10s} | {'PSNR':>10s} | {'LPIPS':>10s} | {'Edge':>10s} |")
print(f"|{'-'*32}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|")
baseline = names[0]
for name in names:
        # MSE: lower is better, so invert the ratio (>100% = better than the baseline)
        mse_pct = (results[baseline]["mse"] / results[name]["mse"]) * 100
        # PSNR: higher is better
        psnr_pct = (results[name]["psnr"] / results[baseline]["psnr"]) * 100
        # LPIPS and Edge: lower is better, invert as well
        lpips_pct = (results[baseline]["lpips"] / results[name]["lpips"]) * 100
        edge_pct = (results[baseline]["edge"] / results[name]["edge"]) * 100
if name == baseline:
print(f"| {name:30s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} | {'100%':>10s} |")
else:
print(f"| {name:30s} | {f'{mse_pct:.1f}%':>10s} | {f'{psnr_pct:.1f}%':>10s} | "
f"{f'{lpips_pct:.1f}%':>10s} | {f'{edge_pct:.1f}%':>10s} |")
print("\n✅ Готово!")