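"""Training script for a MindEye-style fMRI-to-image model with a memory pathway:
ridge embeddings of the 15 preceding brain scans are tagged with learned time
embeddings, compressed into a single memory vector, and fed to the backbone
alongside the current scan. Converted from a notebook; runs standalone or
interactively."""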
import os
import sys
import json
import argparse
import math
import time
import random
import gc

import numpy as np
import h5py
import matplotlib.pyplot as plt
from tqdm import tqdm
from einops import rearrange
import webdataset as wds

import torch
import torch.nn as nn
from torchvision import transforms

from accelerate import Accelerator, DeepSpeedPlugin

# allow TF32 matmuls on Ampere+ GPUs for faster training
torch.backends.cuda.matmul.allow_tf32 = True

import utils

# total batch size across all GPUs; divided by num_devices after arg parsing
global_batch_size = 512

# NOTE: this reads the global RANK env var (set by torchrun/accelerate); on a single
# node it coincides with the local rank, which is how it is used below
local_rank = os.getenv('RANK')
if local_rank is None:
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)

num_devices = torch.cuda.device_count()
if num_devices == 0: num_devices = 1

accelerator = Accelerator(split_batches=False)
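# With split_batches=False, each process draws its own batch_size samples, so the
# effective global batch is batch_size * num_devices.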
print("PID of this process =",os.getpid()) |
|
device = accelerator.device |
|
print("device:",device) |
|
num_workers = num_devices |
|
print(accelerator.state) |
|
world_size = accelerator.state.num_processes |
|
distributed = not accelerator.state.distributed_type == 'NO' |
|
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size) |
|
print = accelerator.print |
if utils.is_interactive():
    # canned arguments for notebook runs (no sys.argv to parse)
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
        --model_name=test \
        --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
        --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
    jupyter_args = jupyter_args.split()
    print(jupyter_args)

    from IPython.display import clear_output
    get_ipython().run_line_magic('load_ext', 'autoreload')
    get_ipython().run_line_magic('autoreload', '2')
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="memory_cat_rr",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--subj", type=int, default=1, choices=[1, 2, 5, 7],
)
parser.add_argument(
    "--batch_size", type=int, default=32,
    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)
parser.add_argument(
    "--wandb_log", action=argparse.BooleanOptionalAction, default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt", action=argparse.BooleanOptionalAction, default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project", type=str, default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct", type=float, default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--use_image_aug", action=argparse.BooleanOptionalAction, default=True,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs", type=int, default=240,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type", type=str, default='cycle', choices=['cycle', 'linear'],
)
parser.add_argument(
    "--ckpt_saving", action=argparse.BooleanOptionalAction, default=True,
)
parser.add_argument(
    "--ckpt_interval", type=int, default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed", type=int, default=42,
)
parser.add_argument(
    "--max_lr", type=float, default=3e-4,
)
parser.add_argument(
    "--n_samples_save", type=int, default=0, choices=[0, 1],
    help="Number of reconstructions for monitoring progress, 0 will speed up training",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# promote all parsed args to module-level globals (e.g. args.subj -> subj)
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

print("global batch_size", batch_size)
batch_size = int(batch_size / num_devices)  # per-device batch size
print("batch_size", batch_size)
outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
os.makedirs(outdir, exist_ok=True)

if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
    # batched, on-GPU augmentations
    img_augment = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224, 224), (0.6, 1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )
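# In the training loop below, augmentation is applied only after the VAE encoding,
# so the blurry-reconstruction targets come from the clean image while the CLIP
# targets are computed from the augmented one.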
# trial counts are only defined for subject 1; other --subj choices would fail below
if subj == 1:
    num_train = 24958
    num_test = 2770
    test_batch_size = num_test

# keep all shards on every node (no node splitting); each rank sees the full dataset
def my_split_by_node(urls): return urls

train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)

train_data = wds.WebDataset(train_url, resampled=False, nodesplitter=my_split_by_node)\
    .shuffle(750, initial=1500, rng=random.Random(42))\
    .decode("torch")\
    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
    .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)

test_data = wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)\
    .shuffle(750, initial=1500, rng=random.Random(42))\
    .decode("torch")\
    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
    .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
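# NOTE (inferred from how these tensors are indexed below): behav[:, 0, 0] holds the
# COCO image index and behav[:, 0, 5] the scan index into the betas array; past_behav
# carries the same columns for each of the 15 preceding trials.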
# load all betas for this subject into CPU RAM as fp16
f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
voxels = f['betas'][:]
f.close()
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").half()
if subj == 1:
    # pad subject 1's voxels with 5 zero columns (matching dtype keeps everything fp16)
    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5), dtype=voxels.dtype)))
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]

# load the 224x224 fp16 COCO stimulus images into CPU RAM
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'][:]
f.close()
images = torch.Tensor(images).to("cpu").half()
print("images", images.shape)
from models import Clipper
# CLIP ViT-L/14, returning the normalized final hidden state rather than the pooled embedding
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)

clip_seq_dim = 257
clip_emb_dim = 768
hidden_dim = 4096
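# 257 tokens = 1 [CLS] + 16x16 patches (224px input at patch size 14), each 768-d;
# hidden_dim is the width of the ridge/backbone latent space.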
from diffusers import AutoencoderKL
# frozen fp16 SDXL VAE, used only to produce latents for the "blurry" reconstruction head
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
autoenc.eval()
autoenc.requires_grad_(False)
autoenc.to(device)
utils.count_params(autoenc)
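# For 224x224 inputs the VAE encodes to 4x28x28 latents (8x downsampling); the backbone's
# blurry head predicts a 64x7x7 code that its Decoder upsamples back to the same 4x28x28 shape.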
class MindEyeModule(nn.Module):
    """Top-level container; submodules (ridge, backbone, memory) are attached below."""
    def __init__(self):
        super(MindEyeModule, self).__init__()
    def forward(self, x):
        return x

model = MindEyeModule()
time_embedding_dim = 512

class RidgeRegression(torch.nn.Module):
    """A single linear map; the 'ridge' penalty comes from the optimizer's weight decay."""
    def __init__(self, input_size, out_features):
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(input_size, out_features)
    def forward(self, x):
        return self.linear(x)

model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# sanity check: voxels concatenated with a time embedding map to hidden_dim
b = torch.randn((2, 1, voxels.shape[1]))
time_emb_test = torch.randn((2, 1, time_embedding_dim))
print(b.shape, model.ridge(torch.cat((b, time_emb_test), dim=-1)).shape)
from functools import partial
from diffusers.models.vae import Decoder

class BrainNetwork(nn.Module):
    """Residual MLP with two heads: a CLIP-token projection and a low-res 'blurry' VAE latent."""
    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25, blurry_dim=16):
        super().__init__()
        self.blurry_dim = blurry_dim
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        self.lin0 = nn.Linear(in_dim, h)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
        self.n_blocks = n_blocks
        self.clip_size = clip_size
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size)
        )
        # upsamples the 64x7x7 blurry code back to a 4x28x28 VAE latent
        self.upsampler = Decoder(
            in_channels=64,
            out_channels=4,
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[64, 128, 256],
            layers_per_block=1,
        )

    def forward(self, x):
        x = self.lin0(x)
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        b = self.blin1(x)
        b = self.upsampler(b.reshape(len(b), -1, 7, 7))
        c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
        return c, b

# in_dim = ridge output (hidden_dim) + compressed memory (hidden_dim); out_dim covers all 257 CLIP tokens
model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
utils.count_params(model.backbone)
utils.count_params(model)

# sanity check with a dummy 8192-d (hidden_dim*2) input
b = torch.randn((2, 8192))
print(b.shape)
clip_, blur_ = model.backbone(b)
print(clip_.shape, blur_.shape)
class MemoryEncoder(nn.Module):
    """Encodes a past voxel pattern together with a learned time embedding.
    NOTE: defined but unused in this script; the TimeEmbedding + ridge + MemoryCompressor
    path below plays this role instead."""
    def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim=512, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
        super().__init__()
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        self.out_dim = out_dim
        self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
        self.final_input_dim = in_dim + embedding_time_dim
        self.lin0 = nn.Linear(self.final_input_dim, h)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        self.n_blocks = n_blocks
        self.num_past_voxels = num_past_voxels
        self.embedding_time_dim = embedding_time_dim
        self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))

    def forward(self, x, time):
        time = time.long()
        time = self.embedding_time(time)
        x = torch.cat((x, time), dim=-1)
        x = self.lin0(x)
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        return x
class MemoryCompressor(nn.Module):
    """Flattens the num_past embedded scans and compresses them into one memory vector."""
    def __init__(self, in_dim=768, num_past=15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
        super().__init__()
        self.num_past = num_past
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        self.final_input_dim = in_dim * num_past
        self.lin0 = nn.Linear(self.final_input_dim, h)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, output_dim, bias=True)
        self.n_blocks = n_blocks
        self.output_dim = output_dim

    def forward(self, x):
        # (batch, num_past, in_dim) -> (batch, num_past * in_dim)
        x = x.reshape(len(x), -1)
        x = self.lin0(x)
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        return x
class TimeEmbedding(nn.Module):
    """Learned embedding for how many scans back a past trial occurred (0..num_past_voxels-1)."""
    def __init__(self, embedding_time_dim=512, num_past_voxels=15):
        super().__init__()
        self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
        self.num_past_voxels = num_past_voxels
        self.embedding_time_dim = embedding_time_dim

    def forward(self, time):
        time = time.long()
        return self.embedding_time(time)

model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
# compress 15 ridge embeddings (each hidden_dim wide) into one 4096-d memory vector
model.memory_compressor = MemoryCompressor(in_dim=model.ridge.out_features, num_past=15, output_dim=4096)

utils.count_params(model.memory_compressor)
utils.count_params(model)
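# Sanity check for the memory path (mirrors the ridge/backbone checks above); shapes assume
# the wiring just defined: 15 past embeddings of width model.ridge.out_features -> 4096-d memory.
t = model.time_embedding(torch.arange(15))
print(t.shape)  # torch.Size([15, 512])
e = torch.randn((2, 15, model.ridge.out_features))
print(model.memory_compressor(e).shape)  # torch.Size([2, 4096])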
# no weight decay for biases and LayerNorm parameters in the backbone;
# decaying the ridge's weights is what makes it ridge-like
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.time_embedding.named_parameters()], 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    total_steps = int(num_epochs*(num_train*num_devices//batch_size))
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
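# pct_start=2/num_epochs gives roughly a two-epoch warmup before the one-cycle decay;
# the scheduler is stepped once per training batch, not per epoch.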

def save_ckpt(tag):
    ckpt_path = outdir + f'/{tag}.pth'
    print(f'saving {ckpt_path}', flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
        }, ckpt_path)
    except Exception as e:
        print(f"Couldn't save ({e})... moving on to prevent crashing.")
    del unwrapped_model

print("\nDone with model preparations!")
utils.count_params(model)
# NOTE: this force-enables wandb on the main process, overriding the --wandb_log CLI flag
wandb_log = True
if local_rank == 0 and wandb_log:
    import wandb

    wandb_project = 'stability'
    wandb_run = model_name
    wandb_notes = ''

    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')
    wandb_config = {
        "model_name": model_name,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "use_image_aug": use_image_aug,
        "max_lr": max_lr,
        "lr_scheduler_type": lr_scheduler_type,
        "mixup_pct": mixup_pct,
        "num_train": num_train,
        "num_test": num_test,
        "seed": seed,
        "distributed": distributed,
        "num_devices": num_devices,
        "world_size": world_size,
    }
    print("wandb_config:\n", wandb_config)

    resume_wandb_run = False  # set True to resume a prior run by id (uses model_name as the run id)
    if resume_wandb_run:
        print("wandb_id:", model_name)
        wandb.init(
            id=model_name,
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
            resume="allow",
        )
    else:
        wandb.init(
            project=wandb_project,
            name=model_name,
            config=wandb_config,
            notes=wandb_notes,
        )
else:
    # non-main processes never log
    wandb_log = False
# resize to 425px (the native NSD stimulus resolution) before correlating
pixcorr_preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr(images, brains):
    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = pixcorr_preprocess(brains).reshape(len(brains), -1)
    corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
    return corrmean
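# The diagonal of the batchwise correlation matrix pairs each reconstruction with its own
# ground-truth image, so pixcorr is the mean per-sample pixelwise Pearson r.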

epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
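# One SoftCLIP temperature per post-mixup epoch, cosine-annealed from 0.004 up to 0.0075;
# indexed by (epoch - mixup cutoff) in the training loop below.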

def load_ckpt():
    """Restore model/optimizer/scheduler state from the latest checkpoint."""
    global epoch
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except Exception:
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    epoch = checkpoint['epoch']
    print("Epoch", epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    model.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint

if resume_from_ckpt:
    load_ckpt()
elif wandb_log and wandb.run.resumed:
    load_ckpt()
torch.cuda.empty_cache()

model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dl, test_dl, lr_scheduler
)

print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
# resume-aware: start from the restored epoch rather than always from 0
progress_bar = tqdm(range(epoch, num_epochs), ncols=1200, disable=(local_rank != 0))
test_image, test_voxel = None, None
mse = nn.MSELoss()
|
for epoch in progress_bar:
    model.train()

    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    test_loss_clip_total = 0.
    test_loss_blurry_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0.

    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
        with torch.cuda.amp.autocast():
            optimizer.zero_grad()

            # current-trial voxels and images, indexed out of the in-memory arrays
            voxel = voxels[behav[:,0,5].cpu().long()].to(device)
            image = images[behav[:,0,0].cpu().long()].to(device).float()

            # the 15 preceding trials' voxels, plus their relative time indices
            past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device)
            past_15_times = torch.arange(15, device=device)

            # encode the clean image BEFORE augmentation so the blurry target is unaugmented
            blurry_image_enc = autoenc.encode(image).latent_dist.mode()

            if use_image_aug: image = img_augment(image)

            clip_target = clip_model.embed_image(image)
            assert not torch.any(torch.isnan(clip_target))

            # BiMixCo phase: mix voxels for the first mixup_pct of training
            if epoch < int(mixup_pct * num_epochs):
                voxel, perm, betas, select = utils.mixco(voxel)

            # flatten past trials to (batch*15, num_voxels) and tag each with its time index
            past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
            past_15_times = past_15_times.repeat(voxel.shape[0], 1)
            past_15_times = past_15_times.reshape(-1)

            time_embeddings = model.time_embedding(past_15_times)
            past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
            embeds_past_voxels = model.ridge(past_info_full)
            embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)

            # compress the 15 past embeddings into a single memory vector per sample
            information_past_voxels = model.memory_compressor(embeds_past_voxels)

            # the current trial gets an all-zero "time" embedding
            positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)

            # backbone input = [ridge(current voxels), compressed memory]
            voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)

            clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)

            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if epoch < int(mixup_pct * num_epochs):
                loss_clip = utils.mixco_nce(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006,
                    perm=perm, betas=betas, select=select)
            else:
                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=epoch_temp)

            loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

            loss_clip_total += loss_clip.item()
            loss_blurry_total += loss_blurry.item()

            loss = loss_blurry + loss_clip
            utils.check_loss(loss)

            accelerator.backward(loss)
            optimizer.step()

            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])

            # retrieval metrics: top-1 accuracy matching voxels to images (fwd) and back (bwd)
            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

            with torch.no_grad():
                # decode two random samples to track pixel-space correlation cheaply
                random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
                blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
                blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)

        if lr_scheduler_type is not None:
            lr_scheduler.step()
|
    model.eval()
    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                # the full test set arrives as a single batch (test_batch_size == num_test)
                if len(behav) != num_test: print("!", len(behav), num_test)

                if test_image is None:
                    voxel = voxels[behav[:,0,5].cpu().long()].to(device)
                    image = behav[:,0,0].cpu().long()

                    # average voxel patterns over repeated presentations of the same image
                    unique_image, _ = torch.unique(image, return_inverse=True)
                    for im in unique_image:
                        locs = torch.where(im == image)[0]
                        if test_image is None:
                            test_image = images[im][None]
                            test_voxel = torch.mean(voxel[locs], axis=0)[None]
                        else:
                            test_image = torch.vstack((test_image, images[im][None]))
                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs], axis=0)[None]))

                # despite the name, this deterministically takes the first batch_size samples
                random_indices = torch.arange(len(test_voxel))[:batch_size]
                voxel = test_voxel[random_indices].to(device)
                image = test_image[random_indices].to(device)

                current_past_behav = past_behav[random_indices]
                past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device)
                past_15_times = torch.arange(15, device=device)

                assert len(image) == batch_size

                blurry_image_enc = autoenc.encode(image).latent_dist.mode()
                clip_target = clip_model.embed_image(image.float())

                past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
                past_15_times = past_15_times.repeat(voxel.shape[0], 1)
                past_15_times = past_15_times.reshape(-1)

                time_embeddings = model.time_embedding(past_15_times)
                past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
                embeds_past_voxels = model.ridge(past_info_full)
                embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
                information_past_voxels = model.memory_compressor(embeds_past_voxels)

                positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
                voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)

                clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)

                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

                # evaluation always uses SoftCLIP at the fixed base temperature
                loss_clip = utils.soft_clip_loss(
                    clip_voxels_norm,
                    clip_target_norm,
                    temp=.006)

                loss_blurry = mse(blurry_image_enc_, blurry_image_enc)

                # accumulate per-term test losses for the logging block below
                test_loss_clip_total += loss_clip.item()
                test_loss_blurry_total += loss_blurry.item()

                loss = loss_blurry + loss_clip
                utils.check_loss(loss)

                test_losses.append(loss.item())

                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

                # decode in two halves to limit peak memory
                blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
                blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
                test_blurry_pixcorr += pixcorr(image, blurry_recon_images)

    # log example ground-truth images and blurry VAE reconstructions once per epoch
    if local_rank == 0 and wandb_log:
        wandb.log({"gt": [wandb.Image(utils.torch_to_Image(image[[0]])),
                          wandb.Image(utils.torch_to_Image(image[[1]]))]})
        wandb.log({"preds": [wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1))),
                             wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))]})

    if local_rank == 0:
        # the whole test set fits in one batch, so exactly one eval step ran
        assert (test_i+1) == 1
        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
            "test/loss": np.mean(test_losses[-(test_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses),
            "test/num_steps": len(test_losses),
            "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
            "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
            "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
            "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
            "train/loss_clip_total": loss_clip_total / (train_i + 1),
            "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
            "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
            "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
            "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
            "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
            }
        progress_bar.set_postfix(**logs)

        # periodically save a checkpoint (skipped in interactive sessions)
        if epoch % ckpt_interval == 0:
            if not utils.is_interactive():
                save_ckpt('last')

        if wandb_log: wandb.log(logs)

    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    gc.collect()
|
print("\n===Finished!===\n") |
|
if ckpt_saving: |
|
save_ckpt(f'last') |
|
if not utils.is_interactive(): |
|
sys.exit(0) |
# quick look at the loss curves (only reached in interactive sessions; script runs exit above)
plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()