#!/usr/bin/env python
# coding: utf-8

# In[1]:


# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
# from subprocess import call
# command = "jupyter nbconvert Train.ipynb --to python"
# call(command,shell=True)


# # Import packages & functions

# In[2]:


# Standard-library and third-party imports (original import order kept on purpose).
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import h5py
from tqdm import tqdm

import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import ToPILImage #CHANGED (added)

from accelerate import Accelerator, DeepSpeedPlugin

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils

global_batch_size = 128 #128


# In[3]:


### Multi-GPU config ###
# The rank is read from the launcher's RANK environment variable; when the
# script is started without a distributed launcher it falls back to 0.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)

# A machine without visible CUDA devices is still treated as one device.
num_devices = torch.cuda.device_count() or 1

accelerator = Accelerator(split_batches=False)

### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###

# if num_devices <= 1 and utils.is_interactive():
#     # can emulate a distributed environment for deepspeed to work in jupyter notebook
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
#     os.environ["RANK"] = "0"
#     os.environ["LOCAL_RANK"] = "0"
#     os.environ["WORLD_SIZE"] = "1"
#     os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
#     global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]

# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
#     with open('deepspeed_config_stage2.json', 'r') as file:
#         config = json.load(file)
#     config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
#     config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
#     with open('deepspeed_config_stage2.json', 'w') as file:
#         json.dump(config, file)
# else:
#     # give some time for the local_rank=0 gpu to prep new deepspeed config file
#     time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)


# In[4]:


# Report the process / device layout chosen by `accelerate`.
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
# From here on, `print` is accelerate's wrapper.
print = accelerator.print # only print if local_rank=0


# # Configurations

# In[5]:


# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    # Example use
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
                    --model_name=captions \
                    --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
                    --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug"
    #max_lr=3e-5 originally
    jupyter_args = jupyter_args.split()
    print(jupyter_args)

    from IPython.display import clear_output # function to clear print outputs in cell
    get_ipython().run_line_magic('load_ext', 'autoreload')
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    get_ipython().run_line_magic('autoreload', '2')


# In[6]:


parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,5,7],
)
parser.add_argument(
    "--batch_size", type=int, default=32,
    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=100,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)
parser.add_argument(
    "--n_samples_save",type=int,default=0,choices=[0,1],
    help="Number of reconstructions for monitoring progress, 0 will speed up training",
)
parser.add_argument(
    "--clip_mse_ratio",type=float,default=0.7,
    help="weight of the contrastive CLIP loss in the total loss; the MSE loss is weighted by (1 - clip_mse_ratio)",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix (e.g. `batch_size`, `subj`, `data_path`)
globals().update(vars(args))

# --batch_size is the global batch size; each device gets an equal share.
print("global batch_size", batch_size)
batch_size = int(batch_size / num_devices)
print("batch_size", batch_size)


# In[7]:


outdir = os.path.abspath(f'../train_logs/{model_name}')
if not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)
if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
    img_augment = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )


# In[8]:


# NOTE(review): this unconditionally overrides the --wandb_log CLI flag.
wandb_log = True


# # Prep data, models, and dataloaders

# ## Dataloader

# In[9]:


# Dataset sizes are only hard-coded for subject 1. Fail early with a clear
# message rather than a NameError on `num_test` for the other allowed subjects.
if subj==1:
    num_train = 24958
    num_test = 2770
else:
    raise ValueError(f"num_train/num_test are only defined for subj=1 (got subj={subj})")
# The whole test set is loaded as a single batch.
test_batch_size = num_test

# Identity node splitter: every process iterates over every shard.
def my_split_by_node(urls): return urls

train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)

train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)

test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)


# ### check dataloaders are working

# In[10]:


# test_indices = []
# test_images = []
# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
#     test_indices = np.append(test_indices, behav[:,0,5].numpy())
#     test_images = np.append(test_images, behav[:,0,0].numpy())
# test_indices = test_indices.astype(np.int16)
# print(test_i, (test_i+1) * test_batch_size, len(test_indices))
# print("---\n")

# train_indices = []
# train_images = []
# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
#     train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
#     train_images = np.append(train_images, behav[:,0,0].numpy())
# train_indices = train_indices.astype(np.int16)
# print(train_i, (train_i+1) * batch_size, len(train_indices))

# # train_images = np.hstack((train_images, test_images))
# # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")


# ## Load data and images

# In[11]:


# load betas (read fully into memory, then close the file)
with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:
    voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").half()
if subj==1:
    # Pad subject 1 with 5 extra zero columns.
    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]

# load orig images (read fully into memory, then close the file)
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]
images = torch.Tensor(images).to("cpu").half()
print("images", images.shape)


