In [1]:
# # Code to convert this notebook to .py if you want to run it via command line or with Slurm
# from subprocess import call
# command = "jupyter nbconvert Train.ipynb --to python"
# call(command,shell=True)

# Import packages & functions

In [2]:
import os
import sys
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import h5py
from tqdm import tqdm

import webdataset as wds
import gc

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.transforms import ToPILImage   #CHANGED (added)

from accelerate import Accelerator, DeepSpeedPlugin

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils

global_batch_size = 128 #128

  from .autonotebook import tqdm as notebook_tqdm


[2023-11-19 16:32:39,711] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
### Multi-GPU config ###
local_rank = os.getenv('RANK')
if local_rank is None: 
    local_rank = 0
else:
    local_rank = int(local_rank)
print("LOCAL RANK ", local_rank)  

num_devices = torch.cuda.device_count()
if num_devices==0: num_devices = 1

accelerator = Accelerator(split_batches=False)

### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###

# if num_devices <= 1 and utils.is_interactive():
#     # can emulate a distributed environment for deepspeed to work in jupyter notebook
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
#     os.environ["RANK"] = "0"
#     os.environ["LOCAL_RANK"] = "0"
#     os.environ["WORLD_SIZE"] = "1"
#     os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
#     global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]

# # alter the deepspeed config according to your global and local batch size
# if local_rank == 0:
#     with open('deepspeed_config_stage2.json', 'r') as file:
#         config = json.load(file)
#     config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
#     config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
#     with open('deepspeed_config_stage2.json', 'w') as file:
#         json.dump(config, file)
# else:
#     # give some time for the local_rank=0 gpu to prep new deepspeed config file
#     time.sleep(10)
# deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)

LOCAL RANK  0


In [4]:
print("PID of this process =",os.getpid())
device = accelerator.device
print("device:",device)
num_workers = num_devices
print(accelerator.state)
world_size = accelerator.state.num_processes
distributed = not accelerator.state.distributed_type == 'NO'
print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
print = accelerator.print # only print if local_rank=0

PID of this process = 2370606
device: cuda
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1


# Configurations

In [5]:
# if running this interactively, can specify jupyter_args here for argparser to use
if utils.is_interactive():
    # Example use
    jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
                    --model_name=captions \
                    --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
                    --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug"
    #max_lr=3e-5 originally
    jupyter_args = jupyter_args.split()
    print(jupyter_args)
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=captions', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-1', '--mixup_pct=.66', '--num_epochs=30', '--ckpt_interval=999', '--no-use_image_aug']


In [6]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,5,7],
)
parser.add_argument(
    "--batch_size", type=int, default=32,
    help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
    help="whether to log to wandb",
)
parser.add_argument(
    "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
    help="if not using wandb and want to resume from a ckpt",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=240,
    help="number of epochs of training",
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=3e-4,
)
parser.add_argument(
    "--n_samples_save",type=int,default=0,choices=[0,1],
    help="Number of reconstructions for monitoring progress, 0 will speed up training",
)

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

# create global variables without the args prefix
for attribute_name in vars(args).keys():
    globals()[attribute_name] = getattr(args, attribute_name)

print("global batch_size", batch_size)
batch_size = int(batch_size / num_devices)
print("batch_size", batch_size)

global batch_size 128
batch_size 128


In [7]:
outdir = os.path.abspath(f'../train_logs/{model_name}')
if not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)
if use_image_aug:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
    img_augment = AugmentationSequential(
        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
        kornia.augmentation.Resize((224, 224)),
        kornia.augmentation.RandomHorizontalFlip(p=0.3),
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
        kornia.augmentation.RandomGrayscale(p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )

In [8]:
wandb_log = False

# Prep data, models, and dataloaders

## Dataloader

In [9]:
if subj==1:
    num_train = 24958
    num_test = 2770
test_batch_size = num_test

def my_split_by_node(urls): return urls
    
train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
print(train_url)

train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)

test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
print(test_url)

test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
                    .shuffle(750, initial=1500, rng=random.Random(42))\
                    .decode("torch")\
                    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                    .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)

/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar
/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar


### check dataloaders are working

In [10]:
# test_indices = []
# test_images = []
# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
#     test_indices = np.append(test_indices, behav[:,0,5].numpy())
#     test_images = np.append(test_images, behav[:,0,0].numpy())
# test_indices = test_indices.astype(np.int16)
# print(test_i, (test_i+1) * test_batch_size, len(test_indices))
# print("---\n")

