diff --git "a/src/Train_MLPMixer.ipynb" "b/src/Train_MLPMixer.ipynb" new file mode 100644--- /dev/null +++ "b/src/Train_MLPMixer.ipynb" @@ -0,0 +1,2476 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", + "# from subprocess import call\n", + "# command = \"jupyter nbconvert Train_with_autoencoder_MLPMixer.ipynb --to python\"\n", + "# call(command,shell=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b0f0f4f3", + "metadata": {}, + "source": [ + "# Import packages & functions" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5bad764b-45c1-45ce-a716-8d055e09821a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-10-28 20:46:20,021] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "import json\n", + "import argparse\n", + "import numpy as np\n", + "import math\n", + "from einops import rearrange\n", + "import time\n", + "import random\n", + "import string\n", + "import h5py\n", + "from tqdm import tqdm\n", + "\n", + "import webdataset as wds\n", + "import gc\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import torch\n", + "import torch.nn as nn\n", + "from torchvision import transforms\n", + "\n", + "from accelerate import Accelerator, DeepSpeedPlugin\n", + "\n", + "# tf32 data type is faster than standard float32\n", + "torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + "# custom functions #\n", + "import utils" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c0267850-3785-4be6-b134-b2a52bf55113", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "utils.mixco" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "LOCAL RANK 0\n", + "Setting batch_size to 128\n", + "[2023-10-28 20:46:28,070] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented\n", + "[2023-10-28 20:46:28,071] [INFO] [comm.py:594:init_distributed] cdb=None\n", + "[2023-10-28 20:46:28,071] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" + ] + } + ], + "source": [ + "### Multi-GPU config ###\n", + "local_rank = os.getenv('RANK')\n", + "if local_rank is None: \n", + " local_rank = 0\n", + "else:\n", + " local_rank = int(local_rank)\n", + "print(\"LOCAL RANK \", local_rank) \n", + "\n", + "num_devices = torch.cuda.device_count()\n", + "if num_devices==0: num_devices = 1\n", + "\n", + "# ## UNCOMMENT BELOW SECTION AND COMMENT OUT DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###\n", + "# accelerator = Accelerator(split_batches=False, mixed_precision=\"fp16\")\n", + "# global_batch_size = batch_size = 128\n", + "# data_type = torch.float16 # change depending on your mixed_precision\n", + "\n", + "### DEEPSPEED INITIALIZATION ###\n", + "if num_devices <= 1 and utils.is_interactive():\n", + " global_batch_size = batch_size = 128\n", + " print(f\"Setting batch_size to {batch_size}\")\n", + " # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", + " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + " os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", + " os.environ[\"RANK\"] = \"0\"\n", + " os.environ[\"LOCAL_RANK\"] = \"0\"\n", + " os.environ[\"WORLD_SIZE\"] = \"1\"\n", + " os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", + "else:\n", + " global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"] \n", + " batch_size = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", + "\n", + "# alter the deepspeed config according to your global and local batch size\n", + "if local_rank == 0:\n", + " with open('deepspeed_config_stage2.json', 'r') as file:\n", + " config = json.load(file)\n", + " config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", + " config['train_micro_batch_size_per_gpu'] = batch_size\n", + " config['bf16'] = {'enabled': False}\n", + " config['fp16'] = {'enabled': True}\n", + " with open('deepspeed_config_stage2.json', 'w') as file:\n", + " json.dump(config, file)\n", + "else:\n", + " # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", + " time.sleep(10)\n", + "deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", + "accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PID of this process = 1896724\n", + "device: cuda:0\n", + "Distributed environment: DEEPSPEED Backend: nccl\n", + "Num processes: 1\n", + "Process index: 0\n", + "Local process index: 0\n", + "Device: cuda:0\n", + "\n", + "Mixed precision type: fp16\n", + "ds_config: {'bf16': {'enabled': False}, 'fp16': {'enabled': True}, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'stage3_gather_16bit_weights_on_model_save': True, 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_prefetch_bucket_size': 10000000.0, 'stage3_param_persistence_threshold': 100000.0, 'reduce_bucket_size': 10000000.0, 'sub_group_size': 1000000000.0, 'offload_optimizer': {'device': 'none', 'nvme_path': '/scratch', 'pin_memory': True}, 'offload_param': {'device': 'none', 'nvme_path': '/scratch', 'buffer_size': 4000000000.0, 'pin_memory': True}}, 'aio': {'block_size': 26214400, 'queue_depth': 64, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}, 'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'steps_per_print': inf, 'train_batch_size': 128, 'train_micro_batch_size_per_gpu': 128, 'wall_clock_breakdown': False, 'zero_allow_untested_optimizer': True}\n", + "\n", + "distributed = True num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16\n" + ] + } + ], + "source": [ + "print(\"PID of this process =\",os.getpid())\n", + "device = accelerator.device\n", + "print(\"device:\",device)\n", + "num_workers = num_devices\n", + "print(accelerator.state)\n", + "world_size = accelerator.state.num_processes\n", + "distributed = not accelerator.state.distributed_type == 'NO'\n", + "\n", + "# set data_type to match your mixed precision (automatically set based on deepspeed config)\n", + "if accelerator.mixed_precision == \"bf16\":\n", + " data_type = torch.bfloat16\n", + "elif accelerator.mixed_precision == \"fp16\":\n", + " data_type = torch.float16\n", + "else:\n", + " data_type = torch.float32\n", + "\n", + "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size, \"data_type =\", data_type)\n", + "print = accelerator.print # only print if local_rank=0" + ] + }, + { + "cell_type": "markdown", + "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", + "metadata": { + "tags": [] + }, + "source": [ + "# Configurations" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model_name: 0qiAxQoaKN_interactive_bsl\n", + "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=0qiAxQoaKN_interactive_bsl', '--subj=1', '--batch_size=128', '--no-blurry_recon', '--no-depth_recon', '--hidden_dim=4096', '--clip_scale=1.', '--blur_scale=100.', '--depth_scale=100.', '--max_lr=3e-4', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug', '--no-ckpt_saving']\n" + ] + } + ], + "source": [ + "# if running this interactively, can specify jupyter_args here for argparser to use\n", + "if utils.is_interactive():\n", + " # create random model_name\n", + " model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))\n", + " model_name = model_name + \"_interactive_bsl\"\n", + " print(\"model_name:\", model_name)\n", + "\n", + " # global_batch_size and batch_size should already be defined in the above cells\n", + " # other variables can be specified in the following string:\n", + " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", + " --model_name={model_name} \\\n", + " --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \\\n", + " --clip_scale=1. --blur_scale=100. --depth_scale=100. \\\n", + " --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving\"\n", + "\n", + " jupyter_args = jupyter_args.split()\n", + " print(jupyter_args)\n", + " \n", + " from IPython.display import clear_output # function to clear print outputs in cell\n", + " %load_ext autoreload \n", + " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", + " %autoreload 2 " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", + "parser.add_argument(\n", + " \"--model_name\", type=str, default=\"testing\",\n", + " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", + ")\n", + "parser.add_argument(\n", + " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", + " help=\"Path to where NSD data is stored / where to download it to\",\n", + ")\n", + "parser.add_argument(\n", + " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", + ")\n", + "parser.add_argument(\n", + " \"--batch_size\", type=int, default=32,\n", + " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to log to wandb\",\n", + ")\n", + "parser.add_argument(\n", + " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", + " help=\"if not using wandb and want to resume from a ckpt\",\n", + ")\n", + "parser.add_argument(\n", + " \"--wandb_project\",type=str,default=\"stability\",\n", + " help=\"wandb project name\",\n", + ")\n", + "parser.add_argument(\n", + " \"--mixup_pct\",type=float,default=.33,\n", + " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", + ")\n", + "parser.add_argument(\n", + " \"--blurry_recon\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to output blurry reconstructions\",\n", + ")\n", + "parser.add_argument(\n", + " \"--depth_recon\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to output depth reconstructions\",\n", + ")\n", + "parser.add_argument(\n", + " \"--blur_scale\",type=float,default=100.,\n", + " help=\"multiply loss from blurry recons by this number\",\n", + ")\n", + "parser.add_argument(\n", + " \"--depth_scale\",type=float,default=100.,\n", + " help=\"multiply loss from depth recons by this number\",\n", + ")\n", + "parser.add_argument(\n", + " \"--clip_scale\",type=float,default=1.,\n", + " help=\"multiply contrastive loss by this number\",\n", + ")\n", + "parser.add_argument(\n", + " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", + " help=\"whether to use image augmentation\",\n", + ")\n", + "parser.add_argument(\n", + " \"--num_epochs\",type=int,default=120,\n", + " help=\"number of epochs of training\",\n", + ")\n", + "parser.add_argument(\n", + " \"--hidden_dim\",type=int,default=4096,\n", + ")\n", + "parser.add_argument(\n", + " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", + ")\n", + "parser.add_argument(\n", + " \"--ckpt_interval\",type=int,default=5,\n", + " help=\"save backup ckpt and reconstruct every x epochs\",\n", + ")\n", + "parser.add_argument(\n", + " \"--seed\",type=int,default=42,\n", + ")\n", + "parser.add_argument(\n", + " \"--max_lr\",type=float,default=3e-4,\n", + ")\n", + "\n", + "if utils.is_interactive():\n", + " args = parser.parse_args(jupyter_args)\n", + "else:\n", + " args = parser.parse_args()\n", + "\n", + "# create global variables without the args prefix\n", + "for attribute_name in vars(args).keys():\n", + " globals()[attribute_name] = getattr(args, attribute_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", + "if not os.path.exists(outdir) and ckpt_saving:\n", + " os.makedirs(outdir,exist_ok=True)\n", + "if use_image_aug:\n", + " import kornia\n", + " from kornia.augmentation.container import AugmentationSequential\n", + " img_augment = AugmentationSequential(\n", + " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", + " kornia.augmentation.Resize((224, 224)),\n", + " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", + " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", + " kornia.augmentation.RandomGrayscale(p=0.3),\n", + " same_on_batch=False,\n", + " data_keys=[\"input\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "42d13c25-1369-4c49-81d4-83d713586096", + "metadata": { + "tags": [] + }, + "source": [ + "# Prep data, models, and dataloaders" + ] + }, + { + "cell_type": "markdown", + "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", + "metadata": {}, + "source": [ + "## Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "81084834-035f-4465-ad59-59e6b806a2f5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", + "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" + ] + } + ], + "source": [ + "if subj==1:\n", + " num_train = 24958\n", + " num_test = 2770\n", + "test_batch_size = num_test\n", + "\n", + "def my_split_by_node(urls): return urls\n", + " \n", + "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", + "# train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..1}.tar\"\n", + "print(train_url)\n", + "\n", + "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)\n", + "\n", + "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", + "print(test_url)\n", + "\n", + "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", + " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", + " .decode(\"torch\")\\\n", + " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", + " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", + "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)" + ] + }, + { + "cell_type": "markdown", + "id": "203b060a-2dd2-4c35-929b-c576be82eb52", + "metadata": {}, + "source": [ + "### check dataloaders are working" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 2770 2770\n", + "---\n", + "\n", + "194 24960 24960\n" + ] + } + ], + "source": [ + "test_vox_indices = []\n", + "test_73k_images = []\n", + "for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", + " test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())\n", + " test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())\n", + "test_vox_indices = test_vox_indices.astype(np.int16)\n", + "print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))\n", + "print(\"---\\n\")\n", + "\n", + "train_vox_indices = []\n", + "train_73k_images = []\n", + "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", + " train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())\n", + " train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())\n", + "train_vox_indices = train_vox_indices.astype(np.int16)\n", + "print(train_i, (train_i+1) * batch_size, len(train_vox_indices))" + ] + }, + { + "cell_type": "markdown", + "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", + "metadata": {}, + "source": [ + "## Load data and images" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "039dd330-7339-4f88-8f00-45f95e47baa0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "subj01 betas loaded into memory\n", + "voxels torch.Size([27750, 15724])\n", + "images torch.Size([73000, 3, 224, 224])\n" + ] + } + ], + "source": [ + "# load betas\n", + "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", + "# f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')\n", + "\n", + "voxels = f['betas'][:]\n", + "print(f\"subj0{subj} betas loaded into memory\")\n", + "voxels = torch.Tensor(voxels).to(\"cpu\").to(data_type)\n", + "print(\"voxels\", voxels.shape)\n", + "num_voxels = voxels.shape[-1]\n", + "\n", + "# load orig images\n", + "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", + "images = f['images'][:]\n", + "images = torch.Tensor(images).to(\"cpu\").to(data_type)\n", + "print(\"images\", images.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", + "metadata": {}, + "source": [ + "## Load models" + ] + }, + { + "cell_type": "markdown", + "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", + "metadata": {}, + "source": [ + "### CLIP image embeddings model" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b0420dc0-199e-4c1a-857d-b1747058b467", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ViT-L/14 cuda:0\n" + ] + } + ], + "source": [ + "from models import Clipper\n", + "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", + "clip_seq_dim = 257\n", + "clip_emb_dim = 768 #1024\n", + "# hidden_dim = 4096\n", + "seq_len = 1 #2 #32 " + ] + }, + { + "cell_type": "markdown", + "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", + "metadata": {}, + "source": [ + "### SD VAE" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "01baff79-8114-482b-b115-6f05aa8ad691", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# if blurry_recon:\n", + "# from diffusers import AutoencoderKL\n", + "# autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", + "# # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", + "# autoenc.eval()\n", + "# autoenc.requires_grad_(False)\n", + "# autoenc.to(device)\n", + "# utils.count_params(autoenc)\n", + "\n", + "if blurry_recon:# or depth_recon:\n", + " from diffusers import VQModel\n", + " autoenc = VQModel.from_pretrained(\"/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae\", torch_dtype=data_type)\n", + " autoenc.eval()\n", + " autoenc.requires_grad_(False)\n", + " autoenc.to(device)\n", + " utils.count_params(autoenc)" + ] + }, + { + "cell_type": "markdown", + "id": "120c8eee-9834-437d-bb60-b38faef50138", + "metadata": {}, + "source": [ + "#### downsampled images" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6d1ba8dd-64c2-4ac9-947e-725b7f2e3e50", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if blurry_recon:\n", + " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", + "\n", + " input_batch = images[[30]].to(device)\n", + " print(input_batch.shape)\n", + "\n", + " downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)\n", + " re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')\n", + " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", + " print(re_upsampled_enc.shape)\n", + " \n", + " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" + ] + }, + { + "cell_type": "markdown", + "id": "6390a3a8-2bef-4e81-9b82-e154d26a1e1d", + "metadata": {}, + "source": [ + "#### MiDaS depth" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f35573e2-95bf-463d-8937-68ad4c2c3c20", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if depth_recon:\n", + " from controlnet_aux.midas import MidasDetector\n", + " \n", + " midas_depth = MidasDetector.from_pretrained(\n", + " \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\", cache_dir=\"/fsx/proj-fmri/shared/cache\").to(device)\n", + " midas_depth.model.eval()\n", + " midas_depth.model.requires_grad_(False)\n", + " midas_depth.model.to(device)\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ba3f9207-b98e-45da-baa6-5cfcfb2ae958", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if depth_recon:\n", + " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", + "\n", + " input_batch = images[[30,31]].float().to(device)\n", + " print(input_batch.shape)\n", + " \n", + " midas_emb = midas_depth.model(input_batch).unsqueeze(1)\n", + " print(midas_emb.shape)\n", + "\n", + " prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()\n", + " print(prediction.shape)\n", + " \n", + " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", + " midas_emb_size = prediction.flatten(1).shape[1]\n", + " print(\"midas_emb\", prediction.shape, prediction.min(), prediction.max())\n", + " print(\"midas_emb_size\", midas_emb_size)\n", + " \n", + " if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224))) \n", + "\n", + " if blurry_recon:\n", + " prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)\n", + " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", + " prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215\n", + " print(\"vae midas_emb\", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())\n", + " \n", + " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" + ] + }, + { + "cell_type": "markdown", + "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", + "metadata": {}, + "source": [ + "### MindEye modules" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "MindEyeModule()" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class MindEyeModule(nn.Module):\n", + " def __init__(self):\n", + " super(MindEyeModule, self).__init__()\n", + " def forward(self, x):\n", + " return x\n", + " \n", + "model = MindEyeModule()\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "64,409,600 total\n", + "64,409,600 trainable\n", + "param counts:\n", + "64,409,600 total\n", + "64,409,600 trainable\n", + "torch.Size([2, 1, 15724]) torch.Size([2, 1, 4096])\n" + ] + } + ], + "source": [ + "class RidgeRegression(torch.nn.Module):\n", + " # make sure to add weight_decay when initializing optimizer\n", + " def __init__(self, input_size, out_features): \n", + " super(RidgeRegression, self).__init__()\n", + " self.out_features = out_features\n", + " self.linear = torch.nn.Linear(input_size, out_features)\n", + " def forward(self, x):\n", + " return self.linear(x)\n", + " \n", + "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", + "utils.count_params(model.ridge)\n", + "utils.count_params(model)\n", + "\n", + "b = torch.randn((2,1,voxels.shape[1]))\n", + "print(b.shape, model.ridge(b).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "param counts:\n", + "950,287,384 total\n", + "950,287,384 trainable\n", + "param counts:\n", + "1,014,696,984 total\n", + "1,014,696,984 trainable\n", + "b.shape torch.Size([2, 1, 4096])\n", + "torch.Size([2, 257, 768]) torch.Size([1]) torch.Size([1])\n" + ] + } + ], + "source": [ + "from functools import partial\n", + "from diffusers.models.vae import Decoder\n", + "class BrainNetwork(nn.Module):\n", + " def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):\n", + " super().__init__()\n", + " self.seq_len = seq_len\n", + " self.h = h\n", + " self.clip_size = clip_size\n", + " \n", + " # Initial linear layer to match the input dimensions to hidden dimensions\n", + " # self.lin0 = nn.Linear(in_dim, seq_len * h)\n", + " \n", + " # Mixer Blocks\n", + " self.mixer_blocks1 = nn.ModuleList([\n", + " self.mixer_block1(h, drop) for _ in range(n_blocks)\n", + " ])\n", + " self.mixer_blocks2 = nn.ModuleList([\n", + " self.mixer_block2(seq_len, drop) for _ in range(n_blocks)\n", + " ])\n", + " \n", + " # Output linear layer\n", + " self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)\n", + "\n", + " # low-rank matrices\n", + " # self.rank = 500\n", + " # self.U = nn.Parameter(torch.randn(self.rank, out_dim))\n", + " # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))\n", + " # self.S = nn.Parameter(torch.randn(out_dim))\n", + "\n", + " self.clip_proj = nn.Sequential(\n", + " nn.LayerNorm(clip_size),\n", + " nn.GELU(),\n", + " nn.Linear(clip_size, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, 2048),\n", + " nn.LayerNorm(2048),\n", + " nn.GELU(),\n", + " nn.Linear(2048, clip_size)\n", + " )\n", + "\n", + " if blurry_recon:\n", + " # self.blin1 = nn.Sequential(\n", + " # nn.Linear(out_dim, 4096, bias=True),\n", + " # nn.LayerNorm(4096),\n", + " # nn.GELU(),\n", + " # nn.Linear(4096, 4096))\n", + " self.blin1 = nn.Linear(h*seq_len, 4096)\n", + " self.bgroupnorm = nn.GroupNorm(1, 256)\n", + " self.bupsampler = Decoder(\n", + " in_channels=256,\n", + " out_channels=128,\n", + " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", + " block_out_channels=[32, 64, 128],\n", + " layers_per_block=1,\n", + " )\n", + "\n", + " if depth_recon:\n", + " # self.dlin1 = nn.Sequential(\n", + " # nn.Linear(h, midas_emb_size),\n", + " # nn.Sigmoid(),\n", + " # )\n", + " self.dlin1 = nn.Linear(h*seq_len, 4096)\n", + " self.dgroupnorm = nn.GroupNorm(1, 256)\n", + " self.dupsampler = Decoder(\n", + " in_channels=256,\n", + " out_channels=1,#128,\n", + " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", + " block_out_channels=[32, 64, 128, 256],\n", + " layers_per_block=1,\n", + " )\n", + " \n", + " def mixer_block1(self, h, drop):\n", + " return nn.Sequential(\n", + " nn.LayerNorm(h),\n", + " self.mlp(h, h, drop), # Token mixing\n", + " )\n", + "\n", + " def mixer_block2(self, seq_len, drop):\n", + " return nn.Sequential(\n", + " nn.LayerNorm(seq_len),\n", + " self.mlp(seq_len, seq_len, drop) # Channel mixing\n", + " )\n", + " \n", + " def mlp(self, in_dim, out_dim, drop):\n", + " return nn.Sequential(\n", + " nn.Linear(in_dim, out_dim),\n", + " nn.GELU(),\n", + " nn.Dropout(drop),\n", + " nn.Linear(out_dim, out_dim),\n", + " )\n", + " \n", + " def forward(self, x):\n", + " # make empty tensors for blur and depth outputs\n", + " b,d = torch.Tensor([0.]), torch.Tensor([0.])\n", + " \n", + " # Initial linear layer\n", + " # x = self.lin0(x)\n", + " \n", + " # Reshape to seq_len by dim\n", + " # x = x.reshape(-1, self.seq_len, self.h)\n", + " \n", + " # Mixer blocks\n", + " residual1 = x\n", + " residual2 = x.permute(0,2,1)\n", + " for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):\n", + " x = block1(x) + residual1\n", + " residual1 = x\n", + " x = x.permute(0,2,1)\n", + " \n", + " x = block2(x) + residual2\n", + " residual2 = x\n", + " x = x.permute(0,2,1)\n", + " \n", + " # Flatten\n", + " x = x.reshape(x.size(0), -1)\n", + " \n", + " c = self.clin1(x)\n", + "\n", + " # low rank linear to out dim cuts # params by nearly half compared to full linear mapping\n", + " # c = (x @ (self.V/100) @ (self.U/100)) + self.S\n", + " \n", + " c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))\n", + "\n", + " if blurry_recon:\n", + " b = self.blin1(x)\n", + " b = b.reshape(len(b), 256, 4, 4)\n", + " b = self.bgroupnorm(b)\n", + " b = self.bupsampler(b)\n", + " \n", + " if depth_recon:\n", + " d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)\n", + " d = d.reshape(len(d), 256, 4, 4)\n", + " d = self.dgroupnorm(d)\n", + " d = self.dupsampler(d)\n", + " \n", + " return c, b, d\n", + "\n", + "model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) \n", + "utils.count_params(model.backbone)\n", + "utils.count_params(model)\n", + "\n", + "# test that the model works on some fake data\n", + "b = torch.randn((2,seq_len,hidden_dim))\n", + "print(\"b.shape\",b.shape)\n", + "with torch.no_grad():\n", + " clip_, blur_, depth_ = model.backbone(b)\n", + "print(clip_.shape, blur_.shape, depth_.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "total_steps 2339\n", + "\n", + "Done with model preparations!\n", + "param counts:\n", + "1,014,696,984 total\n", + "1,014,696,984 trainable\n" + ] + } + ], + "source": [ + "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", + "opt_grouped_parameters = [\n", + " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", + " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", + "]\n", + "\n", + "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n", + "\n", + "if lr_scheduler_type == 'linear':\n", + " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", + " optimizer,\n", + " total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),\n", + " last_epoch=-1\n", + " )\n", + "elif lr_scheduler_type == 'cycle':\n", + " total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))\n", + " print(\"total_steps\", total_steps)\n", + " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", + " optimizer, \n", + " max_lr=max_lr,\n", + " total_steps=total_steps,\n", + " final_div_factor=1000,\n", + " last_epoch=-1, pct_start=2/num_epochs\n", + " )\n", + " \n", + "def save_ckpt(tag): \n", + " ckpt_path = outdir+f'/{tag}.pth'\n", + " print(f'saving {ckpt_path}',flush=True)\n", + " unwrapped_model = accelerator.unwrap_model(model)\n", + " try:\n", + " torch.save({\n", + " 'epoch': epoch,\n", + " 'model_state_dict': unwrapped_model.state_dict(),\n", + " 'optimizer_state_dict': optimizer.state_dict(),\n", + " 'lr_scheduler': lr_scheduler.state_dict(),\n", + " 'train_losses': losses,\n", + " 'test_losses': test_losses,\n", + " 'lrs': lrs,\n", + " }, ckpt_path)\n", + " except:\n", + " print(\"Couldn't save... moving on to prevent crashing.\")\n", + " del unwrapped_model\n", + " \n", + "print(\"\\nDone with model preparations!\")\n", + "utils.count_params(model)" + ] + }, + { + "cell_type": "markdown", + "id": "983f458b-35b8-49f2-b6db-80296cece730", + "metadata": {}, + "source": [ + "# Weights and Biases" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb mindeyev2 run 0qiAxQoaKN_interactive_bsl\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "wandb_config:\n", + " {'model_name': '0qiAxQoaKN_interactive_bsl', 'global_batch_size': 128, 'batch_size': 128, 'num_epochs': 12, 'clip_scale': 1.0, 'blur_scale': 100.0, 'use_image_aug': False, 'max_lr': 0.0003, 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'ckpt_interval': 999, 'ckpt_saving': False, 'seed': 42, 'distributed': True, 'num_devices': 1, 'world_size': 1, 'train_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar', 'test_url': '/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar'}\n", + "wandb_id: 0qiAxQoaKN_interactive_bsl\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.15.12 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.15.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231028_204841-0qiAxQoaKN_interactive_bsl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 0qiAxQoaKN_interactive_bsl to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://stability.wandb.io/ckadirt/mindeyev2" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://stability.wandb.io/ckadirt/mindeyev2/runs/0qiAxQoaKN_interactive_bsl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "wandb_log = True\n", + "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", + " import wandb\n", + " wandb_project = 'mindeyev2'\n", + " wandb_run = model_name\n", + " wandb_notes = ''\n", + " \n", + " print(f\"wandb {wandb_project} run {wandb_run}\")\n", + " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", + " wandb_config = {\n", + " \"model_name\": model_name,\n", + " \"global_batch_size\": global_batch_size,\n", + " \"batch_size\": batch_size,\n", + " \"num_epochs\": num_epochs,\n", + " \"clip_scale\": clip_scale,\n", + " \"blur_scale\": blur_scale,\n", + " \"use_image_aug\": use_image_aug,\n", + " \"max_lr\": max_lr,\n", + " \"mixup_pct\": mixup_pct,\n", + " \"num_train\": num_train,\n", + " \"num_test\": num_test,\n", + " \"ckpt_interval\": ckpt_interval,\n", + " \"ckpt_saving\": ckpt_saving,\n", + " \"seed\": seed,\n", + " \"distributed\": distributed,\n", + " \"num_devices\": num_devices,\n", + " \"world_size\": world_size,\n", + " \"train_url\": train_url,\n", + " \"test_url\": test_url,\n", + " }\n", + " print(\"wandb_config:\\n\",wandb_config)\n", + " if True: # wandb_auto_resume\n", + " print(\"wandb_id:\",model_name)\n", + " wandb.init(\n", + " id = model_name,\n", + " project=wandb_project,\n", + " name=wandb_run,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " resume=\"allow\",\n", + " )\n", + " else:\n", + " wandb.init(\n", + " project=wandb_project,\n", + " name=wandb_run,\n", + " config=wandb_config,\n", + " notes=wandb_notes,\n", + " )\n", + "else:\n", + " wandb_log = False" + ] + }, + { + "cell_type": "markdown", + "id": "d5690151-2131-4918-b750-e869cbd1a8a8", + "metadata": {}, + "source": [ + "# Main" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", + "metadata": {}, + "outputs": [], + "source": [ + "epoch = 0\n", + "losses, test_losses, lrs = [], [], []\n", + "best_test_loss = 1e9\n", + "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", + "\n", + "# Optionally resume from checkpoint #\n", + "if resume_from_ckpt:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "elif wandb_log:\n", + " if wandb.run.resumed:\n", + " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", + " try:\n", + " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", + " except:\n", + " print('last.pth failed... trying last_backup.pth')\n", + " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", + " epoch = checkpoint['epoch']\n", + " print(\"Epoch\",epoch)\n", + " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " del checkpoint\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-10-28 20:48:51,902] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.5, git-hash=unknown, git-branch=unknown\n", + "[2023-10-28 20:48:53,263] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False\n", + "[2023-10-28 20:48:53,265] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer\n", + "[2023-10-28 20:48:53,266] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimizer\n", + "[2023-10-28 20:48:53,267] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW\n", + "[2023-10-28 20:48:53,268] [INFO] [utils.py:54:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=\n", + "[2023-10-28 20:48:53,269] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.float16 ZeRO stage 2 optimizer\n", + "[2023-10-28 20:48:53,269] [INFO] [stage_1_and_2.py:133:__init__] Reduce bucket size 10000000\n", + "[2023-10-28 20:48:53,270] [INFO] [stage_1_and_2.py:134:__init__] Allgather bucket size 500,000,000\n", + "[2023-10-28 20:48:53,271] [INFO] [stage_1_and_2.py:135:__init__] CPU Offload: False\n", + "[2023-10-28 20:48:53,271] [INFO] [stage_1_and_2.py:136:__init__] Round robin gradient partitioning: False\n", + "Rank: 0 partition count [1, 1, 1] and sizes[(64409600, False), (950031116, False), (256268, False)] \n", + "[2023-10-28 20:48:55,761] [INFO] [utils.py:785:see_memory_usage] Before initializing optimizer states\n", + "[2023-10-28 20:48:55,763] [INFO] [utils.py:786:see_memory_usage] MA 7.68 GB Max_MA 7.68 GB CA 7.71 GB Max_CA 8 GB \n", + "[2023-10-28 20:48:55,764] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 72.0 GB, percent = 6.4%\n", + "[2023-10-28 20:48:55,940] [INFO] [utils.py:785:see_memory_usage] After initializing optimizer states\n", + "[2023-10-28 20:48:55,941] [INFO] [utils.py:786:see_memory_usage] MA 15.24 GB Max_MA 26.1 GB CA 26.62 GB Max_CA 27 GB \n", + "[2023-10-28 20:48:55,942] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 72.0 GB, percent = 6.4%\n", + "[2023-10-28 20:48:55,943] [INFO] [stage_1_and_2.py:488:__init__] optimizer state initialized\n", + "[2023-10-28 20:48:56,073] [INFO] [utils.py:785:see_memory_usage] After initializing ZeRO optimizer\n", + "[2023-10-28 20:48:56,074] [INFO] [utils.py:786:see_memory_usage] MA 15.24 GB Max_MA 15.24 GB CA 26.62 GB Max_CA 27 GB \n", + "[2023-10-28 20:48:56,075] [INFO] [utils.py:793:see_memory_usage] CPU Virtual Memory: used = 71.99 GB, percent = 6.4%\n", + "[2023-10-28 20:48:56,078] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW\n", + "[2023-10-28 20:48:56,078] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed using client LR scheduler\n", + "[2023-10-28 20:48:56,079] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed LR Scheduler = None\n", + "[2023-10-28 20:48:56,080] [INFO] [logging.py:96:log_dist] [Rank 0] step=0, skipped=0, lr=[1.200000000000002e-05, 1.200000000000002e-05, 1.200000000000002e-05], mom=[(0.95, 0.999), (0.95, 0.999), (0.95, 0.999)]\n", + "[2023-10-28 20:48:56,081] [INFO] [config.py:960:print] DeepSpeedEngine configuration:\n", + "[2023-10-28 20:48:56,082] [INFO] [config.py:964:print] activation_checkpointing_config {\n", + " \"partition_activations\": false, \n", + " \"contiguous_memory_optimization\": false, \n", + " \"cpu_checkpointing\": false, \n", + " \"number_checkpoints\": null, \n", + " \"synchronize_checkpoint_boundary\": false, \n", + " \"profile\": false\n", + "}\n", + "[2023-10-28 20:48:56,082] [INFO] [config.py:964:print] aio_config ................... {'block_size': 26214400, 'queue_depth': 64, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}\n", + "[2023-10-28 20:48:56,083] [INFO] [config.py:964:print] amp_enabled .................. False\n", + "[2023-10-28 20:48:56,084] [INFO] [config.py:964:print] amp_params ................... False\n", + "[2023-10-28 20:48:56,085] [INFO] [config.py:964:print] autotuning_config ............ {\n", + " \"enabled\": false, \n", + " \"start_step\": null, \n", + " \"end_step\": null, \n", + " \"metric_path\": null, \n", + " \"arg_mappings\": null, \n", + " \"metric\": \"throughput\", \n", + " \"model_info\": null, \n", + " \"results_dir\": \"autotuning_results\", \n", + " \"exps_dir\": \"autotuning_exps\", \n", + " \"overwrite\": true, \n", + " \"fast\": true, \n", + " \"start_profile_step\": 3, \n", + " \"end_profile_step\": 5, \n", + " \"tuner_type\": \"gridsearch\", \n", + " \"tuner_early_stopping\": 5, \n", + " \"tuner_num_trials\": 50, \n", + " \"model_info_path\": null, \n", + " \"mp_size\": 1, \n", + " \"max_train_batch_size\": null, \n", + " \"min_train_batch_size\": 1, \n", + " \"max_train_micro_batch_size_per_gpu\": 1.024000e+03, \n", + " \"min_train_micro_batch_size_per_gpu\": 1, \n", + " \"num_tuning_micro_batch_sizes\": 3\n", + "}\n", + "[2023-10-28 20:48:56,085] [INFO] [config.py:964:print] bfloat16_enabled ............. False\n", + "[2023-10-28 20:48:56,086] [INFO] [config.py:964:print] checkpoint_parallel_write_pipeline False\n", + "[2023-10-28 20:48:56,087] [INFO] [config.py:964:print] checkpoint_tag_validation_enabled True\n", + "[2023-10-28 20:48:56,087] [INFO] [config.py:964:print] checkpoint_tag_validation_fail False\n", + "[2023-10-28 20:48:56,088] [INFO] [config.py:964:print] comms_config ................. \n", + "[2023-10-28 20:48:56,088] [INFO] [config.py:964:print] communication_data_type ...... None\n", + "[2023-10-28 20:48:56,089] [INFO] [config.py:964:print] compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}\n", + "[2023-10-28 20:48:56,090] [INFO] [config.py:964:print] curriculum_enabled_legacy .... False\n", + "[2023-10-28 20:48:56,091] [INFO] [config.py:964:print] curriculum_params_legacy ..... False\n", + "[2023-10-28 20:48:56,091] [INFO] [config.py:964:print] data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}\n", + "[2023-10-28 20:48:56,092] [INFO] [config.py:964:print] data_efficiency_enabled ...... False\n", + "[2023-10-28 20:48:56,093] [INFO] [config.py:964:print] dataloader_drop_last ......... False\n", + "[2023-10-28 20:48:56,093] [INFO] [config.py:964:print] disable_allgather ............ False\n", + "[2023-10-28 20:48:56,094] [INFO] [config.py:964:print] dump_state ................... False\n", + "[2023-10-28 20:48:56,095] [INFO] [config.py:964:print] dynamic_loss_scale_args ...... None\n", + "[2023-10-28 20:48:56,095] [INFO] [config.py:964:print] eigenvalue_enabled ........... False\n", + "[2023-10-28 20:48:56,096] [INFO] [config.py:964:print] eigenvalue_gas_boundary_resolution 1\n", + "[2023-10-28 20:48:56,097] [INFO] [config.py:964:print] eigenvalue_layer_name ........ bert.encoder.layer\n", + "[2023-10-28 20:48:56,097] [INFO] [config.py:964:print] eigenvalue_layer_num ......... 0\n", + "[2023-10-28 20:48:56,098] [INFO] [config.py:964:print] eigenvalue_max_iter .......... 100\n", + "[2023-10-28 20:48:56,099] [INFO] [config.py:964:print] eigenvalue_stability ......... 1e-06\n", + "[2023-10-28 20:48:56,099] [INFO] [config.py:964:print] eigenvalue_tol ............... 0.01\n", + "[2023-10-28 20:48:56,100] [INFO] [config.py:964:print] eigenvalue_verbose ........... False\n", + "[2023-10-28 20:48:56,100] [INFO] [config.py:964:print] elasticity_enabled ........... False\n", + "[2023-10-28 20:48:56,101] [INFO] [config.py:964:print] flops_profiler_config ........ {\n", + " \"enabled\": false, \n", + " \"recompute_fwd_factor\": 0.0, \n", + " \"profile_step\": 1, \n", + " \"module_depth\": -1, \n", + " \"top_modules\": 1, \n", + " \"detailed\": true, \n", + " \"output_file\": null\n", + "}\n", + "[2023-10-28 20:48:56,102] [INFO] [config.py:964:print] fp16_auto_cast ............... False\n", + "[2023-10-28 20:48:56,103] [INFO] [config.py:964:print] fp16_enabled ................. True\n", + "[2023-10-28 20:48:56,103] [INFO] [config.py:964:print] fp16_master_weights_and_gradients False\n", + "[2023-10-28 20:48:56,104] [INFO] [config.py:964:print] global_rank .................. 0\n", + "[2023-10-28 20:48:56,105] [INFO] [config.py:964:print] grad_accum_dtype ............. None\n", + "[2023-10-28 20:48:56,105] [INFO] [config.py:964:print] gradient_accumulation_steps .. 1\n", + "[2023-10-28 20:48:56,106] [INFO] [config.py:964:print] gradient_clipping ............ 1.0\n", + "[2023-10-28 20:48:56,107] [INFO] [config.py:964:print] gradient_predivide_factor .... 1.0\n", + "[2023-10-28 20:48:56,107] [INFO] [config.py:964:print] hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8\n", + "[2023-10-28 20:48:56,108] [INFO] [config.py:964:print] initial_dynamic_scale ........ 65536\n", + "[2023-10-28 20:48:56,109] [INFO] [config.py:964:print] load_universal_checkpoint .... False\n", + "[2023-10-28 20:48:56,109] [INFO] [config.py:964:print] loss_scale ................... 0\n", + "[2023-10-28 20:48:56,110] [INFO] [config.py:964:print] memory_breakdown ............. False\n", + "[2023-10-28 20:48:56,111] [INFO] [config.py:964:print] mics_hierarchial_params_gather False\n", + "[2023-10-28 20:48:56,111] [INFO] [config.py:964:print] mics_shard_size .............. -1\n", + "[2023-10-28 20:48:56,112] [INFO] [config.py:964:print] monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False\n", + "[2023-10-28 20:48:56,113] [INFO] [config.py:964:print] nebula_config ................ {\n", + " \"enabled\": false, \n", + " \"persistent_storage_path\": null, \n", + " \"persistent_time_interval\": 100, \n", + " \"num_of_version_in_retention\": 2, \n", + " \"enable_nebula_load\": true, \n", + " \"load_path\": null\n", + "}\n", + "[2023-10-28 20:48:56,113] [INFO] [config.py:964:print] optimizer_legacy_fusion ...... False\n", + "[2023-10-28 20:48:56,114] [INFO] [config.py:964:print] optimizer_name ............... None\n", + "[2023-10-28 20:48:56,115] [INFO] [config.py:964:print] optimizer_params ............. None\n", + "[2023-10-28 20:48:56,115] [INFO] [config.py:964:print] pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}\n", + "[2023-10-28 20:48:56,116] [INFO] [config.py:964:print] pld_enabled .................. False\n", + "[2023-10-28 20:48:56,117] [INFO] [config.py:964:print] pld_params ................... False\n", + "[2023-10-28 20:48:56,117] [INFO] [config.py:964:print] prescale_gradients ........... False\n", + "[2023-10-28 20:48:56,118] [INFO] [config.py:964:print] scheduler_name ............... None\n", + "[2023-10-28 20:48:56,119] [INFO] [config.py:964:print] scheduler_params ............. None\n", + "[2023-10-28 20:48:56,119] [INFO] [config.py:964:print] sparse_attention ............. None\n", + "[2023-10-28 20:48:56,120] [INFO] [config.py:964:print] sparse_gradients_enabled ..... False\n", + "[2023-10-28 20:48:56,121] [INFO] [config.py:964:print] steps_per_print .............. inf\n", + "[2023-10-28 20:48:56,121] [INFO] [config.py:964:print] train_batch_size ............. 128\n", + "[2023-10-28 20:48:56,122] [INFO] [config.py:964:print] train_micro_batch_size_per_gpu 128\n", + "[2023-10-28 20:48:56,123] [INFO] [config.py:964:print] use_node_local_storage ....... False\n", + "[2023-10-28 20:48:56,123] [INFO] [config.py:964:print] wall_clock_breakdown ......... False\n", + "[2023-10-28 20:48:56,124] [INFO] [config.py:964:print] world_size ................... 1\n", + "[2023-10-28 20:48:56,125] [INFO] [config.py:964:print] zero_allow_untested_optimizer True\n", + "[2023-10-28 20:48:56,125] [INFO] [config.py:964:print] zero_config .................. stage=2 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=10000000 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=False load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=5, buffer_size=4000000000, max_in_cpu=1,000,000,000, pin_memory=True) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='none', nvme_path=PosixPath('/scratch'), buffer_count=4, pin_memory=True, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=10000000 param_persistence_threshold=100000 model_persistence_threshold=sys.maxsize max_live_parameters=1000000000 max_reuse_distance=1000000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False mics_shard_size=-1 mics_hierarchical_params_gather=False memory_efficient_linear=True\n", + "[2023-10-28 20:48:56,126] [INFO] [config.py:964:print] zero_enabled ................. True\n", + "[2023-10-28 20:48:56,127] [INFO] [config.py:964:print] zero_force_ds_cpu_optimizer .. True\n", + "[2023-10-28 20:48:56,127] [INFO] [config.py:964:print] zero_optimization_stage ...... 2\n", + "[2023-10-28 20:48:56,128] [INFO] [config.py:950:print_user_config] json = {\n", + " \"bf16\": {\n", + " \"enabled\": false\n", + " }, \n", + " \"fp16\": {\n", + " \"enabled\": true\n", + " }, \n", + " \"zero_optimization\": {\n", + " \"stage\": 2, \n", + " \"contiguous_gradients\": true, \n", + " \"stage3_gather_16bit_weights_on_model_save\": true, \n", + " \"stage3_max_live_parameters\": 1.000000e+09, \n", + " \"stage3_max_reuse_distance\": 1.000000e+09, \n", + " \"stage3_prefetch_bucket_size\": 1.000000e+07, \n", + " \"stage3_param_persistence_threshold\": 1.000000e+05, \n", + " \"reduce_bucket_size\": 1.000000e+07, \n", + " \"sub_group_size\": 1.000000e+09, \n", + " \"offload_optimizer\": {\n", + " \"device\": \"none\", \n", + " \"nvme_path\": \"/scratch\", \n", + " \"pin_memory\": true\n", + " }, \n", + " \"offload_param\": {\n", + " \"device\": \"none\", \n", + " \"nvme_path\": \"/scratch\", \n", + " \"buffer_size\": 4.000000e+09, \n", + " \"pin_memory\": true\n", + " }\n", + " }, \n", + " \"aio\": {\n", + " \"block_size\": 2.621440e+07, \n", + " \"queue_depth\": 64, \n", + " \"thread_count\": 1, \n", + " \"single_submit\": false, \n", + " \"overlap_events\": true\n", + " }, \n", + " \"gradient_accumulation_steps\": 1, \n", + " \"gradient_clipping\": 1.0, \n", + " \"steps_per_print\": inf, \n", + " \"train_batch_size\": 128, \n", + " \"train_micro_batch_size_per_gpu\": 128, \n", + " \"wall_clock_breakdown\": false, \n", + " \"zero_allow_untested_optimizer\": true\n", + "}\n" + ] + } + ], + "source": [ + "model, optimizer, train_dl, lr_scheduler = accelerator.prepare(\n", + "model, optimizer, train_dl, lr_scheduler\n", + ")\n", + "# leaving out test_dl since we will only have local_rank 0 device do evals" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "469e6313-425f-45ed-875a-ecd5df343e31", + "metadata": {}, + "outputs": [], + "source": [ + "def add_saturation(image, alpha=2):\n", + " gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]\n", + " gray_image = gray_image.unsqueeze(1).expand_as(image)\n", + " saturated_image = alpha * image + (1 - alpha) * gray_image\n", + " return torch.clamp(saturated_image, 0, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "60be0d5f-3e94-4612-9373-61b53d836393", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0qiAxQoaKN_interactive_bsl starting with epoch 0 / 12\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/12 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", + " in <module>:119 \n", + " \n", + " 116 │ │ │ │ │ # utils.check_loss(pixcorr) \n", + " 117 │ │ │ \n", + " 118 │ │ │ utils.check_loss(loss) \n", + " 119 │ │ │ accelerator.backward(loss) \n", + " 120 │ │ │ optimizer.step() \n", + " 121 │ │ │ \n", + " 122 │ │ │ losses.append(loss.item()) \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/accelerator. \n", + " py:1815 in backward \n", + " \n", + " 1812 │ │ │ # deepspeed handles loss scaling by gradient_accumulation_steps in its `back \n", + " 1813 │ │ │ loss = loss / self.gradient_accumulation_steps \n", + " 1814 │ │ if self.distributed_type == DistributedType.DEEPSPEED: \n", + " 1815 │ │ │ self.deepspeed_engine_wrapped.backward(loss, **kwargs) \n", + " 1816 │ │ elif self.distributed_type == DistributedType.MEGATRON_LM: \n", + " 1817 │ │ │ return \n", + " 1818 │ │ elif self.scaler is not None: \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/utils/deepsp \n", + " eed.py:176 in backward \n", + " \n", + " 173 │ │ # - zero grad \n", + " 174 │ │ # - checking overflow \n", + " 175 │ │ # - lr_scheduler step (only if engine.lr_scheduler is not None) \n", + " 176 │ │ self.engine.step() \n", + " 177 │ │ # and this plugin overrides the above calls with no-ops when Accelerate runs und \n", + " 178 │ │ # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabli \n", + " 179 │ │ # training loop that works transparently under many training regimes. \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/engin \n", + " e.py:2053 in step \n", + " \n", + " 2050 │ │ │ │ │ and self.quantizer.any_precision_switch()): \n", + " 2051 │ │ │ │ self._take_model_step(lr_kwargs, self.block_eigenvalue) \n", + " 2052 │ │ │ else: \n", + " 2053 │ │ │ │ self._take_model_step(lr_kwargs) \n", + " 2054 │ │ │ \n", + " 2055 │ │ │ report_progress = self.global_rank == 0 if self.global_rank else True \n", + " 2056 \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/engin \n", + " e.py:1960 in _take_model_step \n", + " \n", + " 1957 │ │ │ │ # https://nvidia.github.io/apex/advanced.html#gradient-clipping \n", + " 1958 │ │ │ │ master_params = amp.master_params(self.optimizer) \n", + " 1959 │ │ │ │ clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clippin \n", + " 1960 │ │ self.optimizer.step() \n", + " 1961 │ │ \n", + " 1962 │ │ if hasattr(self.optimizer, '_global_grad_norm'): \n", + " 1963 │ │ │ self._global_grad_norm = self.optimizer._global_grad_norm \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/ \n", + " stage_1_and_2.py:1733 in step \n", + " \n", + " 1730 │ │ │ │ \n", + " 1731 │ │ │ │ # Step 3:- run the optimizer if no offloading \n", + " 1732 │ │ │ │ self.start_timers([OPTIMIZER_STEP]) \n", + " 1733 │ │ │ │ self._optimizer_step(i) \n", + " 1734 │ │ │ │ # Step 4:- get rid of the fp32 gradients. Not needed anymore \n", + " 1735 │ │ │ │ self.single_partition_of_fp32_groups[i].grad = None \n", + " 1736 │ │ │ │ del single_grad_partition \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/ \n", + " stage_1_and_2.py:1638 in _optimizer_step \n", + " \n", + " 1635 │ │ # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no) \n", + " 1636 │ │ #else: \n", + " 1637 │ │ # self.optimizer.step() \n", + " 1638 │ │ self.optimizer.step() \n", + " 1639 │ │ self.optimizer.param_groups = original_param_groups \n", + " 1640 \n", + " 1641 def step(self, closure=None): \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/lr_schedule \n", + " r.py:69 in wrapper \n", + " \n", + " 66 │ │ │ │ instance = instance_ref() \n", + " 67 │ │ │ │ instance._step_count += 1 \n", + " 68 │ │ │ │ wrapped = func.__get__(instance, cls) \n", + " 69 │ │ │ │ return wrapped(*args, **kwargs) \n", + " 70 │ │ �� \n", + " 71 │ │ │ # Note that the returned function here is no longer a bound method, \n", + " 72 │ │ │ # so attributes like `__func__` and `__self__` no longer exist. \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/optimizer.p \n", + " y:280 in wrapper \n", + " \n", + " 277 │ │ │ │ │ │ │ raise RuntimeError(f\"{func} must return None or a tuple of ( \n", + " 278 │ │ │ │ │ │ │ │ │ │ │ f\"but got {result}.\") \n", + " 279 │ │ │ │ \n", + " 280 │ │ │ │ out = func(*args, **kwargs) \n", + " 281 │ │ │ │ self._optimizer_step_code() \n", + " 282 │ │ │ │ \n", + " 283 │ │ │ │ # call optimizer step post hooks \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/optimizer.p \n", + " y:33 in _use_grad \n", + " \n", + " 30 │ │ prev_grad = torch.is_grad_enabled() \n", + " 31 │ │ try: \n", + " 32 │ │ │ torch.set_grad_enabled(self.defaults['differentiable']) \n", + " 33 │ │ │ ret = func(self, *args, **kwargs) \n", + " 34 │ │ finally: \n", + " 35 │ │ │ torch.set_grad_enabled(prev_grad) \n", + " 36 │ │ return ret \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:17 \n", + " 1 in step \n", + " \n", + " 168 │ │ │ │ state_steps, \n", + " 169 │ │ │ ) \n", + " 170 │ │ │ \n", + " 171 │ │ │ adamw( \n", + " 172 │ │ │ │ params_with_grad, \n", + " 173 │ │ │ │ grads, \n", + " 174 │ │ │ │ exp_avgs, \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:32 \n", + " 1 in adamw \n", + " \n", + " 318 else: \n", + " 319 │ │ func = _single_tensor_adamw \n", + " 320 \n", + " 321 func( \n", + " 322 │ │ params, \n", + " 323 │ │ grads, \n", + " 324 │ │ exp_avgs, \n", + " \n", + " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/adamw.py:56 \n", + " 6 in _multi_tensor_adamw \n", + " \n", + " 563 │ │ │ else: \n", + " 564 │ │ │ │ exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) \n", + " 565 │ │ │ │ torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) \n", + " 566 │ │ │ │ denom = torch._foreach_add(exp_avg_sq_sqrt, eps) \n", + " 567 │ │ │ \n", + " 568 │ │ │ torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size) \n", + " 569 \n", + "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "OutOfMemoryError: CUDA out of memory. Tried to allocate 3.54 GiB (GPU 0; 39.56 GiB total capacity; 26.52 GiB \n", + "already allocated; 2.00 GiB free; 35.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory\n", + "try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and \n", + "PYTORCH_CUDA_ALLOC_CONF\n", + "\n" + ], + "text/plain": [ + "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", + "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m119\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m116 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[2m# utils.check_loss(pixcorr)\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m117 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m118 \u001b[0m\u001b[2m│ │ │ \u001b[0mutils.check_loss(loss) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m119 \u001b[2m│ │ │ \u001b[0maccelerator.backward(loss) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m120 \u001b[0m\u001b[2m│ │ │ \u001b[0moptimizer.step() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m121 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m122 \u001b[0m\u001b[2m│ │ │ \u001b[0mlosses.append(loss.item()) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/\u001b[0m\u001b[1;33maccelerator.\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33mpy\u001b[0m:\u001b[94m1815\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1812 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# deepspeed handles loss scaling by gradient_accumulation_steps in its `back\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1813 \u001b[0m\u001b[2m│ │ │ \u001b[0mloss = loss / \u001b[96mself\u001b[0m.gradient_accumulation_steps \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1814 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.distributed_type == DistributedType.DEEPSPEED: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1815 \u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.deepspeed_engine_wrapped.backward(loss, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1816 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melif\u001b[0m \u001b[96mself\u001b[0m.distributed_type == DistributedType.MEGATRON_LM: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1817 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1818 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melif\u001b[0m \u001b[96mself\u001b[0m.scaler \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/utils/\u001b[0m\u001b[1;33mdeepsp\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33meed.py\u001b[0m:\u001b[94m176\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m173 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# - zero grad\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# - checking overflow\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m175 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# - lr_scheduler step (only if engine.lr_scheduler is not None)\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m176 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.engine.step() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m177 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# and this plugin overrides the above calls with no-ops when Accelerate runs und\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m178 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabli\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m179 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# training loop that works transparently under many training regimes.\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/\u001b[0m\u001b[1;33mengin\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33me.py\u001b[0m:\u001b[94m2053\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2050 \u001b[0m\u001b[2m│ │ │ │ │ \u001b[0m\u001b[95mand\u001b[0m \u001b[96mself\u001b[0m.quantizer.any_precision_switch()): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2051 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._take_model_step(lr_kwargs, \u001b[96mself\u001b[0m.block_eigenvalue) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2052 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2053 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._take_model_step(lr_kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2054 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2055 \u001b[0m\u001b[2m│ │ │ \u001b[0mreport_progress = \u001b[96mself\u001b[0m.global_rank == \u001b[94m0\u001b[0m \u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.global_rank \u001b[94melse\u001b[0m \u001b[94mTrue\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m2056 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/\u001b[0m\u001b[1;33mengin\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33me.py\u001b[0m:\u001b[94m1960\u001b[0m in \u001b[92m_take_model_step\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1957 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# https://nvidia.github.io/apex/advanced.html#gradient-clipping\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1958 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mmaster_params = amp.master_params(\u001b[96mself\u001b[0m.optimizer) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1959 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mclip_grad_norm_(parameters=master_params, max_norm=\u001b[96mself\u001b[0m.gradient_clippin \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1960 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1961 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1962 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m.optimizer, \u001b[33m'\u001b[0m\u001b[33m_global_grad_norm\u001b[0m\u001b[33m'\u001b[0m): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1963 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._global_grad_norm = \u001b[96mself\u001b[0m.optimizer._global_grad_norm \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33mstage_1_and_2.py\u001b[0m:\u001b[94m1733\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1730 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1731 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# Step 3:- run the optimizer if no offloading\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1732 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.start_timers([OPTIMIZER_STEP]) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1733 \u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._optimizer_step(i) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1734 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# Step 4:- get rid of the fp32 gradients. Not needed anymore\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1735 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m.single_partition_of_fp32_groups[i].grad = \u001b[94mNone\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1736 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[94mdel\u001b[0m single_grad_partition \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/zero/\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33mstage_1_and_2.py\u001b[0m:\u001b[94m1638\u001b[0m in \u001b[92m_optimizer_step\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1635 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1636 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m#else:\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1637 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# self.optimizer.step()\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1638 \u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1639 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.param_groups = original_param_groups \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1640 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m1641 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mstep\u001b[0m(\u001b[96mself\u001b[0m, closure=\u001b[94mNone\u001b[0m): \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33mlr_schedule\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33mr.py\u001b[0m:\u001b[94m69\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 66 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minstance = instance_ref() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 67 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minstance._step_count += \u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 68 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mwrapped = func.\u001b[92m__get__\u001b[0m(instance, \u001b[96mcls\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 69 \u001b[2m│ │ │ │ \u001b[0m\u001b[94mreturn\u001b[0m wrapped(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 70 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 71 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# Note that the returned function here is no longer a bound method,\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 72 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# so attributes like `__func__` and `__self__` no longer exist.\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33moptimizer.p\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33my\u001b[0m:\u001b[94m280\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m277 \u001b[0m\u001b[2m│ │ │ │ │ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mRuntimeError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33m{\u001b[0mfunc\u001b[33m}\u001b[0m\u001b[33m must return None or a tuple of (\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m278 \u001b[0m\u001b[2m│ │ │ │ │ │ │ │ │ │ │ \u001b[0m\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mbut got \u001b[0m\u001b[33m{\u001b[0mresult\u001b[33m}\u001b[0m\u001b[33m.\u001b[0m\u001b[33m\"\u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m279 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m280 \u001b[2m│ │ │ │ \u001b[0mout = func(*args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m281 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[96mself\u001b[0m._optimizer_step_code() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m282 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m283 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[2m# call optimizer step post hooks\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33moptimizer.p\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[1;33my\u001b[0m:\u001b[94m33\u001b[0m in \u001b[92m_use_grad\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 30 \u001b[0m\u001b[2m│ │ \u001b[0mprev_grad = torch.is_grad_enabled() \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 31 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mtry\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 32 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch.set_grad_enabled(\u001b[96mself\u001b[0m.defaults[\u001b[33m'\u001b[0m\u001b[33mdifferentiable\u001b[0m\u001b[33m'\u001b[0m]) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 33 \u001b[2m│ │ │ \u001b[0mret = func(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 34 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfinally\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 35 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch.set_grad_enabled(prev_grad) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m 36 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m ret \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m17\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[94m1\u001b[0m in \u001b[92mstep\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m168 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mstate_steps, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m169 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m170 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m171 \u001b[2m│ │ │ \u001b[0madamw( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m172 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mparams_with_grad, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m173 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mgrads, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m174 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mexp_avgs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m32\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[94m1\u001b[0m in \u001b[92madamw\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m318 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m319 \u001b[0m\u001b[2m│ │ \u001b[0mfunc = _single_tensor_adamw \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m320 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m321 \u001b[2m│ \u001b[0mfunc( \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m322 \u001b[0m\u001b[2m│ │ \u001b[0mparams, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m323 \u001b[0m\u001b[2m│ │ \u001b[0mgrads, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m324 \u001b[0m\u001b[2m│ │ \u001b[0mexp_avgs, \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/optim/\u001b[0m\u001b[1;33madamw.py\u001b[0m:\u001b[94m56\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[94m6\u001b[0m in \u001b[92m_multi_tensor_adamw\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m563 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m564 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mexp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m565 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mtorch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m566 \u001b[2m│ │ │ │ \u001b[0mdenom = torch._foreach_add(exp_avg_sq_sqrt, eps) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m567 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m568 \u001b[0m\u001b[2m│ │ │ \u001b[0mtorch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size) \u001b[31m│\u001b[0m\n", + "\u001b[31m│\u001b[0m \u001b[2m569 \u001b[0m \u001b[31m│\u001b[0m\n", + "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", + "\u001b[1;91mOutOfMemoryError: \u001b[0mCUDA out of memory. Tried to allocate \u001b[1;36m3.54\u001b[0m GiB \u001b[1m(\u001b[0mGPU \u001b[1;36m0\u001b[0m; \u001b[1;36m39.56\u001b[0m GiB total capacity; \u001b[1;36m26.52\u001b[0m GiB \n", + "already allocated; \u001b[1;36m2.00\u001b[0m GiB free; \u001b[1;36m35.94\u001b[0m GiB reserved in total by PyTorch\u001b[1m)\u001b[0m If reserved memory is >> allocated memory\n", + "try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and \n", + "PYTORCH_CUDA_ALLOC_CONF\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n", + "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n", + "test_image, test_voxel = None, None\n", + "mse = nn.MSELoss()\n", + "l1 = nn.L1Loss()\n", + "\n", + "for epoch in progress_bar:\n", + " model.train()\n", + " \n", + " fwd_percent_correct = 0.\n", + " bwd_percent_correct = 0.\n", + " test_fwd_percent_correct = 0.\n", + " test_bwd_percent_correct = 0.\n", + "\n", + " loss_clip_total = 0.\n", + " loss_blurry_total = 0.\n", + " loss_depth_total = 0.\n", + " test_loss_clip_total = 0.\n", + " test_loss_blurry_total = 0.\n", + " test_loss_depth_total = 0.\n", + "\n", + " blurry_pixcorr = 0.\n", + " test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n", + " \n", + " for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", + " with torch.cuda.amp.autocast(dtype=data_type):\n", + " optimizer.zero_grad()\n", + " \n", + " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", + " image = images[behav[:,0,0].cpu().long()].to(device).float()\n", + "\n", + " for past in range(1):\n", + " past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\n", + " \n", + " if blurry_recon:\n", + " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", + " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", + "\n", + " if depth_recon:\n", + " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", + " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", + " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", + " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", + " \n", + " if use_image_aug: \n", + " image = img_augment(image)\n", + " \n", + " clip_target = clip_model.embed_image(image)\n", + " assert not torch.any(torch.isnan(clip_target))\n", + " \n", + " if epoch < int(mixup_pct * num_epochs):\n", + " voxel, perm, betas, select = utils.mixco(voxel)\n", + " past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)\n", + " \n", + " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", + " \n", + " # past_voxel_ridge = model.ridge(past_voxel)\n", + " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)\n", + " \n", + " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", + " \n", + " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", + " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", + " \n", + " if epoch < int(mixup_pct * num_epochs): \n", + " loss_clip = utils.mixco_nce(\n", + " clip_voxels_norm,\n", + " clip_target_norm,\n", + " temp=.006, \n", + " perm=perm, betas=betas, select=select)\n", + " else:\n", + " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n", + " loss_clip = utils.soft_clip_loss(\n", + " clip_voxels_norm,\n", + " clip_target_norm,\n", + " temp=epoch_temp)\n", + "\n", + " loss_clip_total += loss_clip.item()\n", + " loss_clip *= clip_scale\n", + " loss = loss_clip\n", + " \n", + " if blurry_recon:\n", + " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", + " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", + " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", + " \n", + " loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", + " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", + " loss_blurry_total += loss_blurry.item()\n", + " loss_blurry *= blur_scale\n", + " loss += loss_blurry\n", + "\n", + " if depth_recon:\n", + " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", + " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", + " loss_depth_total += loss_depth.item()\n", + " loss_depth *= depth_scale\n", + " loss += loss_depth\n", + " \n", + " # forward and backward top 1 accuracy \n", + " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", + " fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()\n", + " bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()\n", + " \n", + " if blurry_recon:\n", + " with torch.no_grad():\n", + " # only doing pixcorr eval on a subset of the samples per batch because its costly & slow to compute autoenc.decode()\n", + " random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)\n", + " # random_samps = np.arange(batch_size//5)\n", + " blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)\n", + " # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean\n", + " pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)\n", + " # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)\n", + " # loss += (1 - pixcorr)\n", + " blurry_pixcorr += pixcorr.item()\n", + " # utils.check_loss(pixcorr)\n", + "\n", + " utils.check_loss(loss)\n", + " accelerator.backward(loss)\n", + " optimizer.step()\n", + " \n", + " losses.append(loss.item())\n", + " lrs.append(optimizer.param_groups[0]['lr'])\n", + " \n", + " if lr_scheduler_type is not None:\n", + " lr_scheduler.step()\n", + "\n", + " model.eval()\n", + " if local_rank==0:\n", + " with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type): \n", + " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl): \n", + " # all test samples should be loaded per batch such that test_i should never exceed 0\n", + " assert len(behav) == num_test\n", + " \n", + " ## Average same-image repeats ##\n", + " if test_image is None:\n", + " voxel = voxels[behav[:,0,5].cpu().long()]\n", + " image = behav[:,0,0].cpu().long()\n", + " \n", + " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n", + " for im in unique_image:\n", + " locs = torch.where(im == image)[0]\n", + " if test_image is None:\n", + " test_image = images[im][None]\n", + " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n", + " else:\n", + " test_image = torch.vstack((test_image, images[im][None]))\n", + " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n", + " \n", + " # random sample of 300\n", + " random_indices = torch.arange(len(test_voxel))[:300]\n", + " voxel = test_voxel[random_indices].to(device)\n", + " image = test_image[random_indices].to(device)\n", + " assert len(image) == 300\n", + "\n", + " if blurry_recon:\n", + " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", + " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", + "\n", + " if depth_recon:\n", + " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", + " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", + " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", + " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", + " \n", + " clip_target = clip_model.embed_image(image.float())\n", + " \n", + " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", + "\n", + " # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)\n", + " \n", + " clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)\n", + " \n", + " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", + " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", + " \n", + " loss_clip = utils.soft_clip_loss(\n", + " clip_voxels_norm,\n", + " clip_target_norm,\n", + " temp=.006)\n", + " test_loss_clip_total += loss_clip.item()\n", + " loss_clip = loss_clip * clip_scale\n", + " loss = loss_clip\n", + "\n", + " if blurry_recon:\n", + " downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)\n", + " re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))\n", + " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", + " \n", + " loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))\n", + " loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))\n", + " test_loss_blurry_total += loss_blurry.item()\n", + " loss_blurry *= blur_scale\n", + " loss += loss_blurry\n", + " \n", + " # halving the batch size because the decoder is computationally heavy\n", + " blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)\n", + " blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", + " pixcorr = utils.pixcorr(image, blurry_recon_images)\n", + " loss += (1 - pixcorr)\n", + " test_blurry_pixcorr += pixcorr.item()\n", + "\n", + " if depth_recon:\n", + " loss_depth = l1(depth_image_enc_, depth_image_enc)\n", + " # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))\n", + " test_loss_depth_total += loss_depth.item()\n", + " loss_depth *= depth_scale\n", + " loss += loss_depth\n", + " \n", + " # forward and backward top 1 accuracy \n", + " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", + " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()\n", + " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()\n", + "\n", + " utils.check_loss(loss) \n", + " test_losses.append(loss.item())\n", + "\n", + " # if utils.is_interactive(): clear_output(wait=True)\n", + " print(\"---\")\n", + " \n", + " assert (test_i+1) == 1\n", + " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n", + " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n", + " \"train/lr\": lrs[-1],\n", + " \"train/num_steps\": len(losses),\n", + " \"test/num_steps\": len(test_losses),\n", + " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n", + " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n", + " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n", + " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n", + " \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n", + " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n", + " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n", + " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n", + " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n", + " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n", + " \"train/loss_depth_total\": loss_depth_total / (train_i + 1),\n", + " \"test/loss_depth_total\": test_loss_depth_total / (test_i + 1),\n", + " }\n", + " \n", + " if blurry_recon: \n", + " # transform blurry recon latents to images and plot it\n", + " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", + " jj=-1\n", + " for j in [0,1,2,3]:\n", + " jj+=1\n", + " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", + " axes[jj].axis('off')\n", + " jj+=1\n", + " axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", + " axes[jj].axis('off')\n", + " \n", + " if wandb_log:\n", + " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", + " plt.close()\n", + " else:\n", + " plt.show()\n", + "\n", + " if depth_recon:\n", + " # transform blurry recon latents to images and plot it\n", + " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", + " # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", + " # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))\n", + " jj=-1\n", + " for j in [0,1,2,3]:\n", + " jj+=1\n", + " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", + " axes[jj].axis('off')\n", + " jj+=1\n", + " axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))\n", + " axes[jj].axis('off')\n", + " if wandb_log:\n", + " logs[f\"test/depth_recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\")\n", + " plt.close()\n", + " else:\n", + " plt.show()\n", + " \n", + " progress_bar.set_postfix(**logs)\n", + " \n", + " # Save model checkpoint and reconstruct\n", + " if epoch % ckpt_interval == 0:\n", + " if not utils.is_interactive():\n", + " save_ckpt(f'last')\n", + " \n", + " if wandb_log: wandb.log(logs)\n", + "\n", + " # wait for other GPUs to catch up if needed\n", + " accelerator.wait_for_everyone()\n", + " torch.cuda.empty_cache()\n", + " gc.collect()\n", + "\n", + "print(\"\\n===Finished!===\\n\")\n", + "if ckpt_saving:\n", + " save_ckpt(f'last')\n", + "if not utils.is_interactive():\n", + " sys.exit(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "35cc1be7-bf76-4ad1-8c6a-de52bd013bf4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sat Oct 28 21:13:17 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 33C P0 50W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 30C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 34C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 30C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 36C P0 53W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 35C P0 72W / 400W | 38467MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 33C P0 50W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 31C P0 51W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 5 N/A N/A 1896724 C ...3/envs/mindeye/bin/python 38464MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(losses)\n", + "plt.show()\n", + "plt.plot(test_losses)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "f2690877-a431-44e8-a2ca-61f4b7397070", + "metadata": {}, + "source": [ + "# Retrieve nearest neighbor in the training set using test set data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b6b8feb-391d-437e-a5d9-a2088f1b1149", + "metadata": {}, + "outputs": [], + "source": [ + "annots = np.load(\"/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "612ac5aa-6f0f-45ad-809e-03df905d184c", + "metadata": {}, + "outputs": [], + "source": [ + "ii=2\n", + "all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))\n", + "with torch.no_grad(), torch.cuda.amp.autocast():\n", + " for batch in tqdm(range(0,len(all_indices),512)):\n", + " if batch==0:\n", + " clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", + " else:\n", + " target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()\n", + " clip_target = torch.vstack((clip_target,target))\n", + " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", + "\n", + " voxel = test_voxel[[ii]].to(device)\n", + " image = test_image[[ii]].to(device)\n", + "\n", + " print(\"Original Image (test set)\")\n", + " display(utils.torch_to_Image(image))\n", + " \n", + " clip_target = clip_model.embed_image(image).cpu()\n", + " # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))\n", + " \n", + " voxel_ridge = model.ridge(voxel).unsqueeze(1)\n", + " clip_voxels, _, _ = model.backbone(voxel_ridge) \n", + " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", + " clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", + "\n", + " print(\"clip_voxels_norm\", clip_voxels_norm.shape)\n", + " print(\"clip_target_norm\", clip_target_norm.shape)\n", + " \n", + " sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(), \n", + " clip_target_norm).flatten()).flip(0)\n", + " picks = all_indices[sortt[:5]]\n", + "\n", + " print(\"\\nNearest neighbors in training set\")\n", + " for ip,p in enumerate(picks):\n", + " display(utils.torch_to_Image(images[[p]]))\n", + " # print(utils.select_annotations([annots[int(p)]]))\n", + " if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]\n", + "\n", + "print(\"\\n=====\\npredicted_caption:\\n\", predicted_caption)" + ] + }, + { + "cell_type": "markdown", + "id": "1473ddaa-5f2b-4448-9194-c7b0801d05db", + "metadata": {}, + "source": [ + "# Feed into Stable Diffusion XL for reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70e50e0d-c44f-4d56-939a-2943535e1747", + "metadata": {}, + "outputs": [], + "source": [ + "from diffusers import StableDiffusionXLPipeline\n", + "pipe = StableDiffusionXLPipeline.from_pretrained(\n", + " \"/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n", + ")\n", + "pipe.to(\"cuda\")\n", + "pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "479e6994-3eaa-47d2-89a3-422c464fab36", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = predicted_caption\n", + "recon = pipe(prompt=prompt).images[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9dc48e1b-5842-4a29-963a-6469d943a72c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(\"Seen image\")\n", + "display(utils.torch_to_Image(image))\n", + "\n", + "print(\"Reconstruction\")\n", + "utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + }, + "toc-autonumbering": true, + "vscode": { + "interpreter": { + "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}