import os
import sys
import json
import argparse
import numpy as np
import math
import time
import random
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator, DeepSpeedPlugin

# Let float32 matmuls on CUDA run in TensorFloat-32 (TF32): faster on GPUs that
# support it, at reduced mantissa precision.
torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): this script was flattened onto a few long lines and part of its
# setup was lost. `utils`, `num_devices`, `local_rank`, `device`, `accelerator`
# and the parsed `args` are used below and later in the file, but they are not
# defined anywhere in this file. Restore their definitions from version control.
global_batch_size = 128
if num_devices <= 1 and utils.is_interactive():
    # NOTE(review): the body of this branch did not survive the flattening;
    # `pass` only keeps the module parseable.
    pass
deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")

if utils.is_interactive():
    # Example use
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
        --model_name=test \
        --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
        --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
    jupyter_args = jupyter_args.split()
    print(jupyter_args)
    from IPython.display import clear_output

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

print("global batch_size", batch_size)
# `batch_size` arrives as the global batch size; each device gets an equal share.
batch_size = int(batch_size / num_devices)
print("batch_size", batch_size)

outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
os.makedirs(outdir, exist_ok=True)

if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
    img_augment = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224, 224), (0.6, 1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )

# # Prep data, models, and dataloaders
# ## Dataloader

if subj == 1:
    num_train = 24958
    num_test = 2770
else:
    # Split sizes are only known for subject 1; fail here rather than with a
    # NameError further down.
    raise NotImplementedError(f"num_train/num_test are not defined for subj={subj}")
# The whole test split is loaded as a single batch.
test_batch_size = num_test


def my_split_by_node(urls):
    # Identity node splitter: every node iterates over every shard.
    return urls


def _make_behav_loader(url, loader_batch_size):
    """Build a shuffled WebDataset over `url` and a DataLoader on top of it.

    Each sample is the tuple (behav, past_behav, future_behav, olds_behav)
    decoded from the shard's .npy entries. Returns (dataset, loader).
    """
    dataset = (
        wds.WebDataset(url, resampled=False, nodesplitter=my_split_by_node)
        .shuffle(750, initial=1500, rng=random.Random(42))
        .decode("torch")
        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")
        .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=loader_batch_size, shuffle=False, drop_last=False, pin_memory=True)
    return dataset, loader


train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)
train_data, train_dl = _make_behav_loader(train_url, batch_size)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)
test_data, test_dl = _make_behav_loader(test_url, test_batch_size)

# Load all betas into CPU memory; the file is closed as soon as they are read.
with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:
    voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").half()
if subj == 1:
    # Append 5 zero columns. Build the padding in the same (half) dtype: a
    # float32 pad would promote the whole concatenated matrix to float32 and
    # double its memory footprint.
    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5), dtype=voxels.dtype)))
print("voxels", voxels.shape)
# NOTE(review): `utils`, `device`, `accelerator` and `local_rank` are used
# throughout this section but do not appear to be defined anywhere in this file.
num_voxels = voxels.shape[-1]

# load orig images (the file is closed as soon as the array is in memory)
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]
images = torch.Tensor(images).to("cpu").half()
print("images", images.shape)

# ## Load models
# ### CLIP image embeddings model
from models import Clipper
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
clip_seq_dim = 257
clip_emb_dim = 768
hidden_dim = 4096

# ### SD VAE (blurry images)
from diffusers import AutoencoderKL
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
# The VAE is frozen: eval mode, no gradients.
autoenc.eval()
autoenc.requires_grad_(False)
autoenc.to(device)
utils.count_params(autoenc)


# ### MindEye modules
class MindEyeModule(nn.Module):
    """Identity container; the trainable submodules are attached as attributes below."""

    def __init__(self):
        super(MindEyeModule, self).__init__()

    def forward(self, x):
        return x


model = MindEyeModule()


class RidgeRegression(torch.nn.Module):
    """A single linear map; the ridge penalty comes from the optimizer's weight_decay."""

    # make sure to add weight_decay when initializing optimizer
    def __init__(self, input_size, out_features):
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(input_size, out_features)

    def forward(self, x):
        return self.linear(x)


model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# Shape sanity check.
b = torch.randn((2, 1, voxels.shape[1]))
print(b.shape, model.ridge(b).shape)

from functools import partial
from diffusers.models.vae import Decoder


def _residual_mlp_blocks(h, n_blocks, norm_type, act_first, drop):
    """Build `n_blocks` width-`h` blocks: Linear(h, h) -> norm/activation -> Dropout.

    norm_type 'bn' selects BatchNorm1d + in-place ReLU; anything else selects
    LayerNorm + GELU. act_first puts the activation before the norm.
    """
    norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
    act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
    act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
    return nn.ModuleList([
        nn.Sequential(
            nn.Linear(h, h),
            *[item() for item in act_and_norm],
            nn.Dropout(drop),
        )
        for _ in range(n_blocks)
    ])