# train_indices = []
# train_images = []
# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
#     train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
#     train_images = np.append(train_images, behav[:,0,0].numpy())
# train_indices = train_indices.astype(np.int16)
# print(train_i, (train_i+1) * batch_size, len(train_indices))

# # train_images = np.hstack((train_images, test_images))
# # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")

## Load data and images

In [11]:
# load betas
f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
voxels = f['betas'][:]
print(f"subj0{subj} betas loaded into memory")
voxels = torch.Tensor(voxels).to("cpu").half()
if subj==1:
    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
print("voxels", voxels.shape)
num_voxels = voxels.shape[-1]

# load orig images
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images'][:]
images = torch.Tensor(images).to("cpu").half()
print("images", images.shape)

subj01 betas loaded into memory
voxels torch.Size([27750, 15729])
images torch.Size([73000, 3, 224, 224])


## Load models

### CLIP image embeddings  model

In [12]:
import transformers
from transformers import Blip2Processor, Blip2ForConditionalGeneration

from PIL import Image

In [13]:
from models import Clipper
clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)

ViT-L/14 cuda:0


In [14]:
cache_blip2 = "/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b"

b2_processor = Blip2Processor.from_pretrained(cache_blip2)
b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map="auto")

#Load in blip2 as well
"""from lavis.models import load_model_and_preprocess
from lavis.models import model_zoo
blip2_model, vis_processors, _ = load_model_and_preprocess(
            name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device)

clip_seq_dim = 257
clip_emb_dim = 1024
hidden_dim = 4096"""

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:08<00:00, 34.47s/it]


'from lavis.models import load_model_and_preprocess\nfrom lavis.models import model_zoo\nblip2_model, vis_processors, _ = load_model_and_preprocess(\n            name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device)\n\nclip_seq_dim = 257\nclip_emb_dim = 1024\nhidden_dim = 4096'

In [74]:
def embed_images_b2(images):
    images = (images * 255).type(torch.uint8)
    with torch.no_grad():
        inputs_processed = b2_processor(images, return_tensors="pt").to("cuda", torch.float16)
        enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])
        return enc_imgs.last_hidden_state.detach(), inputs_processed

