diff --git "a/src/Train-with-memory-cat.ipynb" "b/src/Train-with-memory-cat.ipynb" new file mode 100644--- /dev/null +++ "b/src/Train-with-memory-cat.ipynb" @@ -0,0 +1,1755 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", + "# from subprocess import call\n", + "# command = \"jupyter nbconvert Train-with-memory-cat.ipynb --to python\"\n", + "# call(command,shell=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b0f0f4f3", + "metadata": {}, + "source": [ + "# Import packages & functions" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5bad764b-45c1-45ce-a716-8d055e09821a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import json\n", + "import argparse\n", + "import numpy as np\n", + "import math\n", + "from einops import rearrange\n", + "import time\n", + "import random\n", + "import h5py\n", + "from tqdm import tqdm\n", + "\n", + "import webdataset as wds\n", + "import gc\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "from torchvision import transforms\n", + "\n", + "from accelerate import Accelerator, DeepSpeedPlugin\n", + "\n", + "# tf32 data type is faster than standard float32\n", + "torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + "# custom functions #\n", + "import utils\n", + "\n", + "global_batch_size = 128 #128" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LOCAL RANK 0\n" + ] + } + ], + "source": [ + "### Multi-GPU config ###\n", + "local_rank = os.getenv('RANK')\n", + "if local_rank is None: \n", + " local_rank = 0\n", + "else:\n", + " local_rank = int(local_rank)\n", + "print(\"LOCAL RANK \", local_rank) \n", + "\n", + "num_devices = torch.cuda.device_count()\n", + "if num_devices==0: num_devices = 1\n", + "\n", + "accelerator = Accelerator(split_batches=False)\n", + "\n", + "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n", + "\n", + "# if num_devices <= 1 and utils.is_interactive():\n", + "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", + "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", + "# os.environ[\"RANK\"] = \"0\"\n", + "# os.environ[\"LOCAL_RANK\"] = \"0\"\n", + "# os.environ[\"WORLD_SIZE\"] = \"1\"\n", + "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", + "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", + "\n", + "# # alter the deepspeed config according to your global and local batch size\n", + "# if local_rank == 0:\n", + "# with open('deepspeed_config_stage2.json', 'r') as file:\n", + "# config = json.load(file)\n", + "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", + "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", + "# with open('deepspeed_config_stage2.json', 'w') as file:\n", + "# json.dump(config, file)\n", + "# else:\n", + "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", + 
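"# # note: a fixed sleep is a crude hand-off; checking that local_rank 0 has finished\n", + "# # rewriting deepspeed_config_stage2.json (or a torch.distributed barrier, once the\n", + "# # process group exists) would be a more robust way to synchronize here\n", + 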
"# time.sleep(10)\n", + "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", + "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PID of this process = 3732719\n", + "device: cuda\n", + "Distributed environment: NO\n", + "Num processes: 1\n", + "Process index: 0\n", + "Local process index: 0\n", + "Device: cuda\n", + "\n", + "Mixed precision type: no\n", + "\n", + "distributed = False num_devices = 1 local rank = 0 world size = 1\n" + ] + } + ], + "source": [ + "print(\"PID of this process =\",os.getpid())\n", + "device = accelerator.device\n", + "print(\"device:\",device)\n", + "num_workers = num_devices\n", + "print(accelerator.state)\n", + "world_size = accelerator.state.num_processes\n", + "distributed = not accelerator.state.distributed_type == 'NO'\n", + "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n", + "print = accelerator.print # only print if local_rank=0" + ] + }, + { + "cell_type": "markdown", + "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", + "metadata": { + "tags": [] + }, + "source": [ + "# Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=test', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-5', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug']\n" + ] + } + ], + "source": [ + "# if running this interactively, can specify jupyter_args here for argparser to use\n", + "if utils.is_interactive():\n", + " # Example use\n", + " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", + " --model_name=test \\\n", + " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n", + " --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug\"\n", + "\n", + " jupyter_args = jupyter_args.split()\n", + " print(jupyter_args)\n", + " \n", + " from IPython.display import clear_output # function to clear print outputs in cell\n", + " %load_ext autoreload \n", + " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", + " %autoreload 2 " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "global batch_size 128\n", + "batch_size 128\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", + "parser.add_argument(\n", + " \"--model_name\", type=str, default=\"testing\",\n", + " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", + ")\n", + "parser.add_argument(\n", + " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", + " help=\"Path to where NSD data is stored / where to download it to\",\n", + ")\n", + "parser.add_argument(\n", + " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", + ")\n", + "parser.add_argument(\n", + " \"--batch_size\", type=int, 
default=32,\n", + " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", + " help=\"whether to log to wandb\",\n", + ")\n", + "parser.add_argument(\n", + " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", + " help=\"if not using wandb and want to resume from a ckpt\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_project\",type=str,default=\"stability\",\n", + " help=\"wandb project name\",\n", + ")\n", + "parser.add_argument(\n", + " \"--mixup_pct\",type=float,default=.33,\n", + " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", + ")\n", + "parser.add_argument(\n", + " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to use image augmentation\",\n", + ")\n", + "parser.add_argument(\n", + " \"--num_epochs\",type=int,default=240,\n", + " help=\"number of epochs of training\",\n", + ")\n", + "parser.add_argument(\n", + " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_interval\",type=int,default=5,\n", + " help=\"save backup ckpt and reconstruct every x epochs\",\n", + ")\n", + "parser.add_argument(\n", + " \"--seed\",type=int,default=42,\n", + ")\n", + "parser.add_argument(\n", + " \"--max_lr\",type=float,default=3e-4,\n", + ")\n", + "parser.add_argument(\n", + " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n", + " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n", + ")\n", + "\n", + "if utils.is_interactive():\n", + " args = parser.parse_args(jupyter_args)\n", + "else:\n", + " args = parser.parse_args()\n", + "\n", + "# create global variables without the args prefix\n", + "for attribute_name in vars(args).keys():\n", + " globals()[attribute_name] = getattr(args, attribute_name)\n", + "\n", + "print(\"global batch_size\", batch_size)\n", + "batch_size = int(batch_size / num_devices)\n", + "print(\"batch_size\", batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "outdir = os.path.abspath(f'../train_mem_logs/{model_name}')\n", + "if not os.path.exists(outdir):\n", + " os.makedirs(outdir,exist_ok=True)\n", + "if use_image_aug:\n", + " import kornia\n", + " from kornia.augmentation.container import AugmentationSequential\n", + " img_augment = AugmentationSequential(\n", + " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", + " kornia.augmentation.Resize((224, 224)),\n", + " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", + " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", + " kornia.augmentation.RandomGrayscale(p=0.3),\n", + " same_on_batch=False,\n", + " data_keys=[\"input\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "42d13c25-1369-4c49-81d4-83d713586096", + "metadata": { + "tags": [] + }, + "source": [ + "# Prep data, models, and dataloaders" + ] + }, + { + "cell_type": "markdown", + "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", + "metadata": {}, + "source": [ + "## Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + 
"id": "81084834-035f-4465-ad59-59e6b806a2f5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" + ] + } + ], + "source": [ + "if subj==1:\n", + " num_train = 24958\n", + " num_test = 2770\n", + "test_batch_size = num_test\n", + "\n", + "def my_split_by_node(urls): return urls\n", + " \n", + "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", + "print(train_url)\n", + "\n", + "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n", + "\n", + "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", + "print(test_url)\n", + "\n", + "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)" + ] + }, + { + "cell_type": "markdown", + "id": "203b060a-2dd2-4c35-929b-c576be82eb52", + "metadata": {}, + "source": [ + "### check dataloaders are working" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", + "metadata": {}, + "outputs": [], + "source": [ + "# test_indices = []\n", + "# test_images = []\n", + "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", + "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n", + "# test_images = np.append(test_images, behav[:,0,0].numpy())\n", + "# test_indices = test_indices.astype(np.int16)\n", + "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n", + "# print(\"---\\n\")\n", + "\n", + "# train_indices = []\n", + "# train_images = []\n", + "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", + "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n", + "# train_images = np.append(train_images, behav[:,0,0].numpy())\n", + "# train_indices = train_indices.astype(np.int16)\n", + "# print(train_i, (train_i+1) * batch_size, len(train_indices))\n", + "\n", + "# # train_images = np.hstack((train_images, test_images))\n", + "# # print(\"WARNING: ADDED TEST IMAGES TO TRAIN IMAGES\")" + ] + }, + { + "cell_type": "markdown", + "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", + "metadata": {}, + "source": [ + "## Load data and images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "039dd330-7339-4f88-8f00-45f95e47baa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "subj01 betas loaded into memory\n", + "voxels torch.Size([27750, 15729])\n" + ] + } + ], + "source": [ + "# load betas\n", + 
"f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", + "voxels = f['betas'][:]\n", + "print(f\"subj0{subj} betas loaded into memory\")\n", + "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n", + "if subj==1:\n", + " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n", + "print(\"voxels\", voxels.shape)\n", + "num_voxels = voxels.shape[-1]\n", + "\n", + "# load orig images\n", + "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", + "images = f['images'][:]\n", + "images = torch.Tensor(images).to(\"cpu\").half()\n", + "print(\"images\", images.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", + "metadata": {}, + "source": [ + "## Load models" + ] + }, + { + "cell_type": "markdown", + "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", + "metadata": {}, + "source": [ + "### CLIP image embeddings model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0420dc0-199e-4c1a-857d-b1747058b467", + "metadata": {}, + "outputs": [], + "source": [ + "from models import Clipper\n", + "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", + "\n", + "clip_seq_dim = 257\n", + "clip_emb_dim = 768\n", + "hidden_dim = 4096" + ] + }, + { + "cell_type": "markdown", + "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", + "metadata": {}, + "source": [ + "### SD VAE (blurry images)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01baff79-8114-482b-b115-6f05aa8ad691", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import AutoencoderKL\n", + "autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", + "# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", + "autoenc.eval()\n", + "autoenc.requires_grad_(False)\n", + "autoenc.to(device)\n", + "utils.count_params(autoenc)" + ] + }, + { + "cell_type": "markdown", + "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", + "metadata": {}, + "source": [ + "### MindEye modules" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MindEyeModule()" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class MindEyeModule(nn.Module):\n", + " def __init__(self):\n", + " super(MindEyeModule, self).__init__()\n", + " def forward(self, x):\n", + " return x\n", + " \n", + "model = MindEyeModule()\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "64,430,080 total\n", + "64,430,080 trainable\n", + "param counts:\n", + "64,430,080 total\n", + "64,430,080 trainable\n", + "torch.Size([2, 1, 15729]) torch.Size([2, 1, 4096])\n" + ] + } + ], + "source": [ + "class RidgeRegression(torch.nn.Module):\n", + " # make sure to add weight_decay when initializing optimizer\n", + " def __init__(self, input_size, out_features): \n", + " super(RidgeRegression, self).__init__()\n", + " self.out_features = out_features\n", + " self.linear = torch.nn.Linear(input_size, out_features)\n", + " def forward(self, x):\n", + " return self.linear(x)\n", + " \n", + "model.ridge = 
RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", + "utils.count_params(model.ridge)\n", + "utils.count_params(model)\n", + "\n", + "b = torch.randn((2,1,voxels.shape[1]))\n", + "print(b.shape, model.ridge(b).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "1,071,175,044 total\n", + "1,071,175,044 trainable\n", + "param counts:\n", + "1,621,688,708 total\n", + "1,621,688,708 trainable\n", + "torch.Size([2, 8192])\n", + "torch.Size([2, 257, 768]) torch.Size([2, 4, 28, 28])\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "from diffusers.models.vae import Decoder\n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):\n", + " super().__init__()\n", + " self.blurry_dim = blurry_dim\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.lin0 = nn.Linear(in_dim, h)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.clip_size = clip_size\n", + " self.clip_proj = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + " self.upsampler = Decoder(\n", + " in_channels=64,\n", + " out_channels=4,\n", + " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", + " block_out_channels=[64, 128, 256],\n", + " layers_per_block=1,\n", + " )\n", + " \n", + " def forward(self, x):\n", + " x = self.lin0(x)\n", + " residual = x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " b = self.blin1(x)\n", + " b = self.upsampler(b.reshape(len(b), -1, 7, 7))\n", + " c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))\n", + " return c, b\n", + "\n", + "model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) \n", + "utils.count_params(model.backbone)\n", + "utils.count_params(model)\n", + "\n", + "b = torch.randn((2,8192))\n", + "print(b.shape)\n", + "clip_, blur_ = model.backbone(b)\n", + "print(clip_.shape, blur_.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a34204d0-d268-41ee-8eea-042525262c47", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "150,481,920 total\n", + "150,481,920 trainable\n", + "param counts:\n", + "335,601,664 total\n", + "335,601,664 trainable\n", + "param counts:\n", + "1,621,688,708 total\n", + "1,621,688,708 trainable\n" + ] + } + 
], + "source": [ + "# memory model\n", + "\n", + "from timm.layers.mlp import Mlp\n", + "\n", + "class MemoryEncoder(nn.Module):\n", + " def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.15):\n", + " super().__init__()\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.out_dim = out_dim\n", + " self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)\n", + " self.final_input_dim = in_dim + embedding_time_dim\n", + " self.lin0 = nn.Linear(self.final_input_dim, h)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.num_past_voxels = num_past_voxels\n", + " self.embedding_time_dim = embedding_time_dim\n", + " self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))\n", + "\n", + "\n", + " def forward(self, x, time):\n", + " time = time.long()\n", + " time = self.embedding_time(time)\n", + " x = torch.cat((x, time), dim=-1)\n", + " x = self.lin0(x)\n", + " residual = x\n", + " for res_block in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " return x\n", + " \n", + "# # test the memory encoder\n", + "# memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)\n", + "\n", + "# device = torch.device(\"cpu\")\n", + "# memory_encoder.to(device)\n", + "\n", + "# # count params\n", + "# total_parameters = 0\n", + "# for parameter in memory_encoder.parameters():\n", + "# total_parameters += parameter.numel()\n", + "\n", + "# rand_input = torch.randn((2, 15279)).to(device)\n", + "# rand_time = torch.randint(0, 15, (2,)).to(device)\n", + "# print(rand_input.shape, rand_time.shape)\n", + "# memory_encoder(rand_input, rand_time).shape\n", + "\n", + "class MemoryCompressor(nn.Module):\n", + " def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15):\n", + " super().__init__()\n", + " self.num_past = num_past\n", + " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", + " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", + " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", + " self.final_input_dim = in_dim * num_past\n", + " self.lin0 = nn.Linear(self.final_input_dim, h)\n", + " self.mlp = nn.ModuleList([\n", + " nn.Sequential(\n", + " nn.Linear(h, h),\n", + " *[item() for item in act_and_norm],\n", + " nn.Dropout(drop)\n", + " ) for _ in range(n_blocks)\n", + " ])\n", + " self.lin1 = nn.Linear(h, output_dim, bias=True)\n", + " self.n_blocks = n_blocks\n", + " self.num_past = num_past\n", + " self.output_dim = output_dim\n", + "\n", + " def forward(self, x):\n", + " # x is (batch_size, num_past, in_dim)\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin0(x)\n", + " residual = x\n", + " for res_block 
in range(self.n_blocks):\n", + " x = self.mlp[res_block](x)\n", + " x += residual\n", + " residual = x\n", + " x = x.reshape(len(x), -1)\n", + " x = self.lin1(x)\n", + " return x\n", + " \n", + "# # test the memory compressor\n", + "# memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)\n", + "\n", + "# device = torch.device(\"cpu\")\n", + "# memory_compressor.to(device)\n", + "\n", + "# # count params\n", + "# total_parameters = 0\n", + "# for parameter in memory_compressor.parameters():\n", + "# total_parameters += parameter.numel()\n", + "\n", + "# rand_input = torch.randn((2, 15, 768)).to(device)\n", + "# print(rand_input.shape)\n", + "# memory_compressor(rand_input).shape\n", + "\n", + "model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)\n", + "model.memory_compressor = MemoryCompressor(in_dim=model.memory_encoder.out_dim, num_past=15, output_dim=4096)\n", + "\n", + "utils.count_params(model.memory_encoder)\n", + "utils.count_params(model.memory_compressor)\n", + "utils.count_params(model)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Done with model preparations!\n", + "param counts:\n", + "1,621,688,708 total\n", + "1,621,688,708 trainable\n" + ] + } + ], + "source": [ + "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", + "opt_grouped_parameters = [\n", + " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", + " {'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},\n", + "]\n", + "\n", + "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))\n", + "\n", + "if lr_scheduler_type == 'linear':\n", + " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", + " optimizer,\n", + " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n", + " last_epoch=-1\n", + " )\n", + "elif lr_scheduler_type == 'cycle':\n", + " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n", + " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer, \n", + " max_lr=max_lr,\n", + " total_steps=total_steps,\n", + " final_div_factor=1000,\n", + " last_epoch=-1, pct_start=2/num_epochs\n", + " )\n", + " \n", + "def save_ckpt(tag): \n", + " ckpt_path = outdir+f'/{tag}.pth'\n", + " print(f'saving {ckpt_path}',flush=True)\n", + " unwrapped_model = accelerator.unwrap_model(model)\n", + " try:\n", + " torch.save({\n", + " 'epoch': epoch,\n", + " 'model_state_dict': unwrapped_model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'lr_scheduler': lr_scheduler.state_dict(),\n", + " 'train_losses': losses,\n", + " 'test_losses': test_losses,\n", + " 'lrs': lrs,\n", + " }, ckpt_path)\n", + " except:\n", + " print(\"Couldn't save... 
moving on to prevent crashing.\")\n", + " del unwrapped_model\n", + " \n", + "print(\"\\nDone with model preparations!\")\n", + "utils.count_params(model)" + ] + }, + { + "cell_type": "markdown", + "id": "983f458b-35b8-49f2-b6db-80296cece730", + "metadata": {}, + "source": [ + "# Weights and Biases" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb stability run test\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb_config:\n", + " {'model_name': 'test', 'batch_size': 128, 'num_epochs': 12, 'use_image_aug': False, 'max_lr': 3e-05, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.15.12 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231015_224404-lbkf7608" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run mem1 to Weights & Biases (docs)<br/>
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://stability.wandb.io/ckadirt/stability" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://stability.wandb.io/ckadirt/stability/runs/lbkf7608" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# params for wandb\n", + "wandb_log = True\n", + "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", + " import wandb\n", + " \n", + " wandb_project = 'stability'\n", + " wandb_run = model_name\n", + " wandb_notes = ''\n", + " \n", + " print(f\"wandb {wandb_project} run {wandb_run}\")\n", + " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", + " wandb_config = {\n", + " \"model_name\": model_name,\n", + " \"batch_size\": batch_size,\n", + " \"num_epochs\": num_epochs,\n", + " \"use_image_aug\": use_image_aug,\n", + " \"max_lr\": max_lr,\n", + " \"lr_scheduler_type\": lr_scheduler_type,\n", + " \"mixup_pct\": mixup_pct,\n", + " \"num_train\": num_train,\n", + " \"num_test\": num_test,\n", + " \"seed\": seed,\n", + " \"distributed\": distributed,\n", + " \"num_devices\": num_devices,\n", + " \"world_size\": world_size,\n", + " }\n", + " print(\"wandb_config:\\n\",wandb_config)\n", + " if False: # wandb_auto_resume\n", + " print(\"wandb_id:\",model_name)\n", + " wandb.init(\n", + " id = model_name,\n", + " project=wandb_project,\n", + " name=wandb_run,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " resume=\"allow\",\n", + " )\n", + " else:\n", + " wandb.init(\n", + " project=wandb_project,\n", + " name=model_name,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " )\n", + "else:\n", + " wandb_log = False" + ] + }, + { + "cell_type": "markdown", + "id": "5b0ae095-3203-4eb8-8606-acc2db6ccf20", + "metadata": {}, + "source": [ + "# More custom functions" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "827ead88-7eb3-47cc-82da-31565063b927", + "metadata": {}, + "outputs": [], + "source": [ + "# using the same preprocessing as was used in MindEye + BrainDiffuser\n", + "pixcorr_preprocess = transforms.Compose([\n", + " transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n", + "])\n", + "def pixcorr(images,brains):\n", + " # Flatten images while keeping the batch dimension\n", + " all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)\n", + " all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)\n", + " corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()\n", + " return corrmean" + ] + }, + { + "cell_type": "markdown", + "id": "d5690151-2131-4918-b750-e869cbd1a8a8", + "metadata": {}, + "source": [ + "# Main" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", + "metadata": {}, + "outputs": [], + "source": [ + "epoch = 0\n", + "losses, test_losses, lrs = [], [], []\n", + "best_test_loss = 1e9\n", + "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", + "\n", + "# Optionally resume from checkpoint #\n", + "if resume_from_ckpt:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " 
except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "elif wandb_log:\n", + " if wandb.run.resumed:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n", + "model, optimizer, train_dl, test_dl, lr_scheduler\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "60be0d5f-3e94-4612-9373-61b53d836393", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test starting with epoch 0 / 12\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/12 [00:00<?, ?it/s]" + ] + }, + { + "data": { + "text/html": [ + "╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", + " in <module>:34 \n", + " \n", + " 31 │ │ │ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_s \n", + " 32 │ │ │ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15 \n", + " 33 │ │ │ \n", + " 34 │ │ │ blurry_image_enc = autoenc.encode(image).latent_dist.mode() \n", + " 35 │ │ │ \n", + " 36 │ │ │ if use_image_aug: image = img_augment(image) \n", + " 37 \n", + " \n", + " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/utils/accelerate_utils.py:46 in wrapper \n", + " \n", + " 43 def wrapper(self, *args, **kwargs): \n", + " 44 │ │ if hasattr(self, \"_hf_hook\") and hasattr(self._hf_hook, \"pre_forward\"): \n", + " 45 │ │ │ self._hf_hook.pre_forward(self) \n", + " 46 │ │ return method(self, *args, **kwargs) \n", + " 47 \n", + " 48 return wrapper \n", + " 49 \n", + " \n", + " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/autoencoder_kl.py:258 in encode \n", + " \n", + " 255 │ │ │ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] \n", + " 256 │ │ │ h = torch.cat(encoded_slices) \n", + " 257 │ │ else: \n", + " 258 │ │ │ h = self.encoder(x) \n", + " 259 │ │ \n", + " 260 │ │ moments = self.quant_conv(h) \n", + " 261 │ │ posterior = DiagonalGaussianDistribution(moments) \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", + " .py:1501 in _call_impl \n", + " \n", + " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", + " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", + " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", + " 1501 │ │ │ 
return forward_call(*args, **kwargs) \n", + " 1502 │ │ # Do not call functions when jit is used \n", + " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", + " 1504 │ │ backward_pre_hooks = [] \n", + " \n", + " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/vae.py:141 in forward \n", + " \n", + " 138 │ │ else: \n", + " 139 │ │ │ # down \n", + " 140 │ │ │ for down_block in self.down_blocks: \n", + " 141 │ │ │ │ sample = down_block(sample) \n", + " 142 │ │ │ \n", + " 143 │ │ │ # middle \n", + " 144 │ │ │ sample = self.mid_block(sample) \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", + " .py:1501 in _call_impl \n", + " \n", + " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", + " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", + " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", + " 1501 │ │ │ return forward_call(*args, **kwargs) \n", + " 1502 │ │ # Do not call functions when jit is used \n", + " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", + " 1504 │ │ backward_pre_hooks = [] \n", + " \n", + " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/unet_2d_blocks.py:1247 in forward \n", + " \n", + " 1244 \n", + " 1245 def forward(self, hidden_states, scale: float = 1.0): \n", + " 1246 │ │ for resnet in self.resnets: \n", + " 1247 │ │ │ hidden_states = resnet(hidden_states, temb=None, scale=scale) \n", + " 1248 │ │ \n", + " 1249 │ │ if self.downsamplers is not None: \n", + " 1250 │ │ │ for downsampler in self.downsamplers: \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", + " .py:1501 in _call_impl \n", + " \n", + " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", + " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", + " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", + " 1501 │ │ │ return forward_call(*args, **kwargs) \n", + " 1502 │ │ # Do not call functions when jit is used \n", + " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", + " 1504 │ │ backward_pre_hooks = [] \n", + " \n", + " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/resnet.py:650 in forward \n", + " \n", + " 647 │ │ if self.time_embedding_norm == \"ada_group\" or self.time_embedding_norm == \"spati \n", + " 648 │ │ │ hidden_states = self.norm2(hidden_states, temb) \n", + " 649 │ │ else: \n", + " 650 │ │ │ hidden_states = self.norm2(hidden_states) \n", + " 651 │ │ \n", + " 652 │ │ if temb is not None and self.time_embedding_norm == \"scale_shift\": \n", + " 653 │ │ │ scale, shift = torch.chunk(temb, 2, dim=1) \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", + " .py:1501 in _call_impl \n", + " \n", + " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", + " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", + " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", + " 1501 │ │ │ return forward_call(*args, **kwargs) \n", + " 1502 │ │ # Do not call functions when jit is used \n", + " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", + " 1504 │ │ backward_pre_hooks = [] \n", + " \n", + " 
/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/normal \n", + " ization.py:273 in forward \n", + " \n", + " 270 │ │ │ init.zeros_(self.bias) \n", + " 271 \n", + " 272 def forward(self, input: Tensor) -> Tensor: \n", + " 273 │ │ return F.group_norm( \n", + " 274 │ │ │ input, self.num_groups, self.weight, self.bias, self.eps) \n", + " 275 \n", + " 276 def extra_repr(self) -> str: \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py: \n", + " 2530 in group_norm \n", + " \n", + " 2527 if input.dim() < 2: \n", + " 2528 │ │ raise RuntimeError(f\"Expected at least 2 dimensions for input tensor but receive \n", + " 2529 _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list( \n", + " 2530 return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.e \n", + " 2531 \n", + " 2532 \n", + " 2533 def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.7 \n", + "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB (GPU 0; 39.56 GiB total capacity; 33.04 GiB \n", + "already allocated; 752.56 MiB free; 37.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated \n", + "memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and \n", + "PYTORCH_CUDA_ALLOC_CONF\n", + "\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m34\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 31 \u001b[0m\u001b[2m│ │ │ \u001b[0mpast_15_voxels = voxels[past_behav[:,:,\u001b[94m5\u001b[0m].cpu().long()].to(device) \u001b[2m# batch_s\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 32 \u001b[0m\u001b[2m│ │ │ \u001b[0mpast_15_times = torch.Tensor([i \u001b[94mfor\u001b[0m i \u001b[95min\u001b[0m \u001b[96mrange\u001b[0m(\u001b[94m15\u001b[0m)]).to(device) \u001b[2m# 15\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 33 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 34 \u001b[2m│ │ │ \u001b[0mblurry_image_enc = autoenc.encode(image).latent_dist.mode() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 35 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 36 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m use_image_aug: image = img_augment(image) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 37 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/utils/\u001b[0m\u001b[1;33maccelerate_utils.py\u001b[0m:\u001b[94m46\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m43 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mwrapper\u001b[0m(\u001b[96mself\u001b[0m, *args, **kwargs): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m44 
\u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m, \u001b[33m\"\u001b[0m\u001b[33m_hf_hook\u001b[0m\u001b[33m\"\u001b[0m) \u001b[95mand\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m._hf_hook, \u001b[33m\"\u001b[0m\u001b[33mpre_forward\u001b[0m\u001b[33m\"\u001b[0m): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m45 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._hf_hook.pre_forward(\u001b[96mself\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m46 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m method(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m47 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m48 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m wrapper \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m49 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33mautoencoder_kl.py\u001b[0m:\u001b[94m258\u001b[0m in \u001b[92mencode\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m255 \u001b[0m\u001b[2m│ │ │ \u001b[0mencoded_slices = [\u001b[96mself\u001b[0m.encoder(x_slice) \u001b[94mfor\u001b[0m x_slice \u001b[95min\u001b[0m x.split(\u001b[94m1\u001b[0m)] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m256 \u001b[0m\u001b[2m│ │ │ \u001b[0mh = torch.cat(encoded_slices) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m257 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m258 \u001b[2m│ │ │ \u001b[0mh = \u001b[96mself\u001b[0m.encoder(x) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m259 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m260 \u001b[0m\u001b[2m│ │ \u001b[0mmoments = \u001b[96mself\u001b[0m.quant_conv(h) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m261 \u001b[0m\u001b[2m│ │ \u001b[0mposterior = DiagonalGaussianDistribution(moments) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", + 
"\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33mvae.py\u001b[0m:\u001b[94m141\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m138 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m139 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# down\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m140 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mfor\u001b[0m down_block \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.down_blocks: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m141 \u001b[2m│ │ │ │ \u001b[0msample = down_block(sample) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m142 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m143 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# middle\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m144 \u001b[0m\u001b[2m│ │ │ \u001b[0msample = \u001b[96mself\u001b[0m.mid_block(sample) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33munet_2d_blocks.py\u001b[0m:\u001b[94m1247\u001b[0m in 
+    "  ... in forward, line 1247:\n",
+    "      hidden_states = resnet(hidden_states, temb=None, scale=scale)\n",
+    "  /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1501, in _call_impl:\n",
+    "      return forward_call(*args, **kwargs)\n",
+    "  /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/resnet.py, line 650, in forward:\n",
+    "      hidden_states = self.norm2(hidden_states)\n",
+    "  /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py, line 1501, in _call_impl:\n",
+    "      return forward_call(*args, **kwargs)\n",
+    "  /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/normalization.py, line 273, in forward:\n",
+    "      return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)\n",
+    "  /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py, line 2530, in group_norm:\n",
+    "      return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)\n",
+    "\n",
+    "OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB (GPU 0; 39.56 GiB total capacity;\n",
+    "33.04 GiB already allocated; 752.56 MiB free; 37.34 GiB reserved in total by PyTorch)\n",
+    "If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.\n",
+    "See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
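+    "# NOTE: the output above is from a run of this cell that crashed with a CUDA OOM inside a\n",
+    "# GroupNorm in the VAE (autoenc) forward pass. As the allocator message itself suggests, one\n",
+    "# possible mitigation (untested here) is to configure the allocator before any CUDA tensor\n",
+    "# is created, e.g. near the top of the notebook:\n",
+    "#   os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:128\"\n",
+    "# Lowering batch_size, or decoding fewer latents per eval step, also reduces peak memory.\n",
+    "\n",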
+    "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
+    "progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
+    "test_image, test_voxel = None, None\n",
+    "mse = nn.MSELoss()\n",
+    "for epoch in progress_bar:\n",
+    "    model.train()\n",
+    "\n",
+    "    fwd_percent_correct = 0.\n",
+    "    bwd_percent_correct = 0.\n",
+    "    test_fwd_percent_correct = 0.\n",
+    "    test_bwd_percent_correct = 0.\n",
+    "\n",
+    "    loss_clip_total = 0.\n",
+    "    loss_blurry_total = 0.\n",
+    "    test_loss_clip_total = 0.\n",
+    "    test_loss_blurry_total = 0.\n",
+    "\n",
+    "    blurry_pixcorr = 0.\n",
+    "    test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n",
+    "\n",
+    "    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
+    "        with torch.cuda.amp.autocast():\n",
+    "            optimizer.zero_grad()\n",
+    "\n",
+    "            voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
+    "            image = images[behav[:,0,0].cpu().long()].to(device).float()\n",
+    "\n",
+    "            past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # (batch_size, 15, 15279)\n",
+    "            past_15_times = torch.arange(15, dtype=torch.float32, device=device) # (15,)\n",
+    "\n",
+    "            blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
+    "\n",
+    "            if use_image_aug: image = img_augment(image)\n",
+    "\n",
+    "            clip_target = clip_model.embed_image(image)\n",
+    "            assert not torch.any(torch.isnan(clip_target))\n",
+    "\n",
+    "            if epoch < int(mixup_pct * num_epochs):\n",
+    "                voxel, perm, betas, select = utils.mixco(voxel)\n",
+    "\n",
+    "            # reshape past voxels to be (batch_size * 15, 15279)\n",
+    "            past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])\n",
+    "            past_15_times = past_15_times.repeat(voxel.shape[0], 1)\n",
+    "            past_15_times = past_15_times.reshape(-1)\n",
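+    "\n",
+    "            # Memory pathway (presumably the \"cat\" in this notebook's name, for concatenation):\n",
+    "            # memory_encoder embeds each of the 15 preceding voxel patterns together with its\n",
+    "            # time index, memory_compressor pools the resulting (batch_size, 15, emb) sequence\n",
+    "            # into one vector per sample, and that vector is concatenated with the ridge\n",
+    "            # projection of the current TR before the shared backbone.\n",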
+    "            embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)\n",
+    "            embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)\n",
+    "            information_past_voxels = model.memory_compressor(embeds_past_voxels)\n",
+    "\n",
+    "            voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)\n",
+    "\n",
+    "            clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
+    "\n",
+    "            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
+    "            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
+    "\n",
+    "            if epoch < int(mixup_pct * num_epochs):\n",
+    "                loss_clip = utils.mixco_nce(\n",
+    "                    clip_voxels_norm,\n",
+    "                    clip_target_norm,\n",
+    "                    temp=.006,\n",
+    "                    perm=perm, betas=betas, select=select)\n",
+    "            else:\n",
+    "                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n",
+    "                loss_clip = utils.soft_clip_loss(\n",
+    "                    clip_voxels_norm,\n",
+    "                    clip_target_norm,\n",
+    "                    temp=epoch_temp)\n",
+    "\n",
+    "            loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
+    "\n",
+    "            loss_clip_total += loss_clip.item()\n",
+    "            loss_blurry_total += loss_blurry.item()\n",
+    "\n",
+    "            loss = loss_blurry + loss_clip\n",
+    "            utils.check_loss(loss)\n",
+    "\n",
+    "            accelerator.backward(loss)\n",
+    "            optimizer.step()\n",
+    "\n",
+    "            losses.append(loss.item())\n",
+    "            lrs.append(optimizer.param_groups[0]['lr'])\n",
+    "\n",
+    "            # forward and backward top-1 retrieval accuracy\n",
+    "            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)\n",
+    "            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
+    "            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
+    "\n",
+    "            with torch.no_grad():\n",
+    "                # only doing pixcorr eval on a subset (2) of the samples per batch because it's costly & slow to compute autoenc.decode()\n",
+    "                random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)\n",
+    "                blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)\n",
+    "                blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)\n",
+    "\n",
+    "        if lr_scheduler_type is not None:\n",
+    "            lr_scheduler.step()\n",
+    "\n",
+    "    model.eval()\n",
+    "    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
+    "        with torch.cuda.amp.autocast():\n",
+    "            with torch.no_grad():\n",
+    "                # all test samples should be loaded per batch such that test_i should never exceed 0\n",
+    "                if len(behav) != num_test: print(\"!\",len(behav),num_test)\n",
+    "\n",
+    "                ## Average same-image repeats (reduces voxel noise in the test set) ##\n",
+    "                if test_image is None:\n",
+    "                    voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
+    "                    image = behav[:,0,0].cpu().long()\n",
+    "\n",
+    "                    unique_image, sort_indices = torch.unique(image, return_inverse=True)\n",
+    "                    for im in unique_image:\n",
+    "                        locs = torch.where(im == image)[0]\n",
+    "                        if test_image is None:\n",
+    "                            test_image = images[im][None]\n",
+    "                            test_voxel = torch.mean(voxel[locs],axis=0)[None]\n",
+    "                        else:\n",
+    "                            test_image = torch.vstack((test_image, images[im][None]))\n",
+    "                            test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n",
+    "\n",
+    "                # take the first batch_size averaged samples\n",
+    "                random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]\n",
+    "                voxel = test_voxel[random_indices].to(device)\n",
+    "                image = test_image[random_indices].to(device)\n",
+    "\n",
+    "                current_past_behav = past_behav[random_indices]\n",
+    "\n",
+    "                past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # (batch_size, 15, 15279)\n",
+    "                past_15_times = torch.arange(15, dtype=torch.float32, device=device) # (15,)\n",
+    "\n",
+    "                assert len(image) == batch_size\n",
+    "\n",
+    "                blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
+    "                clip_target = clip_model.embed_image(image.float())\n",
+    "\n",
+    "                past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])\n",
+    "                past_15_times = past_15_times.repeat(voxel.shape[0], 1)\n",
+    "                past_15_times = past_15_times.reshape(-1)\n",
+    "\n",
+    "                embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)\n",
+    "                embeds_past_voxels = embeds_past_voxels.reshape(batch_size, 15, -1)\n",
+    "                information_past_voxels = model.memory_compressor(embeds_past_voxels)\n",
+    "\n",
+    "                voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)\n",
+    "\n",
+    "                clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
+    "\n",
+    "                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
+    "                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
+    "\n",
+    "                loss_clip = utils.soft_clip_loss(\n",
+    "                    clip_voxels_norm,\n",
+    "                    clip_target_norm,\n",
+    "                    temp=.006)\n",
+    "\n",
+    "                loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
+    "\n",
+    "                loss = loss_blurry + loss_clip\n",
+    "                utils.check_loss(loss)\n",
+    "\n",
+    "                test_losses.append(loss.item())\n",
+    "                # accumulate per-component test losses so the logged averages are meaningful\n",
+    "                test_loss_clip_total += loss_clip.item()\n",
+    "                test_loss_blurry_total += loss_blurry.item()\n",
+    "\n",
+    "                # forward and backward top-1 retrieval accuracy\n",
+    "                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)\n",
+    "                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
+    "                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
+    "\n",
+    "                # decode in two halves because the decoder is computationally heavy\n",
+    "                blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)\n",
+    "                blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))\n",
+    "                test_blurry_pixcorr += pixcorr(image, blurry_recon_images)\n",
+    "\n",
+    "                # decode the first two blurry recon latents to images and plot them next to ground truth\n",
+    "                fig, axes = plt.subplots(1, 4, figsize=(8, 4))\n",
+    "                axes[0].imshow(utils.torch_to_Image(image[[0]]))\n",
+    "                axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))\n",
+    "                axes[2].imshow(utils.torch_to_Image(image[[1]]))\n",
+    "                axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))\n",
+    "                axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')\n",
+    "                plt.show()\n",
+    "\n",
+    "    if local_rank==0:\n",
+    "        # if utils.is_interactive(): clear_output(wait=True)\n",
+    "        assert (test_i+1) == 1\n",
+    "        logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n",
+    "            \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n",
+    "            \"train/lr\": lrs[-1],\n",
+    "            \"train/num_steps\": len(losses),\n",
+    "            \"test/num_steps\": len(test_losses),\n",
+    "            \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n",
+    "            \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n",
+    "            \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n",
+    "            \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n",
+    "            \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n",
+    "            \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n",
+    "            \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n",
+    "            \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n",
+    "            \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n",
+    "            \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n",
+    "            }\n",
+    "        progress_bar.set_postfix(**logs)\n",
+    "\n",
+    "        # Save model checkpoint\n",
+    "        if epoch % ckpt_interval == 0:\n",
+    "            if not utils.is_interactive():\n",
+    "                save_ckpt('last')\n",
+    "\n",
+    "        if wandb_log: wandb.log(logs)\n",
+    "\n",
+    "    # wait for other GPUs to catch up if needed\n",
+    "    accelerator.wait_for_everyone()\n",
+    "    torch.cuda.empty_cache()\n",
+    "    gc.collect()\n",
+    "\n",
+    "print(\"\\n===Finished!===\\n\")\n",
+    "if ckpt_saving:\n",
+    "    save_ckpt('last')\n",
+    "if not utils.is_interactive():\n",
+    "    sys.exit(0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+ "source": [ + "plt.plot(losses)\n", + "plt.show()\n", + "plt.plot(test_losses)\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + }, + "toc-autonumbering": true, + "vscode": { + "interpreter": { + "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}