def _run_residual_blocks(x, blocks):
    """Apply `blocks` in order, adding each block's input back onto its output."""
    residual = x
    for block in blocks:
        x = block(x) + residual
        residual = x
    return x


class BrainNetwork(nn.Module):
    """Residual MLP mapping ridge features to CLIP token embeddings and blurry VAE latents.

    forward() returns (c, b):
    - c is clip_proj applied to lin1's output reshaped to (batch, -1, clip_size).
    - b is the upsampler Decoder applied to blin1's output reshaped to
      (batch, -1, 7, 7). blurry_dim must therefore be a multiple of 49, and
      blurry_dim / 49 must match the Decoder's in_channels (64).
    """

    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):
        super().__init__()
        self.blurry_dim = blurry_dim
        self.lin0 = nn.Linear(in_dim, h)
        self.mlp = _residual_mlp_blocks(h, n_blocks, norm_type, act_first, drop)
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
        self.n_blocks = n_blocks
        self.clip_size = clip_size
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size),
        )
        self.upsampler = Decoder(
            in_channels=64,
            out_channels=4,
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[64, 128, 256],
            layers_per_block=1,
        )

    def forward(self, x):
        x = self.lin0(x)
        x = _run_residual_blocks(x, self.mlp)
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        b = self.blin1(x)
        b = self.upsampler(b.reshape(len(b), -1, 7, 7))
        c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
        return c, b


model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim * clip_seq_dim, blurry_dim=64 * 7 * 7)
utils.count_params(model.backbone)
utils.count_params(model)

# Shape sanity check.
b = torch.randn((2, hidden_dim))
print(b.shape)
clip_, blur_ = model.backbone(b)
print(clip_.shape, blur_.shape)


# memory model
class MemoryEncoder(nn.Module):
    """Encode one past scan together with a learned embedding of its time index.

    forward(x, time): x is (N, in_dim) and time holds N integer indices below
    num_past_voxels. The time embedding is concatenated onto x along the last
    dim before the residual MLP. Returns (N, out_dim).
    """

    def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim=512, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
        super().__init__()
        self.out_dim = out_dim
        self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
        self.final_input_dim = in_dim + embedding_time_dim
        self.lin0 = nn.Linear(self.final_input_dim, h)
        self.mlp = _residual_mlp_blocks(h, n_blocks, norm_type, act_first, drop)
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        self.n_blocks = n_blocks
        self.num_past_voxels = num_past_voxels
        self.embedding_time_dim = embedding_time_dim
        # NOTE(review): `memory` is never used in forward(). It is kept so that
        # existing checkpoints still load, but an unused parameter may need
        # find_unused_parameters=True under DDP.
        self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))

    def forward(self, x, time):
        time = self.embedding_time(time.long())
        x = torch.cat((x, time), dim=-1)
        x = self.lin0(x)
        x = _run_residual_blocks(x, self.mlp)
        x = x.reshape(len(x), -1)
        return self.lin1(x)


class MemoryCompressor(nn.Module):
    """Flatten (batch_size, num_past, in_dim) embeddings and map them to (batch_size, output_dim)."""

    def __init__(self, in_dim=768, num_past=15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
        super().__init__()
        self.num_past = num_past
        self.final_input_dim = in_dim * num_past
        self.lin0 = nn.Linear(self.final_input_dim, h)
        self.mlp = _residual_mlp_blocks(h, n_blocks, norm_type, act_first, drop)
        self.lin1 = nn.Linear(h, output_dim, bias=True)
        self.n_blocks = n_blocks
        self.output_dim = output_dim

    def forward(self, x):
        # x is (batch_size, num_past, in_dim)
        x = x.reshape(len(x), -1)
        x = self.lin0(x)
        x = _run_residual_blocks(x, self.mlp)
        x = x.reshape(len(x), -1)
        return self.lin1(x)


# Number of past scans fed to the memory modules.
num_past = 15
model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=num_past, embedding_time_dim=512)
model.memory_compressor = MemoryCompressor(in_dim=model.memory_encoder.out_dim, num_past=num_past, output_dim=4096)
utils.count_params(model.memory_encoder)
utils.count_params(model.memory_compressor)
utils.count_params(model)

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))

# The train DataLoader keeps its trailing partial batch (drop_last=False), so
# each epoch has ceil(num_train / batch_size) steps per device. Rounding down
# instead leaves OneCycleLR short of steps, and it raises ValueError once
# stepped past total_steps.
total_steps = int(num_epochs * math.ceil(num_train * num_devices / batch_size))
if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=total_steps,
        last_epoch=-1,
    )