def embeds_to_captions_b2(embeds, sample = False, temp = 0.9):
    with torch.no_grad():
        input_ids = None #inputs['input_ids']
        attention_mask = None
        batch_size = embeds.shape[0]
        image_embeds = embeds
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = b2_model.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = b2_model.language_projection(query_output)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
        if input_ids is None:
            input_ids = (
                torch.LongTensor([[b2_model.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

        # concatenate query embeddings with prompt embeddings
        inputs_embeds = b2_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

        outputs = b2_model.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            temperature=temp,
            do_sample = sample
        )
        text = b2_processor.batch_decode(outputs, skip_special_tokens=True)
        
        return outputs, text


In [73]:
b2_model.language_model.generate(do_sample = True, temperature=1)

tensor([[    2,  6209,    14,    10,   205,   425,    13,    10,  7297,  1280,
             9,   418,   116,  1437,    38, 10728,    33,   117,  1114,    99]],
       device='cuda:0')

In [16]:
image_test = images[1:20].permute(0,2,3,1)
#raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')
# Convert the image to a NumPy array
#image_test = np.array(raw_image)


In [17]:
"""import matplotlib.pyplot as plt
# Plotting one of the images (taking the first image as an example)
img_to_plot = inputs_rec['pixel_values'][-1]

# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])
img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')
print(img_to_plot.shape)

plt.imshow(img_to_plot)
plt.show()"""

"import matplotlib.pyplot as plt\n# Plotting one of the images (taking the first image as an example)\nimg_to_plot = inputs_rec['pixel_values'][-1]\n\n# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\nimg_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\nprint(img_to_plot.shape)\n\nplt.imshow(img_to_plot)\nplt.show()"

In [18]:
embeds_test, inputs_rec = embed_images_b2(image_test)

In [19]:
#inputs_rec['pixel_values'].shape

In [20]:
#out = b2_model.generate(**inputs_rec)
#print(b2_processor.decode(out[0], skip_special_tokens=True).strip())

In [21]:
outputs_test, text_test = embeds_to_captions_b2(embeds_test)



In [22]:
text_test

['a cat sitting on a toilet seat\n',
 'a person cutting a pizza on a cutting board\n',
 'a sandwich and a drink on a table\n',
 'a man crossing the street in front of a truck\n',
 'a giraffe standing in front of trees\n',
 'three men standing together\n',
 'a bird standing on a rock next to a body of water\n',
 'two men sitting on a street corner in asia\n',
 'a woman and two children playing tennis on a court\n',
 'a tall brick building with a clock on the side\n',
 'a train is on the tracks\n',
 'a man and woman in the water with a surfboard\n',
 'a living room with a desk and a chair\n',
 'a group of men on a basketball court\n',
 'a man holding an umbrella\n',
 'a man in a red shirt\n',
 'a group of people holding cell phones and wine glasses\n',
 'a laptop computer sitting on a table in front of a television\n',
 'a baseball player is swinging a bat on a field\n']

In [23]:
#inputss['pixel_values'].shape

In [24]:
#image_test.shape

In [25]:
max_lr = 1e-4

In [26]:
clip_seq_dim = 257 #blip2 image encoder shapes
clip_emb_dim = 1408 #blip2 image encoder shapes
hidden_dim = 2048

### SD VAE (blurry images)

In [40]:
from diffusers import AutoencoderKL
autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
autoenc.eval()
autoenc.requires_grad_(False)
autoenc.to(device)
utils.count_params(autoenc)

param counts:
83,653,863 total
0 trainable


### MindEye modules

In [41]:
class MindEyeModule(nn.Module):
    def __init__(self):
        super(MindEyeModule, self).__init__()
    def forward(self, x):
        return x
        
model = MindEyeModule()
model

MindEyeModule()

In [42]:
class RidgeRegression(torch.nn.Module):
    # make sure to add weight_decay when initializing optimizer
    def __init__(self, input_size, out_features): 
        super(RidgeRegression, self).__init__()
        self.out_features = out_features
        self.linear = torch.nn.Linear(input_size, out_features)
    def forward(self, x):
        return self.linear(x)
        
model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

b = torch.randn((2,1,voxels.shape[1]))
print(b.shape, model.ridge(b).shape)

param counts:
32,215,040 total
32,215,040 trainable
param counts:
32,215,040 total
32,215,040 trainable
torch.Size([2, 1, 15729]) torch.Size([2, 1, 2048])


In [43]:
from functools import partial
from diffusers.models.vae import Decoder
class BrainNetwork(nn.Module):
    def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):
        super().__init__()
        self.blurry_dim = blurry_dim
        norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
        act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
        act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
        self.lin0 = nn.Linear(in_dim, h)
        self.mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(h, h),
                *[item() for item in act_and_norm],
                nn.Dropout(drop)
            ) for _ in range(n_blocks)
        ])
        self.lin1 = nn.Linear(h, out_dim, bias=True)
        # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
        self.n_blocks = n_blocks
        self.clip_size = clip_size
        self.clip_proj = nn.Sequential(
            nn.LayerNorm(clip_size),
            nn.GELU(),
            nn.Linear(clip_size, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, 2048),
            nn.LayerNorm(2048),
            nn.GELU(),
            nn.Linear(2048, clip_size)
        )
        # self.upsampler = Decoder(
        #         in_channels=64,
        #         out_channels=4,
        #         up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
        #         block_out_channels=[64, 128, 256],
        #         layers_per_block=1,
        #     )
        
    def forward(self, x):
        x = self.lin0(x)
        residual = x
        for res_block in range(self.n_blocks):
            x = self.mlp[res_block](x)
            x += residual
            residual = x
        x = x.reshape(len(x), -1)
        x = self.lin1(x)
        # b = self.blin1(x)
        # b = self.upsampler(b.reshape(len(b), -1, 7, 7))
        c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
        # return c, b
        return c

model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) 
utils.count_params(model.backbone)
utils.count_params(model)

b = torch.randn((4,hidden_dim))
print(b.shape)
clip_ = model.backbone(b)
print(clip_.shape)

param counts:
772,419,072 total
772,419,072 trainable
param counts:
804,634,112 total
804,634,112 trainable
torch.Size([4, 2048])
torch.Size([4, 257, 1408])


In [44]:
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
opt_grouped_parameters = [
    {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))

if lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
        last_epoch=-1
    )
