import os |
|
import sys |
|
import json |
|
import argparse |
|
import numpy as np |
|
import math |
|
from einops import rearrange |
|
import time |
|
import random |
|
import h5py |
|
from tqdm import tqdm |
|
|
|
import webdataset as wds |
|
import gc |
|
|
|
import matplotlib.pyplot as plt |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
from torchvision.transforms import ToPILImage |
|
|
|
from accelerate import Accelerator, DeepSpeedPlugin |
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
|
|
|
import utils |
|
|
|
global_batch_size = 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
# process rank assigned by the launcher (torchrun / accelerate); defaults to 0 when running single-process
local_rank = os.getenv('RANK')
if local_rank is None:
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)
|
|
|
num_devices = torch.cuda.device_count() |
|
if num_devices==0: num_devices = 1 |
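# Accelerate handles device placement and distributed setup; split_batches=False keeps batch_size as the per-process batch size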
|
|
|
accelerator = Accelerator(split_batches=False) |
|
print("PID of this process =",os.getpid()) |
|
device = accelerator.device |
|
print("device:",device) |
|
num_workers = num_devices |
|
print(accelerator.state) |
|
world_size = accelerator.state.num_processes |
|
distributed = accelerator.state.distributed_type != 'NO'
|
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size) |
|
print = accelerator.print |
|
if utils.is_interactive(): |
|
|
|
jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \ |
|
--model_name=captions \ |
|
--subj=1 --batch_size={global_batch_size} --n_samples_save=0 \ |
|
--max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug" |
|
|
|
jupyter_args = jupyter_args.split() |
|
print(jupyter_args) |
|
|
|
from IPython.display import clear_output |
|
get_ipython().run_line_magic('load_ext', 'autoreload') |
|
|
|
get_ipython().run_line_magic('autoreload', '2') |
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Model Training Configuration") |
|
parser.add_argument( |
|
"--model_name", type=str, default="testing", |
|
help="name of model, used for ckpt saving and wandb logging (if enabled)", |
|
) |
|
parser.add_argument( |
|
"--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset", |
|
help="Path to where NSD data is stored / where to download it to", |
|
) |
|
parser.add_argument( |
|
"--subj",type=int, default=1, choices=[1,2,5,7], |
|
) |
|
parser.add_argument( |
|
"--batch_size", type=int, default=32, |
|
    help="Global batch size; it is divided across the available GPUs below",
|
) |
|
parser.add_argument( |
|
"--wandb_log",action=argparse.BooleanOptionalAction,default=False, |
|
help="whether to log to wandb", |
|
) |
|
parser.add_argument( |
|
"--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False, |
|
help="if not using wandb and want to resume from a ckpt", |
|
) |
|
parser.add_argument( |
|
"--wandb_project",type=str,default="stability", |
|
help="wandb project name", |
|
) |
|
parser.add_argument( |
|
"--mixup_pct",type=float,default=.33, |
|
help="proportion of way through training when to switch from BiMixCo to SoftCLIP", |
|
) |
|
parser.add_argument( |
|
"--use_image_aug",action=argparse.BooleanOptionalAction,default=True, |
|
help="whether to use image augmentation", |
|
) |
|
parser.add_argument( |
|
"--num_epochs",type=int,default=100, |
|
help="number of epochs of training", |
|
) |
|
parser.add_argument( |
|
"--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'], |
|
) |
|
parser.add_argument( |
|
"--ckpt_saving",action=argparse.BooleanOptionalAction,default=True, |
|
) |
|
parser.add_argument( |
|
"--ckpt_interval",type=int,default=5, |
|
help="save backup ckpt and reconstruct every x epochs", |
|
) |
|
parser.add_argument( |
|
"--seed",type=int,default=42, |
|
) |
|
parser.add_argument( |
|
"--max_lr",type=float,default=3e-4, |
|
) |
|
parser.add_argument( |
|
"--n_samples_save",type=int,default=0,choices=[0,1], |
|
help="Number of reconstructions for monitoring progress, 0 will speed up training", |
|
) |
|
parser.add_argument( |
|
"--clip_mse_ratio",type=float,default=0.7, |
|
    help="Weight on the contrastive CLIP loss in the combined loss; the MSE term gets (1 - clip_mse_ratio)",
|
) |
|
|
|
if utils.is_interactive(): |
|
args = parser.parse_args(jupyter_args) |
|
else: |
|
args = parser.parse_args() |
|
|
|
|
|
for attribute_name in vars(args).keys(): |
|
globals()[attribute_name] = getattr(args, attribute_name) |
|
|
|
print("global batch_size", batch_size) |
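# derive the per-device batch size from the global batch size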
|
batch_size = int(batch_size / num_devices) |
|
print("batch_size", batch_size) |
|
|
|
|
|
|
|
|
|
|
|
outdir = os.path.abspath(f'../train_logs/{model_name}') |
|
if not os.path.exists(outdir): |
|
os.makedirs(outdir,exist_ok=True) |
|
if use_image_aug: |
|
import kornia |
|
from kornia.augmentation.container import AugmentationSequential |
|
img_augment = AugmentationSequential( |
|
kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3), |
|
kornia.augmentation.Resize((224, 224)), |
|
kornia.augmentation.RandomHorizontalFlip(p=0.3), |
|
kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3), |
|
kornia.augmentation.RandomGrayscale(p=0.3), |
|
same_on_batch=False, |
|
data_keys=["input"], |
|
) |
|
|
|
|
|
|
|
|
|
|
|
wandb_log = True  # force wandb logging on, overriding the --wandb_log CLI flag
|
if subj==1: |
|
num_train = 24958 |
|
num_test = 2770 |
|
test_batch_size = num_test |
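# identity nodesplitter: every process reads the full list of webdataset shards (no sharding across nodes)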
|
|
|
def my_split_by_node(urls): return urls |
|
|
|
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar" |
|
print(train_url) |
|
|
|
train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\ |
|
.shuffle(750, initial=1500, rng=random.Random(42))\ |
|
.decode("torch")\ |
|
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\ |
|
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"]) |
|
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True) |
|
|
|
test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar" |
|
print(test_url) |
|
|
|
test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\ |
|
.shuffle(750, initial=1500, rng=random.Random(42))\ |
|
.decode("torch")\ |
|
.rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\ |
|
.to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"]) |
|
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True) |
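# Load all of this subject's fMRI betas (fp16) from HDF5 into CPU memory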
|
f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') |
|
voxels = f['betas'][:] |
|
print(f"subj0{subj} betas loaded into memory") |
|
voxels = torch.Tensor(voxels).to("cpu").half() |
|
if subj==1: |
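    # pad the voxel dimension with 5 zero columns for subject 1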
|
voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5)))) |
|
print("voxels", voxels.shape) |
|
num_voxels = voxels.shape[-1] |
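# Load the 224x224 COCO stimulus images (fp16) into CPU memory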
|
|
|
|
|
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') |
|
images = f['images'][:] |
|
images = torch.Tensor(images).to("cpu").half() |
|
print("images", images.shape) |
|
import transformers |
|
from transformers import Blip2Processor, Blip2ForConditionalGeneration |
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
from models import Clipper |
|
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True) |
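# BLIP-2 (OPT-2.7B) from a local snapshot: its frozen ViT supplies the target image embeddings, and its Q-Former + OPT decode embeddings back into captions. Note that the CLIP ViT-L/14 model above is loaded but not referenced again below.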
|
|
|
|
|
|
|
|
|
|
|
cache_blip2 = "/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b" |
|
|
|
b2_processor = Blip2Processor.from_pretrained(cache_blip2) |
|
b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map="auto") |
|
|
|
|
|
"""from lavis.models import load_model_and_preprocess |
|
from lavis.models import model_zoo |
|
blip2_model, vis_processors, _ = load_model_and_preprocess( |
|
name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device) |
|
|
|
clip_seq_dim = 257 |
|
clip_emb_dim = 1024 |
|
hidden_dim = 4096""" |
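# embed_images_b2: run a uint8 image batch through BLIP-2's frozen ViT; returns the last hidden state (the training target) plus the processor outputs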
|
|
|
|
|
|
|
|
|
|
|
def embed_images_b2(images): |
|
images = (images * 255).type(torch.uint8) |
|
with torch.no_grad(): |
|
inputs_processed = b2_processor(images, return_tensors="pt").to("cuda", torch.float16) |
|
enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values']) |
|
return enc_imgs.last_hidden_state.detach(), inputs_processed |
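# embeds_to_captions_b2: turn (predicted) ViT embeddings into text via BLIP-2's Q-Former, language projection, and OPT generate; mirrors Blip2ForConditionalGeneration.generate but starts from embeddings instead of pixels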
|
|
|
def embeds_to_captions_b2(embeds): |
|
with torch.no_grad(): |
|
input_ids = None |
|
attention_mask = None |
|
batch_size = embeds.shape[0] |
|
image_embeds = embeds |
|
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) |
|
|
|
query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1) |
|
query_outputs = b2_model.qformer( |
|
query_embeds=query_tokens, |
|
encoder_hidden_states=image_embeds, |
|
encoder_attention_mask=image_attention_mask, |
|
return_dict=True, |
|
) |
|
query_output = query_outputs.last_hidden_state |
|
|
|
language_model_inputs = b2_model.language_projection(query_output) |
|
language_attention_mask = torch.ones( |
|
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device |
|
) |
|
if input_ids is None: |
|
input_ids = ( |
|
torch.LongTensor([[b2_model.config.text_config.bos_token_id]]) |
|
.repeat(batch_size, 1) |
|
.to(image_embeds.device) |
|
) |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids) |
|
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1) |
|
|
|
|
|
inputs_embeds = b2_model.get_input_embeddings()(input_ids) |
|
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1) |
|
|
|
outputs = b2_model.language_model.generate( |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=attention_mask, |
|
) |
|
text = b2_processor.batch_decode(outputs, skip_special_tokens=True) |
|
|
|
return outputs, text |
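# Sanity check: embed a few COCO images with BLIP-2 and decode them back into captions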
|
|
|
|
|
|
|
|
|
|
|
image_test = images[1:20].permute(0,2,3,1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""import matplotlib.pyplot as plt |
|
# Plotting one of the images (taking the first image as an example) |
|
img_to_plot = inputs_rec['pixel_values'][-1] |
|
|
|
# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C]) |
|
img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu') |
|
print(img_to_plot.shape) |
|
|
|
plt.imshow(img_to_plot) |
|
plt.show()""" |
|
|
|
|
|
|
|
|
|
|
|
embeds_test, inputs_rec = embed_images_b2(image_test) |
|
outputs_test, text_test = embeds_to_captions_b2(embeds_test) |
|
|
|
|
|
|
|
|
|
|
|
text_test |
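# Token grid the voxel backbone must predict: 257 BLIP-2 ViT tokens of width 1408; hidden_dim is the width of the ridge bottleneck feeding the backbone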
|
clip_seq_dim = 257 |
|
clip_emb_dim = 1408 |
|
hidden_dim = 2048 |
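# SDXL fp16-fix VAE, loaded frozen; it is not referenced again in the training loop below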
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from diffusers import AutoencoderKL |
|
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache") |
|
|
|
autoenc.eval() |
|
autoenc.requires_grad_(False) |
|
autoenc.to(device) |
|
utils.count_params(autoenc) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MindEyeModule(nn.Module): |
|
def __init__(self): |
|
super(MindEyeModule, self).__init__() |
|
def forward(self, x): |
|
return x |
|
|
|
model = MindEyeModule() |
|
model |
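# Ridge-style voxel encoder: a single linear layer from flattened voxels to hidden_dim (the optimizer's weight decay acts as the ridge penalty)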
|
|
|
|
|
|
|
|
|
|
|
class RidgeRegression(torch.nn.Module): |
|
|
|
def __init__(self, input_size, out_features): |
|
super(RidgeRegression, self).__init__() |
|
self.out_features = out_features |
|
self.linear = torch.nn.Linear(input_size, out_features) |
|
def forward(self, x): |
|
return self.linear(x) |
|
|
|
model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim) |
|
utils.count_params(model.ridge) |
|
utils.count_params(model) |
|
|
|
b = torch.randn((2,1,voxels.shape[1])) |
|
print(b.shape, model.ridge(b).shape) |
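# BrainNetwork backbone: residual MLP over the ridge output, a linear head to a (clip_seq_dim x clip_emb_dim) token grid, and an MLP projector (clip_proj); the blurry_dim argument is unused in this variant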
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from diffusers.models.vae import Decoder |
|
class BrainNetwork(nn.Module): |
|
def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.35, blurry_dim=16): |
|
super().__init__() |
|
self.blurry_dim = blurry_dim |
|
norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h) |
|
act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU |
|
act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn) |
|
self.lin0 = nn.Linear(in_dim, h) |
|
self.mlp = nn.ModuleList([ |
|
nn.Sequential( |
|
nn.Linear(h, h), |
|
*[item() for item in act_and_norm], |
|
nn.Dropout(drop) |
|
) for _ in range(n_blocks) |
|
]) |
|
self.lin1 = nn.Linear(h, out_dim, bias=True) |
|
|
|
self.n_blocks = n_blocks |
|
self.clip_size = clip_size |
|
self.clip_proj = nn.Sequential( |
|
nn.LayerNorm(clip_size), |
|
nn.GELU(), |
|
nn.Linear(clip_size, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, 2048), |
|
nn.LayerNorm(2048), |
|
nn.GELU(), |
|
nn.Linear(2048, clip_size) |
|
) |
|
def forward(self, x): |
|
x = self.lin0(x) |
|
residual = x |
|
for res_block in range(self.n_blocks): |
|
x = self.mlp[res_block](x) |
|
x += residual |
|
residual = x |
|
x = x.reshape(len(x), -1) |
|
x = self.lin1(x) |
|
|
|
|
|
c = self.clip_proj(x.reshape(len(x), -1, self.clip_size)) |
|
|
|
return c |
|
|
|
model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) |
|
utils.count_params(model.backbone) |
|
utils.count_params(model) |
|
|
|
b = torch.randn((4,hidden_dim)) |
|
print(b.shape) |
|
clip_ = model.backbone(b) |
|
print(clip_.shape) |
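# AdamW optimizer; the no_decay list is intended to exempt backbone biases and LayerNorm parameters from weight decay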
|
|
|
|
|
|
|
|
|
|
|
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] |
|
opt_grouped_parameters = [ |
|
{'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2}, |
|
{'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2}, |
|
{'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, |
|
] |
|
|
|
optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95)) |
|
|
|
if lr_scheduler_type == 'linear': |
|
lr_scheduler = torch.optim.lr_scheduler.LinearLR( |
|
optimizer, |
|
total_iters=int(num_epochs*(num_train*num_devices//batch_size)), |
|
last_epoch=-1 |
|
) |
|
elif lr_scheduler_type == 'cycle': |
|
total_steps=int(num_epochs*(num_train*num_devices//batch_size)) |
|
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( |
|
optimizer, |
|
max_lr=max_lr, |
|
total_steps=total_steps, |
|
final_div_factor=1000, |
|
last_epoch=-1, pct_start=2/num_epochs |
|
) |
|
|
|
def save_ckpt(tag): |
|
ckpt_path = outdir+f'/{tag}.pth' |
|
print(f'saving {ckpt_path}',flush=True) |
|
unwrapped_model = accelerator.unwrap_model(model) |
|
try: |
|
torch.save({ |
|
'epoch': epoch, |
|
'model_state_dict': unwrapped_model.state_dict(), |
|
'optimizer_state_dict': optimizer.state_dict(), |
|
'lr_scheduler': lr_scheduler.state_dict(), |
|
'train_losses': losses, |
|
'test_losses': test_losses, |
|
'lrs': lrs, |
|
}, ckpt_path) |
|
    except Exception as e:
        print(f"Couldn't save ({e})... moving on to prevent crashing.")
|
del unwrapped_model |
|
|
|
print("\nDone with model preparations!") |
|
utils.count_params(model) |
|
if local_rank==0 and wandb_log:
|
import wandb |
|
|
|
wandb_project = 'mindeyev2' |
|
wandb_run = model_name |
|
wandb_notes = '' |
|
|
|
print(f"wandb {wandb_project} run {wandb_run}") |
|
wandb.login(host='https://stability.wandb.io') |
|
wandb_config = { |
|
"model_name": model_name, |
|
"batch_size": batch_size, |
|
"num_epochs": num_epochs, |
|
"use_image_aug": use_image_aug, |
|
"max_lr": max_lr, |
|
"lr_scheduler_type": lr_scheduler_type, |
|
"mixup_pct": mixup_pct, |
|
"num_train": num_train, |
|
"num_test": num_test, |
|
"seed": seed, |
|
"distributed": distributed, |
|
"num_devices": num_devices, |
|
"world_size": world_size, |
|
} |
|
print("wandb_config:\n",wandb_config) |
|
    if False:  # flip to True to resume an existing wandb run by id (model_name is reused as the run id)
|
print("wandb_id:",model_name) |
|
wandb.init( |
|
id = model_name, |
|
project=wandb_project, |
|
name=wandb_run, |
|
config=wandb_config, |
|
notes=wandb_notes, |
|
resume="allow", |
|
) |
|
else: |
|
wandb.init( |
|
project=wandb_project, |
|
name=wandb_run, |
|
config=wandb_config, |
|
notes=wandb_notes, |
|
) |
|
else: |
|
wandb_log = False |
|
pixcorr_preprocess = transforms.Compose([ |
|
transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), |
|
]) |
|
def pixcorr(images,brains): |
|
|
|
all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1) |
|
all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1) |
|
corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean() |
|
return corrmean |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
epoch = 0 |
|
losses, test_losses, lrs = [], [], [] |
|
best_test_loss = 1e9 |
|
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs)) |
|
|
|
|
|
if resume_from_ckpt: |
|
print("\n---resuming from last.pth ckpt---\n") |
|
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except Exception:
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    epoch = checkpoint['epoch']
    print("Epoch", epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    model.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint
|
elif wandb_log: |
|
if wandb.run.resumed: |
|
print("\n---resuming from last.pth ckpt---\n") |
|
        try:
            checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
        except Exception:
            print('last.pth failed... trying last_backup.pth')
            checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
        epoch = checkpoint['epoch']
        print("Epoch", epoch)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare( |
|
model, optimizer, train_dl, test_dl, lr_scheduler |
|
) |
|
|
|
|
|
|
|
|
|
|
|
"""transform = transforms.Compose( |
|
[ |
|
transforms.Resize( |
|
(224, 224), |
|
), |
|
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), |
|
] |
|
) |
|
|
|
def tensor_2_embed(image): |
|
image_for_blip2 = transform(image) |
|
|
|
#Generate embeddings |
|
with blip2_model.maybe_autocast(): |
|
blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2)) |
|
|
|
return blip2_target |
|
|
|
def embed_2_caption(image_embeds, model): |
|
image_embeds = image_embeds.float() |
|
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( |
|
image.device) |
|
|
|
query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1) |
|
query_output = model.Qformer.bert( |
|
query_embeds=query_tokens, |
|
encoder_hidden_states=image_embeds, |
|
encoder_attention_mask=image_atts, |
|
return_dict=True) |
|
|
|
inputs_t5 = model.t5_proj(query_output.last_hidden_state) |
|
atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) |
|
prompt = model.prompt |
|
input_tokens = model.t5_tokenizer( |
|
prompt, padding="longest", return_tensors="pt" |
|
).to(image.device) |
|
encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) |
|
|
|
with model.maybe_autocast(dtype=torch.bfloat16): |
|
inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids) |
|
inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) |
|
|
|
outputs = model.t5_model.generate( |
|
inputs_embeds=inputs_embeds, |
|
attention_mask=encoder_atts) |
|
output_text = model.t5_tokenizer.batch_decode( |
|
outputs, skip_special_tokens=True) |
|
|
|
return output_text""" |
|
|
|
|
|
|
|
|
|
|
|
wandb_log = True  # re-enable logging for the loop below (wandb.log itself is still only called on rank 0)
|
|
|
|
|
|
|
|
|
|
|
print(f"{model_name} starting with epoch {epoch} / {num_epochs}") |
|
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0)) |
|
test_image, test_voxel = None, None |
|
mse = nn.MSELoss() |
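# Training loop: BiMixCo contrastive loss for the first mixup_pct of epochs, then SoftCLIP, combined with an MSE term between predicted and BLIP-2 image embeddings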
|
for epoch in progress_bar: |
|
model.train() |
|
|
|
fwd_percent_correct = 0. |
|
bwd_percent_correct = 0. |
|
test_fwd_percent_correct = 0. |
|
test_bwd_percent_correct = 0. |
|
|
|
loss_clip_total = 0. |
|
loss_blurry_total = 0. |
|
test_loss_clip_total = 0. |
|
test_loss_blurry_total = 0. |
|
|
|
blurry_pixcorr = 0. |
|
test_blurry_pixcorr = 0. |
|
|
|
for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl): |
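        # note: no training happens on epoch 0 (the loop breaks on the first batch); only the test pass below runs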
|
if epoch == 0: |
|
lrs.append(0) |
|
break |
|
with torch.cuda.amp.autocast(): |
|
optimizer.zero_grad() |
|
|
|
voxel = voxels[behav[:,0,5].cpu().long()].to(device) |
|
|
|
image = images[behav[:,0,0].cpu().long()].to(device).float() |
|
|
|
|
|
|
|
if use_image_aug: image = img_augment(image) |
|
|
|
clip_target = embed_images_b2(image)[0].to(device) |
|
assert not torch.any(torch.isnan(clip_target)) |
|
|
|
if epoch < int(mixup_pct * num_epochs): |
|
voxel, perm, betas, select = utils.mixco(voxel) |
|
|
|
voxel_ridge = model.ridge(voxel) |
|
|
|
|
|
clip_voxels = model.backbone(voxel_ridge) |
|
|
|
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
|
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
|
|
|
if epoch < int(mixup_pct * num_epochs): |
|
loss_clip = utils.mixco_nce( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=.006, |
|
perm=perm, betas=betas, select=select) |
|
else: |
|
epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)] |
|
loss_clip = utils.soft_clip_loss( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=epoch_temp) |
|
|
|
loss_mse= mse(clip_voxels, clip_target) |
|
|
|
|
|
|
|
loss_clip_total += loss_clip.item() |
|
|
|
|
|
|
|
loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse) |
|
if (train_i % 10 == 0): |
|
print(train_i, loss) |
|
|
|
utils.check_loss(loss) |
|
accelerator.backward(loss) |
|
optimizer.step() |
|
|
|
losses.append(loss.item()) |
|
lrs.append(optimizer.param_groups[0]['lr']) |
|
|
|
|
|
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
|
fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1) |
|
bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if lr_scheduler_type is not None: |
|
lr_scheduler.step() |
|
|
|
model.eval() |
|
for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): |
|
with torch.cuda.amp.autocast(): |
|
with torch.no_grad(): |
|
|
|
if len(behav) != num_test: print("!",len(behav),num_test) |
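                # first test pass only: average the repeated presentations of each unique test image into a single voxel pattern and cache it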
|
|
|
|
|
if test_image is None: |
|
voxel = voxels[behav[:,0,5].cpu().long()].to(device) |
|
|
|
image = behav[:,0,0].cpu().long() |
|
|
|
unique_image, sort_indices = torch.unique(image, return_inverse=True) |
|
for im in unique_image: |
|
locs = torch.where(im == image)[0] |
|
if test_image is None: |
|
test_image = images[im][None] |
|
test_voxel = torch.mean(voxel[locs],axis=0)[None] |
|
else: |
|
test_image = torch.vstack((test_image, images[im][None])) |
|
test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None])) |
|
|
|
|
|
random_indices = torch.arange(len(test_voxel))[:batch_size] |
|
voxel = test_voxel[random_indices].to(device) |
|
image = test_image[random_indices].to(device) |
|
assert len(image) == batch_size |
|
|
|
|
|
|
|
|
|
clip_target = embed_images_b2(image)[0].to(device) |
|
|
|
voxel_ridge = model.ridge(voxel) |
|
|
|
|
|
clip_voxels = model.backbone(voxel_ridge) |
|
|
|
clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1) |
|
clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) |
|
|
|
loss_clip = utils.soft_clip_loss( |
|
clip_voxels_norm, |
|
clip_target_norm, |
|
temp=.006) |
|
|
|
loss_mse = mse(clip_voxels, clip_target) |
|
|
|
|
|
|
|
|
|
loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse) |
|
|
|
utils.check_loss(loss) |
|
|
|
test_losses.append(loss.item()) |
|
|
|
|
|
labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) |
|
test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1) |
|
test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1) |
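    # end-of-epoch logging on rank 0: losses, retrieval accuracy, example images, and captions decoded from the predicted embeddings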
|
|
|
if local_rank==0: |
|
|
|
        assert (test_i+1) == 1  # the whole test set fits in a single batch, so only one test iteration is expected
|
logs = {"train/loss": np.mean(losses[-(train_i+1):]), |
|
"test/loss": np.mean(test_losses[-(test_i+1):]), |
|
"train/lr": lrs[-1], |
|
"train/num_steps": len(losses), |
|
"test/num_steps": len(test_losses), |
|
"train/fwd_pct_correct": fwd_percent_correct / (train_i + 1), |
|
"train/bwd_pct_correct": bwd_percent_correct / (train_i + 1), |
|
"test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1), |
|
"test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1), |
|
"train/loss_clip_total": loss_clip_total / (train_i + 1), |
|
"train/loss_blurry_total": loss_blurry_total / (train_i + 1), |
|
"test/loss_clip_total": test_loss_clip_total / (test_i + 1), |
|
"test/loss_blurry_total": test_loss_blurry_total / (test_i + 1), |
|
"train/blurry_pixcorr": blurry_pixcorr / (train_i + 1), |
|
"test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1), |
|
} |
|
progress_bar.set_postfix(**logs) |
|
|
|
fig, axes = plt.subplots(1, 8, figsize=(10, 4)) |
|
jj=-1 |
|
for j in [0,1,2,3,4,5,6,7]: |
|
jj+=1 |
|
axes[jj].imshow(utils.torch_to_Image(image[j])) |
|
axes[jj].axis('off') |
|
|
|
if wandb_log: |
|
generated_captions = embeds_to_captions_b2(clip_voxels[0:8]) |
|
print(generated_captions[1]) |
|
logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}" + "\n".join(generated_captions[1])) |
|
plt.close() |
|
|
|
if epoch % ckpt_interval == 0: |
|
if not utils.is_interactive(): |
|
save_ckpt(f'last') |
|
|
|
if wandb_log: wandb.log(logs) |
|
|
|
|
|
accelerator.wait_for_everyone() |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
|
print("\n===Finished!===\n") |
|
if ckpt_saving: |
|
save_ckpt(f'last') |
|
if not utils.is_interactive(): |
|
sys.exit(0) |
|
|
|
|
|
|
|
|
|
|
|
plt.plot(losses) |
|
plt.show() |
|
plt.plot(test_losses) |
|
plt.show() |
|
|
|
|
|
|
|
|
|
|
|
print('test') |
|
|
|
|
|
|
|
|
|
|
|
# legacy helper; depends on the LAVIS blip2_model / vis_processors from the commented-out block above
def tensor_2_embed_old(tensor):
|
embed_array = torch.zeros((tensor.shape[0],257, 1024)) |
|
to_pil = ToPILImage() |
|
for sample in range(tensor.shape[0]): |
|
PIL_image = to_pil(tensor[sample]) |
|
image_for_blip2 = vis_processors["eval"](PIL_image).unsqueeze(0).to(device) |
|
|
|
with blip2_model.maybe_autocast(): |
|
blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2)) |
|
embed_array[sample] = blip2_target |
|
|
|
return embed_array |
|