elif lr_scheduler_type == 'cycle':
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1,
        pct_start=2 / num_epochs,
    )


def save_ckpt(tag):
    """Save the unwrapped model, optimizer, scheduler and loss histories to `outdir/<tag>.pth`.

    Failures are reported and swallowed on purpose, so that a failed save does
    not kill a long training run.
    """
    ckpt_path = outdir + f'/{tag}.pth'
    print(f'saving {ckpt_path}', flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
        }, ckpt_path)
    except Exception as err:
        print("Couldn't save... moving on to prevent crashing.", repr(err))
    del unwrapped_model


print("\nDone with model preparations!")
utils.count_params(model)

# # Weights and Biases
# params for wandb
wandb_log = True
if local_rank == 0 and wandb_log:  # only use main process for wandb logging
    import wandb

    wandb_project = 'stability'
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')  # , relogin=True)
    wandb_config = {
        "model_name": model_name,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "use_image_aug": use_image_aug,
        "max_lr": max_lr,
        "lr_scheduler_type": lr_scheduler_type,
        "mixup_pct": mixup_pct,
        "num_train": num_train,
        "num_test": num_test,
        "seed": seed,
        "distributed": distributed,
        "num_devices": num_devices,
        "world_size": world_size,
    }
    print("wandb_config:\n", wandb_config)
    if False:  # wandb_auto_resume
        print("wandb_id:", model_name)
        wandb.init(
            id=model_name,
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
            resume="allow",
        )
    else:
        wandb.init(
            project=wandb_project,
            name=model_name,
            config=wandb_config,
            notes=wandb_notes,
        )
else:
    wandb_log = False

# # More custom functions
# using the same preprocessing as was used in MindEye + BrainDiffuser
pixcorr_preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])


def pixcorr(images, brains):
    """Mean per-sample Pearson correlation between resized `images` and `brains`."""
    # Flatten images while keeping the batch dimension
    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
    corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
    return corrmean


# # Main
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
# Epochs before this index train with the mixco loss; later ones use the soft CLIP loss.
mixup_epochs = int(mixup_pct * num_epochs)
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - mixup_epochs)

# Optionally resume from checkpoint
if resume_from_ckpt or (wandb_log and wandb.run.resumed):
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir + '/last.pth', map_location='cpu')
    except Exception:
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir + '/last_backup.pth', map_location='cpu')
    epoch = checkpoint['epoch']
    print("Epoch", epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # save_ckpt stores the unwrapped `model` state_dict under this key.
    model.load_state_dict(checkpoint['model_state_dict'])
    losses, test_losses, lrs = checkpoint['train_losses'], checkpoint['test_losses'], checkpoint['lrs']
    # Checkpoints are written after an epoch finishes, so continue with the next one.
    epoch += 1
    del checkpoint
torch.cuda.empty_cache()

model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dl, test_dl, lr_scheduler
)


def encode_past_voxels(past_voxels):
    """Summarize each sample's past scans into one vector.

    past_voxels is (batch, n_past, num_voxels). Every scan is encoded with its
    time index 0..n_past-1 by model.memory_encoder; the n_past encodings are
    then compressed by model.memory_compressor into (batch, output_dim).
    """
    batch, n_past = past_voxels.shape[:2]
    times = torch.arange(n_past, device=past_voxels.device).repeat(batch)
    embeds = model.memory_encoder(past_voxels.reshape(batch * n_past, -1), times)
    return model.memory_compressor(embeds.reshape(batch, n_past, -1))