# ## Load models

# ### CLIP image embeddings model

# In[12]:


import transformers
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image


# In[13]:


from models import Clipper
# NOTE(review): clip_model's only call sites below are commented out (BLIP-2
# embeddings are used instead); it is still loaded onto the GPU here.
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)


# In[14]:


cache_blip2 = "/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b"

b2_processor = Blip2Processor.from_pretrained(cache_blip2)
b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map="auto")

#Load in blip2 as well
"""from lavis.models import load_model_and_preprocess
from lavis.models import model_zoo
blip2_model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device)

clip_seq_dim = 257
clip_emb_dim = 1024
hidden_dim = 4096"""


# In[15]:


def embed_images_b2(images):
    """Encode images with the BLIP-2 vision tower.

    Args:
        images: image tensor. Assumes pixel values in [0, 1], since they are
            scaled by 255 and cast to uint8 before the BLIP-2 processor
            (TODO confirm for every caller).

    Returns:
        ``(hidden_states, inputs_processed)``: the vision model's
        ``last_hidden_state`` (detached) and the processor output on CUDA in
        float16.
    """
    images = (images * 255).type(torch.uint8)
    with torch.no_grad():
        inputs_processed = b2_processor(images, return_tensors="pt").to("cuda", torch.float16)
        enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])
        return enc_imgs.last_hidden_state.detach(), inputs_processed


def embeds_to_captions_b2(embeds):
    """Decode BLIP-2 vision-encoder embeddings into caption strings.

    Steps:
    1. Run the Q-Former over ``embeds``.
    2. Project its outputs with ``language_projection``.
    3. Prepend them to a single BOS token.
    4. Call the language model's ``generate``.

    Args:
        embeds: embeddings shaped like ``embed_images_b2``'s first output
            (batch first). These may also be predicted by the brain model.

    Returns:
        ``(outputs, text)``: generated token ids and the decoded captions,
        with special tokens skipped.
    """
    with torch.no_grad():
        n_images = embeds.shape[0]
        image_embeds = embeds
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = b2_model.query_tokens.expand(n_images, -1, -1)
        query_outputs = b2_model.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = b2_model.language_projection(query_output)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        # No text prompt: every caption starts from a single BOS token.
        input_ids = (
            torch.LongTensor([[b2_model.config.text_config.bos_token_id]])
            .repeat(n_images, 1)
            .to(image_embeds.device)
        )
        attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

        # concatenate query embeddings with prompt embeddings
        inputs_embeds = b2_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

        outputs = b2_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        text = b2_processor.batch_decode(outputs, skip_special_tokens=True)

        return outputs, text


# In[16]:


# Smoke test of the BLIP-2 round trip (images -> embeddings -> captions).
image_test = images[1:20].permute(0,2,3,1)
#raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')
# Convert the image to a NumPy array
#image_test = np.array(raw_image)


# In[17]:


"""import matplotlib.pyplot as plt

# Plotting one of the images (taking the first image as an example)
img_to_plot = inputs_rec['pixel_values'][-1]

# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])
img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')
print(img_to_plot.shape)

plt.imshow(img_to_plot)
plt.show()"""


# In[18]:


embeds_test, inputs_rec = embed_images_b2(image_test)


# In[19]:


#inputs_rec['pixel_values'].shape


# In[20]:


#out = b2_model.generate(**inputs_rec)
#print(b2_processor.decode(out[0], skip_special_tokens=True).strip())


# In[21]:


outputs_test, text_test = embeds_to_captions_b2(embeds_test)


# In[22]:


text_test


# In[23]:


#inputss['pixel_values'].shape


# In[24]:


#image_test.shape


# In[25]:


# In[26]:


clip_seq_dim = 257 #blip2 image encoder shapes
clip_emb_dim = 1408 #blip2 image encoder shapes
hidden_dim = 2048


# ### SD VAE (blurry images)

# In[27]:


from diffusers import AutoencoderKL
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
# Freeze the VAE: eval mode, no gradients, moved to the training device.
autoenc.eval()
autoenc.requires_grad_(False)
autoenc.to(device)
utils.count_params(autoenc)


# ### MindEye modules

# In[28]:


class MindEyeModule(nn.Module):
    """Identity container; the ``ridge`` and ``backbone`` submodules are attached below."""
    def __init__(self):
        super(MindEyeModule, self).__init__()
    def forward(self, x):
        return x

model = MindEyeModule()
model


# In[29]:


