|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm

import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms

from accelerate import Accelerator, DeepSpeedPlugin
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
|
|
|
import utils |
|
|
|
global_batch_size = 16 |
|
|
|
# synchronous CUDA kernel launches make device-side errors easier to localize (slows training)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
|
|
|
|
|
|
|
|
|
|
|
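# Distributed setup: the launcher's global RANK env var is read in as local_rank (defaulting to 0
# for single-process / interactive runs). In interactive single-GPU mode the env vars expected by
# torch.distributed and GLOBAL_BATCH_SIZE are filled in here; otherwise GLOBAL_BATCH_SIZE is
# expected to come from the launcher. Rank 0 writes a copy of the DeepSpeed ZeRO stage-2 config
# with the batch-size fields patched in; note the DeepSpeedPlugin below is pointed at the original
# cpu-offload config file.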
|
local_rank = os.getenv('RANK') |
|
if local_rank is None: |
|
local_rank = 0 |
|
else: |
|
local_rank = int(local_rank) |
|
print("LOCAL RANK ", local_rank) |
|
|
|
num_devices = torch.cuda.device_count() |
|
if num_devices==0: num_devices = 1 |
|
|
|
|
|
|
|
|
|
|
|
if num_devices <= 1 and utils.is_interactive(): |
|
|
|
os.environ["MASTER_ADDR"] = "localhost" |
|
os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000) |
|
os.environ["RANK"] = "0" |
|
os.environ["LOCAL_RANK"] = "0" |
|
os.environ["WORLD_SIZE"] = "1" |
|
os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) |
|
global_batch_size = os.environ["GLOBAL_BATCH_SIZE"] |
|
|
|
|
|
if local_rank == 0: |
|
with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2_cpuoffload.json', 'r') as file: |
|
config = json.load(file) |
|
config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"]) |
|
config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices |
|
with open('deepspeed_config_stage2.json', 'w') as file: |
|
json.dump(config, file) |
|
else: |
|
|
|
time.sleep(10) |
|
deepspeed_plugin = DeepSpeedPlugin("/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2_cpuoffload.json") |
|
accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin) |
|
|
|
|
|
|
|
|
|
|
|
print("PID of this process =",os.getpid()) |
|
device = accelerator.device |
|
print("device:",device) |
|
num_workers = num_devices |
|
print(accelerator.state) |
|
world_size = accelerator.state.num_processes |
|
distributed = accelerator.state.distributed_type != 'NO'
|
|
|
|
|
if accelerator.mixed_precision == "bf16": |
|
data_type = torch.bfloat16 |
|
elif accelerator.mixed_precision == "fp16": |
|
data_type = torch.float16 |
|
else: |
|
data_type = torch.float32 |
|
|
|
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type) |
|
print = accelerator.print |
|
|
|
|
|
|
|
|
|
|
|
accelerator.state.distributed_type |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if utils.is_interactive(): |
|
|
|
model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10)) |
|
model_name = model_name + "_interactive" |
|
print("model_name:", model_name) |
|
|
|
|
|
|
|
jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \ |
|
--model_name={model_name} \ |
|
--subj=1 --batch_size={global_batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=1024 \ |
|
--clip_scale=1. --blur_scale=100. --depth_scale=100. \ |
|
--max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving" |
|
|
|
jupyter_args = jupyter_args.split() |
|
print(jupyter_args) |
|
|
|
from IPython.display import clear_output |
|
get_ipython().run_line_magic('load_ext', 'autoreload') |
|
|
|
get_ipython().run_line_magic('autoreload', '2') |
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Model Training Configuration") |
|
parser.add_argument( |
|
"--model_name", type=str, default="testing", |
|
help="name of model, used for ckpt saving and wandb logging (if enabled)", |
|
) |
|
parser.add_argument( |
|
"--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset", |
|
help="Path to where NSD data is stored / where to download it to", |
|
) |
|
parser.add_argument( |
|
"--subj",type=int, default=1, choices=[1,2,5,7], |
|
) |
|
parser.add_argument( |
|
"--batch_size", type=int, default=32, |
|
help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser", |
|
) |
|
parser.add_argument( |
|
"--wandb_log",action=argparse.BooleanOptionalAction,default=True, |
|
help="whether to log to wandb", |
|
) |
|
parser.add_argument( |
|
"--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False, |
|
help="if not using wandb and want to resume from a ckpt", |
|
) |
|
parser.add_argument( |
|
"--wandb_project",type=str,default="stability", |
|
help="wandb project name", |
|
) |
|
parser.add_argument( |
|
"--mixup_pct",type=float,default=.33, |
|
help="proportion of way through training when to switch from BiMixCo to SoftCLIP", |
|
) |
|
parser.add_argument( |
|
"--blurry_recon",action=argparse.BooleanOptionalAction,default=True, |
|
help="whether to output blurry reconstructions", |
|
) |
|
parser.add_argument( |
|
"--depth_recon",action=argparse.BooleanOptionalAction,default=True, |
|
help="whether to output depth reconstructions", |
|
) |
|
parser.add_argument( |
|
"--blur_scale",type=float,default=100., |
|
help="multiply loss from blurry recons by this number", |
|
) |
|
parser.add_argument( |
|
"--depth_scale",type=float,default=100., |
|
help="multiply loss from depth recons by this number", |
|
) |
|
parser.add_argument( |
|
"--clip_scale",type=float,default=1., |
|
help="multiply contrastive loss by this number", |
|
) |
|
parser.add_argument( |
|
"--use_image_aug",action=argparse.BooleanOptionalAction,default=True, |
|
help="whether to use image augmentation", |
|
) |
|
parser.add_argument( |
|
"--num_epochs",type=int,default=120, |
|
help="number of epochs of training", |
|
) |
|
parser.add_argument( |
|
"--hidden_dim",type=int,default=4096, |
|
) |
|
parser.add_argument( |
|
"--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'], |
|
) |
|
parser.add_argument( |
|
"--ckpt_saving",action=argparse.BooleanOptionalAction,default=True, |
|
) |
|
parser.add_argument( |
|
"--ckpt_interval",type=int,default=5, |
|
help="save backup ckpt and reconstruct every x epochs", |
|
) |
|
parser.add_argument( |
|
"--seed",type=int,default=42, |
|
) |
|
parser.add_argument( |
|
"--max_lr",type=float,default=3e-4, |
|
) |
|
parser.add_argument( |
|
"--seq_len",type=int,default=2, |
|
) |
|
|
|
if utils.is_interactive(): |
|
args = parser.parse_args(jupyter_args) |
|
else: |
|
args = parser.parse_args() |
|
|
|
|
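# Promote parsed args to module-level globals (args.subj -> subj, etc.) so the rest of the script
# can reference them directly.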
|
for attribute_name in vars(args).keys(): |
|
globals()[attribute_name] = getattr(args, attribute_name) |
|
|
|
|
|
|
|
|
|
|
|
outdir = os.path.abspath(f'../train_logs/{model_name}') |
|
if not os.path.exists(outdir) and ckpt_saving: |
|
os.makedirs(outdir,exist_ok=True) |
|
if use_image_aug: |
|
import kornia |
|
from kornia.augmentation.container import AugmentationSequential |
|
img_augment = AugmentationSequential( |
|
kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3), |
|
kornia.augmentation.Resize((224, 224)), |
|
kornia.augmentation.RandomHorizontalFlip(p=0.3), |
|
kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3), |
|
kornia.augmentation.RandomGrayscale(p=0.3), |
|
same_on_batch=False, |
|
data_keys=["input"], |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
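# NSD webdataset prep: subject 1 has 24958 training trials and 2770 test trials. The tar shards
# only carry the behavioral arrays (behav / past_behav / future_behav / olds_behav); the voxels
# and 73k COCO images themselves are loaded from HDF5 further below and indexed via behav.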
|
if subj==1: |
|
num_train = 24958 |
|
num_test = 2770 |
|
test_batch_size = num_test |
|
|
|
def my_split_by_node(urls): return urls |
|
|
|
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar" |
|
|
|
print(train_url) |
|
|
|
train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\ |
|
.shuffle(750, initial=1500, rng=random.Random(42))\ |
|
.decode("torch")\ |
|
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\ |
|
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"]) |
|
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True) |
|
|
|
test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar" |
|
print(test_url) |
|
|
|
test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\ |
|
.shuffle(750, initial=1500, rng=random.Random(42))\ |
|
.decode("torch")\ |
|
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\ |
|
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"]) |
|
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
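# One pass over each dataloader to collect, per trial, the voxel-row index (behav[:,0,5]) and the
# 73k-image index (behav[:,0,0]); train_73k_images is reused later for the retrieval evaluation.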
|
test_vox_indices = [] |
|
test_73k_images = [] |
|
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): |
|
test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy()) |
|
test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy()) |
|
test_vox_indices = test_vox_indices.astype(np.int16) |
|
print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices)) |
|
print("---\n") |
|
|
|
train_vox_indices = [] |
|
train_73k_images = [] |
|
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl): |
|
train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy()) |
|
train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy()) |
|
train_vox_indices = train_vox_indices.astype(np.int16) |
|
print(train_i, (train_i+1) * batch_size, len(train_vox_indices)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
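# Load this subject's betas into CPU memory (trials x voxels) along with the 73k COCO images at
# 224x224; both stay on CPU and are indexed per batch inside the training loop.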
|
f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') |
|
|
|
|
|
voxels = f['betas'][:] |
|
print(f"subj0{subj} betas loaded into memory") |
|
voxels = torch.Tensor(voxels).to("cpu").to(data_type) |
|
print("voxels", voxels.shape) |
|
num_voxels = voxels.shape[-1] |
|
|
|
|
|
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') |
|
images = f['images'][:] |
|
images = torch.Tensor(images).to("cpu").to(data_type) |
|
print("images", images.shape) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
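# Frozen CLIP ViT-L/14 encoders: clip_model returns the full 257x768 hidden state (the target for
# the backbone's CLIP head), while clip_model2 below returns pooled embeddings used to summarize
# the past-trial images fed into the sequence.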
|
from models import Clipper |
|
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True) |
|
clip_seq_dim = 257 |
|
clip_emb_dim = 768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip_model2 = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=False, norm_embs=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
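# Blurry-reconstruction branch: a frozen VQ autoencoder provides low-level latent targets; its
# latents are scaled by 0.18215, matching the scaling used when decoding elsewhere in this script.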
|
if blurry_recon: |
|
from diffusers import VQModel |
|
autoenc = VQModel.from_pretrained("/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae", torch_dtype=data_type) |
|
autoenc.eval() |
|
autoenc.requires_grad_(False) |
|
autoenc.to(device) |
|
utils.count_params(autoenc) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if blurry_recon: |
|
if utils.is_interactive(): display(utils.torch_to_Image(images[[30]])) |
|
|
|
input_batch = images[[30]].to(device) |
|
print(input_batch.shape) |
|
|
|
downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False) |
|
re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest') |
|
re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215 |
|
print(re_upsampled_enc.shape) |
|
|
|
if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
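# Depth-reconstruction branch: frozen MiDaS (DPT-large) depth maps, resized to 32x32 and
# max-normalized per image, serve as the depth targets.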
|
if depth_recon: |
|
from controlnet_aux.midas import MidasDetector |
|
|
|
midas_depth = MidasDetector.from_pretrained( |
|
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large", cache_dir="/fsx/proj-fmri/shared/cache").to(device) |
|
midas_depth.model.eval() |
|
midas_depth.model.requires_grad_(False) |
|
midas_depth.model.to(device) |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
if depth_recon: |
|
if utils.is_interactive(): display(utils.torch_to_Image(images[[30]])) |
|
|
|
input_batch = images[[30,31]].float().to(device) |
|
print(input_batch.shape) |
|
|
|
midas_emb = midas_depth.model(input_batch).unsqueeze(1) |
|
print(midas_emb.shape) |
|
|
|
prediction = utils.resize(midas_emb, 32) |
|
print(prediction.shape) |
|
|
|
prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half() |
|
midas_emb_size = prediction.flatten(1).shape[1] |
|
print("midas_emb", prediction.shape, prediction.min(), prediction.max()) |
|
print("midas_emb_size", midas_emb_size) |
|
|
|
if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224))) |
|
|
|
if blurry_recon: |
|
prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1) |
|
prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half() |
|
prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215 |
|
print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max()) |
|
|
|
if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MindEyeModule(nn.Module): |
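    """Top-level container; the ridge, time_embedding, and backbone submodules are attached below."""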
|
def __init__(self): |
|
super(MindEyeModule, self).__init__() |
|
def forward(self, x): |
|
return x |
|
|
|
model = MindEyeModule() |
|
model |
|
|
|
|
|
|
|
|
|
|
|
time_embedding_dim = 512 |
|
|
|
class RidgeRegression(torch.nn.Module): |
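    """Linear map from concatenated (voxels + time embedding) to hidden_dim. Despite the name, the
    ridge penalty is applied via AdamW weight decay in the optimizer parameter groups below."""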
|
|
|
def __init__(self, input_size, out_features): |
|
super(RidgeRegression, self).__init__() |
|
self.out_features = out_features |
|
self.linear = torch.nn.Linear(input_size, out_features) |
|
def forward(self, x): |
|
return self.linear(x) |
|
|
|
model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim) |
|
utils.count_params(model.ridge) |
|
utils.count_params(model) |
|
|
|
b = torch.randn((2,1,voxels.shape[1])) |
|
time_emb_test = torch.randn((2,1,time_embedding_dim)) |
|
print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape) |
|
|
|
|
|
|
|
|
|
|
|
num_past_voxels = 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from diffusers.models.vae import Decoder |
|
class BrainNetwork(nn.Module): |
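    """MLP-Mixer-style backbone: alternating residual MLP blocks over the hidden dimension
    (mixer_blocks1) and over the sequence dimension (mixer_blocks2), followed by a CLIP-token head
    (clin1 + clip_proj) and optional heads for blurry-image latents (blin1/bupsampler) and depth
    maps (dlin1/dupsampler)."""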
|
def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768): |
|
super().__init__() |
|
self.seq_len = seq_len |
|
self.h = h |
|
self.clip_size = clip_size |
|
|
|
|
|
|
|
|
|
|
|
self.mixer_blocks1 = nn.ModuleList([ |
|
self.mixer_block1(h, drop) for _ in range(n_blocks) |
|
]) |
|
self.mixer_blocks2 = nn.ModuleList([ |
|
self.mixer_block2(seq_len, drop) for _ in range(n_blocks) |
|
]) |
|
|
|
|
|
self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.clip_proj = nn.Sequential( |
|
nn.LayerNorm(clip_size), |
|
nn.GELU(), |
|
nn.Linear(clip_size, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, clip_size) |
|
) |
|
|
|
if blurry_recon: |
|
|
|
|
|
|
|
|
|
|
|
self.blin1 = nn.Linear(h*seq_len, 4096) |
|
self.bgroupnorm = nn.GroupNorm(1, 256) |
|
self.bupsampler = Decoder( |
|
in_channels=256, |
|
out_channels=128, |
|
up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], |
|
block_out_channels=[32, 64, 128], |
|
layers_per_block=1, |
|
) |
|
|
|
if depth_recon: |
|
|
|
|
|
|
|
|
|
self.dlin1 = nn.Linear(h*seq_len, 4096) |
|
self.dgroupnorm = nn.GroupNorm(1, 256) |
|
self.dupsampler = Decoder( |
|
in_channels=256, |
|
out_channels=1, |
|
up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], |
|
block_out_channels=[32, 64, 128, 256], |
|
layers_per_block=1, |
|
) |
|
|
|
def mixer_block1(self, h, drop): |
|
return nn.Sequential( |
|
nn.LayerNorm(h), |
|
self.mlp(h, h, drop), |
|
) |
|
|
|
def mixer_block2(self, seq_len, drop): |
|
return nn.Sequential( |
|
nn.LayerNorm(seq_len), |
|
self.mlp(seq_len, seq_len, drop) |
|
) |
|
|
|
def mlp(self, in_dim, out_dim, drop): |
|
return nn.Sequential( |
|
nn.Linear(in_dim, out_dim), |
|
nn.GELU(), |
|
nn.Dropout(drop), |
|
nn.Linear(out_dim, out_dim), |
|
) |
|
|
|
def forward(self, x, idx = None): |
|
print(idx) |
|
|
|
b,d = torch.Tensor([0.]), torch.Tensor([0.]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
residual1 = x |
|
residual2 = x.permute(0,2,1) |
|
|
|
for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2): |
|
x = block1(x) + residual1 |
|
|
|
residual1 = x |
|
x = x.permute(0,2,1) |
|
|
|
x = block2(x) + residual2 |
|
|
|
residual2 = x |
|
x = x.permute(0,2,1) |
|
|
|
|
|
x = x.reshape(x.size(0), -1) |
|
|
|
c = self.clin1(x) |
|
|
|
|
|
|
|
|
|
c = self.clip_proj(c.reshape(len(c), -1, self.clip_size)) |
|
|
|
if blurry_recon: |
|
b = self.blin1(x) |
|
b = b.reshape(len(b), 256, 4, 4) |
|
b = self.bgroupnorm(b) |
|
b = self.bupsampler(b) |
|
|
|
if depth_recon: |
|
d = self.dlin1(x) |
|
d = d.reshape(len(d), 256, 4, 4) |
|
d = self.dgroupnorm(d) |
|
d = self.dupsampler(d) |
|
|
|
return c, b, d |
|
|
|
|
|
class TimeEmbedding(nn.Module): |
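    """Learned embedding of how many trials back a past scan occurred (0..num_past_voxels-1)."""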
|
def __init__(self, embedding_time_dim=512, num_past_voxels=15): |
|
super().__init__() |
|
self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim) |
|
self.num_past_voxels = num_past_voxels |
|
self.embedding_time_dim = embedding_time_dim |
|
|
|
def forward(self, time): |
|
|
|
time = time.long() |
|
time = self.embedding_time(time) |
|
return time |
|
|
|
|
|
|
|
model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15) |
|
|
|
model.backbone = BrainNetwork(h=hidden_dim + clip_emb_dim, in_dim=hidden_dim + clip_emb_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) |
|
utils.count_params(model.backbone) |
|
utils.count_params(model) |
|
|
|
|
|
b = torch.randn((1,seq_len,hidden_dim + clip_emb_dim)) |
|
print("b.shape",b.shape) |
|
with torch.no_grad(): |
|
clip_, blur_, depth_ = model.backbone(b) |
|
print(clip_.shape, blur_.shape, depth_.shape) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
voxel_ridge = torch.randn(512,4096) |
|
voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim) |
|
print("b.shape",voxel_ridge.shape) |
|
with torch.no_grad(): |
|
clip_, blur_, depth_ = model.backbone(voxel_ridge) |
|
print(clip_.shape, blur_.shape, depth_.shape)""" |
|
|
|
|
|
|
|
|
|
|
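# Optimizer: AdamW with weight decay on the ridge layer and on backbone parameters other than
# biases and LayerNorm weights; the LR uses either torch's LinearLR or a OneCycleLR schedule sized
# to the number of training steps per device.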
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] |
|
opt_grouped_parameters = [ |
|
{'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2}, |
|
{'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2}, |
|
{'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, |
|
] |
|
|
|
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr) |
|
|
|
if lr_scheduler_type == 'linear': |
|
lr_scheduler = torch.optim.lr_scheduler.LinearLR( |
|
optimizer, |
|
total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))), |
|
last_epoch=-1 |
|
) |
|
elif lr_scheduler_type == 'cycle': |
|
total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size))) |
|
print("total_steps", total_steps) |
|
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( |
|
optimizer, |
|
max_lr=max_lr, |
|
total_steps=total_steps, |
|
final_div_factor=1000, |
|
last_epoch=-1, pct_start=2/num_epochs |
|
) |
|
|
|
def save_ckpt(tag): |
|
ckpt_path = outdir+f'/{tag}.pth' |
|
print(f'saving {ckpt_path}',flush=True) |
|
unwrapped_model = accelerator.unwrap_model(model) |
|
try: |
|
torch.save({ |
|
'epoch': epoch, |
|
'model_state_dict': unwrapped_model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'lr_scheduler': lr_scheduler.state_dict(), |
|
'train_losses': losses, |
|
'test_losses': test_losses, |
|
'lrs': lrs, |
|
}, ckpt_path) |
|
    except Exception as e:
        print(f"Couldn't save ({e})... moving on to prevent crashing.")
|
del unwrapped_model |
|
|
|
print("\nDone with model preparations!") |
|
utils.count_params(model) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""pp = None |
|
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl): |
|
#with torch.cuda.amp.autocast(dtype=data_type): |
|
#optimizer.zero_grad() |
|
|
|
voxel = voxels[behav[:,0,5].cpu().long()]#.to(device) |
|
image = images[behav[:,0,0].cpu().long()].float()#.to(device).float() |
|
|
|
past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279 |
|
past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15 |
|
print(past_behav[:,:seq_len-1,0].cpu().long()) |
|
past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()] |
|
|
|
break |
|
|
|
print(past_15_times) |
|
#for past in range(1): |
|
# past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device) |
|
|
|
#if blurry_recon: |
|
# blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215 |
|
blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215 |
|
|
|
if depth_recon: |
|
# depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128) |
|
depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32) |
|
depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half() |
|
depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215 |
|
|
|
if use_image_aug: |
|
image = img_augment(image) |
|
|
|
clip_target = clip_model.embed_image(image) |
|
assert not torch.any(torch.isnan(clip_target)) |
|
|
|
if epoch < int(mixup_pct * num_epochs): |
|
voxel, perm, betas, select = utils.mixco(voxel) |
|
past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select) |
|
|
|
for p in range(seq_len-1): |
|
print(past_behav.shape) #128, 15, 17 |
|
print(past_behav[:,p,-1]) |
|
print(past_15_voxels.shape) # 128, 1, 15724 |
|
mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1]) |
|
print(mask) # 128 |
|
past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :]) |
|
print(past_15_voxels) |
|
pp = past_15_voxels |
|
|
|
break""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if local_rank==0 and wandb_log: |
|
import wandb |
|
wandb_project = 'mindeyev2' |
|
wandb_run = model_name |
|
wandb_notes = '' |
|
|
|
print(f"wandb {wandb_project} run {wandb_run}") |
|
wandb.login(host='https://stability.wandb.io') |
|
wandb_config = { |
|
"model_name": model_name, |
|
"global_batch_size": global_batch_size, |
|
"batch_size": batch_size, |
|
"num_epochs": num_epochs, |
|
"clip_scale": clip_scale, |
|
"blur_scale": blur_scale, |
|
"use_image_aug": use_image_aug, |
|
"max_lr": max_lr, |
|
"mixup_pct": mixup_pct, |
|
"num_train": num_train, |
|
"num_test": num_test, |
|
"ckpt_interval": ckpt_interval, |
|
"ckpt_saving": ckpt_saving, |
|
"seed": seed, |
|
"distributed": distributed, |
|
"num_devices": num_devices, |
|
"world_size": world_size, |
|
"train_url": train_url, |
|
"test_url": test_url, |
|
} |
|
print("wandb_config:\n",wandb_config) |
|
if False: |
|
print("wandb_id:",model_name) |
|
wandb.init( |
|
id = model_name, |
|
project=wandb_project, |
|
name=wandb_run, |
|
config=wandb_config, |
|
notes=wandb_notes, |
|
resume="allow", |
|
) |
|
else: |
|
wandb.init( |
|
project=wandb_project, |
|
name=wandb_run, |
|
config=wandb_config, |
|
notes=wandb_notes, |
|
) |
|
else: |
|
wandb_log = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epoch = 0 |
|
losses, test_losses, lrs = [], [], [] |
|
best_test_loss = 1e9 |
|
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs)) |
|
|
|
|
|
if resume_from_ckpt: |
|
print("\n---resuming from last.pth ckpt---\n") |
|
try: |
|
checkpoint = torch.load(outdir+'/last.pth', map_location='cpu') |
|
except: |
|
print('last.pth failed... trying last_backup.pth') |
|
checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu') |
|
epoch = checkpoint['epoch'] |
|
print("Epoch",epoch) |
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) |
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
del checkpoint |
|
elif wandb_log: |
|
if wandb.run.resumed: |
|
print("\n---resuming from last.pth ckpt---\n") |
|
try: |
|
checkpoint = torch.load(outdir+'/last.pth', map_location='cpu') |
|
except: |
|
print('last.pth failed... trying last_backup.pth') |
|
checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu') |
|
epoch = checkpoint['epoch'] |
|
print("Epoch",epoch) |
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) |
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
del checkpoint |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
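# Wrap model, optimizer, train loader and scheduler for DeepSpeed/DDP; the test loader is left
# unwrapped since evaluation only runs on rank 0.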
|
model, optimizer, train_dl, lr_scheduler = accelerator.prepare( |
|
model, optimizer, train_dl, lr_scheduler |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_saturation(image, alpha=2): |
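    """Boost saturation by extrapolating away from the luma (Rec. 601 weights) grayscale version of
    the image: alpha=1 is the identity, alpha>1 increases saturation. Output is clamped to [0,1]."""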
|
gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :] |
|
gray_image = gray_image.unsqueeze(1).expand_as(image) |
|
saturated_image = alpha * image + (1 - alpha) * gray_image |
|
return torch.clamp(saturated_image, 0, 1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
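# Training loop. Each step: gather the current and past voxels/images for the batch, embed past
# images with the pooled CLIP model, append a learned time embedding to every voxel vector, run the
# ridge layer and sequence backbone, then combine the CLIP contrastive loss (BiMixCo before
# mixup_pct of training, SoftCLIP after) with optional blurry-latent and depth L1 losses.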
|
print(f"{model_name} starting with epoch {epoch} / {num_epochs}") |
|
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0)) |
|
test_image, test_voxel = None, None |
|
mse = nn.MSELoss() |
|
l1 = nn.L1Loss() |
|
|
|
for epoch in progress_bar: |
|
model.train() |
|
|
|
fwd_percent_correct = 0. |
|
bwd_percent_correct = 0. |
|
test_fwd_percent_correct = 0. |
|
test_bwd_percent_correct = 0. |
|
|
|
loss_clip_total = 0. |
|
loss_blurry_total = 0. |
|
loss_depth_total = 0. |
|
test_loss_clip_total = 0. |
|
test_loss_blurry_total = 0. |
|
test_loss_depth_total = 0. |
|
|
|
blurry_pixcorr = 0. |
|
test_blurry_pixcorr = 0. |
|
|
|
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl): |
|
        with torch.cuda.amp.autocast(dtype=data_type):
|
optimizer.zero_grad() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
voxel = voxels[behav[:,0,5].cpu().long()].to(device) |
|
image = images[behav[:,0,0].cpu().long()].to(device).float() |
|
|
|
past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) |
|
|
|
past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()].to(device).float() |
|
past_array = [i for i in range(seq_len-1)] |
|
past_15_times = torch.Tensor(past_array) |
|
|
|
|
|
past_15_times = past_15_times.to(device) |
|
|
|
|
|
|
|
if blurry_recon: |
|
|
|
blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215 |
|
|
|
if depth_recon: |
|
|
|
depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32) |
|
depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half() |
|
depth_image_enc = depth_images |
|
|
|
if use_image_aug: |
|
image = img_augment(image) |
|
|
|
clip_target = clip_model.embed_image(image) |
|
assert not torch.any(torch.isnan(clip_target)) |
|
|
|
if epoch < int(mixup_pct * num_epochs): |
|
voxel, perm, betas, select = utils.mixco(voxel) |
|
past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select) |
|
|
|
|
|
|
|
for p in range(seq_len-1): |
|
|
|
|
|
|
|
mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1]) |
|
|
|
past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :]) |
|
past_15_images[mask, p, :] = torch.zeros_like(past_15_images[0, p, :]) |
|
|
|
|
|
past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1]) |
|
past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1]) |
|
|
|
past_15_embeddings = clip_model2.embed_image(past_15_images) |
|
|
|
past_15_embeddings = torch.cat([torch.zeros(batch_size, past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0) |
|
|
|
|
|
|
|
past_15_times = past_15_times.repeat(voxel.shape[0], 1) |
|
past_15_times = past_15_times.reshape(-1) |
|
time_embeddings = model.time_embedding(past_15_times) |
|
|
|
past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1) |
|
|
|
positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device) |
|
voxel = torch.cat((voxel, positional_current_voxel), dim=-1) |
|
voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2)) |
|
voxel_ridge = voxel_ridge.view(seq_len,int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2) |
|
|
|
|
|
past_15_embeddings = past_15_embeddings.reshape(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2) |
|
|
|
|
|
|
|
|
|
voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge) |
|
|
|
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
|
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
|
|
|
if epoch < int(mixup_pct * num_epochs): |
|
loss_clip = utils.mixco_nce( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=.006, |
|
perm=perm, betas=betas, select=select) |
|
else: |
|
epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)] |
|
loss_clip = utils.soft_clip_loss( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=epoch_temp) |
|
|
|
loss_clip_total += loss_clip.item() |
|
loss_clip *= clip_scale |
|
loss = loss_clip |
|
|
|
if blurry_recon: |
|
downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False) |
|
re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')) |
|
re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215 |
|
|
|
loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc)) |
|
loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_)) |
|
loss_blurry_total += loss_blurry.item() |
|
loss_blurry *= blur_scale |
|
loss += loss_blurry |
|
|
|
if depth_recon: |
|
loss_depth = l1(depth_image_enc_, depth_image_enc) |
|
|
|
loss_depth_total += loss_depth.item() |
|
loss_depth *= depth_scale |
|
loss += loss_depth |
|
|
|
|
|
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
|
fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item() |
|
bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item() |
|
|
|
if blurry_recon: |
|
with torch.no_grad(): |
|
|
|
random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False) |
|
|
|
blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1) |
|
|
|
pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images) |
|
|
|
|
|
blurry_pixcorr += pixcorr.item() |
|
|
|
|
|
utils.check_loss(loss) |
|
accelerator.backward(loss) |
|
optimizer.step() |
|
|
|
losses.append(loss.item()) |
|
lrs.append(optimizer.param_groups[0]['lr']) |
|
|
|
if lr_scheduler_type is not None: |
|
lr_scheduler.step() |
|
|
|
model.eval() |
|
if local_rank==0: |
|
with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type): |
|
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): |
|
|
|
assert len(behav) == num_test |
|
|
|
|
|
if test_image is None: |
|
voxel = voxels[behav[:,0,5].cpu().long()] |
|
image = behav[:,0,0].cpu().long() |
|
|
|
unique_image, sort_indices = torch.unique(image, return_inverse=True) |
|
for im in unique_image: |
|
locs = torch.where(im == image)[0] |
|
if test_image is None: |
|
test_image = images[im][None] |
|
test_voxel = torch.mean(voxel[locs],axis=0)[None] |
|
else: |
|
test_image = torch.vstack((test_image, images[im][None])) |
|
test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None])) |
|
|
|
|
|
random_indices = torch.arange(len(test_voxel))[:300] |
|
voxel = test_voxel[random_indices].to(device) |
|
image = test_image[random_indices].to(device) |
|
assert len(image) == 300 |
|
|
|
current_past_behav = past_behav[random_indices] |
|
|
|
past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) |
|
past_15_images = images[current_past_behav[:,:seq_len-1,0].cpu().long()].to(device).float() |
|
past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) |
|
|
|
if blurry_recon: |
|
|
|
blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215 |
|
|
|
if depth_recon: |
|
|
|
depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32) |
|
depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half() |
|
depth_image_enc = depth_images |
|
|
|
clip_target = clip_model.embed_image(image.float()) |
|
|
|
|
|
past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1]) |
|
past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1]) |
|
|
|
past_15_embeddings = clip_model2.embed_image(past_15_images) |
|
|
|
past_15_embeddings = torch.cat([torch.zeros(image.shape[0], past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0) |
|
|
|
past_15_times = past_15_times.repeat(voxel.shape[0], 1) |
|
past_15_times = past_15_times.reshape(-1) |
|
time_embeddings = model.time_embedding(past_15_times) |
|
past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1) |
|
|
|
positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device) |
|
voxel = torch.cat((voxel, positional_current_voxel), dim=-1) |
|
voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2)) |
|
voxel_ridge = voxel_ridge.view(seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2) |
|
past_15_embeddings = past_15_embeddings.view(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2) |
|
|
|
voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge) |
|
|
|
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
|
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
|
|
|
loss_clip = utils.soft_clip_loss( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=.006) |
|
test_loss_clip_total += loss_clip.item() |
|
loss_clip = loss_clip * clip_scale |
|
loss = loss_clip |
|
|
|
if blurry_recon: |
|
downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False) |
|
re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')) |
|
re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215 |
|
|
|
loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc)) |
|
loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_)) |
|
test_loss_blurry_total += loss_blurry.item() |
|
loss_blurry *= blur_scale |
|
loss += loss_blurry |
|
|
|
|
|
blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1) |
|
blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
|
pixcorr = utils.pixcorr(image, blurry_recon_images) |
|
loss += (1 - pixcorr) |
|
test_blurry_pixcorr += pixcorr.item() |
|
|
|
if depth_recon: |
|
loss_depth = l1(depth_image_enc_, depth_image_enc) |
|
|
|
test_loss_depth_total += loss_depth.item() |
|
loss_depth *= depth_scale |
|
loss += loss_depth |
|
|
|
|
|
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
|
test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item() |
|
test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item() |
|
|
|
utils.check_loss(loss) |
|
test_losses.append(loss.item()) |
|
|
|
|
|
print("---") |
|
|
|
assert (test_i+1) == 1 |
|
logs = {"train/loss": np.mean(losses[-(train_i+1):]), |
|
"test/loss": np.mean(test_losses[-(test_i+1):]), |
|
"train/lr": lrs[-1], |
|
"train/num_steps": len(losses), |
|
"test/num_steps": len(test_losses), |
|
"train/fwd_pct_correct": fwd_percent_correct / (train_i + 1), |
|
"train/bwd_pct_correct": bwd_percent_correct / (train_i + 1), |
|
"test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1), |
|
"test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1), |
|
"train/loss_clip_total": loss_clip_total / (train_i + 1), |
|
"train/loss_blurry_total": loss_blurry_total / (train_i + 1), |
|
"test/loss_clip_total": test_loss_clip_total / (test_i + 1), |
|
"test/loss_blurry_total": test_loss_blurry_total / (test_i + 1), |
|
"train/blurry_pixcorr": blurry_pixcorr / (train_i + 1), |
|
"test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1), |
|
"train/loss_depth_total": loss_depth_total / (train_i + 1), |
|
"test/loss_depth_total": test_loss_depth_total / (test_i + 1), |
|
} |
|
|
|
if blurry_recon: |
|
|
|
fig, axes = plt.subplots(1, 8, figsize=(10, 4)) |
|
jj=-1 |
|
for j in [0,1,2,3]: |
|
jj+=1 |
|
axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
|
axes[jj].axis('off') |
|
jj+=1 |
|
axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1))) |
|
axes[jj].axis('off') |
|
|
|
if wandb_log: |
|
logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}") |
|
plt.close() |
|
else: |
|
plt.show() |
|
|
|
if depth_recon: |
|
|
|
fig, axes = plt.subplots(1, 8, figsize=(10, 4)) |
|
|
|
|
|
jj=-1 |
|
for j in [0,1,2,3]: |
|
jj+=1 |
|
axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224))) |
|
axes[jj].axis('off') |
|
jj+=1 |
|
axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224))) |
|
axes[jj].axis('off') |
|
if wandb_log: |
|
logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}") |
|
plt.close() |
|
else: |
|
plt.show() |
|
|
|
progress_bar.set_postfix(**logs) |
|
|
|
|
|
if epoch % ckpt_interval == 0: |
|
if not utils.is_interactive(): |
|
save_ckpt(f'last') |
|
|
|
if wandb_log: wandb.log(logs) |
|
|
|
|
|
accelerator.wait_for_everyone() |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
print("\n===Finished!===\n") |
|
if ckpt_saving: |
|
save_ckpt(f'last') |
|
if not utils.is_interactive(): |
|
sys.exit(0) |
|
|
|
|
|
|
|
|
|
|
|
plt.plot(losses) |
|
plt.show() |
|
plt.plot(test_losses) |
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy") |
|
|
|
|
|
|
|
|
|
|
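# Retrieval sanity check (interactive): embed every unique training image with CLIP, then rank them
# by cosine similarity to the prediction for a single test voxel pattern; the nearest neighbor's
# COCO caption is reused as a prompt below. Note that model.ridge is called here on the raw voxels
# only, without the time-embedding padding and past-trial sequence used during training.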
|
ii=2 |
|
all_indices = np.unique(train_73k_images) |
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
for batch in tqdm(range(0,len(all_indices),512)): |
|
if batch==0: |
|
clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu() |
|
else: |
|
target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu() |
|
clip_target = torch.vstack((clip_target,target)) |
|
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
|
|
|
voxel = test_voxel[[ii]].to(device) |
|
image = test_image[[ii]].to(device) |
|
|
|
print("Original Image (test set)") |
|
display(utils.torch_to_Image(image)) |
|
|
|
clip_target = clip_model.embed_image(image).cpu() |
|
|
|
|
|
voxel_ridge = model.ridge(voxel).unsqueeze(1) |
|
clip_voxels, _, _ = model.backbone(voxel_ridge) |
|
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
|
|
|
print("clip_voxels_norm", clip_voxels_norm.shape) |
|
print("clip_target_norm", clip_target_norm.shape) |
|
|
|
sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(), |
|
clip_target_norm).flatten()).flip(0) |
|
picks = all_indices[sortt[:5]] |
|
|
|
print("\nNearest neighbors in training set") |
|
for ip,p in enumerate(picks): |
|
display(utils.torch_to_Image(images[[p]])) |
|
|
|
if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0] |
|
|
|
print("\n=====\npredicted_caption:\n", predicted_caption) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
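# Caption-to-image baseline: feed the retrieved nearest-neighbor caption to Stable Diffusion XL
# (loaded from a local cache path) and compare the generation against the seen image.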
|
from diffusers import StableDiffusionXLPipeline |
|
pipe = StableDiffusionXLPipeline.from_pretrained( |
|
"/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True |
|
) |
|
pipe.to("cuda") |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
prompt = predicted_caption |
|
recon = pipe(prompt=prompt).images[0] |
|
|
|
|
|
|
|
|
|
|
|
print("Seen image") |
|
display(utils.torch_to_Image(image)) |
|
|
|
print("Reconstruction") |
|
utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224)) |
|
|
|
|