mindeyev2old2 / src /Train-with-memory-rr.py
ckadirt's picture
Upload folder using huggingface_hub
b8ea2b2 verified
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
#from subprocess import call
#command = "jupyter nbconvert Train-with-memory-rr.ipynb --to python"
#call(command,shell=True)
# In[24]:
#get_ipython().system('nvidia-smi')
# # Import packages & functions
# In[3]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import h5py
from tqdm import tqdm
import webdataset as wds
import gc
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator, DeepSpeedPlugin
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
# custom functions #
import utils
global_batch_size = 512 #128
# In[4]:
### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None:
local_rank = 0
else:
local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)
num_devices = torch.cuda.device_count()
if num_devices==0: num_devices = 1
accelerator = Accelerator(split_batches=False)
### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
# if num_devices <= 1 and utils.is_interactive():
# # can emulate a distributed environment for deepspeed to work in jupyter notebook
# os.environ["MASTER_ADDR"] = "localhost"
# os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
# os.environ["RANK"] = "0"
# os.environ["LOCAL_RANK"] = "0"
# os.environ["WORLD_SIZE"] = "1"
# os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
# global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
# with open('deepspeed_config_stage2.json', 'r') as file:
# config = json.load(file)
# config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
# config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
# with open('deepspeed_config_stage2.json', 'w') as file:
# json.dump(config, file)
# else:
# # give some time for the local_rank=0 gpu to prep new deepspeed config file
# time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
# In[5]:
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
print = accelerator.print # only print if local_rank=0
# # Configurations
# In[6]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
# Example use
jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
--model_name=test \
--subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
--max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
jupyter_args = jupyter_args.split()
print(jupyter_args)
from IPython.display import clear_output # function to clear print outputs in cell
get_ipython().run_line_magic('load_ext', 'autoreload')
# this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
get_ipython().run_line_magic('autoreload', '2')
# In[7]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
"--model_name", type=str, default="memory_cat_rr",
help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
"--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
"--subj",type=int, default=1, choices=[1,2,5,7],
)
parser.add_argument(
"--batch_size", type=int, default=32,
help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)
parser.add_argument(
"--wandb_log",action=argparse.BooleanOptionalAction,default=False,
help="whether to log to wandb",
)
parser.add_argument(
"--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
"--wandb_project",type=str,default="stability",
help="wandb project name",
)
parser.add_argument(
"--mixup_pct",type=float,default=.33,
help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
"--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
help="whether to use image augmentation",
)
parser.add_argument(
"--num_epochs",type=int,default=240,
help="number of epochs of training",
)
parser.add_argument(
"--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
"--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
"--ckpt_interval",type=int,default=5,
help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
"--seed",type=int,default=42,
)
parser.add_argument(
"--max_lr",type=float,default=3e-4,
)
parser.add_argument(
"--n_samples_save",type=int,default=0,choices=[0,1],
help="Number of reconstructions for monitoring progress, 0 will speed up training",
)
if utils.is_interactive():
args = parser.parse_args(jupyter_args)
else:
args = parser.parse_args()
# create global variables without the args prefix
for attribute_name in vars(args).keys():
globals()[attribute_name] = getattr(args, attribute_name)
print("global batch_size", batch_size)
batch_size = int(batch_size / num_devices)
print("batch_size", batch_size)
# In[8]:
outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
if not os.path.exists(outdir):
os.makedirs(outdir,exist_ok=True)
if use_image_aug:
import kornia
from kornia.augmentation.container import AugmentationSequential
img_augment = AugmentationSequential(
kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
kornia.augmentation.Resize((224, 224)),
kornia.augmentation.RandomHorizontalFlip(p=0.3),
kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
kornia.augmentation.RandomGrayscale(p=0.3),
same_on_batch=False,
data_keys=["input"],
)
# # Prep data, models, and dataloaders
# ## Dataloader
# In[9]:
if subj==1:
num_train = 24958
num_test = 2770
test_batch_size = num_test
def my_split_by_node(urls): return urls
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)
train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
.shuffle(750, initial=1500, rng=random.Random(42))\
.decode("torch")\
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)
test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
.shuffle(750, initial=1500, rng=random.Random(42))\
.decode("torch")\
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
# ### check dataloaders are working
# In[10]:
# test_indices = []
# test_images = []
# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
# test_indices = np.append(test_indices, behav[:,0,5].numpy())
# test_images = np.append(test_images, behav[:,0,0].numpy())
# test_indices = test_indices.astype(np.int16)
# print(test_i, (test_i+1) * test_batch_size, len(test_indices))
# print("---\n")
# train_indices = []
# train_images = []
# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
# train_images = np.append(train_images, behav[:,0,0].numpy())
# train_indices = train_indices.astype(np.int16)
# print(train_i, (train_i+1) * batch_size, len(train_indices))
# # train_images = np.hstack((train_images, test_images))
# # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
# ## Load data and images
# In[11]:
# load betas
f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").half()
if subj==1:
voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]
# load orig images
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'][:]
images = torch.Tensor(images).to("cpu").half()
print("images", images.shape)
# ## Load models
# ### CLIP image embeddings model
# In[12]:
from models import Clipper
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
clip_seq_dim = 257
clip_emb_dim = 768
hidden_dim = 4096
# ### SD VAE (blurry images)
# In[13]:
from diffusers import AutoencoderKL
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
autoenc.eval()
autoenc.requires_grad_(False)
autoenc.to(device)
utils.count_params(autoenc)
# ### MindEye modules
# In[14]:
class MindEyeModule(nn.Module):
def __init__(self):
super(MindEyeModule, self).__init__()
def forward(self, x):
return x
model = MindEyeModule()
model
# In[15]:
time_embedding_dim = 512
class RidgeRegression(torch.nn.Module):
# make sure to add weight_decay when initializing optimizer
def __init__(self, input_size, out_features):
super(RidgeRegression, self).__init__()
self.out_features = out_features
self.linear = torch.nn.Linear(input_size, out_features)
def forward(self, x):
return self.linear(x)
model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)
b = torch.randn((2,1,voxels.shape[1]))
time_emb_test = torch.randn((2,1,time_embedding_dim))
print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
# In[16]:
from functools import partial
from diffusers.models.vae import Decoder
class BrainNetwork(nn.Module):
def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25, blurry_dim=16):
super().__init__()
self.blurry_dim = blurry_dim
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
self.lin0 = nn.Linear(in_dim, h)
self.mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(h, h),
*[item() for item in act_and_norm],
nn.Dropout(drop)
) for _ in range(n_blocks)
])
self.lin1 = nn.Linear(h, out_dim, bias=True)
self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
self.n_blocks = n_blocks
self.clip_size = clip_size
self.clip_proj = nn.Sequential(
nn.LayerNorm(clip_size),
nn.GELU(),
nn.Linear(clip_size, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Linear(2048, 2048),
nn.LayerNorm(2048),
nn.GELU(),
nn.Linear(2048, clip_size)
)
self.upsampler = Decoder(
in_channels=64,
out_channels=4,
up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
block_out_channels=[64, 128, 256],
layers_per_block=1,
)
def forward(self, x):
x = self.lin0(x)
residual = x
for res_block in range(self.n_blocks):
x = self.mlp[res_block](x)
x += residual
residual = x
x = x.reshape(len(x), -1)
x = self.lin1(x)
b = self.blin1(x)
b = self.upsampler(b.reshape(len(b), -1, 7, 7))
c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
return c, b
model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
utils.count_params(model.backbone)
utils.count_params(model)
b = torch.randn((2,8192))
print(b.shape)
clip_, blur_ = model.backbone(b)
print(clip_.shape, blur_.shape)
# In[17]:
# memory model
from timm.layers.mlp import Mlp
class MemoryEncoder(nn.Module):
def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
super().__init__()
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
self.out_dim = out_dim
self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
self.final_input_dim = in_dim + embedding_time_dim
self.lin0 = nn.Linear(self.final_input_dim, h)
self.mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(h, h),
*[item() for item in act_and_norm],
nn.Dropout(drop)
) for _ in range(n_blocks)
])
self.lin1 = nn.Linear(h, out_dim, bias=True)
self.n_blocks = n_blocks
self.num_past_voxels = num_past_voxels
self.embedding_time_dim = embedding_time_dim
self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))
def forward(self, x, time):
time = time.long()
time = self.embedding_time(time)
x = torch.cat((x, time), dim=-1)
x = self.lin0(x)
residual = x
for res_block in range(self.n_blocks):
x = self.mlp[res_block](x)
x += residual
residual = x
x = x.reshape(len(x), -1)
x = self.lin1(x)
return x
# # test the memory encoder
# memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)
# device = torch.device("cpu")
# memory_encoder.to(device)
# # count params
# total_parameters = 0
# for parameter in memory_encoder.parameters():
# total_parameters += parameter.numel()
# rand_input = torch.randn((2, 15279)).to(device)
# rand_time = torch.randint(0, 15, (2,)).to(device)
# print(rand_input.shape, rand_time.shape)
# memory_encoder(rand_input, rand_time).shape
class MemoryCompressor(nn.Module):
def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
super().__init__()
self.num_past = num_past
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
self.final_input_dim = in_dim * num_past
self.lin0 = nn.Linear(self.final_input_dim, h)
self.mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(h, h),
*[item() for item in act_and_norm],
nn.Dropout(drop)
) for _ in range(n_blocks)
])
self.lin1 = nn.Linear(h, output_dim, bias=True)
self.n_blocks = n_blocks
self.num_past = num_past
self.output_dim = output_dim
def forward(self, x):
# x is (batch_size, num_past, in_dim)
x = x.reshape(len(x), -1)
x = self.lin0(x)
residual = x
for res_block in range(self.n_blocks):
x = self.mlp[res_block](x)
x += residual
residual = x
x = x.reshape(len(x), -1)
x = self.lin1(x)
return x
# # test the memory compressor
# memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)
# device = torch.device("cpu")
# memory_compressor.to(device)
# # count params
# total_parameters = 0
# for parameter in memory_compressor.parameters():
# total_parameters += parameter.numel()
# rand_input = torch.randn((2, 15, 768)).to(device)
# print(rand_input.shape)
# memory_compressor(rand_input).shape
class TimeEmbedding(nn.Module):
def __init__(self, embedding_time_dim=512, num_past_voxels=15):
super().__init__()
self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
self.num_past_voxels = num_past_voxels
self.embedding_time_dim = embedding_time_dim
def forward(self, time):
# time is (batch_size,)
time = time.long()
time = self.embedding_time(time)
return time # (batch_size, embedding_time_dim)
#model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
model.memory_compressor = MemoryCompressor(in_dim=model.ridge.out_features, num_past=15, output_dim=4096)
#utils.count_params(model.memory_encoder)
utils.count_params(model.memory_compressor)
utils.count_params(model)
# In[18]:
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
{'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
{'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
{'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
#{'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
{'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
{'params': [p for n, p in model.time_embedding.named_parameters()], 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
if lr_scheduler_type == 'linear':
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
last_epoch=-1
)
elif lr_scheduler_type == 'cycle':
total_steps=int(num_epochs*(num_train*num_devices//batch_size))
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=max_lr,
total_steps=total_steps,
final_div_factor=1000,
last_epoch=-1, pct_start=2/num_epochs
)
def save_ckpt(tag):
ckpt_path = outdir+f'/{tag}.pth'
print(f'saving {ckpt_path}',flush=True)
unwrapped_model = accelerator.unwrap_model(model)
try:
torch.save({
'epoch': epoch,
'model_state_dict': unwrapped_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'train_losses': losses,
'test_losses': test_losses,
'lrs': lrs,
}, ckpt_path)
except:
print("Couldn't save... moving on to prevent crashing.")
del unwrapped_model
print("\nDone with model preparations!")
utils.count_params(model)
# In[ ]:
# # Weights and Biases
# In[19]:
# params for wandb
wandb_log = True
if local_rank==0 and wandb_log: # only use main process for wandb logging
import wandb
wandb_project = 'stability'
wandb_run = model_name
wandb_notes = ''
print(f"wandb {wandb_project} run {wandb_run}")
wandb.login(host='https://stability.wandb.io')#, relogin=True)
wandb_config = {
"model_name": model_name,
"batch_size": batch_size,
"num_epochs": num_epochs,
"use_image_aug": use_image_aug,
"max_lr": max_lr,
"lr_scheduler_type": lr_scheduler_type,
"mixup_pct": mixup_pct,
"num_train": num_train,
"num_test": num_test,
"seed": seed,
"distributed": distributed,
"num_devices": num_devices,
"world_size": world_size,
}
print("wandb_config:\n",wandb_config)
if False: # wandb_auto_resume
print("wandb_id:",model_name)
wandb.init(
id = model_name,
project=wandb_project,
name=wandb_run,
config=wandb_config,
notes=wandb_notes,
resume="allow",
)
else:
wandb.init(
project=wandb_project,
name=model_name,
config=wandb_config,
notes=wandb_notes,
)
else:
wandb_log = False
# # More custom functions
# In[20]:
# using the same preprocessing as was used in MindEye + BrainDiffuser
pixcorr_preprocess = transforms.Compose([
transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr(images,brains):
# Flatten images while keeping the batch dimension
all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
return corrmean
# # Main
# In[21]:
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
# Optionally resume from checkpoint #
if resume_from_ckpt:
print("\n---resuming from last.pth ckpt---\n")
try:
checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
except:
print('last.pth failed... trying last_backup.pth')
checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
epoch = checkpoint['epoch']
print("Epoch",epoch)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])
del checkpoint
elif wandb_log:
if wandb.run.resumed:
print("\n---resuming from last.pth ckpt---\n")
try:
checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
except:
print('last.pth failed... trying last_backup.pth')
checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
epoch = checkpoint['epoch']
print("Epoch",epoch)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])
del checkpoint
torch.cuda.empty_cache()
# In[22]:
model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
model, optimizer, train_dl, test_dl, lr_scheduler
)
# In[23]:
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))
test_image, test_voxel = None, None
mse = nn.MSELoss()
for epoch in progress_bar:
model.train()
fwd_percent_correct = 0.
bwd_percent_correct = 0.
test_fwd_percent_correct = 0.
test_bwd_percent_correct = 0.
loss_clip_total = 0.
loss_blurry_total = 0.
test_loss_clip_total = 0.
test_loss_blurry_total = 0.
blurry_pixcorr = 0.
test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
#if epoch == 0 or epoch == 1:
# break
with torch.cuda.amp.autocast():
optimizer.zero_grad()
voxel = voxels[behav[:,0,5].cpu().long()].to(device)
image = images[behav[:,0,0].cpu().long()].to(device).float()
past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
blurry_image_enc = autoenc.encode(image).latent_dist.mode()
if use_image_aug: image = img_augment(image)
clip_target = clip_model.embed_image(image)
assert not torch.any(torch.isnan(clip_target))
if epoch < int(mixup_pct * num_epochs):
voxel, perm, betas, select = utils.mixco(voxel)
# reshape past voxels to be (batch_size * 15, 15279)
past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
past_15_times = past_15_times.repeat(voxel.shape[0], 1)
past_15_times = past_15_times.reshape(-1)
#print(past_15_voxels.shape, past_15_times.shape)
time_embeddings = model.time_embedding(past_15_times)
past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
embeds_past_voxels = model.ridge(past_info_full)
#print(embeds_past_voxels.shape)
embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
#print(embeds_past_voxels.shape)
information_past_voxels = model.memory_compressor(embeds_past_voxels)
positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
#print(torch.cat((voxel, positional_current_voxel), dim=-1).shape, positional_current_voxel.shape, voxel.shape)
voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
if epoch < int(mixup_pct * num_epochs):
loss_clip = utils.mixco_nce(
clip_voxels_norm,
clip_target_norm,
temp=.006,
perm=perm, betas=betas, select=select)
else:
epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
loss_clip = utils.soft_clip_loss(
clip_voxels_norm,
clip_target_norm,
temp=epoch_temp)
loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
loss_clip_total += loss_clip.item()
loss_blurry_total += loss_blurry.item()
loss = loss_blurry + loss_clip
utils.check_loss(loss)
accelerator.backward(loss)
optimizer.step()
losses.append(loss.item())
lrs.append(optimizer.param_groups[0]['lr'])
# forward and backward top 1 accuracy
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
with torch.no_grad():
# only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()
random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
if lr_scheduler_type is not None:
lr_scheduler.step()
model.eval()
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
print('test')
with torch.cuda.amp.autocast():
with torch.no_grad():
# all test samples should be loaded per batch such that test_i should never exceed 0
if len(behav) != num_test: print("!",len(behav),num_test)
## Average same-image repeats ##
if test_image is None:
voxel = voxels[behav[:,0,5].cpu().long()].to(device)
image = behav[:,0,0].cpu().long()
unique_image, sort_indices = torch.unique(image, return_inverse=True)
for im in unique_image:
locs = torch.where(im == image)[0]
if test_image is None:
test_image = images[im][None]
test_voxel = torch.mean(voxel[locs],axis=0)[None]
else:
test_image = torch.vstack((test_image, images[im][None]))
test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
# sample of batch_size
random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
voxel = test_voxel[random_indices].to(device)
image = test_image[random_indices].to(device)
current_past_behav = past_behav[random_indices]
past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
assert len(image) == batch_size
blurry_image_enc = autoenc.encode(image).latent_dist.mode()
clip_target = clip_model.embed_image(image.float())
past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
past_15_times = past_15_times.repeat(voxel.shape[0], 1)
past_15_times = past_15_times.reshape(-1)
print(past_15_voxels.shape, past_15_times.shape)
#print(past_15_voxels.shape, past_15_times.shape)
time_embeddings = model.time_embedding(past_15_times)
past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
embeds_past_voxels = model.ridge(past_info_full)
#print(embeds_past_voxels.shape)
embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
#print(embeds_past_voxels.shape)
information_past_voxels = model.memory_compressor(embeds_past_voxels)
positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
loss_clip = utils.soft_clip_loss(
clip_voxels_norm,
clip_target_norm,
temp=.006)
loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
loss = loss_blurry + loss_clip
utils.check_loss(loss)
test_losses.append(loss.item())
# forward and backward top 1 accuracy
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
# halving the batch size because the decoder is computationally heavy
blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
# transform blurry recon latents to images and plot it
#fig, axes = plt.subplots(1, 4, figsize=(8, 4))
#axes[0].imshow(utils.torch_to_Image(image[[0]]))
#axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
#axes[2].imshow(utils.torch_to_Image(image[[1]]))
#axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
#axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
#plt.show()
wandb.log({"gt": [wandb.Image(utils.torch_to_Image(image[[0]])), wandb.Image(utils.torch_to_Image(image[[1]])) ]}
wandb.log({"preds": [utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)), utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)) ]}
if local_rank==0:
# if utils.is_interactive(): clear_output(wait=True)
assert (test_i+1) == 1
logs = {"train/loss": np.mean(losses[-(train_i+1):]),
"test/loss": np.mean(test_losses[-(test_i+1):]),
"train/lr": lrs[-1],
"train/num_steps": len(losses),
"test/num_steps": len(test_losses),
"train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
"train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
"test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
"test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
"train/loss_clip_total": loss_clip_total / (train_i + 1),
"train/loss_blurry_total": loss_blurry_total / (train_i + 1),
"test/loss_clip_total": test_loss_clip_total / (test_i + 1),
"test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
"train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
"test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
}
progress_bar.set_postfix(**logs)
# Save model checkpoint and reconstruct
if epoch % ckpt_interval == 0:
if not utils.is_interactive():
save_ckpt(f'last')
if wandb_log: wandb.log(logs)
# wait for other GPUs to catch up if needed
accelerator.wait_for_everyone()
torch.cuda.empty_cache()
gc.collect()
print("\n===Finished!===\n")
if ckpt_saving:
save_ckpt(f'last')
if not utils.is_interactive():
sys.exit(0)
# In[ ]:
plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()