"""Visual sanity check for the data pipeline and the forward-noising process.

Loads one training batch, then shows a sample next to its noised version at a
fixed diffusion timestep.
"""

import matplotlib.pyplot as plt
import torch

from config import Config
from dataloader import get_dataloaders
from noise_scheduler import FrequencyAwareNoise


def _to_image(img: torch.Tensor):
    """Convert a single CHW tensor into an array ``plt.imshow`` can display.

    Undoes the ``* 0.5 + 0.5`` mapping used by the original script, which
    presumably reverses a normalization of pixels to [-1, 1] -- TODO confirm
    against the dataloader transforms. The result is clamped to [0, 1]:
    noised samples routinely fall outside that range, and matplotlib would
    otherwise clip them itself with a warning.
    """
    # detach/cpu/float so .numpy() works for grad-tracking, device or half tensors.
    img = (img.detach().cpu().float() * 0.5 + 0.5).clamp(0.0, 1.0)
    img = img.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects
    # imshow rejects an (H, W, 1) array; drop the channel axis for grayscale data.
    if img.shape[-1] == 1:
        img = img[..., 0]
    return img.numpy()


def debug_data(timestep: int = 500, sample_index: int = 0) -> None:
    """Plot one training sample before and after forward noising.

    Args:
        timestep: Timestep passed to ``FrequencyAwareNoise.apply_noise`` for
            every sample in the batch (default 500, as in the original script).
        sample_index: Which sample of the first batch to display (default 0).
    """
    config = Config()
    train_loader, _ = get_dataloaders(config)
    x0, _ = next(iter(train_loader))

    noise_scheduler = FrequencyAwareNoise(config)
    # One int64 timestep per batch element, equivalent to
    # torch.tensor([timestep] * len(x0)), but created on x0's device.
    t = torch.full((x0.shape[0],), timestep, dtype=torch.long, device=x0.device)
    with torch.no_grad():  # visualization only; no autograd graph is needed
        xt = noise_scheduler.apply_noise(x0, t)

    _, (ax_orig, ax_noisy) = plt.subplots(1, 2, figsize=(10, 5))
    # cmap applies only to 2-D (grayscale) images; it is ignored for RGB data.
    ax_orig.imshow(_to_image(x0[sample_index]), cmap="gray")
    ax_orig.set_title("Original")
    ax_noisy.imshow(_to_image(xt[sample_index]), cmap="gray")
    ax_noisy.set_title(f"Noisy (t={timestep})")
    plt.show()


if __name__ == "__main__":
    debug_data()