{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "#from subprocess import call\n", "#command = \"jupyter nbconvert Train_MLPMixer-Copy1.ipynb --to python\"\n", "#call(command,shell=True)" ] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": {}, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2023-11-03 11:37:15,856] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import math\n", "from einops import rearrange\n", "import time\n", "import random\n", "import string\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils" ] }, { "cell_type": "code", "execution_count": 3, "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n", "Setting batch_size to 128\n", "[2023-11-03 11:37:18,062] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented\n", "[2023-11-03 11:37:18,062] [INFO] [comm.py:594:init_distributed] cdb=None\n", "[2023-11-03 11:37:18,063] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl\n" ] } ], "source": [ "### Multi-GPU config ###\n", "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", " local_rank = 0\n", "else:\n", " local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank) \n", "\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "\n", "# ## UNCOMMENT BELOW SECTION AND COMMENT OUT DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###\n", "# accelerator = Accelerator(split_batches=False, mixed_precision=\"fp16\")\n", "# global_batch_size = batch_size = 128\n", "# data_type = torch.float16 # change depending on your mixed_precision\n", "\n", "### DEEPSPEED INITIALIZATION ###\n", "if num_devices <= 1 and utils.is_interactive():\n", " global_batch_size = batch_size = 128\n", " print(f\"Setting batch_size to {batch_size}\")\n", " # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", " os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", " os.environ[\"RANK\"] = \"0\"\n", " os.environ[\"LOCAL_RANK\"] = \"0\"\n", " os.environ[\"WORLD_SIZE\"] = \"1\"\n", " os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", "else:\n", " global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"] \n", " batch_size = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", "\n", "# alter the deepspeed config according to your global and local batch size\n", "if local_rank == 0:\n", " with open('deepspeed_config_stage2.json', 'r') as file:\n", " config = json.load(file)\n", " config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", " config['train_micro_batch_size_per_gpu'] = batch_size\n", " config['bf16'] = {'enabled': False}\n", " config['fp16'] = {'enabled': True}\n", " with open('deepspeed_config_stage2.json', 'w') as file:\n", " json.dump(config, file)\n", "else:\n", " # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", " time.sleep(10)\n", "deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2_cpuoffload.json\")\n", "accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of this process = 418110\n", "device: cuda:0\n", "Distributed environment: DEEPSPEED Backend: nccl\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda:0\n", "\n", "Mixed precision type: fp16\n", "ds_config: {'bf16': {'enabled': False}, 'fp16': {'enabled': True}, 'zero_optimization': {'stage': 2, 'contiguous_gradients': True, 'stage3_gather_16bit_weights_on_model_save': True, 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_prefetch_bucket_size': 10000000.0, 'stage3_param_persistence_threshold': 100000.0, 'reduce_bucket_size': 10000000.0, 'sub_group_size': 1000000000.0, 'offload_optimizer': {'device': 'cpu', 'nvme_path': '/scratch', 'pin_memory': True}, 'offload_param': {'device': 'none', 'nvme_path': '/scratch', 'buffer_size': 4000000000.0, 'pin_memory': True}}, 'aio': {'block_size': 26214400, 'queue_depth': 32, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}, 'gradient_accumulation_steps': 1, 'gradient_clipping': 1.0, 'steps_per_print': inf, 'train_batch_size': 256, 'train_micro_batch_size_per_gpu': 32, 'wall_clock_breakdown': False, 'zero_allow_untested_optimizer': True}\n", "\n", "distributed = True num_devices = 1 local rank = 0 world size = 1 data_type = torch.float16\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "distributed = not accelerator.state.distributed_type == 'NO'\n", "\n", "# set data_type to match your mixed precision (automatically set based on deepspeed config)\n", "if accelerator.mixed_precision == \"bf16\":\n", " data_type = torch.bfloat16\n", "elif accelerator.mixed_precision == \"fp16\":\n", " data_type = torch.float16\n", "else:\n", " data_type = torch.float32\n", "\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size, \"data_type =\", data_type)\n", "print = accelerator.print # only print if local_rank=0" ] }, { "cell_type": "markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 5, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model_name: CUIEjK9AY8_interactive\n", "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=CUIEjK9AY8_interactive', '--subj=1', '--batch_size=128', '--no-blurry_recon', '--no-depth_recon', '--hidden_dim=4096', '--clip_scale=1.', '--blur_scale=100.', '--depth_scale=100.', '--max_lr=3e-4', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug', '--no-ckpt_saving']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", " # create random model_name\n", " model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))\n", " model_name = model_name + \"_interactive\"\n", " print(\"model_name:\", model_name)\n", "\n", " # global_batch_size and batch_size should already be defined in the above cells\n", " # other variables can be specified in the following string:\n", " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name={model_name} \\\n", " --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \\\n", " --clip_scale=1. --blur_scale=100. --depth_scale=100. \\\n", " --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving\"\n", "\n", " jupyter_args = jupyter_args.split()\n", " print(jupyter_args)\n", " \n", " from IPython.display import clear_output # function to clear print outputs in cell\n", " %load_ext autoreload \n", " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", " %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 6, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": { "tags": [] }, "outputs": [], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", " \"--model_name\", type=str, default=\"testing\",\n", " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", " help=\"Path to where NSD data is stored / where to download it to\",\n", ")\n", "parser.add_argument(\n", " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", " \"--batch_size\", type=int, default=32,\n", " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_project\",type=str,default=\"stability\",\n", " help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", " \"--mixup_pct\",type=float,default=.33,\n", " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", " \"--blurry_recon\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to output blurry reconstructions\",\n", ")\n", "parser.add_argument(\n", " \"--depth_recon\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to output depth reconstructions\",\n", ")\n", "parser.add_argument(\n", " \"--blur_scale\",type=float,default=100.,\n", " help=\"multiply loss from blurry recons by this number\",\n", ")\n", "parser.add_argument(\n", " \"--depth_scale\",type=float,default=100.,\n", " help=\"multiply loss from depth recons by this number\",\n", ")\n", "parser.add_argument(\n", " \"--clip_scale\",type=float,default=1.,\n", " help=\"multiply contrastive loss by this number\",\n", ")\n", "parser.add_argument(\n", " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", " \"--num_epochs\",type=int,default=120,\n", " help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", " \"--hidden_dim\",type=int,default=4096,\n", ")\n", "parser.add_argument(\n", " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", " \"--ckpt_interval\",type=int,default=5,\n", " help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", " \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", " \"--max_lr\",type=float,default=3e-4,\n", ")\n", "\n", "if utils.is_interactive():\n", " args = parser.parse_args(jupyter_args)\n", "else:\n", " args = parser.parse_args()\n", "\n", "# create global variables without the args prefix\n", "for attribute_name in vars(args).keys():\n", " globals()[attribute_name] = getattr(args, attribute_name)" ] }, { "cell_type": "code", "execution_count": 7, "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", "metadata": { "tags": [] }, "outputs": [], "source": [ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", "if not os.path.exists(outdir) and ckpt_saving:\n", " os.makedirs(outdir,exist_ok=True)\n", "if use_image_aug:\n", " import kornia\n", " from kornia.augmentation.container import AugmentationSequential\n", " img_augment = AugmentationSequential(\n", " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", " kornia.augmentation.Resize((224, 224)),\n", " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", " kornia.augmentation.RandomGrayscale(p=0.3),\n", " same_on_batch=False,\n", " data_keys=[\"input\"],\n", " )" ] }, { "cell_type": "markdown", "id": "42d13c25-1369-4c49-81d4-83d713586096", "metadata": { "tags": [] }, "source": [ "# Prep data, models, and dataloaders" ] }, { "cell_type": "markdown", "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", "metadata": {}, "source": [ "## Dataloader" ] }, { "cell_type": "code", "execution_count": 8, "id": "81084834-035f-4465-ad59-59e6b806a2f5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" ] } ], "source": [ "if subj==1:\n", " num_train = 24958\n", " num_test = 2770\n", "test_batch_size = num_test\n", "\n", "def my_split_by_node(urls): return urls\n", " \n", "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", "# train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..1}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 9, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 2770 2770\n", "---\n", "\n", "194 24960 24960\n" ] } ], "source": [ "test_vox_indices = []\n", "test_73k_images = []\n", "for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", " test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())\n", " test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())\n", "test_vox_indices = test_vox_indices.astype(np.int16)\n", "print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))\n", "print(\"---\\n\")\n", "\n", "train_vox_indices = []\n", "train_73k_images = []\n", "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())\n", " train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())\n", "train_vox_indices = train_vox_indices.astype(np.int16)\n", "print(train_i, (train_i+1) * batch_size, len(train_vox_indices))" ] }, { "cell_type": "markdown", "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", "metadata": {}, "source": [ "## Load data and images" ] }, { "cell_type": "code", "execution_count": 10, "id": "039dd330-7339-4f88-8f00-45f95e47baa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subj01 betas loaded into memory\n", "voxels torch.Size([27750, 15724])\n", "images torch.Size([73000, 3, 224, 224])\n" ] } ], "source": [ "# load betas\n", "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", "# f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')\n", "\n", "voxels = f['betas'][:]\n", "print(f\"subj0{subj} betas loaded into memory\")\n", "voxels = torch.Tensor(voxels).to(\"cpu\").to(data_type)\n", "print(\"voxels\", voxels.shape)\n", "num_voxels = voxels.shape[-1]\n", "\n", "# load orig images\n", "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", "images = f['images'][:]\n", "images = torch.Tensor(images).to(\"cpu\").to(data_type)\n", "print(\"images\", images.shape)" ] }, { "cell_type": "markdown", "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", "metadata": {}, "source": [ "## Load models" ] }, { "cell_type": "markdown", "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", "metadata": {}, "source": [ "### CLIP image embeddings model" ] }, { "cell_type": "code", "execution_count": 11, "id": "b0420dc0-199e-4c1a-857d-b1747058b467", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "from models import Clipper\n", "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", "clip_seq_dim = 257\n", "clip_emb_dim = 768 #1024\n", "# hidden_dim = 4096\n", "seq_len = 1 #2 #32 " ] }, { "cell_type": "markdown", "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", "metadata": {}, "source": [ "### SD VAE" ] }, { "cell_type": "code", "execution_count": 12, "id": "01baff79-8114-482b-b115-6f05aa8ad691", "metadata": { "tags": [] }, "outputs": [], "source": [ "# if blurry_recon:\n", "# from diffusers import AutoencoderKL\n", "# autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", "# # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", "# autoenc.eval()\n", "# autoenc.requires_grad_(False)\n", "# autoenc.to(device)\n", "# utils.count_params(autoenc)\n", "\n", "if blurry_recon:# or depth_recon:\n", " from diffusers import VQModel\n", " autoenc = VQModel.from_pretrained(\"/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae\", torch_dtype=data_type)\n", " autoenc.eval()\n", " autoenc.requires_grad_(False)\n", " autoenc.to(device)\n", " utils.count_params(autoenc)" ] }, { "cell_type": "markdown", "id": "120c8eee-9834-437d-bb60-b38faef50138", "metadata": {}, "source": [ "#### downsampled images" ] }, { "cell_type": "code", "execution_count": 13, "id": "6d1ba8dd-64c2-4ac9-947e-725b7f2e3e50", "metadata": { "tags": [] }, "outputs": [], "source": [ "if blurry_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30]].to(device)\n", " print(input_batch.shape)\n", "\n", " downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)\n", " re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')\n", " re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215\n", " print(re_upsampled_enc.shape)\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "6390a3a8-2bef-4e81-9b82-e154d26a1e1d", "metadata": {}, "source": [ "#### MiDaS depth" ] }, { "cell_type": "code", "execution_count": 14, "id": "f35573e2-95bf-463d-8937-68ad4c2c3c20", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " from controlnet_aux.midas import MidasDetector\n", " \n", " midas_depth = MidasDetector.from_pretrained(\n", " \"valhalla/t2iadapter-aux-models\", filename=\"dpt_large_384.pt\", model_type=\"dpt_large\", cache_dir=\"/fsx/proj-fmri/shared/cache\").to(device)\n", " midas_depth.model.eval()\n", " midas_depth.model.requires_grad_(False)\n", " midas_depth.model.to(device)\n", " pass" ] }, { "cell_type": "code", "execution_count": 15, "id": "ba3f9207-b98e-45da-baa6-5cfcfb2ae958", "metadata": { "tags": [] }, "outputs": [], "source": [ "if depth_recon:\n", " if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))\n", "\n", " input_batch = images[[30,31]].float().to(device)\n", " print(input_batch.shape)\n", " \n", " midas_emb = midas_depth.model(input_batch).unsqueeze(1)\n", " print(midas_emb.shape)\n", "\n", " prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()\n", " print(prediction.shape)\n", " \n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", " midas_emb_size = prediction.flatten(1).shape[1]\n", " print(\"midas_emb\", prediction.shape, prediction.min(), prediction.max())\n", " print(\"midas_emb_size\", midas_emb_size)\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224))) \n", "\n", " if blurry_recon:\n", " prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)\n", " prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()\n", " prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215\n", " print(\"vae midas_emb\", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())\n", " \n", " if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))" ] }, { "cell_type": "markdown", "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", "metadata": {}, "source": [ "### MindEye modules" ] }, { "cell_type": "code", "execution_count": 16, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 17, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "66,506,752 total\n", "66,506,752 trainable\n", "param counts:\n", "66,506,752 total\n", "66,506,752 trainable\n", "torch.Size([2, 1, 15724]) torch.Size([2, 1, 4096])\n" ] } ], "source": [ "time_embedding_dim = 512\n", "\n", "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.out_features = out_features\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,1,voxels.shape[1]))\n", "time_emb_test = torch.randn((2,1,time_embedding_dim))\n", "print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)" ] }, { "cell_type": "code", "execution_count": 59, "id": "ee763515-4853-4c5b-9265-ec5aa1bde971", "metadata": { "tags": [] }, "outputs": [], "source": [ "num_past_voxels = 15\n", "seq_len = 1 + 1" ] }, { "cell_type": "code", "execution_count": 73, "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "824,409,280 total\n", "824,409,280 trainable\n", "param counts:\n", "890,923,712 total\n", "890,923,712 trainable\n", "b.shape torch.Size([256, 4, 1024])\n", "torch.Size([256, 257, 768]) torch.Size([1]) torch.Size([1])\n" ] } ], "source": [ "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):\n", " super().__init__()\n", " self.seq_len = seq_len\n", " self.h = h\n", " self.clip_size = clip_size\n", " \n", " # Initial linear layer to match the input dimensions to hidden dimensions\n", " # self.lin0 = nn.Linear(in_dim, seq_len * h)\n", " \n", " # Mixer Blocks\n", " self.mixer_blocks1 = nn.ModuleList([\n", " self.mixer_block1(h, drop) for _ in range(n_blocks)\n", " ])\n", " self.mixer_blocks2 = nn.ModuleList([\n", " self.mixer_block2(seq_len, drop) for _ in range(n_blocks)\n", " ])\n", " \n", " # Output linear layer\n", " self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)\n", "\n", " # low-rank matrices\n", " # self.rank = 500\n", " # self.U = nn.Parameter(torch.randn(self.rank, out_dim))\n", " # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))\n", " # self.S = nn.Parameter(torch.randn(out_dim))\n", "\n", " self.clip_proj = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", "\n", " if blurry_recon:\n", " # self.blin1 = nn.Sequential(\n", " # nn.Linear(out_dim, 4096, bias=True),\n", " # nn.LayerNorm(4096),\n", " # nn.GELU(),\n", " # nn.Linear(4096, 4096))\n", " self.blin1 = nn.Linear(h*seq_len, 4096)\n", " self.bgroupnorm = nn.GroupNorm(1, 256)\n", " self.bupsampler = Decoder(\n", " in_channels=256,\n", " out_channels=128,\n", " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " block_out_channels=[32, 64, 128],\n", " layers_per_block=1,\n", " )\n", "\n", " if depth_recon:\n", " # self.dlin1 = nn.Sequential(\n", " # nn.Linear(h, midas_emb_size),\n", " # nn.Sigmoid(),\n", " # )\n", " self.dlin1 = nn.Linear(h*seq_len, 4096)\n", " self.dgroupnorm = nn.GroupNorm(1, 256)\n", " self.dupsampler = Decoder(\n", " in_channels=256,\n", " out_channels=1,#128,\n", " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " block_out_channels=[32, 64, 128, 256],\n", " layers_per_block=1,\n", " )\n", " \n", " def mixer_block1(self, h, drop):\n", " return nn.Sequential(\n", " nn.LayerNorm(h),\n", " self.mlp(h, h, drop), # Token mixing\n", " )\n", "\n", " def mixer_block2(self, seq_len, drop):\n", " return nn.Sequential(\n", " nn.LayerNorm(seq_len),\n", " self.mlp(seq_len, seq_len, drop) # Channel mixing\n", " )\n", " \n", " def mlp(self, in_dim, out_dim, drop):\n", " return nn.Sequential(\n", " nn.Linear(in_dim, out_dim),\n", " nn.GELU(),\n", " nn.Dropout(drop),\n", " nn.Linear(out_dim, out_dim),\n", " )\n", " \n", " def forward(self, x):\n", " # make empty tensors for blur and depth outputs\n", " b,d = torch.Tensor([0.]), torch.Tensor([0.])\n", " \n", " # Initial linear layer\n", " # x = self.lin0(x)\n", " \n", " # Reshape to seq_len by dim\n", " # x = x.reshape(-1, self.seq_len, self.h)\n", " \n", " # Mixer blocks\n", " residual1 = x\n", " residual2 = x.permute(0,2,1)\n", " for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):\n", " x = block1(x) + residual1\n", " residual1 = x\n", " x = x.permute(0,2,1)\n", " \n", " x = block2(x) + residual2\n", " residual2 = x\n", " x = x.permute(0,2,1)\n", " \n", " # Flatten\n", " x = x.reshape(x.size(0), -1)\n", " \n", " c = self.clin1(x)\n", "\n", " # low rank linear to out dim cuts # params by nearly half compared to full linear mapping\n", " # c = (x @ (self.V/100) @ (self.U/100)) + self.S\n", " \n", " c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))\n", "\n", " if blurry_recon:\n", " b = self.blin1(x)\n", " b = b.reshape(len(b), 256, 4, 4)\n", " b = self.bgroupnorm(b)\n", " b = self.bupsampler(b)\n", " \n", " if depth_recon:\n", " d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)\n", " d = d.reshape(len(d), 256, 4, 4)\n", " d = self.dgroupnorm(d)\n", " d = self.dupsampler(d)\n", " \n", " return c, b, d\n", "\n", "\n", "class TimeEmbedding(nn.Module):\n", " def __init__(self, embedding_time_dim=512, num_past_voxels=15):\n", " super().__init__()\n", " self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)\n", " self.num_past_voxels = num_past_voxels\n", " self.embedding_time_dim = embedding_time_dim\n", "\n", " def forward(self, time):\n", " # time is (batch_size,)\n", " time = time.long()\n", " time = self.embedding_time(time)\n", " return time # (batch_size, embedding_time_dim)\n", " \n", "\n", "#model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)\n", "model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)\n", "\n", "model.backbone = BrainNetwork(h=1024, in_dim=1024, seq_len=4, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim) \n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "\n", "# test that the model works on some fake data\n", "b = torch.randn((256,4,1024))\n", "print(\"b.shape\",b.shape)\n", "with torch.no_grad():\n", " clip_, blur_, depth_ = model.backbone(b)\n", "print(clip_.shape, blur_.shape, depth_.shape)\n" ] }, { "cell_type": "code", "execution_count": 70, "id": "fa7aa4ab-64af-45d4-8ed8-259043c16c29", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "b.shape torch.Size([256, 2, 4096])\n", "torch.Size([256, 257, 768]) torch.Size([1]) torch.Size([1])\n" ] } ], "source": [ "\n", "voxel_ridge = torch.randn(512,4096)\n", "voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)\n", "print(\"b.shape\",voxel_ridge.shape)\n", "with torch.no_grad():\n", " clip_, blur_, depth_ = model.backbone(voxel_ridge)\n", "print(clip_.shape, blur_.shape, depth_.shape)" ] }, { "cell_type": "code", "execution_count": 64, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total_steps 2339\n", "\n", "Done with model preparations!\n", "param counts:\n", "1,825,253,952 total\n", "1,825,253,952 trainable\n" ] } ], "source": [ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", "opt_grouped_parameters = [\n", " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", "]\n", "\n", "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)\n", "\n", "if lr_scheduler_type == 'linear':\n", " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", " optimizer,\n", " total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),\n", " last_epoch=-1\n", " )\n", "elif lr_scheduler_type == 'cycle':\n", " total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))\n", " print(\"total_steps\", total_steps)\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer, \n", " max_lr=max_lr,\n", " total_steps=total_steps,\n", " final_div_factor=1000,\n", " last_epoch=-1, pct_start=2/num_epochs\n", " )\n", " \n", "def save_ckpt(tag): \n", " ckpt_path = outdir+f'/{tag}.pth'\n", " print(f'saving {ckpt_path}',flush=True)\n", " unwrapped_model = accelerator.unwrap_model(model)\n", " try:\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': unwrapped_model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'lr_scheduler': lr_scheduler.state_dict(),\n", " 'train_losses': losses,\n", " 'test_losses': test_losses,\n", " 'lrs': lrs,\n", " }, ckpt_path)\n", " except:\n", " print(\"Couldn't save... moving on to prevent crashing.\")\n", " del unwrapped_model\n", " \n", "print(\"\\nDone with model preparations!\")\n", "utils.count_params(model)" ] }, { "cell_type": "code", "execution_count": 49, "id": "5d600d8d-6a0d-4584-840e-c89521ff6364", "metadata": {}, "outputs": [], "source": [ "seq_len = 4" ] }, { "cell_type": "code", "execution_count": 57, "id": "7f8eef14-4806-4d11-af12-a36e4acad4df", "metadata": { "tags": [] }, "outputs": [], "source": [ "voxel_ridge = torch.randn(512,4096)\n", "voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)" ] }, { "cell_type": "code", "execution_count": 58, "id": "c578b861-2eb8-4b1c-8b2a-61268b9e9bef", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 4, 4096])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "voxel_ridge.shape" ] }, { "cell_type": "code", "execution_count": 55, "id": "3c937d01-eea5-4bad-906a-36d07b3c30a4", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([0., 1., 2., 3.])\n", "torch.Size([128, 15, 17])\n", "tensor([ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., -1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0.], dtype=torch.float64)\n", "torch.Size([128, 3, 15724])\n", "tensor([ True, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, True, False, False, False, False, True, False,\n", " False, False, False, False, True, True, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, True, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, True, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, True, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, True, False, False, False])\n", "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.3796, 2.2207, 1.7715, ..., 3.0859, 1.5938, 0.6274],\n", " [ 0.3513, 1.1621, 0.0634, ..., 0.7300, 0.2314, 0.1820]],\n", "\n", " [[ 0.6987, -1.0283, 0.8228, ..., -0.0037, -1.2510, -0.5884],\n", " [-0.6133, 0.8877, 1.2207, ..., 0.7803, 1.0908, 0.8496],\n", " [ 0.2534, -0.0533, 0.3386, ..., -0.4102, -0.5864, 0.3982]],\n", "\n", " [[ 0.2139, -0.6875, -0.3225, ..., 0.8423, -0.5718, -0.0623],\n", " [-0.3418, -1.8799, -1.8477, ..., -1.2676, -2.4707, -1.8398],\n", " [-0.3765, 1.1113, -1.2715, ..., 0.5498, 0.0262, -0.4839]],\n", "\n", " ...,\n", "\n", " [[-0.9604, 0.2109, -1.0596, ..., 0.2092, -0.7017, -0.7466],\n", " [ 0.8960, -1.1387, -1.4111, ..., -0.3269, -0.2957, -0.5659],\n", " [ 2.4707, 0.8105, 0.7910, ..., 0.6099, 1.0049, 0.3572]],\n", "\n", " [[ 0.8110, -0.8374, -0.8813, ..., 0.2411, 0.5176, 1.0039],\n", " [-0.0218, -0.6675, 0.0044, ..., 0.2050, 0.2045, 0.7485],\n", " [-0.3379, -1.2100, -0.0176, ..., -0.4167, -0.3860, -0.1342]],\n", "\n", " [[-0.2690, 0.8086, 0.2878, ..., 1.0840, 0.8159, -0.7021],\n", " [ 0.3433, 0.2678, -0.9961, ..., -0.2805, -1.2490, -0.7988],\n", " [ 0.0687, 0.5952, 0.0823, ..., -0.0443, 0.6401, -0.4092]]],\n", " dtype=torch.float16)\n", "torch.Size([128, 15, 17])\n", "tensor([ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n", " 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., -1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0.], dtype=torch.float64)\n", "torch.Size([128, 3, 15724])\n", "tensor([False, False, False, False, True, False, False, False, False, False,\n", " True, False, False, False, False, False, False, True, False, False,\n", " True, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, True, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, True, False,\n", " False, False, False, False, False, False, False, False, False, True,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, True, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False])\n", "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.3796, 2.2207, 1.7715, ..., 3.0859, 1.5938, 0.6274],\n", " [ 0.3513, 1.1621, 0.0634, ..., 0.7300, 0.2314, 0.1820]],\n", "\n", " [[ 0.6987, -1.0283, 0.8228, ..., -0.0037, -1.2510, -0.5884],\n", " [-0.6133, 0.8877, 1.2207, ..., 0.7803, 1.0908, 0.8496],\n", " [ 0.2534, -0.0533, 0.3386, ..., -0.4102, -0.5864, 0.3982]],\n", "\n", " [[ 0.2139, -0.6875, -0.3225, ..., 0.8423, -0.5718, -0.0623],\n", " [-0.3418, -1.8799, -1.8477, ..., -1.2676, -2.4707, -1.8398],\n", " [-0.3765, 1.1113, -1.2715, ..., 0.5498, 0.0262, -0.4839]],\n", "\n", " ...,\n", "\n", " [[-0.9604, 0.2109, -1.0596, ..., 0.2092, -0.7017, -0.7466],\n", " [ 0.8960, -1.1387, -1.4111, ..., -0.3269, -0.2957, -0.5659],\n", " [ 2.4707, 0.8105, 0.7910, ..., 0.6099, 1.0049, 0.3572]],\n", "\n", " [[ 0.8110, -0.8374, -0.8813, ..., 0.2411, 0.5176, 1.0039],\n", " [-0.0218, -0.6675, 0.0044, ..., 0.2050, 0.2045, 0.7485],\n", " [-0.3379, -1.2100, -0.0176, ..., -0.4167, -0.3860, -0.1342]],\n", "\n", " [[-0.2690, 0.8086, 0.2878, ..., 1.0840, 0.8159, -0.7021],\n", " [ 0.3433, 0.2678, -0.9961, ..., -0.2805, -1.2490, -0.7988],\n", " [ 0.0687, 0.5952, 0.0823, ..., -0.0443, 0.6401, -0.4092]]],\n", " dtype=torch.float16)\n", "torch.Size([128, 15, 17])\n", "tensor([ 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., -1., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n", " 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,\n", " 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,\n", " 0., -1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0.], dtype=torch.float64)\n", "torch.Size([128, 3, 15724])\n", "tensor([False, False, False, True, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, True, False, False, False, False, False, False, True,\n", " False, False, False, False, False, False, True, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, True,\n", " False, False, False, False, True, False, False, False, False, False,\n", " True, False, False, False, True, False, True, False, False, True,\n", " True, False, False, False, False, False, False, False, False, False,\n", " False, False, False, True, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False, False, False,\n", " False, False, False, False, False, False, False, False])\n", "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", " [ 0.3796, 2.2207, 1.7715, ..., 3.0859, 1.5938, 0.6274],\n", " [ 0.3513, 1.1621, 0.0634, ..., 0.7300, 0.2314, 0.1820]],\n", "\n", " [[ 0.6987, -1.0283, 0.8228, ..., -0.0037, -1.2510, -0.5884],\n", " [-0.6133, 0.8877, 1.2207, ..., 0.7803, 1.0908, 0.8496],\n", " [ 0.2534, -0.0533, 0.3386, ..., -0.4102, -0.5864, 0.3982]],\n", "\n", " [[ 0.2139, -0.6875, -0.3225, ..., 0.8423, -0.5718, -0.0623],\n", " [-0.3418, -1.8799, -1.8477, ..., -1.2676, -2.4707, -1.8398],\n", " [-0.3765, 1.1113, -1.2715, ..., 0.5498, 0.0262, -0.4839]],\n", "\n", " ...,\n", "\n", " [[-0.9604, 0.2109, -1.0596, ..., 0.2092, -0.7017, -0.7466],\n", " [ 0.8960, -1.1387, -1.4111, ..., -0.3269, -0.2957, -0.5659],\n", " [ 2.4707, 0.8105, 0.7910, ..., 0.6099, 1.0049, 0.3572]],\n", "\n", " [[ 0.8110, -0.8374, -0.8813, ..., 0.2411, 0.5176, 1.0039],\n", " [-0.0218, -0.6675, 0.0044, ..., 0.2050, 0.2045, 0.7485],\n", " [-0.3379, -1.2100, -0.0176, ..., -0.4167, -0.3860, -0.1342]],\n", "\n", " [[-0.2690, 0.8086, 0.2878, ..., 1.0840, 0.8159, -0.7021],\n", " [ 0.3433, 0.2678, -0.9961, ..., -0.2805, -1.2490, -0.7988],\n", " [ 0.0687, 0.5952, 0.0823, ..., -0.0443, 0.6401, -0.4092]]],\n", " dtype=torch.float16)\n" ] } ], "source": [ "pp = None\n", "for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " with torch.cuda.amp.autocast(dtype=data_type):\n", " #optimizer.zero_grad()\n", "\n", " voxel = voxels[behav[:,0,5].cpu().long()]#.to(device)\n", " image = images[behav[:,0,0].cpu().long()].float()#.to(device).float()\n", "\n", " past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279\n", " past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15\n", " \n", " print(past_15_times)\n", " #for past in range(1):\n", " # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)\n", " \"\"\"\n", " #if blurry_recon:\n", " # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215\n", " blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215\n", "\n", " if depth_recon:\n", " # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)\n", " depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)\n", " depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()\n", " depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215\n", "\n", " if use_image_aug: \n", " image = img_augment(image)\n", "\n", " clip_target = clip_model.embed_image(image)\n", " assert not torch.any(torch.isnan(clip_target))\n", "\n", " if epoch < int(mixup_pct * num_epochs):\n", " voxel, perm, betas, select = utils.mixco(voxel)\n", " past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)\n", " \"\"\"\n", " for p in range(seq_len-1):\n", " print(past_behav.shape) #128, 15, 17\n", " print(past_behav[:,p,-1])\n", " print(past_15_voxels.shape) # 128, 1, 15724\n", " mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])\n", " print(mask) # 128\n", " past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])\n", " print(past_15_voxels)\n", " pp = past_15_voxels\n", " \n", " break" ] }, { "cell_type": "code", "execution_count": 54, "id": "5c7eb009-9fcf-4e35-9e1b-c0ca6418b6b8", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor([0., 0., 0., ..., 0., 0., 0.], dtype=torch.float16)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pp[20, 0, :]" ] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 66, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": {}, "outputs": [], "source": [ "if local_rank==0 and wandb_log: # only use main process for wandb logging\n", " import wandb\n", " wandb_project = 'mindeyev2'\n", " wandb_run = model_name\n", " wandb_notes = ''\n", " \n", " print(f\"wandb {wandb_project} run {wandb_run}\")\n", " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", " wandb_config = {\n", " \"model_name\": model_name,\n", " \"global_batch_size\": global_batch_size,\n", " \"batch_size\": batch_size,\n", " \"num_epochs\": num_epochs,\n", " \"clip_scale\": clip_scale,\n", " \"blur_scale\": blur_scale,\n", " \"use_image_aug\": use_image_aug,\n", " \"max_lr\": max_lr,\n", " \"mixup_pct\": mixup_pct,\n", " \"num_train\": num_train,\n", " \"num_test\": num_test,\n", " \"ckpt_interval\": ckpt_interval,\n", " \"ckpt_saving\": ckpt_saving,\n", " \"seed\": seed,\n", " \"distributed\": distributed,\n", " \"num_devices\": num_devices,\n", " \"world_size\": world_size,\n", " \"train_url\": train_url,\n", " \"test_url\": test_url,\n", " }\n", " print(\"wandb_config:\\n\",wandb_config)\n", " if True: # wandb_auto_resume\n", " print(\"wandb_id:\",model_name)\n", " wandb.init(\n", " id = model_name,\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " resume=\"allow\",\n", " )\n", " else:\n", " wandb.init(\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " )\n", "else:\n", " wandb_log = False" ] }, { "cell_type": "markdown", "id": "d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 67, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": {}, "outputs": [], "source": [ "epoch = 0\n", "losses, test_losses, lrs = [], [], []\n", "best_test_loss = 1e9\n", "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", "\n", "# Optionally resume from checkpoint #\n", "if resume_from_ckpt:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "elif wandb_log:\n", " if wandb.run.resumed:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 68, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[93m [WARNING] \u001b[0m cpu_adam cuda is missing or is incompatible with installed torch, only cpu ops can be compiled!\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Using /admin/home-ckadirt/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", "Emitting ninja build file /admin/home-ckadirt/.cache/torch_extensions/py310_cu117/cpu_adam/build.ninja...\n", "Building extension module cpu_adam...\n", "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ninja: no work to do.\n", "Time to load cpu_adam op: 3.0333330631256104 seconds\n", "[2023-11-03 13:59:08,078] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed info: version=0.9.5, git-hash=unknown, git-branch=unknown\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading extension module cpu_adam...\n" ] }, { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:1 │\n", "│ │\n", "│ ❱ 1 model, optimizer, train_dl, lr_scheduler = accelerator.prepare( │\n", "│ 2 model, optimizer, train_dl, lr_scheduler │\n", "│ 3 ) │\n", "│ 4 # leaving out test_dl since we will only have local_rank 0 device do evals │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/accelerator. │\n", "│ py:1178 in prepare │\n", "│ │\n", "│ 1175 │ │ │ elif self.device.type == \"xpu\" and is_xpu_available(): │\n", "│ 1176 │ │ │ │ args = self._prepare_ipex(*args) │\n", "│ 1177 │ │ if self.distributed_type == DistributedType.DEEPSPEED: │\n", "│ ❱ 1178 │ │ │ result = self._prepare_deepspeed(*args) │\n", "│ 1179 │ │ elif self.distributed_type == DistributedType.MEGATRON_LM: │\n", "│ 1180 │ │ │ result = self._prepare_megatron_lm(*args) │\n", "│ 1181 │ │ else: │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/accelerate/accelerator. │\n", "│ py:1505 in _prepare_deepspeed │\n", "│ │\n", "│ 1502 │ │ │ │ │ │ if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VA │\n", "│ 1503 │ │ │ │ │ │ │ kwargs[\"lr_scheduler\"] = scheduler │\n", "│ 1504 │ │ │ │\n", "│ ❱ 1505 │ │ │ engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) │\n", "│ 1506 │ │ │ if optimizer is not None: │\n", "│ 1507 │ │ │ │ optimizer = DeepSpeedOptimizerWrapper(optimizer) │\n", "│ 1508 │ │ │ if scheduler is not None: │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/__init__.py:1 │\n", "│ 51 in initialize │\n", "│ │\n", "│ 148 │ assert config != None, \"DeepSpeed requires --deepspeed_config to specify configurati │\n", "│ 149 │ │\n", "│ 150 │ if not isinstance(model, PipelineModule): │\n", "│ ❱ 151 │ │ config_class = DeepSpeedConfig(config, mpu) │\n", "│ 152 │ │ if config_class.hybrid_engine.enabled: │\n", "│ 153 │ │ │ engine = DeepSpeedHybridEngine(args=args, │\n", "│ 154 │ │ │ │ │ │ │ │ │ │ model=model, │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/confi │\n", "│ g.py:769 in __init__ │\n", "│ │\n", "│ 766 │ │ │\n", "│ 767 │ │ # Pass a copy so that user json is unmodified, e.g. for logging │\n", "│ 768 │ │ self._initialize_params(copy.copy(self._param_dict)) │\n", "│ ❱ 769 │ │ self._configure_train_batch_size() │\n", "│ 770 │ │ self._do_sanity_check() │\n", "│ 771 │ │\n", "│ 772 │ def _initialize_params(self, param_dict): │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/confi │\n", "│ g.py:942 in _configure_train_batch_size │\n", "│ │\n", "│ 939 │ │\n", "│ 940 │ def _configure_train_batch_size(self): │\n", "│ 941 │ │ self._set_batch_related_parameters() │\n", "│ ❱ 942 │ │ self._batch_assertion() │\n", "│ 943 │ │\n", "│ 944 │ def _do_sanity_check(self): │\n", "│ 945 │ │ self._do_error_check() │\n", "│ │\n", "│ /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/deepspeed/runtime/confi │\n", "│ g.py:890 in _batch_assertion │\n", "│ │\n", "│ 887 │ │ │\n", "│ 888 │ │ assert (grad_acc > 0), f\"Gradient accumulation steps: {grad_acc} has to be great │\n", "│ 889 │ │ │\n", "│ ❱ 890 │ │ assert train_batch == micro_batch * grad_acc * self.world_size, ( │\n", "│ 891 │ │ │ f\"Check batch related parameters. train_batch_size is not equal \" │\n", "│ 892 │ │ │ \"to micro_batch_per_gpu * gradient_acc_step * world_size \" │\n", "│ 893 │ │ │ f\"{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}\") │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "AssertionError: Check batch related parameters. train_batch_size is not equal to micro_batch_per_gpu * \n", "gradient_acc_step * world_size 256 != 32 * 1 * 1\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:1 │\n", "│ │\n", "│ ❱ 1 print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\") │\n", "│ 2 progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0)) │\n", "│ 3 test_image, test_voxel = None, None │\n", "│ 4 mse = nn.MSELoss() │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "NameError: name 'epoch' is not defined\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m