elif lr_scheduler_type == 'cycle':
    total_steps=int(num_epochs*(num_train*num_devices//batch_size))
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, 
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1, pct_start=2/num_epochs
    )
    
def save_ckpt(tag):    
    ckpt_path = outdir+f'/{tag}.pth'
    print(f'saving {ckpt_path}',flush=True)
    unwrapped_model = accelerator.unwrap_model(model)
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': unwrapped_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
            }, ckpt_path)
    except:
        print("Couldn't save... moving on to prevent crashing.")
    del unwrapped_model
        
print("\nDone with model preparations!")
utils.count_params(model)


Done with model preparations!
param counts:
804,634,112 total
804,634,112 trainable


# Weights and Biases

In [32]:
# params for wandb
if local_rank==0 and True: # only use main process for wandb logging
    import wandb
    
    wandb_project = 'mindeyev2'
    wandb_run = model_name
    wandb_notes = ''
    
    print(f"wandb {wandb_project} run {wandb_run}")
    wandb.login(host='https://stability.wandb.io')#, relogin=True)
    wandb_config = {
      "model_name": model_name,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      "lr_scheduler_type": lr_scheduler_type,
      "mixup_pct": mixup_pct,
      "num_train": num_train,
      "num_test": num_test,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
    }
    print("wandb_config:\n",wandb_config)
    if False: # wandb_auto_resume
        print("wandb_id:",model_name)
        wandb.init(
            id = model_name,
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
            resume="allow",
        )
    else:
        wandb.init(
            project=wandb_project,
            name=wandb_run,
            config=wandb_config,
            notes=wandb_notes,
        )
else:
    wandb_log = False

wandb mindeyev2 run captions


