{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[NbConvertApp] Converting notebook Train_MLPMixer-img.ipynb to python\n", "[NbConvertApp] Writing 53671 bytes to Train_MLPMixer-img.py\n" ] }, { "data": { "text/plain": [ "0" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "from subprocess import call\n", "command = \"jupyter nbconvert Train_MLPMixer-img.ipynb --to python\"\n", "call(command,shell=True)" ] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": {}, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-11-06 01:14:35,497] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "\n", "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import math\n", "from einops import rearrange\n", "import time\n", "import random\n", "import string\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils\n", "\n", "global_batch_size = 16 #128\n", "\n", "import os\n", "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "b8da928c-ea86-4959-86ee-dde55c07e58f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n" ] } ], "source": [ "### Multi-GPU config ###\n", "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", " local_rank = 0\n", "else:\n", " local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank) \n", "\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "\n", "accelerator = Accelerator(split_batches=False)\n", "\n", "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n", "\n", "# if num_devices <= 1 and utils.is_interactive():\n", "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", "# os.environ[\"RANK\"] = \"0\"\n", "# os.environ[\"LOCAL_RANK\"] = \"0\"\n", "# os.environ[\"WORLD_SIZE\"] = \"1\"\n", "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", "\n", "# # alter the deepspeed config according to your global 
and local batch size\n", "# if local_rank == 0:\n", "# with open('deepspeed_config_stage2.json', 'r') as file:\n", "# config = json.load(file)\n", "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", "# with open('deepspeed_config_stage2.json', 'w') as file:\n", "# json.dump(config, file)\n", "# else:\n", "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", "# time.sleep(10)\n", "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of this process = 1849460\n", "device: cuda\n", "Distributed environment: NO\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda\n", "\n", "Mixed precision type: no\n", "\n", "distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float32\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "distributed = not accelerator.state.distributed_type == 'NO'\n", "\n", "# set data_type to match your mixed precision (automatically set based on deepspeed config)\n", "if accelerator.mixed_precision == \"bf16\":\n", " data_type = torch.bfloat16\n", "elif accelerator.mixed_precision == \"fp16\":\n", " data_type = torch.float16\n", "else:\n", " data_type = torch.float32\n", "\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size, \"data_type =\", data_type)\n", "print = accelerator.print # only print if local_rank=0" ] }, { "cell_type": "code", "execution_count": 5, "id": "a7b0548c-fc95-43d9-94a7-d3e53176c736", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "accelerator.state.distributed_type" ] }, { "cell_type": "markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 6, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model_name: mPo0n5r22f_interactive\n", "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=mPo0n5r22f_interactive', '--subj=1', '--batch_size=16', '--no-blurry_recon', '--no-depth_recon', '--hidden_dim=1024', '--clip_scale=1.', '--blur_scale=100.', '--depth_scale=100.', '--max_lr=3e-4', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug', '--no-ckpt_saving']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", " # create random model_name\n", " model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))\n", " model_name = model_name + \"_interactive\"\n", " print(\"model_name:\", model_name)\n", "\n", " # global_batch_size and batch_size should already be defined in the above cells\n", " # other variables 
can be specified in the following string:\n", " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name={model_name} \\\n", " --subj=1 --batch_size={global_batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=1024 \\\n", " --clip_scale=1. --blur_scale=100. --depth_scale=100. \\\n", " --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving\"\n", "\n", " jupyter_args = jupyter_args.split()\n", " print(jupyter_args)\n", " \n", " from IPython.display import clear_output # function to clear print outputs in cell\n", " %load_ext autoreload \n", " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", " %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 7, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": { "tags": [] }, "outputs": [], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", " \"--model_name\", type=str, default=\"testing\",\n", " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", " help=\"Path to where NSD data is stored / where to download it to\",\n", ")\n", "parser.add_argument(\n", " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", " \"--batch_size\", type=int, default=32,\n", " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_project\",type=str,default=\"stability\",\n", " help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", " \"--mixup_pct\",type=float,default=.33,\n", " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", " \"--blurry_recon\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to output blurry reconstructions\",\n", ")\n", "parser.add_argument(\n", " \"--depth_recon\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to output depth reconstructions\",\n", ")\n", "parser.add_argument(\n", " \"--blur_scale\",type=float,default=100.,\n", " help=\"multiply loss from blurry recons by this number\",\n", ")\n", "parser.add_argument(\n", " \"--depth_scale\",type=float,default=100.,\n", " help=\"multiply loss from depth recons by this number\",\n", ")\n", "parser.add_argument(\n", " \"--clip_scale\",type=float,default=1.,\n", " help=\"multiply contrastive loss by this number\",\n", ")\n", "parser.add_argument(\n", " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", " \"--num_epochs\",type=int,default=120,\n", " help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", " \"--hidden_dim\",type=int,default=4096,\n", ")\n", "parser.add_argument(\n", " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", " 
\"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", " \"--ckpt_interval\",type=int,default=5,\n", " help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", " \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", " \"--max_lr\",type=float,default=3e-4,\n", ")\n", "parser.add_argument(\n", " \"--seq_len\",type=int,default=2,\n", ")\n", "\n", "if utils.is_interactive():\n", " args = parser.parse_args(jupyter_args)\n", "else:\n", " args = parser.parse_args()\n", "\n", "# create global variables without the args prefix\n", "for attribute_name in vars(args).keys():\n", " globals()[attribute_name] = getattr(args, attribute_name)" ] }, { "cell_type": "code", "execution_count": 8, "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", "metadata": { "tags": [] }, "outputs": [], "source": [ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", "if not os.path.exists(outdir) and ckpt_saving:\n", " os.makedirs(outdir,exist_ok=True)\n", "if use_image_aug:\n", " import kornia\n", " from kornia.augmentation.container import AugmentationSequential\n", " img_augment = AugmentationSequential(\n", " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", " kornia.augmentation.Resize((224, 224)),\n", " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", " kornia.augmentation.RandomGrayscale(p=0.3),\n", " same_on_batch=False,\n", " data_keys=[\"input\"],\n", " )" ] }, { "cell_type": "markdown", "id": "42d13c25-1369-4c49-81d4-83d713586096", "metadata": { "tags": [] }, "source": [ "# Prep data, models, and dataloaders" ] }, { "cell_type": "markdown", "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", "metadata": {}, "source": [ "## Dataloader" ] }, { "cell_type": "code", "execution_count": 9, "id": "81084834-035f-4465-ad59-59e6b806a2f5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" ] } ], "source": [ "if subj==1:\n", " num_train = 24958\n", " num_test = 2770\n", "test_batch_size = num_test\n", "\n", "def my_split_by_node(urls): return urls\n", " \n", "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", "# train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..1}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", 
\"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 10, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2770 2770\n", "---\n", "\n", "1560 24976 24976\n" ] } ], "source": [ "test_vox_indices = []\n", "test_73k_images = []\n", "for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", " test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())\n", " test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())\n", "test_vox_indices = test_vox_indices.astype(np.int16)\n", "print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))\n", "print(\"---\\n\")\n", "\n", "train_vox_indices = []\n", "train_73k_images = []\n", "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())\n", " train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())\n", "train_vox_indices = train_vox_indices.astype(np.int16)\n", "print(train_i, (train_i+1) * batch_size, len(train_vox_indices))" ] }, { "cell_type": "markdown", "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", "metadata": {}, "source": [ "## Load data and images" ] }, { "cell_type": "code", "execution_count": 11, "id": "039dd330-7339-4f88-8f00-45f95e47baa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subj01 betas loaded into memory\n", "voxels torch.Size([27750, 15724])\n", "images torch.Size([73000, 3, 224, 224])\n" ] } ], "source": [ "# load betas\n", "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", "# f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')\n", "\n", "voxels = f['betas'][:]\n", "print(f\"subj0{subj} betas loaded into memory\")\n", "voxels = torch.Tensor(voxels).to(\"cpu\").to(data_type)\n", "print(\"voxels\", voxels.shape)\n", "num_voxels = voxels.shape[-1]\n", "\n", "# load orig images\n", "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", "images = f['images'][:]\n", "images = torch.Tensor(images).to(\"cpu\").to(data_type)\n", "print(\"images\", images.shape)" ] }, { "cell_type": "markdown", "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", "metadata": {}, "source": [ "## Load models" ] }, { "cell_type": "markdown", "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", "metadata": {}, "source": [ "### CLIP image embeddings model" ] }, { "cell_type": "code", "execution_count": 12, "id": "b0420dc0-199e-4c1a-857d-b1747058b467", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "from models import Clipper\n", "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", "clip_seq_dim = 257\n", "clip_emb_dim = 768 #1024\n", "# hidden_dim = 4096\n", "#seq_len = 1 #2 #32 " ] }, { "cell_type": "code", "execution_count": 13, "id": "e1c6467c-fd57-4a4c-b5bb-f13fcb36d5b7", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "clip_model2 = Clipper(\"ViT-L/14\", 
device=torch.device(f\"cuda:{local_rank}\"), hidden_state=False, norm_embs=True)" ] }, { "cell_type": "code", "execution_count": 14, "id": "50211450-8b0c-4a7a-87c4-64c109a4242b", "metadata": { "tags": [] }, "outputs": [], "source": [ "#out2t = clip_model2.embed_image(torch.randn(32,3,224,224))" ] }, { "cell_type": "code", "execution_count": 15, "id": "d7c01109-4bb9-4763-9477-460beeb6a949", "metadata": { "tags": [] }, "outputs": [], "source": [ "#out2t.shape" ] }, { "cell_type": "markdown", "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", "metadata": {}, "source": [ "### SD VAE" ] }, { "cell_type": "code", "execution_count": 16, "id": "01baff79-8114-482b-b115-6f05aa8ad691", "metadata": { "tags": [] }, "outputs": [], "source": [ "# if blurry_recon:\n", "# from diffusers import AutoencoderKL\n", "# autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", "# # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", "# autoenc.eval()\n", "# autoenc.requires_grad_(False)\n", "# autoenc.to(device)\n", "# utils.count_params(autoenc)\n", "\n", "if blurry_recon:# or depth_recon:\n", " from diffusers import VQModel\n", " autoenc = VQModel.from_pretrained(\"/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae\", torch_dtype=data_type)\n", " autoenc.eval()\n", " autoenc.requires_grad_(False)\n", " autoenc.to(device)\n", " utils.count_params(autoenc)" ] }, { "cell_type": "markdown", "id": "120c8eee-9834-437d-bb60-b38faef50138", "metadata": {}, "source": [ "#### downsampled images" ] }, { "cell_type": "code", "execution_count": 17, "id": "6d1ba8dd-64c2-4ac9-947e-725b7f2e3e50", "metadata": { "tags": [] }, "outputs": [], "source": [ "if blurry_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30]].to(device)\n", " print(input_batch.shape)\n", "\n", " downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " print(re_upsampled_enc.shape)\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "6390a3a8-2bef-4e81-9b82-e154d26a1e1d", "metadata": {}, "source": [ "#### MiDaS depth" ] }, { "cell_type": "code", "execution_count": 18, "id": "f35573e2-95bf-463d-8937-68ad4c2c3c20", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " from controlnet_aux.midas import MidasDetector\n", " \n", " midas_depth = MidasDetector.from_pretrained(\n", " \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\", cache_dir=\"/fsx/proj-fmri/shared/cache\").to(device)\n", " midas_depth.model.eval()\n", " midas_depth.model.requires_grad_(False)\n", " midas_depth.model.to(device)\n", " pass" ] }, { "cell_type": "code", "execution_count": 19, "id": "ba3f9207-b98e-45da-baa6-5cfcfb2ae958", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30,31]].float().to(device)\n", " print(input_batch.shape)\n", " \n", " 
midas_emb = midas_depth.model(input_batch).unsqueeze(1)\n", " print(midas_emb.shape)\n", "\n", " prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()\n", " print(prediction.shape)\n", " \n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", " midas_emb_size = prediction.flatten(1).shape[1]\n", " print(\"midas_emb\", prediction.shape, prediction.min(), prediction.max())\n", " print(\"midas_emb_size\", midas_emb_size)\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224))) \n", "\n", " if blurry_recon:\n", " prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)\n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", " prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215\n", " print(\"vae midas_emb\", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", "metadata": {}, "source": [ "### MindEye modules" ] }, { "cell_type": "code", "execution_count": 20, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 21, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "16,626,688 total\n", "16,626,688 trainable\n", "param counts:\n", "16,626,688 total\n", "16,626,688 trainable\n", "torch.Size([2, 1, 15724]) torch.Size([2, 1, 1024])\n" ] } ], "source": [ "time_embedding_dim = 512\n", "\n", "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.out_features = out_features\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,1,voxels.shape[1]))\n", "time_emb_test = torch.randn((2,1,time_embedding_dim))\n", "print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)" ] }, { "cell_type": "code", "execution_count": 22, "id": "ee763515-4853-4c5b-9265-ec5aa1bde971", "metadata": { "tags": [] }, "outputs": [], "source": [ "num_past_voxels = 15\n", "#seq_len = 1 + 1" ] }, { "cell_type": "code", "execution_count": 23, "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "740,666,432 total\n", "740,666,432 trainable\n", "param counts:\n", "757,300,800 total\n", "757,300,800 trainable\n", "b.shape torch.Size([1, 2, 1792])\n", "torch.Size([1, 257, 768]) 
torch.Size([1]) torch.Size([1])\n" ] } ], "source": [ "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):\n", " super().__init__()\n", " self.seq_len = seq_len\n", " self.h = h\n", " self.clip_size = clip_size\n", " \n", " # Initial linear layer to match the input dimensions to hidden dimensions\n", " # self.lin0 = nn.Linear(in_dim, seq_len * h)\n", " \n", " # Mixer Blocks\n", " self.mixer_blocks1 = nn.ModuleList([\n", " self.mixer_block1(h, drop) for _ in range(n_blocks)\n", " ])\n", " self.mixer_blocks2 = nn.ModuleList([\n", " self.mixer_block2(seq_len, drop) for _ in range(n_blocks)\n", " ])\n", " \n", " # Output linear layer\n", " self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)\n", "\n", " # low-rank matrices\n", " # self.rank = 500\n", " # self.U = nn.Parameter(torch.randn(self.rank, out_dim))\n", " # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))\n", " # self.S = nn.Parameter(torch.randn(out_dim))\n", "\n", " self.clip_proj = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", "\n", " if blurry_recon:\n", " # self.blin1 = nn.Sequential(\n", " # nn.Linear(out_dim, 4096, bias=True),\n", " # nn.LayerNorm(4096),\n", " # nn.GELU(),\n", " # nn.Linear(4096, 4096))\n", " self.blin1 = nn.Linear(h*seq_len, 4096)\n", " self.bgroupnorm = nn.GroupNorm(1, 256)\n", " self.bupsampler = Decoder(\n", " in_channels=256,\n", " out_channels=128,\n", " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " block_out_channels=[32, 64, 128],\n", " layers_per_block=1,\n", " )\n", "\n", " if depth_recon:\n", " # self.dlin1 = nn.Sequential(\n", " # nn.Linear(h, midas_emb_size),\n", " # nn.Sigmoid(),\n", " # )\n", " self.dlin1 = nn.Linear(h*seq_len, 4096)\n", " self.dgroupnorm = nn.GroupNorm(1, 256)\n", " self.dupsampler = Decoder(\n", " in_channels=256,\n", " out_channels=1,#128,\n", " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " block_out_channels=[32, 64, 128, 256],\n", " layers_per_block=1,\n", " )\n", " \n", " def mixer_block1(self, h, drop):\n", " return nn.Sequential(\n", " nn.LayerNorm(h),\n", " self.mlp(h, h, drop), # Token mixing\n", " )\n", "\n", " def mixer_block2(self, seq_len, drop):\n", " return nn.Sequential(\n", " nn.LayerNorm(seq_len),\n", " self.mlp(seq_len, seq_len, drop) # Channel mixing\n", " )\n", " \n", " def mlp(self, in_dim, out_dim, drop):\n", " return nn.Sequential(\n", " nn.Linear(in_dim, out_dim),\n", " nn.GELU(),\n", " nn.Dropout(drop),\n", " nn.Linear(out_dim, out_dim),\n", " )\n", " \n", " def forward(self, x):\n", " # make empty tensors for blur and depth outputs\n", " b,d = torch.Tensor([0.]), torch.Tensor([0.])\n", " \n", " # Initial linear layer\n", " # x = self.lin0(x)\n", " \n", " # Reshape to seq_len by dim\n", " # x = x.reshape(-1, self.seq_len, self.h)\n", " \n", " # Mixer blocks\n", " residual1 = x\n", " residual2 = x.permute(0,2,1)\n", " for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):\n", " x = block1(x) + residual1\n", " residual1 = x\n", " x = x.permute(0,2,1)\n", " \n", " x = block2(x) + residual2\n", " residual2 = x\n", " x = x.permute(0,2,1)\n", " \n", " # Flatten\n", " x = x.reshape(x.size(0), -1)\n", " \n", " c = self.clin1(x)\n", "\n", " # 
low rank linear to out dim cuts # params by nearly half compared to full linear mapping\n", " # c = (x @ (self.V/100) @ (self.U/100)) + self.S\n", " \n", " c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))\n", "\n", " if blurry_recon:\n", " b = self.blin1(x)\n", " b = b.reshape(len(b), 256, 4, 4)\n", " b = self.bgroupnorm(b)\n", " b = self.bupsampler(b)\n", " \n", " if depth_recon:\n", " d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)\n", " d = d.reshape(len(d), 256, 4, 4)\n", " d = self.dgroupnorm(d)\n", " d = self.dupsampler(d)\n", " \n", " return c, b, d\n", "\n", "\n", "class TimeEmbedding(nn.Module):\n", " def __init__(self, embedding_time_dim=512, num_past_voxels=15):\n", " super().__init__()\n", " self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)\n", " self.num_past_voxels = num_past_voxels\n", " self.embedding_time_dim = embedding_time_dim\n", "\n", " def forward(self, time):\n", " # time is (batch_size,)\n", " time = time.long()\n", " time = self.embedding_time(time)\n", " return time # (batch_size, embedding_time_dim)\n", " \n", "\n", "#model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)\n", "model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)\n", "\n", "model.backbone = BrainNetwork(h=hidden_dim + clip_emb_dim, in_dim=hidden_dim + clip_emb_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) \n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "# test that the model works on some fake data\n", "b = torch.randn((1,seq_len,hidden_dim + clip_emb_dim))\n", "print(\"b.shape\",b.shape)\n", "with torch.no_grad():\n", " clip_, blur_, depth_ = model.backbone(b)\n", "print(clip_.shape, blur_.shape, depth_.shape)\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "fa7aa4ab-64af-45d4-8ed8-259043c16c29", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\nvoxel_ridge = torch.randn(512,4096)\\nvoxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)\\nprint(\"b.shape\",voxel_ridge.shape)\\nwith torch.no_grad():\\n clip_, blur_, depth_ = model.backbone(voxel_ridge)\\nprint(clip_.shape, blur_.shape, depth_.shape)'" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "voxel_ridge = torch.randn(512,4096)\n", "voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)\n", "print(\"b.shape\",voxel_ridge.shape)\n", "with torch.no_grad():\n", " clip_, blur_, depth_ = model.backbone(voxel_ridge)\n", "print(clip_.shape, blur_.shape, depth_.shape)\"\"\"" ] }, { "cell_type": "code", "execution_count": 25, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total_steps 18718\n", "\n", "Done with model preparations!\n", "param counts:\n", "757,300,800 total\n", "757,300,800 trainable\n" ] } ], "source": [ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", "opt_grouped_parameters = [\n", " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", "]\n", "\n", "optimizer = torch.optim.AdamW(opt_grouped_parameters, 
lr=max_lr)\n", "\n", "if lr_scheduler_type == 'linear':\n", " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", " optimizer,\n", " total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),\n", " last_epoch=-1\n", " )\n", "elif lr_scheduler_type == 'cycle':\n", " total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))\n", " print(\"total_steps\", total_steps)\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer, \n", " max_lr=max_lr,\n", " total_steps=total_steps,\n", " final_div_factor=1000,\n", " last_epoch=-1, pct_start=2/num_epochs\n", " )\n", " \n", "def save_ckpt(tag): \n", " ckpt_path = outdir+f'/{tag}.pth'\n", " print(f'saving {ckpt_path}',flush=True)\n", " unwrapped_model = accelerator.unwrap_model(model)\n", " try:\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': unwrapped_model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'lr_scheduler': lr_scheduler.state_dict(),\n", " 'train_losses': losses,\n", " 'test_losses': test_losses,\n", " 'lrs': lrs,\n", " }, ckpt_path)\n", " except:\n", " print(\"Couldn't save... moving on to prevent crashing.\")\n", " del unwrapped_model\n", " \n", "print(\"\\nDone with model preparations!\")\n", "utils.count_params(model)" ] }, { "cell_type": "code", "execution_count": 26, "id": "5d600d8d-6a0d-4584-840e-c89521ff6364", "metadata": {}, "outputs": [], "source": [ "#nn++" ] }, { "cell_type": "code", "execution_count": 27, "id": "3c937d01-eea5-4bad-906a-36d07b3c30a4", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'pp = None\\nfor train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\\n #with torch.cuda.amp.autocast(dtype=data_type):\\n #optimizer.zero_grad()\\n\\n voxel = voxels[behav[:,0,5].cpu().long()]#.to(device)\\n image = images[behav[:,0,0].cpu().long()].float()#.to(device).float()\\n\\n past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279\\n past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15\\n print(past_behav[:,:seq_len-1,0].cpu().long())\\n past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()]\\n \\n break\\n \\n print(past_15_times)\\n #for past in range(1):\\n # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\\n \\n #if blurry_recon:\\n # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\\n blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\\n\\n if depth_recon:\\n # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\\n depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\\n depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\\n depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\\n\\n if use_image_aug: \\n image = img_augment(image)\\n\\n clip_target = clip_model.embed_image(image)\\n assert not torch.any(torch.isnan(clip_target))\\n\\n if epoch < int(mixup_pct * num_epochs):\\n voxel, perm, betas, select = utils.mixco(voxel)\\n past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)\\n \\n for p in range(seq_len-1):\\n print(past_behav.shape) #128, 15, 17\\n print(past_behav[:,p,-1])\\n print(past_15_voxels.shape) # 128, 1, 15724\\n mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])\\n 
print(mask) # 128\\n past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])\\n print(past_15_voxels)\\n pp = past_15_voxels\\n \\n break'" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"pp = None\n", "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " #with torch.cuda.amp.autocast(dtype=data_type):\n", " #optimizer.zero_grad()\n", "\n", " voxel = voxels[behav[:,0,5].cpu().long()]#.to(device)\n", " image = images[behav[:,0,0].cpu().long()].float()#.to(device).float()\n", "\n", " past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279\n", " past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15\n", " print(past_behav[:,:seq_len-1,0].cpu().long())\n", " past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()]\n", " \n", " break\n", " \n", " print(past_15_times)\n", " #for past in range(1):\n", " # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\n", " \n", " #if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", "\n", " if use_image_aug: \n", " image = img_augment(image)\n", "\n", " clip_target = clip_model.embed_image(image)\n", " assert not torch.any(torch.isnan(clip_target))\n", "\n", " if epoch < int(mixup_pct * num_epochs):\n", " voxel, perm, betas, select = utils.mixco(voxel)\n", " past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)\n", " \n", " for p in range(seq_len-1):\n", " print(past_behav.shape) #128, 15, 17\n", " print(past_behav[:,p,-1])\n", " print(past_15_voxels.shape) # 128, 1, 15724\n", " mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])\n", " print(mask) # 128\n", " past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])\n", " print(past_15_voxels)\n", " pp = past_15_voxels\n", " \n", " break\"\"\"" ] }, { "cell_type": "code", "execution_count": 28, "id": "5c7eb009-9fcf-4e35-9e1b-c0ca6418b6b8", "metadata": { "tags": [] }, "outputs": [], "source": [ "#pp[20, 0, :]" ] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 29, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "wandb mindeyev2 run mPo0n5r22f_interactive\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb_config:\n", " {'model_name': 'mPo0n5r22f_interactive', 'global_batch_size': 16, 'batch_size': 16, 'num_epochs': 12, 'clip_scale': 1.0, 'blur_scale': 100.0, 'use_image_aug': False, 'max_lr': 0.0003, 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'ckpt_interval': 999, 'ckpt_saving': False, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1, 'train_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar', 'test_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar'}\n" ] }, { "data": { "text/html": [ "wandb version 0.15.12 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231106_011703-qm896vre" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run mPo0n5r22f_interactive to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://stability.wandb.io/ckadirt/mindeyev2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://stability.wandb.io/ckadirt/mindeyev2/runs/qm896vre" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", " import wandb\n", " wandb_project = 'mindeyev2'\n", " wandb_run = model_name\n", " wandb_notes = ''\n", " \n", " print(f\"wandb {wandb_project} run {wandb_run}\")\n", " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", " wandb_config = {\n", " \"model_name\": model_name,\n", " \"global_batch_size\": global_batch_size,\n", " \"batch_size\": batch_size,\n", " \"num_epochs\": num_epochs,\n", " \"clip_scale\": clip_scale,\n", " \"blur_scale\": blur_scale,\n", " \"use_image_aug\": use_image_aug,\n", " \"max_lr\": max_lr,\n", " \"mixup_pct\": mixup_pct,\n", " \"num_train\": num_train,\n", " \"num_test\": num_test,\n", " \"ckpt_interval\": ckpt_interval,\n", " \"ckpt_saving\": ckpt_saving,\n", " \"seed\": seed,\n", " \"distributed\": distributed,\n", " \"num_devices\": num_devices,\n", " \"world_size\": world_size,\n", " \"train_url\": train_url,\n", " \"test_url\": test_url,\n", " }\n", " print(\"wandb_config:\\n\",wandb_config)\n", " if False: # wandb_auto_resume\n", " print(\"wandb_id:\",model_name)\n", " wandb.init(\n", " id = model_name,\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " resume=\"allow\",\n", " )\n", " else:\n", " wandb.init(\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " )\n", "else:\n", " wandb_log = False" ] }, { "cell_type": "markdown", "id": "d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 30, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": {}, "outputs": [], "source": [ "epoch = 0\n", "losses, test_losses, lrs = [], [], []\n", "best_test_loss = 1e9\n", "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", "\n", "# Optionally resume from checkpoint #\n", "if resume_from_ckpt:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "elif wandb_log:\n", " if wandb.run.resumed:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... 
trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 31, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": {}, "outputs": [], "source": [ "model, optimizer, train_dl, lr_scheduler = accelerator.prepare(\n", "model, optimizer, train_dl, lr_scheduler\n", ")\n", "# leaving out test_dl since we will only have local_rank 0 device do evals" ] }, { "cell_type": "code", "execution_count": 32, "id": "469e6313-425f-45ed-875a-ecd5df343e31", "metadata": {}, "outputs": [], "source": [ "def add_saturation(image, alpha=2):\n", " gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]\n", " gray_image = gray_image.unsqueeze(1).expand_as(image)\n", " saturated_image = alpha * image + (1 - alpha) * gray_image\n", " return torch.clamp(saturated_image, 0, 1)" ] }, { "cell_type": "code", "execution_count": 33, "id": "43e126bd-5320-4512-91c1-cdcb67a622f0", "metadata": { "tags": [] }, "outputs": [], "source": [ "#b = torch.randn(1,2)\n", "#b.to(device)" ] }, { "cell_type": "code", "execution_count": 34, "id": "264b34ed-8214-48d2-871f-0a7413526ac4", "metadata": { "tags": [] }, "outputs": [], "source": [ "#device" ] }, { "cell_type": "code", "execution_count": 35, "id": "ff64dc04-6082-4b87-979a-224868db57e1", "metadata": { "tags": [] }, "outputs": [], "source": [ "#past_15_times = torch.Tensor([i for i in range(seq_len-1)]).long() # 15\n", "#past_15_times.to(device)" ] }, { "cell_type": "code", "execution_count": 36, "id": "00d5e5e1-0e22-409a-9f5f-5eeeb6b00620", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
nn++
SyntaxError: invalid syntax
\n" ], "text/plain": [ "\u001b[91m╭──────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", "\u001b[91m│\u001b[0m nn++ \u001b[91m│\u001b[0m\n", "\u001b[91m│\u001b[0m \u001b[1;91m▲\u001b[0m \u001b[91m│\u001b[0m\n", "\u001b[91m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mSyntaxError: \u001b[0minvalid syntax\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "nn++" ] }, { "cell_type": "code", "execution_count": null, "id": "85af68da-dd37-4d79-aca1-49f67e5bbd98", "metadata": { "tags": [] }, "outputs": [], "source": [ "#images.shape" ] }, { "cell_type": "code", "execution_count": 41, "id": "60be0d5f-3e94-4612-9373-61b53d836393", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mPo0n5r22f_interactive starting with epoch 0 / 12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/12 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", " in <module>:307 \n", " \n", " 304 │ │ │ assert (test_i+1) == 1 \n", " 305 │ │ │ logs = {\"train/loss\": np.mean(losses[-(train_i+1):]), \n", " 306 │ │ │ │ \"test/loss\": np.mean(test_losses[-(test_i+1):]), \n", " 307 │ │ │ │ \"train/lr\": lrs[-1], \n", " 308 │ │ │ │ \"train/num_steps\": len(losses), \n", " 309 │ │ │ │ \"test/num_steps\": len(test_losses), \n", " 310 │ │ │ │ \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1), \n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "IndexError: list index out of range\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m307\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m304 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94massert\u001b[0m (test_i+\u001b[94m1\u001b[0m) == \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m305 \u001b[0m\u001b[2m│ │ │ \u001b[0mlogs = {\u001b[33m\"\u001b[0m\u001b[33mtrain/loss\u001b[0m\u001b[33m\"\u001b[0m: np.mean(losses[-(train_i+\u001b[94m1\u001b[0m):]), \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m306 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mtest/loss\u001b[0m\u001b[33m\"\u001b[0m: np.mean(test_losses[-(test_i+\u001b[94m1\u001b[0m):]), \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m307 \u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mtrain/lr\u001b[0m\u001b[33m\"\u001b[0m: lrs[-\u001b[94m1\u001b[0m], \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m308 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mtrain/num_steps\u001b[0m\u001b[33m\"\u001b[0m: \u001b[96mlen\u001b[0m(losses), \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m309 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mtest/num_steps\u001b[0m\u001b[33m\"\u001b[0m: \u001b[96mlen\u001b[0m(test_losses), \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m310 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[33m\"\u001b[0m\u001b[33mtrain/fwd_pct_correct\u001b[0m\u001b[33m\"\u001b[0m: fwd_percent_correct / (train_i + 
\u001b[94m1\u001b[0m), \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mIndexError: \u001b[0mlist index out of range\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n", "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n", "test_image, test_voxel = None, None\n", "mse = nn.MSELoss()\n", "l1 = nn.L1Loss()\n", "\n", "for epoch in progress_bar:\n", " model.train()\n", " \n", " fwd_percent_correct = 0.\n", " bwd_percent_correct = 0.\n", " test_fwd_percent_correct = 0.\n", " test_bwd_percent_correct = 0.\n", "\n", " loss_clip_total = 0.\n", " loss_blurry_total = 0.\n", " loss_depth_total = 0.\n", " test_loss_clip_total = 0.\n", " test_loss_blurry_total = 0.\n", " test_loss_depth_total = 0.\n", "\n", " blurry_pixcorr = 0.\n", " test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n", " \n", " for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " break\n", " with torch.cuda.amp.autocast():\n", " optimizer.zero_grad()\n", " \n", " #voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", " #image = images[behav[:,0,0].cpu().long()].to(device).float()\n", " \n", " #past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279\n", " #past_15_times = torch.Tensor([i for i in range(seq_len - 1)]).to(device) # 15\n", " \n", " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", " image = images[behav[:,0,0].cpu().long()].to(device).float()\n", "\n", " past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279\n", " #print(past_behav[:,:seq_len-1,0].cpu().long(), behav[:,0,0].cpu().long(), past_behav[:,:seq_len-1,0].cpu().long()[0])\n", " past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()].to(device).float()\n", " past_array = [i for i in range(seq_len-1)]\n", " past_15_times = torch.Tensor(past_array) # 15\n", " #print(past_15_times)\n", " #print(past_15_voxels.shape, past_behav[:,:seq_len-1,5].cpu().long())\n", " past_15_times = past_15_times.to(device)\n", " #for past in range(1):\n", " # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\n", " \n", " if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", " \n", " if use_image_aug: \n", " image = img_augment(image)\n", " \n", " clip_target = clip_model.embed_image(image)\n", " assert not torch.any(torch.isnan(clip_target))\n", " \n", " if epoch < int(mixup_pct * num_epochs):\n", " voxel, perm, betas, select = utils.mixco(voxel)\n", " past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)\n", " \n", " #print(past_15_images.shape)\n", " \n", " for p in range(seq_len-1):\n", " #print(past_behav.shape) #128, 15, 17\n", " 
#print(past_behav[:,p,-1])\n", " #print(past_15_voxels.shape) # 128, 1, 15724\n", " mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])\n", " #print(mask) # 128\n", " past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])\n", " past_15_images[mask, p, :] = torch.zeros_like(past_15_images[0, p, :])\n", " #print(past_15_voxels)\n", " \n", " past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])\n", " past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1])\n", " #print(past_15_images.shape)\n", " past_15_embeddings = clip_model2.embed_image(past_15_images)\n", " #print(past_15_embeddings.shape, 'uteho')\n", " past_15_embeddings = torch.cat([torch.zeros(batch_size, past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0)\n", " #print('tuhet', past_15_embeddings.shape)\n", " #print('yepe', past_15_embeddings[0,:])\n", " #print('yepe', past_15_embeddings[17,:])\n", " past_15_times = past_15_times.repeat(voxel.shape[0], 1)\n", " past_15_times = past_15_times.reshape(-1)\n", " time_embeddings = model.time_embedding(past_15_times)\n", " \n", " past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)\n", " \n", " positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)\n", " voxel = torch.cat((voxel, positional_current_voxel), dim=-1)\n", " voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))\n", " voxel_ridge = voxel_ridge.view(seq_len,int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)\n", " #past_15_embeddings = torch.split(past_15_embeddings, seq_len)\n", " #print(past_15_embeddings, 'ttt')\n", " past_15_embeddings = past_15_embeddings.reshape(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2)\n", " #unsqueeze(1) # bz * 2, 1, 4096\n", " #print(voxel_ridge.shape, past_15_embeddings.shape)\n", " #print('yepe', past_15_embeddings[10,0,:])\n", " #print('yepe', past_15_embeddings[10,1,:])\n", " voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1)\n", " #print(voxel_ridge[:,0,-10:-1])\n", " #print(voxel_ridge[:,0,10:20])\n", " #raise(\"uehot\")\n", " # past_voxel_ridge = model.ridge(past_voxel)\n", " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)\n", " #print(voxel_ridge.shape)\n", " \n", " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", " \n", " if epoch < int(mixup_pct * num_epochs): \n", " loss_clip = utils.mixco_nce(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=.006, \n", " perm=perm, betas=betas, select=select)\n", " else:\n", " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n", " loss_clip = utils.soft_clip_loss(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=epoch_temp)\n", "\n", " loss_clip_total += loss_clip.item()\n", " loss_clip *= clip_scale\n", " loss = loss_clip\n", " \n", " if blurry_recon:\n", " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " \n", " loss_blurry = 
(l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", " loss_blurry_total += loss_blurry.item()\n", " loss_blurry *= blur_scale\n", " loss += loss_blurry\n", "\n", " if depth_recon:\n", " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", " loss_depth_total += loss_depth.item()\n", " loss_depth *= depth_scale\n", " loss += loss_depth\n", " \n", " # forward and backward top 1 accuracy \n", " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()\n", " bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()\n", " \n", " if blurry_recon:\n", " with torch.no_grad():\n", " # only doing pixcorr eval on a subset of the samples per batch because its costly & slow to compute autoenc.decode()\n", " random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)\n", " # random_samps = np.arange(batch_size//5)\n", " blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)\n", " # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean\n", " pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)\n", " # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)\n", " # loss += (1 - pixcorr)\n", " blurry_pixcorr += pixcorr.item()\n", " # utils.check_loss(pixcorr)\n", "\n", " utils.check_loss(loss)\n", " accelerator.backward(loss)\n", " optimizer.step()\n", " \n", " losses.append(loss.item())\n", " lrs.append(optimizer.param_groups[0]['lr'])\n", " \n", " if lr_scheduler_type is not None:\n", " lr_scheduler.step()\n", "\n", " model.eval()\n", " if local_rank==0:\n", " with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type): \n", " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): \n", " # all test samples should be loaded per batch such that test_i should never exceed 0\n", " assert len(behav) == num_test\n", " \n", " ## Average same-image repeats ##\n", " if test_image is None:\n", " voxel = voxels[behav[:,0,5].cpu().long()]\n", " image = behav[:,0,0].cpu().long()\n", " \n", " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n", " for im in unique_image:\n", " locs = torch.where(im == image)[0]\n", " if test_image is None:\n", " test_image = images[im][None]\n", " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n", " else:\n", " test_image = torch.vstack((test_image, images[im][None]))\n", " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n", " \n", " # random sample of 300\n", " random_indices = torch.arange(len(test_voxel))[:300]\n", " voxel = test_voxel[random_indices].to(device)\n", " image = test_image[random_indices].to(device)\n", " assert len(image) == 300\n", " \n", " current_past_behav = past_behav[random_indices]\n", "\n", " past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279\n", " past_15_images = images[current_past_behav[:,:seq_len-1,0].cpu().long()].to(device).float()\n", " past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 
15\n", "\n", " if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", " \n", " clip_target = clip_model.embed_image(image.float())\n", " \n", "\n", " past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])\n", " past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1])\n", " print(past_15_images.shape)\n", " past_15_embeddings = clip_model2.embed_image(past_15_images)\n", " print(past_15_embeddings.shape)\n", " past_15_embeddings = torch.cat([torch.zeros(image.shape[0], past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0)\n", " print(past_15_embeddings.shape)\n", " past_15_times = past_15_times.repeat(voxel.shape[0], 1)\n", " past_15_times = past_15_times.reshape(-1)\n", " time_embeddings = model.time_embedding(past_15_times)\n", " past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)\n", "\n", " positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)\n", " voxel = torch.cat((voxel, positional_current_voxel), dim=-1)\n", " voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))\n", " voxel_ridge = voxel_ridge.view(seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)\n", " past_15_embeddings = past_15_embeddings.view(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2)\n", " print(past_15_embeddings.shape, voxel_ridge.shape)\n", " voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1)\n", " \n", " #voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", "\n", " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)\n", " \n", " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", " \n", " loss_clip = utils.soft_clip_loss(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=.006)\n", " test_loss_clip_total += loss_clip.item()\n", " loss_clip = loss_clip * clip_scale\n", " loss = loss_clip\n", "\n", " if blurry_recon:\n", " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " \n", " loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", " test_loss_blurry_total += loss_blurry.item()\n", " loss_blurry *= blur_scale\n", " loss += loss_blurry\n", " \n", " # halving the batch size because the decoder is computationally heavy\n", " blurry_recon_images 
= (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)\n", " blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " pixcorr = utils.pixcorr(image, blurry_recon_images)\n", " loss += (1 - pixcorr)\n", " test_blurry_pixcorr += pixcorr.item()\n", "\n", " if depth_recon:\n", " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", " test_loss_depth_total += loss_depth.item()\n", " loss_depth *= depth_scale\n", " loss += loss_depth\n", " \n", " # forward and backward top 1 accuracy \n", " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()\n", " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()\n", "\n", " utils.check_loss(loss) \n", " test_losses.append(loss.item())\n", "\n", " # if utils.is_interactive(): clear_output(wait=True)\n", " print(\"---\")\n", " \n", " assert (test_i+1) == 1\n", " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n", " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n", " \"train/lr\": lrs[-1],\n", " \"train/num_steps\": len(losses),\n", " \"test/num_steps\": len(test_losses),\n", " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n", " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n", " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n", " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n", " \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n", " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n", " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n", " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n", " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n", " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n", " \"train/loss_depth_total\": loss_depth_total / (train_i + 1),\n", " \"test/loss_depth_total\": test_loss_depth_total / (test_i + 1),\n", " }\n", " \n", " if blurry_recon: \n", " # transform blurry recon latents to images and plot it\n", " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", " jj=-1\n", " for j in [0,1,2,3]:\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " axes[jj].axis('off')\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " axes[jj].axis('off')\n", " \n", " if wandb_log:\n", " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", " plt.close()\n", " else:\n", " plt.show()\n", "\n", " if depth_recon:\n", " # transform blurry recon latents to images and plot it\n", " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", " # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", " jj=-1\n", " for j in [0,1,2,3]:\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", " 
axes[jj].axis('off')\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", " axes[jj].axis('off')\n", " if wandb_log:\n", " logs[f\"test/depth_recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", " plt.close()\n", " else:\n", " plt.show()\n", " \n", " progress_bar.set_postfix(**logs)\n", " \n", " # Save model checkpoint\n", " if epoch % ckpt_interval == 0:\n", " if not utils.is_interactive():\n", " save_ckpt(f'last')\n", " \n", " if wandb_log: wandb.log(logs)\n", "\n", " # wait for other GPUs to catch up if needed\n", " accelerator.wait_for_everyone()\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", "\n", "print(\"\\n===Finished!===\\n\")\n", "if ckpt_saving:\n", " save_ckpt(f'last')\n", "if not utils.is_interactive():\n", " sys.exit(0)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a", "metadata": {}, "outputs": [], "source": [ "plt.plot(losses)\n", "plt.show()\n", "plt.plot(test_losses)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "f2690877-a431-44e8-a2ca-61f4b7397070", "metadata": {}, "source": [ "# Retrieve nearest neighbor in the training set using test set data" ] }, { "cell_type": "code", "execution_count": null, "id": "5b6b8feb-391d-437e-a5d9-a2088f1b1149", "metadata": {}, "outputs": [], "source": [ "annots = np.load(\"/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy\")" ] }, { "cell_type": "code", "execution_count": null, "id": "612ac5aa-6f0f-45ad-809e-03df905d184c", "metadata": {}, "outputs": [], "source": [ "ii=2\n", "all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))\n", "with torch.no_grad(), torch.cuda.amp.autocast():\n", " for batch in tqdm(range(0,len(all_indices),512)):\n", " if batch==0:\n", " clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", " else:\n", " target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", " clip_target = torch.vstack((clip_target,target))\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", "\n", " voxel = test_voxel[[ii]].to(device)\n", " image = test_image[[ii]].to(device)\n", "\n", " print(\"Original Image (test set)\")\n", " display(utils.torch_to_Image(image))\n", " \n", " clip_target = clip_model.embed_image(image).cpu()\n", " # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))\n", " \n", " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", " clip_voxels, _, _ = model.backbone(voxel_ridge) \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " # clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) # sanity check only: retrieves from the ground-truth embedding instead of the brain prediction\n", "\n", " print(\"clip_voxels_norm\", clip_voxels_norm.shape)\n", " print(\"clip_target_norm\", clip_target_norm.shape)\n", " \n", " sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(), \n", " clip_target_norm).flatten()).flip(0)\n", " picks = all_indices[sortt[:5]]\n", "\n", " print(\"\\nNearest neighbors in training set\")\n", " for ip,p in enumerate(picks):\n", " display(utils.torch_to_Image(images[[p]]))\n", " # print(utils.select_annotations([annots[int(p)]]))\n", " if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]\n", "\n", "print(\"\\n=====\\npredicted_caption:\\n\", predicted_caption)" ] },
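 { "cell_type": "markdown", "id": "3e6d1f58-9a2b-4c7d-8e4f-5b1a2c3d4e5f", "metadata": {}, "source": [ "The next cell is an optional, hedged sketch rather than part of the original pipeline: it wraps the nearest-neighbor caption retrieval above into a small helper so other test indices can be inspected quickly. The helper name `predicted_caption_for` is hypothetical; everything it calls (`model.ridge`, `model.backbone`, `utils.batchwise_cosine_similarity`, `utils.select_annotations`) and every variable it reuses (`test_voxel`, `all_indices`, `clip_target_norm`, `annots`) comes from the cells above, and it mirrors the simplified forward pass used there (no past-trial sequence)." ] }, { "cell_type": "code", "execution_count": null, "id": "7f3a9c21-4b6d-4e8a-9c1f-2d5e8b7a6c31", "metadata": {}, "outputs": [], "source": [ "# Hedged sketch: assumes the retrieval cell above has just been run, so all_indices,\n", "# clip_target_norm, annots and test_voxel are still in memory. predicted_caption_for is a\n", "# hypothetical helper name, not part of the original codebase.\n", "def predicted_caption_for(ii, topk=5):\n", "    with torch.no_grad(), torch.cuda.amp.autocast():\n", "        voxel = test_voxel[[ii]].to(device)\n", "        # same simplified forward pass as the cell above (no past-trial sequence)\n", "        voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", "        clip_voxels, _, _ = model.backbone(voxel_ridge)\n", "        clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", "        # rank the training-pool embeddings by cosine similarity to the prediction\n", "        sims = utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(), clip_target_norm).flatten()\n", "        picks = all_indices[torch.argsort(sims).flip(0)[:topk]]\n", "    caption = utils.select_annotations([annots[int(picks[0])]])[0]\n", "    return caption, picks\n", "\n", "# example usage:\n", "# caption, picks = predicted_caption_for(3)\n", "# print(caption)" ] }, { "cell_type": "markdown", "id": "1473ddaa-5f2b-4448-9194-c7b0801d05db",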
"metadata": {}, "source": [ "# Feed into Stable Diffusion XL for reconstructions" ] }, { "cell_type": "code", "execution_count": null, "id": "70e50e0d-c44f-4d56-939a-2943535e1747", "metadata": {}, "outputs": [], "source": [ "from diffusers import StableDiffusionXLPipeline\n", "pipe = StableDiffusionXLPipeline.from_pretrained(\n", " \"/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n", ")\n", "pipe.to(\"cuda\")\n", "pass" ] }, { "cell_type": "code", "execution_count": null, "id": "479e6994-3eaa-47d2-89a3-422c464fab36", "metadata": {}, "outputs": [], "source": [ "prompt = predicted_caption\n", "recon = pipe(prompt=prompt).images[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "9dc48e1b-5842-4a29-963a-6469d943a72c", "metadata": { "tags": [] }, "outputs": [], "source": [ "print(\"Seen image\")\n", "display(utils.torch_to_Image(image))\n", "\n", "print(\"Reconstruction\")\n", "utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": true }, "toc-autonumbering": true, "vscode": { "interpreter": { "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" } } }, "nbformat": 4, "nbformat_minor": 5 }