class RidgeRegression(torch.nn.Module):
    """A single ``nn.Linear`` from ``input_size`` to ``out_features``.

    The ridge penalty is not applied here: make sure to add weight_decay
    when initializing optimizer.
    """
    def __init__(self, input_size, out_features):
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(input_size, out_features)
    def forward(self, x):
        return self.linear(x)

model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# Shape smoke test for the ridge layer.
b = torch.randn((2,1,voxels.shape[1]))
print(b.shape, model.ridge(b).shape)


# In[30]:


from functools import partial
from diffusers.models.vae import Decoder
class BrainNetwork(nn.Module):
    """Residual MLP that maps ridge outputs to a sequence of ``clip_size`` tokens.

    Layers, in order:
    - ``lin0``: linear from ``in_dim`` to ``h``.
    - ``n_blocks`` residual blocks, each Linear(h, h), then norm/activation
      (order set by ``act_first``), then Dropout.
    - ``lin1``: linear to ``out_dim``.
    - A reshape to ``(batch, -1, clip_size)``, then the ``clip_proj`` MLP
      applied per token.

    ``blurry_dim`` is stored but only used by the commented-out blurry branch.
    """
    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.35, blurry_dim=16):
        super().__init__()
        self.blurry_dim = blurry_dim
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        self.lin0 = nn.Linear(in_dim, h)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
        self.n_blocks = n_blocks
        self.clip_size = clip_size
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size)
        )
        # self.upsampler = Decoder(
        #         in_channels=64,
        #         out_channels=4,
        #         up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
        #         block_out_channels=[64, 128, 256],
        #         layers_per_block=1,
        #     )

    def forward(self, x):
        x = self.lin0(x)
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        # b = self.blin1(x)
        # b = self.upsampler(b.reshape(len(b), -1, 7, 7))
        c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
        # return c, b
        return c

model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
utils.count_params(model.backbone)
utils.count_params(model)

# Shape smoke test for the backbone.
b = torch.randn((4,hidden_dim))
print(b.shape)
clip_ = model.backbone(b)
print(clip_.shape)


# In[31]:


# NOTE(review): parameter names here are attribute paths (e.g. "clip_proj.0.weight",
# "mlp.0.1.weight"), so the 'LayerNorm.*' patterns never match. Only biases are
# excluded from weight decay; LayerNorm weights are still decayed.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    total_steps=int(num_epochs*(num_train*num_devices//batch_size))
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )


def save_ckpt(tag):
    """Save a checkpoint to ``{outdir}/{tag}.pth``.

    The checkpoint holds the unwrapped model, optimizer and LR-scheduler
    state dicts, plus the module-level ``epoch``, ``losses``, ``test_losses``
    and ``lrs``.

    This is best-effort: an ordinary failure is reported and training
    continues. KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
            }, ckpt_path)
    except Exception as err:
        print("Couldn't save... moving on to prevent crashing.")
        print(f"reason: {err!r}")
    del unwrapped_model

print("\nDone with model preparations!")
utils.count_params(model)


# # Weights and Biases

# In[32]:


# params for wandb
# NOTE(review): `and True` is a manual switch; the --wandb_log flag is not consulted here.
if local_rank==0 and True: # only use main process for wandb logging
    import wandb

    wandb_project = 'mindeyev2'
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')#, relogin=True)
    wandb_config = {
      "model_name": model_name,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      "lr_scheduler_type": lr_scheduler_type,
      "mixup_pct": mixup_pct,
      "num_train": num_train,
      "num_test": num_test,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
    }
    print("wandb_config:\n",wandb_config)
    if False: # wandb_auto_resume
        print("wandb_id:",model_name)
        wandb.init(
            id = model_name,
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
            resume="allow",
        )
    else:
        wandb.init(
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
        )
else:
    # Non-main processes never log to wandb.
    wandb_log = False


# # More custom functions

# In[33]:


# using the same preprocessing as was used in MindEye + BrainDiffuser
pixcorr_preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr(images,brains):
    """Mean per-sample Pearson correlation between resized images and reconstructions.

    Both inputs are resized with ``pixcorr_preprocess`` and flattened per
    sample. The diagonal of the batchwise correlation matrix is then
    averaged.
    """
    # Flatten images while keeping the batch dimension. reshape (not view)
    # because the resize output is not guaranteed to be contiguous.
    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = pixcorr_preprocess(brains).reshape(len(brains), -1)
    corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
    return corrmean


# # Main

# In[34]:
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
# SoftCLIP temperatures, indexed by the number of epochs past the mixup phase.
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))


def _load_last_ckpt():
    """Restore optimizer, LR-scheduler and model state from the latest checkpoint.

    Tries ``{outdir}/last.pth`` first and falls back to ``last_backup.pth``.
    It runs before ``accelerator.prepare``, so ``model`` is still the plain
    module whose ``state_dict`` ``save_ckpt`` stores.

    Returns:
        The epoch recorded in the checkpoint.
    """
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except Exception:
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    ckpt_epoch = checkpoint['epoch']
    print("Epoch",ckpt_epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # Load into the model being trained (previously an undefined name, so
    # every resume raised NameError).
    model.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint
    return ckpt_epoch


# Optionally resume from checkpoint #
if resume_from_ckpt:
    epoch = _load_last_ckpt()
elif wandb_log and wandb.run.resumed:
    epoch = _load_last_ckpt()
torch.cuda.empty_cache()


# In[35]:


model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dl, test_dl, lr_scheduler
)


# In[36]:


"""transform = transforms.Compose(
    [
        transforms.Resize(
            (224, 224),
        ),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ]
)