[34m[1mwandb[0m: Currently logged in as: [33mckadirt[0m. Use [1m`wandb login --relogin`[0m to force relogin


wandb_config:
 {'model_name': 'captions', 'batch_size': 128, 'num_epochs': 30, 'use_image_aug': False, 'max_lr': 0.0001, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}


# More custom functions

In [34]:
# using the same preprocessing as was used in MindEye + BrainDiffuser
pixcorr_preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr(images,brains):
    # Flatten images while keeping the batch dimension
    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
    corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
    return corrmean

# Main

In [51]:
epoch = 0
losses, test_losses, lrs = [], [], []
best_test_loss = 1e9
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

# Optionally resume from checkpoint #
if resume_from_ckpt:
    print("\n---resuming from last.pth ckpt---\n")
    try:
        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
    except:
        print('last.pth failed... trying last_backup.pth')
        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
    epoch = checkpoint['epoch']
    print("Epoch",epoch)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])
    del checkpoint
elif wandb_log:
    if wandb.run.resumed:
        print("\n---resuming from last.pth ckpt---\n")
        try:
            checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
        except:
            print('last.pth failed... trying last_backup.pth')
            checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
        epoch = checkpoint['epoch']
        print("Epoch",epoch)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint
torch.cuda.empty_cache()

In [36]:
checkpoint = torch.load('/fsx/proj-fmri/ckadirt/MindEyeV2/train_logs/caption_clip_0.5_bz/last.pth', map_location='cpu')

In [45]:
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [46]:
model

MindEyeModule(
  (ridge): RidgeRegression(
    (linear): Linear(in_features=15729, out_features=2048, bias=True)
  )
  (backbone): BrainNetwork(
    (lin0): Linear(in_features=2048, out_features=2048, bias=True)
    (mlp): ModuleList(
      (0-3): 4 x Sequential(
        (0): Linear(in_features=2048, out_features=2048, bias=True)
        (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (2): GELU(approximate='none')
        (3): Dropout(p=0.15, inplace=False)
      )
    )
    (lin1): Linear(in_features=2048, out_features=361856, bias=True)
    (clip_proj): Sequential(
      (0): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=1408, out_features=2048, bias=True)
      (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (4): GELU(approximate='none')
      (5): Linear(in_features=2048, out_features=2048, bias=True)
      (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    

In [47]:
model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
model, optimizer, train_dl, test_dl, lr_scheduler
)

In [None]:
"""transform = transforms.Compose(
            [
                transforms.Resize(
                    (224, 224),
                ),
                transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

def tensor_2_embed(image):    
    image_for_blip2 = transform(image)
    
    #Generate embeddings
    with blip2_model.maybe_autocast():
        blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))
                
    return blip2_target

def embed_2_caption(image_embeds, model):
    image_embeds = image_embeds.float()
    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                image.device)

    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
    query_output = model.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True)

    inputs_t5 = model.t5_proj(query_output.last_hidden_state)
    atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
    prompt = model.prompt
    input_tokens = model.t5_tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).to(image.device)
    encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
    
    with model.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = model.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts)
            output_text = model.t5_tokenizer.batch_decode(
                outputs, skip_special_tokens=True)
    
    return output_text"""

In [48]:
wandb_log = False

In [49]:
predicted_embeddings = None

In [52]:
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
test_image, test_voxel = None, None
mse = nn.MSELoss()
for epoch in progress_bar:
    model.train()
    
    fwd_percent_correct = 0.
    bwd_percent_correct = 0.
    test_fwd_percent_correct = 0.
    test_bwd_percent_correct = 0.

    loss_clip_total = 0.
    loss_blurry_total = 0.
    test_loss_clip_total = 0.
    test_loss_blurry_total = 0.

    blurry_pixcorr = 0.
    test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
    
    """for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
        if epoch == 0:
            lrs.append(0)
            break
        with torch.cuda.amp.autocast():
            optimizer.zero_grad()

            voxel = voxels[behav[:,0,5].cpu().long()].to(device)
            
            image = images[behav[:,0,0].cpu().long()].to(device).float()

            # blurry_image_enc = autoenc.encode(image).latent_dist.mode()
            
            if use_image_aug: image = img_augment(image)
            # clip_target = clip_model.embed_image(image)
            clip_target = embed_images_b2(image)[0].to(device) #####CHANGED
            assert not torch.any(torch.isnan(clip_target))
  
            if epoch < int(mixup_pct * num_epochs):
                voxel, perm, betas, select = utils.mixco(voxel)

            voxel_ridge = model.ridge(voxel)
            
            # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
            clip_voxels = model.backbone(voxel_ridge)
            
            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)

            if epoch < int(mixup_pct * num_epochs):                
                loss_clip = utils.mixco_nce(
                     clip_voxels_norm,
                     clip_target_norm,
                     temp=.006, 
                     perm=perm, betas=betas, select=select)
            else:
                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                loss_clip = utils.soft_clip_loss(
                     clip_voxels_norm,
                     clip_target_norm,
                     temp=epoch_temp)
            
            loss_mse= mse(clip_voxels, clip_target)

            # loss_blurry = mse(blurry_image_enc_, blurry_image_enc) 

            loss_clip_total += loss_clip.item()
            # loss_blurry_total += loss_blurry.item()

            # loss = loss_blurry + loss_clip
            loss = 0.7 * loss_clip + 0.3 * loss_mse
            if (train_i % 10 == 0):
                print(train_i, loss)
            # print(batch_size)
            utils.check_loss(loss)
            accelerator.backward(loss)
            optimizer.step()
    
            losses.append(loss.item())
            lrs.append(optimizer.param_groups[0]['lr'])
    
            # forward and backward top 1 accuracy        
            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) 
            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

            # with torch.no_grad():
            #     # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()
            #     random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)
            #     blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
            #     blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)

            if lr_scheduler_type is not None:
                lr_scheduler.step()"""
    
    model.eval()
    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
        with torch.cuda.amp.autocast():
            with torch.no_grad():   
                # all test samples should be loaded per batch such that test_i should never exceed 0
                if len(behav) != num_test: print("!",len(behav),num_test)
                
                ## Average same-image repeats ##
                if test_image is None:
                    voxel = voxels[behav[:,0,5].cpu().long()].to(device)
                    
                    image = behav[:,0,0].cpu().long()
                    
                    unique_image, sort_indices = torch.unique(image, return_inverse=True)
                    for im in unique_image:
                        locs = torch.where(im == image)[0]
                        if test_image is None:
                            test_image = images[im][None]
                            test_voxel = torch.mean(voxel[locs],axis=0)[None]
                        else:
                            test_image = torch.vstack((test_image, images[im][None]))
                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
    
                # sample of batch_size
                random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
                voxel = test_voxel[random_indices].to(device)
                image = test_image[random_indices].to(device)
                assert len(image) == batch_size
    
                # blurry_image_enc = autoenc.encode(image).latent_dist.mode()
        
                # clip_target = clip_model.embed_image(image.float())
                clip_target = embed_images_b2(image)[0].to(device) #####CHANGED
    
                voxel_ridge = model.ridge(voxel)
                
                # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
                clip_voxels = model.backbone(voxel_ridge)
                
                predicted_embeddings = clip_voxels
                break
                
                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
        
                # loss_clip = utils.soft_clip_loss(
                #     clip_voxels_norm,
                #     clip_target_norm,
                #     temp=.006)
                
                loss_clip = mse(clip_voxels, clip_target)

                # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
                
                # loss = loss_blurry + loss_clip
                loss = loss_clip
                
                utils.check_loss(loss)
        
                test_losses.append(loss.item())
        
                # forward and backward top 1 accuracy        
                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) 
                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)

                # # halving the batch size because the decoder is computationally heavy
                # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
                # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
                # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)

                #Find captions and print next to images
                #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)
                #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)

                #true_embed1 = tensor_2_embed(image[[0]])
                #true_embed2 = tensor_2_embed(image[[1]])

                # print(clip_voxels[[0]].shape)
                # print(true_embed1.shape)
                
                #true_caption1 = embed_2_caption(true_embed1, blip2_model)
                #true_caption2 = embed_2_caption(true_embed2, blip2_model)
                
                # transform blurry recon latents to images and plot it
                #fig, axes = plt.subplots(2, 2, figsize=(8, 4))
                #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))
                #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))
                #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')
                #axes[0,0].set_title(caption1)
                #axes[0,1].set_title(caption2)
                #axes[1,0].set_title(true_caption1)
                #axes[1,1].set_title(true_caption2)

                #plt.show()
                
                # # transform blurry recon latents to images and plot it
                # fig, axes = plt.subplots(1, 4, figsize=(8, 4))
                # axes[0].imshow(utils.torch_to_Image(image[[0]]))
                # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
                # axes[2].imshow(utils.torch_to_Image(image[[1]]))
                # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
                # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
                # axes[0].set_title(caption1)
                # axes[3].set_title(caption2)
                # plt.show()
                
    break
    if local_rank==0:      
        # if utils.is_interactive(): clear_output(wait=True)
        assert (test_i+1) == 1
        logs = {"train/loss": np.mean(losses[-(train_i+1):]),
            "test/loss": np.mean(test_losses[-(test_i+1):]),
            "train/lr": lrs[-1],
            "train/num_steps": len(losses),
            "test/num_steps": len(test_losses),
            "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
            "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
            "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
            "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
            "train/loss_clip_total": loss_clip_total / (train_i + 1),
            "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
            "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
            "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
            "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
            "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
            }
        progress_bar.set_postfix(**logs)
                          
        fig, axes = plt.subplots(1, 8, figsize=(10, 4))
        jj=-1
        for j in [0,1,2,3,4,5,6,7]:
            jj+=1
            axes[jj].imshow(utils.torch_to_Image(image[j]))
            axes[jj].axis('off')

        if wandb_log:
            generated_captions = embeds_to_captions_b2(clip_voxels[0:8])
            print(generated_captions[1])
            logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}" + "\n".join(generated_captions[1]))
            plt.close()
        # Save model checkpoint and reconstruct
        if epoch % ckpt_interval == 0:
            if not utils.is_interactive():
                save_ckpt(f'last')
                
        if wandb_log: wandb.log(logs)

    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()
    gc.collect()

print("\n===Finished!===\n")
if ckpt_saving:
    save_ckpt(f'last')
if not utils.is_interactive():
    sys.exit(0)

captions starting with epoch 0 / 30


  0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   


===Finished!===

saving /fsx/proj-fmri/ckadirt/MindEyeV2/train_logs/captions/last.pth





In [54]:
predicted_embeddings.shape

torch.Size([128, 257, 1408])

In [55]:
generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8])
print(generated_captions[1])



['a group of people are sitting around a table\n', 'a man is holding a glass of water in front of a television\n', 'a man is riding a skateboard on a hill\n', 'a group of people standing around a bike\n', 'a building with a sign that says "the house"\n', 'a plate of food with vegetables and meat\n', 'a white cup with a small bottle of wine\n', 'a group of people playing baseball and one is holding a ball\n']


In [75]:
generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)
print(generated_captions[1])

['a group of people are sitting at a table with food and drinks\n', 'a man in a kitchen with a large screen\n', 'a man on a surfboard with his legs in the air\n', 'a group of people are standing on the beach in front of a boat\n', 'a building with a sign that says "home of the person"\n', 'a vegetable salad with a variety of vegetables and other ingredients\n', 'a white cup with a small amount of coffee and a bottle of wine\n', 'a group of people playing baseball and soccer\n']


In [95]:
def concatenate_lists_any_depth(list1, list2):
    """
    Concatenates two lists of potentially varying depths, forming a new list of lists.

    Args:
    list1 (list): The first list to concatenate. Elements can be of any type.
    list2 (list): The second list to concatenate. Elements can be of any type.

    Returns:
    list: A new list containing lists of elements from the original lists.
    """
    # Ensure that both lists have the same length
    if len(list1) != len(list2):
        raise ValueError("Lists must be of the same length")

    concatenated_list = []

    for a, b in zip(list1, list2):
        # If the elements are not lists, convert them to lists
        if not isinstance(a, list):
            a = [a]
        if not isinstance(b, list):
            b = [b]

        # Concatenate the lists
        concatenated_list.append(a + b)

    return concatenated_list

In [96]:
def sample_several(embeddings, num=10, temp=0.3):
    # embeddings shape = batch, 257, 1408
    results = None  # Initialize results as None

    for i in range(num):  # Iterate from 0 to num-1
        if results is None:
            # For the first iteration, assign the results directly
            results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]
        else:
            # For subsequent iterations, combine the new results with the existing ones
            new_results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]
            results = concatenate_lists_any_depth(results, new_results)

    return results  # Return the combined results


In [77]:
generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)
print(generated_captions[1])