print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch, num_epochs), ncols=1200, disable=(local_rank != 0))
# Averaged-over-repeats test set. It is built once from the first test batch and
# reused every epoch.
test_image, test_voxel, test_past_behav = None, None, None
mse = nn.MSELoss()
for epoch in progress_bar:
    model.train()

    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    test_loss_clip_total = 0.
    test_loss_blurry_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0.  # needs >.456 to beat low-level subj01 results in mindeye v1

    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
        with torch.cuda.amp.autocast():
            optimizer.zero_grad()

            voxel = voxels[behav[:, 0, 5].cpu().long()].to(device)
            image = images[behav[:, 0, 0].cpu().long()].to(device).float()
            past_voxels = voxels[past_behav[:, :, 5].cpu().long()].to(device)  # (batch_size, n_past, num_voxels)

            blurry_image_enc = autoenc.encode(image).latent_dist.mode()

            if use_image_aug:
                image = img_augment(image)

            clip_target = clip_model.embed_image(image)
            assert not torch.any(torch.isnan(clip_target))

            if epoch < mixup_epochs:
                voxel, perm, betas, select = utils.mixco(voxel)

            voxel_ridge = model.ridge(voxel) + encode_past_voxels(past_voxels)
            clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)

            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if epoch < mixup_epochs:
                loss_clip = utils.mixco_nce(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006,
                    perm=perm, betas=betas, select=select)
            else:
                epoch_temp = soft_loss_temps[epoch - mixup_epochs]
                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=epoch_temp)

            loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

            loss_clip_total += loss_clip.item()
            loss_blurry_total += loss_blurry.item()

            loss = loss_blurry + loss_clip
            utils.check_loss(loss)

            accelerator.backward(loss)
            optimizer.step()

            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])

            # forward and backward top 1 accuracy
            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

            with torch.no_grad():
                # Pixcorr is evaluated on only 2 random samples per batch (fewer if
                # the batch is smaller), because autoenc.decode() is costly and slow.
                random_samps = np.random.choice(np.arange(len(voxel)), size=min(2, len(voxel)), replace=False)
                blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0, 1)
                blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)

            if lr_scheduler_type is not None:
                lr_scheduler.step()

    model.eval()
    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                # all test samples should be loaded per batch such that test_i should never exceed 0
                if len(behav) != num_test:
                    print("!", len(behav), num_test)

                ## Average same-image repeats ##
                if test_image is None:
                    voxel = voxels[behav[:, 0, 5].cpu().long()].to(device)
                    image = behav[:, 0, 0].cpu().long()

                    image_list, voxel_list, past_list = [], [], []
                    for im in torch.unique(image):
                        locs = torch.where(im == image)[0]
                        image_list.append(images[im])
                        voxel_list.append(torch.mean(voxel[locs], dim=0))
                        # Pair the averaged voxels with the past-scan history of
                        # this image's first repeat, so that history and sample
                        # refer to the same image.
                        past_list.append(past_behav[locs[0]])
                    test_image = torch.stack(image_list)
                    test_voxel = torch.stack(voxel_list)
                    test_past_behav = torch.stack(past_list)

                # sample of batch_size
                random_indices = torch.arange(len(test_voxel))[:batch_size]
                voxel = test_voxel[random_indices].to(device)
                image = test_image[random_indices].to(device)
                past_voxels = voxels[test_past_behav[random_indices][:, :, 5].cpu().long()].to(device)  # (batch_size, n_past, num_voxels)
                assert len(image) == batch_size

                blurry_image_enc = autoenc.encode(image).latent_dist.mode()

                clip_target = clip_model.embed_image(image.float())

                voxel_ridge = model.ridge(voxel) + encode_past_voxels(past_voxels)
                clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)

                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006)
                loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

                test_loss_clip_total += loss_clip.item()
                test_loss_blurry_total += loss_blurry.item()

                loss = loss_blurry + loss_clip
                utils.check_loss(loss)

                test_losses.append(loss.item())

                # forward and backward top 1 accuracy
                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

                # decode in two halves because the decoder is computationally heavy
                half = len(voxel) // 2
                blurry_recon_images = torch.vstack((
                    autoenc.decode(blurry_image_enc_[:half]).sample.clamp(0, 1),
                    autoenc.decode(blurry_image_enc_[half:]).sample.clamp(0, 1),
                ))
                test_blurry_pixcorr += pixcorr(image, blurry_recon_images)

                # transform blurry recon latents to images and plot it
                fig, axes = plt.subplots(1, 4, figsize=(8, 4))
                axes[0].imshow(utils.torch_to_Image(image[[0]]))
                axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0, 1)))
                axes[2].imshow(utils.torch_to_Image(image[[1]]))
                axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0, 1)))
                for ax in axes:
                    ax.axis('off')
                plt.show()
                # Release the figure so they do not accumulate across epochs.
                plt.close(fig)

    if local_rank == 0:
        # if utils.is_interactive(): clear_output(wait=True)
        assert (test_i + 1) == 1
        logs = {"train/loss": np.mean(losses[-(train_i + 1):]),
                "test/loss": np.mean(test_losses[-(test_i + 1):]),
                "train/lr": lrs[-1],
                "train/num_steps": len(losses),
                "test/num_steps": len(test_losses),
                "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
                "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
                "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
                "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
                "train/loss_clip_total": loss_clip_total / (train_i + 1),
                "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
                "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
                "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
                "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
                "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
                }
        progress_bar.set_postfix(**logs)

        # Save model checkpoint and reconstruct
        if epoch % ckpt_interval == 0:
            if not utils.is_interactive():
                save_ckpt('last')

        if wandb_log:
            wandb.log(logs)

    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    gc.collect()

print("\n===Finished!===\n")
if ckpt_saving:
    save_ckpt('last')
if not utils.is_interactive():
    sys.exit(0)

plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()