def tensor_2_embed(image):
    image_for_blip2 = transform(image)

    #Generate embeddings
    with blip2_model.maybe_autocast():
        blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))

    return blip2_target

def embed_2_caption(image_embeds, model):
    image_embeds = image_embeds.float()
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                image.device)

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_output = model.Qformer.bert(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=True)

    inputs_t5 = model.t5_proj(query_output.last_hidden_state)
    atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
    prompt = model.prompt
    input_tokens = model.t5_tokenizer(
        prompt, padding="longest", return_tensors="pt"
    ).to(image.device)
    encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
    with model.maybe_autocast(dtype=torch.bfloat16):
        inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
        outputs = model.t5_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts)
    output_text = model.t5_tokenizer.batch_decode(
        outputs, skip_special_tokens=True)

    return output_text"""


# In[37]:


wandb_log = True


# In[ ]:


print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
# Averaged test set: built on the first test pass and reused in later epochs.
test_image, test_voxel = None, None
mse = nn.MSELoss()
# Epochs [0, mixup_epochs) train with BiMixCo (mixco_nce); later epochs use SoftCLIP.
mixup_epochs = int(mixup_pct * num_epochs)
for epoch in progress_bar:
    model.train()

    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    test_loss_clip_total = 0.
    test_loss_blurry_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1

    # ---------------- training ----------------
    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
        # Epoch 0 takes no optimizer step: only the test pass below runs.
        if epoch == 0:
            lrs.append(0)
            break
        with torch.cuda.amp.autocast():
            optimizer.zero_grad()

            # behav[:,0,5] indexes rows of `voxels`; behav[:,0,0] indexes `images`.
            voxel = voxels[behav[:,0,5].cpu().long()].to(device)
            image = images[behav[:,0,0].cpu().long()].to(device).float()

            # blurry_image_enc = autoenc.encode(image).latent_dist.mode()

            if use_image_aug:
                image = img_augment(image)
            # clip_target = clip_model.embed_image(image)
            clip_target = embed_images_b2(image)[0].to(device) #####CHANGED
            assert not torch.any(torch.isnan(clip_target))

            use_mixco = epoch < mixup_epochs
            if use_mixco:
                voxel, perm, betas, select = utils.mixco(voxel)

            voxel_ridge = model.ridge(voxel)

            # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
            clip_voxels = model.backbone(voxel_ridge)

            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if use_mixco:
                loss_clip = utils.mixco_nce(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006,
                    perm=perm, betas=betas, select=select)
            else:
                epoch_temp = soft_loss_temps[epoch-mixup_epochs]
                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=epoch_temp)

            loss_mse = mse(clip_voxels, clip_target)

            # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

            loss_clip_total += loss_clip.item()
            # loss_blurry_total += loss_blurry.item()

            # loss = loss_blurry + loss_clip
            loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse)
            if train_i % 10 == 0:
                print(train_i, loss)
            utils.check_loss(loss)

            accelerator.backward(loss)
            optimizer.step()

            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])

            # forward and backward top 1 accuracy
            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

            # with torch.no_grad():
            #     # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()
            #     random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)
            #     blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
            #     blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)

            # The scheduler is stepped once per optimizer step.
            if lr_scheduler_type is not None:
                lr_scheduler.step()

    # ---------------- evaluation ----------------
    model.eval()
    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                # all test samples should be loaded per batch such that test_i should never exceed 0
                if len(behav) != num_test: print("!",len(behav),num_test)

                ## Average same-image repeats ##
                if test_image is None:
                    voxel = voxels[behav[:,0,5].cpu().long()].to(device)

                    image = behav[:,0,0].cpu().long()

                    unique_image = torch.unique(image)
                    for im in unique_image:
                        locs = torch.where(im == image)[0]
                        if test_image is None:
                            test_image = images[im][None]
                            test_voxel = torch.mean(voxel[locs],axis=0)[None]
                        else:
                            test_image = torch.vstack((test_image, images[im][None]))
                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))

                # sample of batch_size (deterministic: always the first batch_size unique images)
                random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
                voxel = test_voxel[random_indices].to(device)
                image = test_image[random_indices].to(device)
                assert len(image) == batch_size

                # blurry_image_enc = autoenc.encode(image).latent_dist.mode()

                # clip_target = clip_model.embed_image(image.float())
                clip_target = embed_images_b2(image)[0].to(device) #####CHANGED

                voxel_ridge = model.ridge(voxel)

                # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
                clip_voxels = model.backbone(voxel_ridge)

                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006)

                loss_mse = mse(clip_voxels, clip_target)

                # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

                # Accumulate so "test/loss_clip_total" is not always 0.
                test_loss_clip_total += loss_clip.item()

                # loss = loss_blurry + loss_clip
                loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse)

                utils.check_loss(loss)

                test_losses.append(loss.item())

                # forward and backward top 1 accuracy
                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

                # # halving the batch size because the decoder is computationally heavy
                # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
                # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
                # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)

                #Find captions and print next to images
                #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)
                #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)

                #true_embed1 = tensor_2_embed(image[[0]])
                #true_embed2 = tensor_2_embed(image[[1]])

                # print(clip_voxels[[0]].shape)
                # print(true_embed1.shape)

                #true_caption1 = embed_2_caption(true_embed1, blip2_model)
                #true_caption2 = embed_2_caption(true_embed2, blip2_model)
                # transform blurry recon latents to images and plot it
                #fig, axes = plt.subplots(2, 2, figsize=(8, 4))
                #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))
                #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))
                #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')
                #axes[0,0].set_title(caption1)
                #axes[0,1].set_title(caption2)
                #axes[1,0].set_title(true_caption1)
                #axes[1,1].set_title(true_caption2)

                #plt.show()

                # # transform blurry recon latents to images and plot it
                # fig, axes = plt.subplots(1, 4, figsize=(8, 4))
                # axes[0].imshow(utils.torch_to_Image(image[[0]]))
                # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
                # axes[2].imshow(utils.torch_to_Image(image[[1]]))
                # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
                # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
                # axes[0].set_title(caption1)
                # axes[3].set_title(caption2)
                # plt.show()

    # ---------------- logging (main process only) ----------------
    if local_rank==0:
        # if utils.is_interactive(): clear_output(wait=True)
        assert (test_i+1) == 1
        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
            "test/loss": np.mean(test_losses[-(test_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses),
            "test/num_steps": len(test_losses),
            "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
            "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
            "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
            "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
            "train/loss_clip_total": loss_clip_total / (train_i + 1),
            "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
            "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
            "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
            "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
            "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
            }
        progress_bar.set_postfix(**logs)

        # Show the first 8 test images; captions are predicted from the brain embeddings.
        fig, axes = plt.subplots(1, 8, figsize=(10, 4))
        for j in range(8):
            axes[j].imshow(utils.torch_to_Image(image[j]))
            axes[j].axis('off')

        if wandb_log:
            generated_captions = embeds_to_captions_b2(clip_voxels[0:8])
            print(generated_captions[1])
            logs["test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}" + "\n".join(generated_captions[1]))
        plt.close()
        # Save model checkpoint and reconstruct
        if epoch % ckpt_interval == 0:
            if not utils.is_interactive():
                save_ckpt('last')

        if wandb_log: wandb.log(logs)

    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    gc.collect()

print("\n===Finished!===\n")
# Only the main process writes the final checkpoint, like the per-epoch save,
# so that several ranks never write the same file concurrently.
if ckpt_saving and local_rank==0:
    save_ckpt('last')
if not utils.is_interactive():
    sys.exit(0)


# In[ ]:


plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()


# In[ ]:


print('test')


# In[ ]:


# NOTE(review): `vis_processors` and `blip2_model` only exist inside disabled
# (string-literal) cells of this script, so calling this helper raises NameError.
def tensor_2_embed_old(tensor):
    embed_array = torch.zeros((tensor.shape[0],257, 1024))
    to_pil = ToPILImage()
    for sample in range(tensor.shape[0]):
        PIL_image = to_pil(tensor[sample])
        image_for_blip2 = vis_processors["eval"](PIL_image).unsqueeze(0).to(device)
        #Generate embeddings
        with blip2_model.maybe_autocast():
            blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))
        embed_array[sample] = blip2_target

    return embed_array