['a group of people sitting on a bench in front of a building\n', 'a woman is using a computer to make a video\n', 'a man in a black shirt is sitting on a surfboard\n', 'a group of people on the beach with a bike and some other things\n', 'a large building with a sign that says "the old farmhouse"\n', 'a plate with many different types of vegetables\n', 'a white cup with a bottle of wine and a small bottle of wine\n', 'a group of people are playing baseball in a field\n']


In [99]:
several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.5)
several

[['people are sitting at a table with a bunch of chairs\n',
  'several people in the yard with some food\n',
  'people sitting on a bench near a water fountain\n',
  'a group of people are sitting around a table\n',
  'a group of people in a room with several people in the foreground\n',
  'a group of people sitting around a table with food\n',
  'the people in the background are sitting on the edge of a table\n',
  'beverages and food are served at a family picnic\n',
  'a group of people eating in a restaurant\n',
  'a group of people sitting around a table\n',
  'people are sitting at a table next to a tree\n',
  'people are sitting around a table with a lot of food\n'],
 ['a person is holding a newspaper in a restaurant\n',
  'the man is holding a cup of coffee in front of a television\n',
  'a woman is preparing to cook in a kitchen\n',
  'a man working in an office setting with a computer and a man in a chair\n',
  'a person is using a smartphone in a restaurant\n',
  'a man is h

In [100]:
several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.3)
several

[['a group of people are sitting around a table\n',
  'a group of people are sitting around a table with food\n',
  'a group of people sitting at a table with food\n',
  'a group of people are sitting on the ground in front of a table\n',
  'a group of people sitting around a table with a person and a dog\n',
  'a group of people are sitting on the ground and eating\n',
  'the group is sitting around a table with food\n',
  'people are sitting around a table with food\n',
  'a group of people sitting around a table with food\n',
  'the people are eating in front of a table\n',
  'a group of people are sitting on a bench in a field\n',
  'a group of people are sitting on a bench\n'],
 ['a man is using a computer and a phone\n',
  'a person in a kitchen with a large screen\n',
  'a man is preparing food in a kitchen\n',
  'a man is standing in front of a computer and a woman is sitting behind him\n',
  'a man is using a computer to play a game\n',
  'a man is using a computer to play a g

In [None]:
plt.plot(losses)
plt.show()
plt.plot(test_losses)
plt.show()

In [None]:
print('test')

In [None]:
def tensor_2_embed_old(tensor):
    embed_array = torch.zeros((tensor.shape[0],257, 1024)) 
    to_pil = ToPILImage()
    for sample in range(tensor.shape[0]):
        PIL_image = to_pil(tensor[sample])
        image_for_blip2 = vis_processors["eval"](PIL_image).unsqueeze(0).to(device)
        #Generate embeddings
        with blip2_model.maybe_autocast():
                blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))
                embed_array[sample] = blip2_target
                
    return embed_array