{ "cells": [ { "cell_type": "code", "execution_count": 30, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "This application is used to convert notebook files (*.ipynb)\n", " to various other formats.\n", "\n", " WARNING: THE COMMANDLINE INTERFACE MAY CHANGE IN FUTURE RELEASES.\n", "\n", "Options\n", "=======\n", "The options below are convenience aliases to configurable class-options,\n", "as listed in the \"Equivalent to\" description-line of the aliases.\n", "To see all configurable class-options for some , use:\n", " --help-all\n", "\n", "--debug\n", " set log level to logging.DEBUG (maximize logging output)\n", " Equivalent to: [--Application.log_level=10]\n", "--show-config\n", " Show the application's configuration (human-readable format)\n", " Equivalent to: [--Application.show_config=True]\n", "--show-config-json\n", " Show the application's configuration (json format)\n", " Equivalent to: [--Application.show_config_json=True]\n", "--generate-config\n", " generate default config file\n", " Equivalent to: [--JupyterApp.generate_config=True]\n", "-y\n", " Answer yes to any questions instead of prompting.\n", " Equivalent to: [--JupyterApp.answer_yes=True]\n", "--execute\n", " Execute the notebook prior to export.\n", " Equivalent to: [--ExecutePreprocessor.enabled=True]\n", "--allow-errors\n", " Continue notebook execution even if one of the cells throws an error and include the error message in the cell output (the default behaviour is to abort conversion). This flag is only relevant if '--execute' was specified, too.\n", " Equivalent to: [--ExecutePreprocessor.allow_errors=True]\n", "--stdin\n", " read a single notebook file from stdin. 
Write the resulting notebook with default basename 'notebook.*'\n", " Equivalent to: [--NbConvertApp.from_stdin=True]\n", "--stdout\n", " Write notebook output to stdout instead of files.\n", " Equivalent to: [--NbConvertApp.writer_class=StdoutWriter]\n", "--inplace\n", " Run nbconvert in place, overwriting the existing notebook (only\n", " relevant when converting to notebook format)\n", " Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory=]\n", "--clear-output\n", " Clear output of current file and save in place,\n", " overwriting the existing notebook.\n", " Equivalent to: [--NbConvertApp.use_output_suffix=False --NbConvertApp.export_format=notebook --FilesWriter.build_directory= --ClearOutputPreprocessor.enabled=True]\n", "--no-prompt\n", " Exclude input and output prompts from converted document.\n", " Equivalent to: [--TemplateExporter.exclude_input_prompt=True --TemplateExporter.exclude_output_prompt=True]\n", "--no-input\n", " Exclude input cells and output prompts from converted document.\n", " This mode is ideal for generating code-free reports.\n", " Equivalent to: [--TemplateExporter.exclude_output_prompt=True --TemplateExporter.exclude_input=True --TemplateExporter.exclude_input_prompt=True]\n", "--allow-chromium-download\n", " Whether to allow downloading chromium if no suitable version is found on the system.\n", " Equivalent to: [--WebPDFExporter.allow_chromium_download=True]\n", "--disable-chromium-sandbox\n", " Disable chromium security sandbox when converting to PDF..\n", " Equivalent to: [--WebPDFExporter.disable_sandbox=True]\n", "--show-input\n", " Shows code input. This flag is only useful for dejavu users.\n", " Equivalent to: [--TemplateExporter.exclude_input=False]\n", "--embed-images\n", " Embed the images as base64 dataurls in the output. 
This flag is only useful for the HTML/WebPDF/Slides exports.\n", " Equivalent to: [--HTMLExporter.embed_images=True]\n", "--sanitize-html\n", " Whether the HTML in Markdown cells and cell outputs should be sanitized..\n", " Equivalent to: [--HTMLExporter.sanitize_html=True]\n", "--log-level=\n", " Set the log level by value or name.\n", " Choices: any of [0, 10, 20, 30, 40, 50, 'DEBUG', 'INFO', 'WARN', 'ERROR', 'CRITICAL']\n", " Default: 30\n", " Equivalent to: [--Application.log_level]\n", "--config=\n", " Full path of a config file.\n", " Default: ''\n", " Equivalent to: [--JupyterApp.config_file]\n", "--to=\n", " The export format to be used, either one of the built-in formats\n", " ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf']\n", " or a dotted object name that represents the import path for an\n", " ``Exporter`` class\n", " Default: ''\n", " Equivalent to: [--NbConvertApp.export_format]\n", "--template=\n", " Name of the template to use\n", " Default: ''\n", " Equivalent to: [--TemplateExporter.template_name]\n", "--template-file=\n", " Name of the template file to use\n", " Default: None\n", " Equivalent to: [--TemplateExporter.template_file]\n", "--theme=\n", " Template specific theme(e.g. 
the name of a JupyterLab CSS theme distributed\n", " as prebuilt extension for the lab template)\n", " Default: 'light'\n", " Equivalent to: [--HTMLExporter.theme]\n", "--sanitize_html=\n", " Whether the HTML in Markdown cells and cell outputs should be sanitized.This\n", " should be set to True by nbviewer or similar tools.\n", " Default: False\n", " Equivalent to: [--HTMLExporter.sanitize_html]\n", "--writer=\n", " Writer class used to write the\n", " results of the conversion\n", " Default: 'FilesWriter'\n", " Equivalent to: [--NbConvertApp.writer_class]\n", "--post=\n", " PostProcessor class used to write the\n", " results of the conversion\n", " Default: ''\n", " Equivalent to: [--NbConvertApp.postprocessor_class]\n", "--output=\n", " overwrite base name use for output files.\n", " can only be used when converting one notebook at a time.\n", " Default: ''\n", " Equivalent to: [--NbConvertApp.output_base]\n", "--output-dir=\n", " Directory to write output(s) to. Defaults\n", " to output to the directory of each notebook. To recover\n", " previous default behaviour (outputting to the current\n", " working directory) use . 
as the flag value.\n", " Default: ''\n", " Equivalent to: [--FilesWriter.build_directory]\n", "--reveal-prefix=\n", " The URL prefix for reveal.js (version 3.x).\n", " This defaults to the reveal CDN, but can be any url pointing to a copy\n", " of reveal.js.\n", " For speaker notes to work, this must be a relative path to a local\n", " copy of reveal.js: e.g., \"reveal.js\".\n", " If a relative path is given, it must be a subdirectory of the\n", " current directory (from which the server is run).\n", " See the usage documentation\n", " (https://nbconvert.readthedocs.io/en/latest/usage.html#reveal-js-html-slideshow)\n", " for more details.\n", " Default: ''\n", " Equivalent to: [--SlidesExporter.reveal_url_prefix]\n", "--nbformat=\n", " The nbformat version to write.\n", " Use this to downgrade notebooks.\n", " Choices: any of [1, 2, 3, 4]\n", " Default: 4\n", " Equivalent to: [--NotebookExporter.nbformat_version]\n", "\n", "Examples\n", "--------\n", "\n", " The simplest way to use nbconvert is\n", "\n", " > jupyter nbconvert mynotebook.ipynb --to html\n", "\n", " Options include ['asciidoc', 'custom', 'html', 'latex', 'markdown', 'notebook', 'pdf', 'python', 'rst', 'script', 'slides', 'webpdf'].\n", "\n", " > jupyter nbconvert --to latex mynotebook.ipynb\n", "\n", " Both HTML and LaTeX support multiple output templates. LaTeX includes\n", " 'base', 'article' and 'report'. HTML includes 'basic', 'lab' and\n", " 'classic'. 
You can specify the flavor of the format used.\n", "\n", " > jupyter nbconvert --to html --template lab mynotebook.ipynb\n", "\n", " You can also pipe the output to stdout, rather than a file\n", "\n", " > jupyter nbconvert mynotebook.ipynb --stdout\n", "\n", " PDF is generated via latex\n", "\n", " > jupyter nbconvert mynotebook.ipynb --to pdf\n", "\n", " You can get (and serve) a Reveal.js-powered slideshow\n", "\n", " > jupyter nbconvert myslides.ipynb --to slides --post serve\n", "\n", " Multiple notebooks can be given at the command line in a couple of\n", " different ways:\n", "\n", " > jupyter nbconvert notebook*.ipynb\n", " > jupyter nbconvert notebook1.ipynb notebook2.ipynb\n", "\n", " or you can specify the notebooks list in a config file, containing::\n", "\n", " c.NbConvertApp.notebooks = [\"my_notebook.ipynb\"]\n", "\n", " > jupyter nbconvert --config mycfg.py\n", "\n", "To see all available configurables, use `--help-all`.\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[NbConvertApp] WARNING | pattern 'Train-with-memory_cat.ipynb' matched no files\n" ] }, { "data": { "text/plain": [ "255" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "from subprocess import call\n", "command = \"jupyter nbconvert Train-with-memory_cat.ipynb --to python\"\n", "call(command,shell=True)" ] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": {}, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2023-10-15 21:57:10,361] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import math\n", "from einops import rearrange\n", "import time\n", "import random\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils\n", "\n", "global_batch_size = 128 #128" ] }, { "cell_type": "code", "execution_count": 3, "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n" ] } ], "source": [ "### Multi-GPU config ###\n", "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", " local_rank = 0\n", "else:\n", " local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank) \n", "\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "\n", "accelerator = Accelerator(split_batches=False)\n", "\n", "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n", "\n", "# if num_devices <= 1 and utils.is_interactive():\n", "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", "# os.environ[\"RANK\"] = \"0\"\n", "# 
os.environ[\"LOCAL_RANK\"] = \"0\"\n", "# os.environ[\"WORLD_SIZE\"] = \"1\"\n", "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", "\n", "# # alter the deepspeed config according to your global and local batch size\n", "# if local_rank == 0:\n", "# with open('deepspeed_config_stage2.json', 'r') as file:\n", "# config = json.load(file)\n", "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", "# with open('deepspeed_config_stage2.json', 'w') as file:\n", "# json.dump(config, file)\n", "# else:\n", "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", "# time.sleep(10)\n", "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of this process = 2525880\n", "device: cuda\n", "Distributed environment: NO\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda\n", "\n", "Mixed precision type: no\n", "\n", "distributed = False num_devices = 1 local rank = 0 world size = 1\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "distributed = not accelerator.state.distributed_type == 'NO'\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n", "print = accelerator.print # only print if local_rank=0" ] }, { "cell_type": 
"markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 5, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=test', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-5', '--mixup_pct=.66', '--num_epochs=12', '--ckpt_interval=999', '--no-use_image_aug']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", " # Example use\n", " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name=test \\\n", " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n", " --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug\"\n", "\n", " jupyter_args = jupyter_args.split()\n", " print(jupyter_args)\n", " \n", " from IPython.display import clear_output # function to clear print outputs in cell\n", " %load_ext autoreload \n", " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", " %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 6, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global batch_size 128\n", "batch_size 128\n" ] } ], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", " \"--model_name\", type=str, default=\"testing\",\n", " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", " help=\"Path to where NSD data is stored / where to download it 
to\",\n", ")\n", "parser.add_argument(\n", " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", " \"--batch_size\", type=int, default=32,\n", " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_project\",type=str,default=\"stability\",\n", " help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", " \"--mixup_pct\",type=float,default=.33,\n", " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", " \"--num_epochs\",type=int,default=240,\n", " help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", " \"--ckpt_interval\",type=int,default=5,\n", " help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", " \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", " \"--max_lr\",type=float,default=3e-4,\n", ")\n", "parser.add_argument(\n", " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n", " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n", ")\n", "\n", "if utils.is_interactive():\n", " args = parser.parse_args(jupyter_args)\n", "else:\n", " args = 
parser.parse_args()\n",
    "\n",
    "# create global variables without the args prefix\n",
    "for attribute_name in vars(args).keys():\n",
    "    globals()[attribute_name] = getattr(args, attribute_name)\n",
    "\n",
    "print(\"global batch_size\", batch_size)\n",
    "batch_size = int(batch_size / num_devices)\n",
    "print(\"batch_size\", batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "outdir = os.path.abspath(f'../train_mem_logs/{model_name}')\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir,exist_ok=True)\n",
    "if use_image_aug:\n",
    "    import kornia\n",
    "    from kornia.augmentation.container import AugmentationSequential\n",
    "    img_augment = AugmentationSequential(\n",
    "        kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
    "        kornia.augmentation.Resize((224, 224)),\n",
    "        kornia.augmentation.RandomHorizontalFlip(p=0.3),\n",
    "        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
    "        kornia.augmentation.RandomGrayscale(p=0.3),\n",
    "        same_on_batch=False,\n",
    "        data_keys=[\"input\"],\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42d13c25-1369-4c49-81d4-83d713586096",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Prep data, models, and dataloaders"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c023f24-5233-4a15-a2f5-78487b3a8546",
   "metadata": {},
   "source": [
    "## Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "81084834-035f-4465-ad59-59e6b806a2f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n",
      "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n"
     ]
    }
   ],
   "source": [
    "# Sample counts are only recorded for subj01. Fail early with a clear message\n",
    "# instead of a NameError on num_test for the other --subj choices (2, 5, 7).\n",
    "if subj==1:\n",
    "    num_train = 24958\n",
    "    num_test = 2770\n",
    "else:\n",
    "    raise NotImplementedError(f\"num_train/num_test are not defined for subj0{subj}\")\n",
    "test_batch_size = num_test\n",
    "\n",
    "def my_split_by_node(urls): return urls\n",
    "\n",
    "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + 
\"{0..36}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 9, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": {}, "outputs": [], "source": [ "# test_indices = []\n", "# test_images = []\n", "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n", "# test_images = np.append(test_images, behav[:,0,0].numpy())\n", "# test_indices = test_indices.astype(np.int16)\n", "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n", "# print(\"---\\n\")\n", "\n", "# train_indices = []\n", "# train_images = []\n", "# for train_i, (behav, past_behav, future_behav, 
old_behav) in enumerate(train_dl):\n",
    "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n",
    "# train_images = np.append(train_images, behav[:,0,0].numpy())\n",
    "# train_indices = train_indices.astype(np.int16)\n",
    "# print(train_i, (train_i+1) * batch_size, len(train_indices))\n",
    "\n",
    "# # train_images = np.hstack((train_images, test_images))\n",
    "# # print(\"WARNING: ADDED TEST IMAGES TO TRAIN IMAGES\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634",
   "metadata": {},
   "source": [
    "## Load data and images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "039dd330-7339-4f88-8f00-45f95e47baa0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subj01 betas loaded into memory\n",
      "voxels torch.Size([27750, 15729])\n",
      "images torch.Size([73000, 3, 224, 224])\n"
     ]
    }
   ],
   "source": [
    "# load betas (read fully into memory, then close the HDF5 file)\n",
    "with h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r') as f:\n",
    "    voxels = f['betas'][:]\n",
    "print(f\"subj0{subj} betas loaded into memory\")\n",
    "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n",
    "if subj==1:\n",
    "    # NOTE(review): torch.zeros is float32 by default, so this concat likely promotes\n",
    "    # voxels to float32 -- confirm that is the intended dtype\n",
    "    voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n",
    "print(\"voxels\", voxels.shape)\n",
    "num_voxels = voxels.shape[-1]\n",
    "\n",
    "# load orig images (read fully into memory, then close the HDF5 file)\n",
    "with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:\n",
    "    images = f['images'][:]\n",
    "# torch.from_numpy shares the numpy buffer instead of building a float32 copy the way\n",
    "# torch.Tensor(...) does; for float16 data (as the filename suggests) .half() is then a no-op\n",
    "images = torch.from_numpy(images).half()\n",
    "print(\"images\", images.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5",
   "metadata": {},
   "source": [
    "### CLIP image embeddings model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b0420dc0-199e-4c1a-857d-b1747058b467",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ViT-L/14 cuda:0\n"
     ]
    }
   ],
   "source": [
    "from models import Clipper\n",
"clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)\n", "\n", "clip_seq_dim = 257\n", "clip_emb_dim = 768\n", "hidden_dim = 4096" ] }, { "cell_type": "markdown", "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", "metadata": {}, "source": [ "### SD VAE (blurry images)" ] }, { "cell_type": "code", "execution_count": 12, "id": "01baff79-8114-482b-b115-6f05aa8ad691", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "83,653,863 total\n", "0 trainable\n" ] } ], "source": [ "from diffusers import AutoencoderKL\n", "autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", "# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", "autoenc.eval()\n", "autoenc.requires_grad_(False)\n", "autoenc.to(device)\n", "utils.count_params(autoenc)" ] }, { "cell_type": "markdown", "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", "metadata": {}, "source": [ "### MindEye modules" ] }, { "cell_type": "code", "execution_count": 13, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 14, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "64,430,080 total\n", "64,430,080 trainable\n", "param counts:\n", "64,430,080 total\n", "64,430,080 trainable\n", "torch.Size([2, 1, 15729]) torch.Size([2, 1, 4096])\n" ] } ], "source": [ "class 
RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.out_features = out_features\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,1,voxels.shape[1]))\n", "print(b.shape, model.ridge(b).shape)" ] }, { "cell_type": "code", "execution_count": 22, "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "1,071,175,044 total\n", "1,071,175,044 trainable\n", "param counts:\n", "1,621,688,708 total\n", "1,621,688,708 trainable\n", "torch.Size([2, 8192])\n", "torch.Size([2, 257, 768]) torch.Size([2, 4, 28, 28])\n" ] } ], "source": [ "from functools import partial\n", "from diffusers.models.vae import Decoder\n", "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):\n", " super().__init__()\n", " self.blurry_dim = blurry_dim\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.lin0 = nn.Linear(in_dim, h)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", " self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)\n", " self.n_blocks = 
n_blocks\n", " self.clip_size = clip_size\n", " self.clip_proj = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", " self.upsampler = Decoder(\n", " in_channels=64,\n", " out_channels=4,\n", " up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " block_out_channels=[64, 128, 256],\n", " layers_per_block=1,\n", " )\n", " \n", " def forward(self, x):\n", " x = self.lin0(x)\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " b = self.blin1(x)\n", " b = self.upsampler(b.reshape(len(b), -1, 7, 7))\n", " c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))\n", " return c, b\n", "\n", "model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) \n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,8192))\n", "print(b.shape)\n", "clip_, blur_ = model.backbone(b)\n", "print(clip_.shape, blur_.shape)" ] }, { "cell_type": "code", "execution_count": 23, "id": "a34204d0-d268-41ee-8eea-042525262c47", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "150,481,920 total\n", "150,481,920 trainable\n", "param counts:\n", "335,601,664 total\n", "335,601,664 trainable\n", "param counts:\n", "1,621,688,708 total\n", "1,621,688,708 trainable\n" ] } ], "source": [ "# memory model\n", "\n", "from timm.layers.mlp import Mlp\n", "\n", "class MemoryEncoder(nn.Module):\n", " def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.15):\n", " 
super().__init__()\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.out_dim = out_dim\n", " self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)\n", " self.final_input_dim = in_dim + embedding_time_dim\n", " self.lin0 = nn.Linear(self.final_input_dim, h)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.num_past_voxels = num_past_voxels\n", " self.embedding_time_dim = embedding_time_dim\n", " self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))\n", "\n", "\n", " def forward(self, x, time):\n", " time = time.long()\n", " time = self.embedding_time(time)\n", " x = torch.cat((x, time), dim=-1)\n", " x = self.lin0(x)\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " return x\n", " \n", "# # test the memory encoder\n", "# memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)\n", "\n", "# device = torch.device(\"cpu\")\n", "# memory_encoder.to(device)\n", "\n", "# # count params\n", "# total_parameters = 0\n", "# for parameter in memory_encoder.parameters():\n", "# total_parameters += parameter.numel()\n", "\n", "# rand_input = torch.randn((2, 15279)).to(device)\n", "# rand_time = torch.randint(0, 15, (2,)).to(device)\n", "# print(rand_input.shape, rand_time.shape)\n", "# memory_encoder(rand_input, rand_time).shape\n", "\n", "class 
MemoryCompressor(nn.Module):\n", " def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15):\n", " super().__init__()\n", " self.num_past = num_past\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.final_input_dim = in_dim * num_past\n", " self.lin0 = nn.Linear(self.final_input_dim, h)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, output_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.num_past = num_past\n", " self.output_dim = output_dim\n", "\n", " def forward(self, x):\n", " # x is (batch_size, num_past, in_dim)\n", " x = x.reshape(len(x), -1)\n", " x = self.lin0(x)\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " return x\n", " \n", "# # test the memory compressor\n", "# memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)\n", "\n", "# device = torch.device(\"cpu\")\n", "# memory_compressor.to(device)\n", "\n", "# # count params\n", "# total_parameters = 0\n", "# for parameter in memory_compressor.parameters():\n", "# total_parameters += parameter.numel()\n", "\n", "# rand_input = torch.randn((2, 15, 768)).to(device)\n", "# print(rand_input.shape)\n", "# memory_compressor(rand_input).shape\n", "\n", "model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)\n", "model.memory_compressor = MemoryCompressor(in_dim=model.memory_encoder.out_dim, 
num_past=15, output_dim=4096)\n",
"\n",
"utils.count_params(model.memory_encoder)\n",
"utils.count_params(model.memory_compressor)\n",
"utils.count_params(model)\n",
"\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Done with model preparations!\n", "param counts:\n", "1,621,688,708 total\n", "1,621,688,708 trainable\n" ] } ], "source": [
"no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
"# backbone params whose name contains an entry of no_decay get no weight decay; everything else uses 1e-2\n",
"opt_grouped_parameters = [\n",
"    {'params': list(model.ridge.parameters()), 'weight_decay': 1e-2},\n",
"    {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
"    {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n",
"    {'params': list(model.memory_encoder.parameters()), 'weight_decay': 1e-2},\n",
"    {'params': list(model.memory_compressor.parameters()), 'weight_decay': 1e-2},\n",
"]\n",
"\n",
"optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))\n",
"\n",
"# total number of scheduler steps over the run (same formula for both scheduler types)\n",
"total_steps = int(num_epochs*(num_train*num_devices//batch_size))\n",
"if lr_scheduler_type == 'linear':\n",
"    lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
"        optimizer,\n",
"        total_iters=total_steps,\n",
"        last_epoch=-1\n",
"    )\n",
"elif lr_scheduler_type == 'cycle':\n",
"    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
"        optimizer,\n",
"        max_lr=max_lr,\n",
"        total_steps=total_steps,\n",
"        final_div_factor=1000,\n",
"        last_epoch=-1, pct_start=2/num_epochs\n",
"    )\n",
"else:\n",
"    # fail here instead of with a NameError later (accelerator.prepare and save_ckpt both use lr_scheduler)\n",
"    raise ValueError(f\"unknown lr_scheduler_type: {lr_scheduler_type!r} (expected 'linear' or 'cycle')\")\n",
"\n",
"def save_ckpt(tag):\n",
"    \"\"\"Save model, optimizer and scheduler state plus the loss/lr history to outdir/<tag>.pth.\n",
"\n",
"    Reads the notebook globals model, optimizer, lr_scheduler, epoch, losses, test_losses and lrs.\n",
"    A failed save is reported and skipped so a long training run is not killed by a bad write.\n",
"    \"\"\"\n",
"    ckpt_path = outdir+f'/{tag}.pth'\n",
"    print(f'saving {ckpt_path}',flush=True)\n",
"    unwrapped_model = accelerator.unwrap_model(model)\n",
"    try:\n",
"        torch.save({\n",
"            'epoch': epoch,\n",
"            'model_state_dict': unwrapped_model.state_dict(),\n",
"            'optimizer_state_dict': optimizer.state_dict(),\n",
"            'lr_scheduler': lr_scheduler.state_dict(),\n",
"            'train_losses': losses,\n",
"            'test_losses': test_losses,\n",
"            'lrs': lrs,\n",
"        }, ckpt_path)\n",
"    except Exception as e:\n",
"        # catch Exception (not a bare except) so KeyboardInterrupt / SystemExit still stop the run\n",
"        print(f\"Couldn't save {ckpt_path}: {e!r}. Moving on to prevent crashing.\")\n",
"    finally:\n",
"        del unwrapped_model\n",
" \n",
"print(\"\\nDone with model preparations!\")\n",
"utils.count_params(model)" ] }, { "cell_type": "code", "execution_count": 18, "id": "3edca702-e148-4f2d-82b9-1c42bca5f73f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n",
       " in <module>:1                                                                                    \n",
       "                                                                                                  \n",
       " 1 nnnn                                                                                         \n",
       "   2                                                                                              \n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "NameError: name 'nnnn' is not defined\n",
       "
\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m1\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1 nnnn \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mNameError: \u001b[0mname \u001b[32m'nnnn'\u001b[0m is not defined\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 25, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "wandb stability run test\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb_config:\n", " {'model_name': 'test', 'batch_size': 128, 'num_epochs': 12, 'use_image_aug': False, 'max_lr': 3e-05, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}\n" ] }, { "data": { "text/html": [ "wandb version 0.15.12 is available! 
To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231015_224404-lbkf7608" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run mem1 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://stability.wandb.io/ckadirt/stability" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://stability.wandb.io/ckadirt/stability/runs/lbkf7608" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# params for wandb\n",
"wandb_log = True\n",
"# set True to resume (resume=\"allow\") the W&B run whose id is model_name instead of starting a new run\n",
"wandb_auto_resume = False\n",
"if local_rank==0 and wandb_log: # only use main process for wandb logging\n",
"    import wandb\n",
"    \n",
"    wandb_project = 'stability'\n",
"    wandb_run = model_name\n",
"    wandb_notes = ''\n",
"    \n",
"    print(f\"wandb {wandb_project} run {wandb_run}\")\n",
"    wandb.login(host='https://stability.wandb.io')#, relogin=True)\n",
"    wandb_config = {\n",
"        \"model_name\": model_name,\n",
"        \"batch_size\": batch_size,\n",
"        \"num_epochs\": num_epochs,\n",
"        \"use_image_aug\": use_image_aug,\n",
"        \"max_lr\": max_lr,\n",
"        \"lr_scheduler_type\": lr_scheduler_type,\n",
"        \"mixup_pct\": mixup_pct,\n",
"        \"num_train\": num_train,\n",
"        \"num_test\": num_test,\n",
"        \"seed\": seed,\n",
"        \"distributed\": distributed,\n",
"        \"num_devices\": num_devices,\n",
"        \"world_size\": world_size,\n",
"    }\n",
"    print(\"wandb_config:\\n\",wandb_config)\n",
"    # arguments shared by the fresh-run and resume paths\n",
"    init_kwargs = dict(project=wandb_project, name=wandb_run, config=wandb_config, notes=wandb_notes)\n",
"    if wandb_auto_resume:\n",
"        print(\"wandb_id:\",model_name)\n",
"        wandb.init(id=model_name, resume=\"allow\", **init_kwargs)\n",
"    else:\n",
"        wandb.init(**init_kwargs)\n",
"else:\n",
"    wandb_log = False" ] }, { "cell_type": "markdown", "id": "5b0ae095-3203-4eb8-8606-acc2db6ccf20", "metadata": {}, "source": [ "# More custom functions" ] }, { "cell_type": "code", "execution_count": 26, "id": "827ead88-7eb3-47cc-82da-31565063b927", "metadata": {}, 
"outputs": [], "source": [
"# using the same preprocessing as was used in MindEye + BrainDiffuser\n",
"pixcorr_preprocess = transforms.Compose([\n",
"    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
"])\n",
"def pixcorr(images,brains):\n",
"    \"\"\"Mean Pearson correlation between each image and its brain reconstruction, over flattened pixels.\n",
"\n",
"    Both batches are resized with pixcorr_preprocess and flattened per sample before correlating.\n",
"    \"\"\"\n",
"    # Flatten while keeping the batch dimension; reshape (unlike view) also accepts non-contiguous tensors\n",
"    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)\n",
"    all_brain_recons_flattened = pixcorr_preprocess(brains).reshape(len(brains), -1)\n",
"    # presumably batchwise_pearson_correlation returns an all-pairs matrix, so the diagonal pairs image i with recon i\n",
"    corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()\n",
"    return corrmean" ] }, { "cell_type": "markdown", "id": "d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 27, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": {}, "outputs": [], "source": [
"epoch = 0\n",
"losses, test_losses, lrs = [], [], []\n",
"best_test_loss = 1e9\n",
"soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
"\n",
"def load_resume_ckpt():\n",
"    \"\"\"Restore training state from outdir/last.pth, falling back to outdir/last_backup.pth.\n",
"\n",
"    Restores epoch, the loss/lr history, optimizer, lr scheduler and model weights written by save_ckpt.\n",
"    \"\"\"\n",
"    global epoch, losses, test_losses, lrs\n",
"    print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
"    try:\n",
"        checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
"    except Exception:\n",
"        print('last.pth failed... trying last_backup.pth')\n",
"        checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
"    epoch = checkpoint['epoch']\n",
"    print(\"Epoch\",epoch)\n",
"    # keep the saved history so the next save_ckpt does not overwrite it with empty lists\n",
"    losses = checkpoint.get('train_losses', losses)\n",
"    test_losses = checkpoint.get('test_losses', test_losses)\n",
"    lrs = checkpoint.get('lrs', lrs)\n",
"    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
"    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
"    # save_ckpt stores accelerator.unwrap_model(model).state_dict(), so load into the plain model\n",
"    model.load_state_dict(checkpoint['model_state_dict'])\n",
"    del checkpoint\n",
"\n",
"# Optionally resume from checkpoint #\n",
"if resume_from_ckpt or (wandb_log and wandb.run.resumed):\n",
"    load_resume_ckpt()\n",
"torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 28, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": {}, "outputs": [], "source": [ "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n", "model, optimizer, train_dl, test_dl, lr_scheduler\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "60be0d5f-3e94-4612-9373-61b53d836393", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test starting with epoch 0 / 12\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/12 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", " in <module>:34 \n", " \n", " 31 │ │ │ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_s \n", " 32 │ │ │ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15 \n", " 33 │ │ │ \n", " 34 │ │ │ blurry_image_enc = autoenc.encode(image).latent_dist.mode() \n", " 35 │ │ │ \n", " 36 │ │ │ if use_image_aug: image = img_augment(image) \n", " 37 \n", " \n", " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/utils/accelerate_utils.py:46 in wrapper \n", " \n", " 43 def wrapper(self, *args, **kwargs): \n", " 44 │ │ if hasattr(self, \"_hf_hook\") and hasattr(self._hf_hook, \"pre_forward\"): \n", " 45 │ │ │ self._hf_hook.pre_forward(self) \n", " 46 │ │ return method(self, *args, **kwargs) \n", " 47 \n", " 48 return wrapper \n", " 49 \n", " \n", " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/autoencoder_kl.py:258 in encode \n", " \n", " 255 │ │ │ 
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] \n", " 256 │ │ │ h = torch.cat(encoded_slices) \n", " 257 │ │ else: \n", " 258 │ │ │ h = self.encoder(x) \n", " 259 │ │ \n", " 260 │ │ moments = self.quant_conv(h) \n", " 261 │ │ posterior = DiagonalGaussianDistribution(moments) \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", " .py:1501 in _call_impl \n", " \n", " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", " 1501 │ │ │ return forward_call(*args, **kwargs) \n", " 1502 │ │ # Do not call functions when jit is used \n", " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", " 1504 │ │ backward_pre_hooks = [] \n", " \n", " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/vae.py:141 in forward \n", " \n", " 138 │ │ else: \n", " 139 │ │ │ # down \n", " 140 │ │ │ for down_block in self.down_blocks: \n", " 141 │ │ │ │ sample = down_block(sample) \n", " 142 │ │ │ \n", " 143 │ │ │ # middle \n", " 144 │ │ │ sample = self.mid_block(sample) \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", " .py:1501 in _call_impl \n", " \n", " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", " 1501 │ │ │ return forward_call(*args, **kwargs) \n", " 1502 │ │ # Do not call functions when jit is used \n", " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", " 1504 │ │ backward_pre_hooks = [] \n", " \n", " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/unet_2d_blocks.py:1247 in forward \n", " \n", " 1244 \n", " 1245 
def forward(self, hidden_states, scale: float = 1.0): \n", " 1246 │ │ for resnet in self.resnets: \n", " 1247 │ │ │ hidden_states = resnet(hidden_states, temb=None, scale=scale) \n", " 1248 │ │ \n", " 1249 │ │ if self.downsamplers is not None: \n", " 1250 │ │ │ for downsampler in self.downsamplers: \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", " .py:1501 in _call_impl \n", " \n", " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", " 1501 │ │ │ return forward_call(*args, **kwargs) \n", " 1502 │ │ # Do not call functions when jit is used \n", " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] \n", " 1504 │ │ backward_pre_hooks = [] \n", " \n", " /fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/resnet.py:650 in forward \n", " \n", " 647 │ │ if self.time_embedding_norm == \"ada_group\" or self.time_embedding_norm == \"spati \n", " 648 │ │ │ hidden_states = self.norm2(hidden_states, temb) \n", " 649 │ │ else: \n", " 650 │ │ │ hidden_states = self.norm2(hidden_states) \n", " 651 │ │ \n", " 652 │ │ if temb is not None and self.time_embedding_norm == \"scale_shift\": \n", " 653 │ │ │ scale, shift = torch.chunk(temb, 2, dim=1) \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module \n", " .py:1501 in _call_impl \n", " \n", " 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks \n", " 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks \n", " 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): \n", " 1501 │ │ │ return forward_call(*args, **kwargs) \n", " 1502 │ │ # Do not call functions when jit is used \n", " 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] 
\n", " 1504 │ │ backward_pre_hooks = [] \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/normal \n", " ization.py:273 in forward \n", " \n", " 270 │ │ │ init.zeros_(self.bias) \n", " 271 \n", " 272 def forward(self, input: Tensor) -> Tensor: \n", " 273 │ │ return F.group_norm( \n", " 274 │ │ │ input, self.num_groups, self.weight, self.bias, self.eps) \n", " 275 \n", " 276 def extra_repr(self) -> str: \n", " \n", " /admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/functional.py: \n", " 2530 in group_norm \n", " \n", " 2527 if input.dim() < 2: \n", " 2528 │ │ raise RuntimeError(f\"Expected at least 2 dimensions for input tensor but receive \n", " 2529 _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list( \n", " 2530 return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.e \n", " 2531 \n", " 2532 \n", " 2533 def local_response_norm(input: Tensor, size: int, alpha: float = 1e-4, beta: float = 0.7 \n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB (GPU 0; 39.56 GiB total capacity; 33.04 GiB \n", "already allocated; 752.56 MiB free; 37.34 GiB reserved in total by PyTorch) If reserved memory is >> allocated \n", "memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and \n", "PYTORCH_CUDA_ALLOC_CONF\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m34\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 31 \u001b[0m\u001b[2m│ │ │ \u001b[0mpast_15_voxels = voxels[past_behav[:,:,\u001b[94m5\u001b[0m].cpu().long()].to(device) \u001b[2m# batch_s\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 32 \u001b[0m\u001b[2m│ │ │ \u001b[0mpast_15_times = torch.Tensor([i \u001b[94mfor\u001b[0m i \u001b[95min\u001b[0m \u001b[96mrange\u001b[0m(\u001b[94m15\u001b[0m)]).to(device) \u001b[2m# 15\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 33 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 34 \u001b[2m│ │ │ \u001b[0mblurry_image_enc = autoenc.encode(image).latent_dist.mode() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 35 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 36 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mif\u001b[0m use_image_aug: image = img_augment(image) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 37 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/utils/\u001b[0m\u001b[1;33maccelerate_utils.py\u001b[0m:\u001b[94m46\u001b[0m in \u001b[92mwrapper\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m43 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mwrapper\u001b[0m(\u001b[96mself\u001b[0m, *args, 
**kwargs): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m44 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m, \u001b[33m\"\u001b[0m\u001b[33m_hf_hook\u001b[0m\u001b[33m\"\u001b[0m) \u001b[95mand\u001b[0m \u001b[96mhasattr\u001b[0m(\u001b[96mself\u001b[0m._hf_hook, \u001b[33m\"\u001b[0m\u001b[33mpre_forward\u001b[0m\u001b[33m\"\u001b[0m): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m45 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m._hf_hook.pre_forward(\u001b[96mself\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m46 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m method(\u001b[96mself\u001b[0m, *args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m47 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m48 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m wrapper \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m49 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33mautoencoder_kl.py\u001b[0m:\u001b[94m258\u001b[0m in \u001b[92mencode\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m255 \u001b[0m\u001b[2m│ │ │ \u001b[0mencoded_slices = [\u001b[96mself\u001b[0m.encoder(x_slice) \u001b[94mfor\u001b[0m x_slice \u001b[95min\u001b[0m x.split(\u001b[94m1\u001b[0m)] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m256 \u001b[0m\u001b[2m│ │ │ \u001b[0mh = torch.cat(encoded_slices) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m257 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m258 \u001b[2m│ │ │ \u001b[0mh = \u001b[96mself\u001b[0m.encoder(x) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m259 
\u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m260 \u001b[0m\u001b[2m│ │ \u001b[0mmoments = \u001b[96mself\u001b[0m.quant_conv(h) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m261 \u001b[0m\u001b[2m│ │ \u001b[0mposterior = DiagonalGaussianDistribution(moments) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", 
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33mvae.py\u001b[0m:\u001b[94m141\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m138 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m139 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# down\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m140 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mfor\u001b[0m down_block \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.down_blocks: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m141 \u001b[2m│ │ │ │ \u001b[0msample = down_block(sample) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m142 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m143 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# middle\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m144 \u001b[0m\u001b[2m│ │ │ \u001b[0msample = \u001b[96mself\u001b[0m.mid_block(sample) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ 
\u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33munet_2d_blocks.py\u001b[0m:\u001b[94m1247\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1244 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1245 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mforward\u001b[0m(\u001b[96mself\u001b[0m, hidden_states, scale: \u001b[96mfloat\u001b[0m = \u001b[94m1.0\u001b[0m): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1246 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mfor\u001b[0m resnet \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.resnets: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1247 \u001b[2m│ │ │ \u001b[0mhidden_states = resnet(hidden_states, temb=\u001b[94mNone\u001b[0m, scale=scale) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1248 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m 
\u001b[2m1249 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.downsamplers \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1250 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mfor\u001b[0m downsampler \u001b[95min\u001b[0m \u001b[96mself\u001b[0m.downsamplers: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ 
\u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/fsx/proj-fmri/ckadirt/diffusers/src/diffusers/models/\u001b[0m\u001b[1;33mresnet.py\u001b[0m:\u001b[94m650\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m647 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mself\u001b[0m.time_embedding_norm == \u001b[33m\"\u001b[0m\u001b[33mada_group\u001b[0m\u001b[33m\"\u001b[0m \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m.time_embedding_norm == \u001b[33m\"\u001b[0m\u001b[33mspati\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m648 \u001b[0m\u001b[2m│ │ │ \u001b[0mhidden_states = \u001b[96mself\u001b[0m.norm2(hidden_states, temb) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m649 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94melse\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m650 \u001b[2m│ │ │ \u001b[0mhidden_states = \u001b[96mself\u001b[0m.norm2(hidden_states) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m651 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m652 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m temb \u001b[95mis\u001b[0m \u001b[95mnot\u001b[0m \u001b[94mNone\u001b[0m \u001b[95mand\u001b[0m \u001b[96mself\u001b[0m.time_embedding_norm == \u001b[33m\"\u001b[0m\u001b[33mscale_shift\u001b[0m\u001b[33m\"\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m653 \u001b[0m\u001b[2m│ │ │ \u001b[0mscale, shift = torch.chunk(temb, \u001b[94m2\u001b[0m, dim=\u001b[94m1\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mmodule\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m 
\u001b[1;33m.py\u001b[0m:\u001b[94m1501\u001b[0m in \u001b[92m_call_impl\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1498 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[95mnot\u001b[0m (\u001b[96mself\u001b[0m._backward_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._backward_pre_hooks \u001b[95mor\u001b[0m \u001b[96mself\u001b[0m._forward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1499 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_backward_pre_hooks \u001b[95mor\u001b[0m _global_backward_hooks \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1500 \u001b[0m\u001b[2m│ │ │ │ \u001b[0m\u001b[95mor\u001b[0m _global_forward_hooks \u001b[95mor\u001b[0m _global_forward_pre_hooks): \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m1501 \u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m forward_call(*args, **kwargs) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1502 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[2m# Do not call functions when jit is used\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1503 \u001b[0m\u001b[2m│ │ \u001b[0mfull_backward_hooks, non_full_backward_hooks = [], [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m1504 \u001b[0m\u001b[2m│ │ \u001b[0mbackward_pre_hooks = [] \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/\u001b[0m\u001b[1;33mnormal\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[1;33mization.py\u001b[0m:\u001b[94m273\u001b[0m in \u001b[92mforward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m270 \u001b[0m\u001b[2m│ │ │ \u001b[0minit.zeros_(\u001b[96mself\u001b[0m.bias) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m271 
\u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m272 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mforward\u001b[0m(\u001b[96mself\u001b[0m, \u001b[96minput\u001b[0m: Tensor) -> Tensor: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m273 \u001b[2m│ │ \u001b[0m\u001b[94mreturn\u001b[0m F.group_norm( \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m274 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96minput\u001b[0m, \u001b[96mself\u001b[0m.num_groups, \u001b[96mself\u001b[0m.weight, \u001b[96mself\u001b[0m.bias, \u001b[96mself\u001b[0m.eps) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m275 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m276 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mextra_repr\u001b[0m(\u001b[96mself\u001b[0m) -> \u001b[96mstr\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/\u001b[0m\u001b[1;33mfunctional.py\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[94m2530\u001b[0m in \u001b[92mgroup_norm\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2527 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mif\u001b[0m \u001b[96minput\u001b[0m.dim() < \u001b[94m2\u001b[0m: \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2528 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mRuntimeError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m\"\u001b[0m\u001b[33mExpected at least 2 dimensions for input tensor but receive\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2529 \u001b[0m\u001b[2m│ \u001b[0m_verify_batch_size([\u001b[96minput\u001b[0m.size(\u001b[94m0\u001b[0m) * \u001b[96minput\u001b[0m.size(\u001b[94m1\u001b[0m) // num_groups, num_groups] + \u001b[96mlist\u001b[0m( 
\u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m2530 \u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m torch.group_norm(\u001b[96minput\u001b[0m, num_groups, weight, bias, eps, torch.backends.cudnn.e \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2531 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2532 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m2533 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92mlocal_response_norm\u001b[0m(\u001b[96minput\u001b[0m: Tensor, size: \u001b[96mint\u001b[0m, alpha: \u001b[96mfloat\u001b[0m = \u001b[94m1e-4\u001b[0m, beta: \u001b[96mfloat\u001b[0m = \u001b[94m0.7\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mOutOfMemoryError: \u001b[0mCUDA out of memory. Tried to allocate \u001b[1;36m3.06\u001b[0m GiB \u001b[1m(\u001b[0mGPU \u001b[1;36m0\u001b[0m; \u001b[1;36m39.56\u001b[0m GiB total capacity; \u001b[1;36m33.04\u001b[0m GiB \n", "already allocated; \u001b[1;36m752.56\u001b[0m MiB free; \u001b[1;36m37.34\u001b[0m GiB reserved in total by PyTorch\u001b[1m)\u001b[0m If reserved memory is >> allocated \n", "memory try setting max_split_size_mb to avoid fragmentation. 
See documentation for Memory Management and \n", "PYTORCH_CUDA_ALLOC_CONF\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Training loop: each epoch runs one pass over train_dl, then a single full-batch evaluation on test_dl.\n",
"print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
"progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
"# Test tensors are built once (repeats of the same image averaged together) and reused every epoch.\n",
"test_image, test_voxel, test_past_behav = None, None, None\n",
"mse = nn.MSELoss()\n",
"n_past = 15  # number of preceding trials fed to the memory encoder\n",
"past_times = torch.arange(n_past, dtype=torch.float32, device=device)  # relative positions 0..n_past-1\n",
"# Latents decoded per autoenc.decode() call during testing; lower this if the VAE decoder runs out of GPU memory.\n",
"decode_chunk_size = max(1, batch_size // 2)\n",
"\n",
"def encode_past_voxels(past_behav_batch, batch_len):\n",
"    \"\"\"Embed the n_past preceding voxel patterns of each trial and compress them into one memory vector per trial.\n",
"\n",
"    past_behav_batch: behaviour rows of the preceding trials; column 5 is the voxel index (as used throughout this cell).\n",
"    batch_len: number of trials in the batch, used to un-flatten the encoder output.\n",
"    \"\"\"\n",
"    past_voxels = voxels[past_behav_batch[:,:,5].cpu().long()].to(device)  # (batch, n_past, num_voxels)\n",
"    past_voxels = past_voxels.reshape(-1, past_voxels.shape[-1])  # (batch * n_past, num_voxels)\n",
"    times = past_times.repeat(batch_len, 1).reshape(-1)  # (batch * n_past,)\n",
"    embeds = model.memory_encoder(past_voxels, times).reshape(batch_len, n_past, -1)\n",
"    return model.memory_compressor(embeds)\n",
"\n",
"for epoch in progress_bar:\n",
"    model.train()\n",
"\n",
"    fwd_percent_correct = 0.\n",
"    bwd_percent_correct = 0.\n",
"    test_fwd_percent_correct = 0.\n",
"    test_bwd_percent_correct = 0.\n",
"\n",
"    loss_clip_total = 0.\n",
"    loss_blurry_total = 0.\n",
"    test_loss_clip_total = 0.\n",
"    test_loss_blurry_total = 0.\n",
"\n",
"    blurry_pixcorr = 0.\n",
"    test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n",
"\n",
"    for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
"        with torch.cuda.amp.autocast():\n",
"            optimizer.zero_grad()\n",
"\n",
"            voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
"            image = images[behav[:,0,0].cpu().long()].to(device).float()\n",
"\n",
"            # VAE latents of the un-augmented image are the target of the blurry branch\n",
"            blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
"\n",
"            if use_image_aug: image = img_augment(image)\n",
"\n",
"            clip_target = clip_model.embed_image(image)\n",
"            assert not torch.any(torch.isnan(clip_target))\n",
"\n",
"            if epoch < int(mixup_pct * num_epochs):\n",
"                voxel, perm, betas, select = utils.mixco(voxel)\n",
"\n",
"            information_past_voxels = encode_past_voxels(past_behav, voxel.shape[0])\n",
"            voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)\n",
"\n",
"            clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
"\n",
"            clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
"            clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
"\n",
"            if epoch < int(mixup_pct * num_epochs):\n",
"                loss_clip = utils.mixco_nce(\n",
"                    clip_voxels_norm,\n",
"                    clip_target_norm,\n",
"                    temp=.006,\n",
"                    perm=perm, betas=betas, select=select)\n",
"            else:\n",
"                epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n",
"                loss_clip = utils.soft_clip_loss(\n",
"                    clip_voxels_norm,\n",
"                    clip_target_norm,\n",
"                    temp=epoch_temp)\n",
"\n",
"            loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
"\n",
"            loss_clip_total += loss_clip.item()\n",
"            loss_blurry_total += loss_blurry.item()\n",
"\n",
"            loss = loss_blurry + loss_clip\n",
"            utils.check_loss(loss)\n",
"\n",
"            accelerator.backward(loss)\n",
"            optimizer.step()\n",
"\n",
"            losses.append(loss.item())\n",
"            lrs.append(optimizer.param_groups[0]['lr'])\n",
"\n",
"            # forward and backward top 1 accuracy\n",
"            labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)\n",
"            fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
"            bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
"\n",
"            with torch.no_grad():\n",
"                # pixcorr is only evaluated on a random subset (2) of each batch because autoenc.decode() is costly & slow\n",
"                random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)\n",
"                blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)\n",
"                blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)\n",
"\n",
"            if lr_scheduler_type is not None:\n",
"                lr_scheduler.step()\n",
"\n",
"    model.eval()\n",
"    for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
"        with torch.cuda.amp.autocast():\n",
"            with torch.no_grad():\n",
"                # all test samples should be loaded per batch such that test_i should never exceed 0\n",
"                if len(behav) != num_test: print(\"!\",len(behav),num_test)\n",
"\n",
"                ## Average same-image repeats ##\n",
"                # Each averaged sample keeps the past-trial context of the first presentation of its image,\n",
"                # so voxel, image and memory input stay aligned row-for-row.\n",
"                if test_image is None:\n",
"                    voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
"                    image = behav[:,0,0].cpu().long()\n",
"\n",
"                    image_rows, voxel_rows, past_rows = [], [], []\n",
"                    for im in torch.unique(image):\n",
"                        locs = torch.where(im == image)[0]\n",
"                        image_rows.append(images[im])\n",
"                        voxel_rows.append(torch.mean(voxel[locs],axis=0))\n",
"                        past_rows.append(past_behav[locs[0]])\n",
"                    test_image = torch.stack(image_rows)\n",
"                    test_voxel = torch.stack(voxel_rows)\n",
"                    test_past_behav = torch.stack(past_rows)\n",
"\n",
"                # the first batch_size averaged samples (deterministic, so every epoch is evaluated on the same set)\n",
"                random_indices = torch.arange(len(test_voxel))[:batch_size]\n",
"                voxel = test_voxel[random_indices].to(device)\n",
"                image = test_image[random_indices].to(device)\n",
"                current_past_behav = test_past_behav[random_indices]\n",
"\n",
"                assert len(image) == batch_size\n",
"\n",
"                blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
"\n",
"                clip_target = clip_model.embed_image(image.float())\n",
"\n",
"                information_past_voxels = encode_past_voxels(current_past_behav, voxel.shape[0])\n",
"                voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)\n",
"\n",
"                clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
"\n",
"                clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
"                clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
"\n",
"                loss_clip = utils.soft_clip_loss(\n",
"                    clip_voxels_norm,\n",
"                    clip_target_norm,\n",
"                    temp=.006)\n",
"\n",
"                loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
"\n",
"                # accumulate the component losses so the test/loss_*_total logs are populated\n",
"                test_loss_clip_total += loss_clip.item()\n",
"                test_loss_blurry_total += loss_blurry.item()\n",
"\n",
"                loss = loss_blurry + loss_clip\n",
"                utils.check_loss(loss)\n",
"\n",
"                test_losses.append(loss.item())\n",
"\n",
"                # forward and backward top 1 accuracy\n",
"                labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)\n",
"                test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
"                test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
"\n",
"                # decode in chunks because the VAE decoder is memory-hungry at full batch size\n",
"                blurry_recon_images = torch.cat([\n",
"                    autoenc.decode(latent_chunk).sample.clamp(0,1)\n",
"                    for latent_chunk in blurry_image_enc_.split(decode_chunk_size)\n",
"                ])\n",
"                test_blurry_pixcorr += pixcorr(image, blurry_recon_images)\n",
"\n",
"                # show the first two test images next to their blurry reconstructions (reusing the decoded images)\n",
"                fig, axes = plt.subplots(1, 4, figsize=(8, 4))\n",
"                axes[0].imshow(utils.torch_to_Image(image[[0]]))\n",
"                axes[1].imshow(utils.torch_to_Image(blurry_recon_images[[0]]))\n",
"                axes[2].imshow(utils.torch_to_Image(image[[1]]))\n",
"                axes[3].imshow(utils.torch_to_Image(blurry_recon_images[[1]]))\n",
"                for ax in axes:\n",
"                    ax.axis('off')\n",
"                plt.show()\n",
"\n",
"    if local_rank==0:\n",
"        assert (test_i+1) == 1\n",
"        logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n",
"            \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n",
"            \"train/lr\": lrs[-1],\n",
"            \"train/num_steps\": len(losses),\n",
"            \"test/num_steps\": len(test_losses),\n",
"            \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n",
"            \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n",
"            \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n",
"            \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n",
"            \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n",
"            \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n",
"            \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n",
"            \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n",
"            \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n",
"            \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n",
"            }\n",
"        progress_bar.set_postfix(**logs)\n",
"\n",
"        # Save model checkpoint\n",
"        if epoch % ckpt_interval == 0:\n",
"            if not utils.is_interactive():\n",
"                save_ckpt(f'last')\n",
"\n",
"        if wandb_log: wandb.log(logs)\n",
"\n",
"    # wait for other GPUs to catch up if needed\n",
"    accelerator.wait_for_everyone()\n",
"    torch.cuda.empty_cache()\n",
"    gc.collect()\n",
"\n",
"print(\"\\n===Finished!===\\n\")\n",
"if ckpt_saving:\n",
"    save_ckpt(f'last')\n",
"if not utils.is_interactive():\n",
"    sys.exit(0)" ]
}, { "cell_type": "code", "execution_count": null, "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a", "metadata": { "tags": [] }, "outputs": [], "source": [
"# Loss curves: training loss is recorded once per optimizer step, test loss once per evaluation pass.\n",
"fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(12, 4))\n",
"ax_train.plot(losses)\n",
"ax_train.set(title=\"Train loss\", xlabel=\"optimizer step\", ylabel=\"loss\")\n",
"ax_test.plot(test_losses)\n",
"ax_test.set(title=\"Test loss\", xlabel=\"evaluation pass\", ylabel=\"loss\")\n",
"plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": true }, "toc-autonumbering": true, "vscode": { "interpreter": { "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" } } }, "nbformat": 4, "nbformat_minor": 5 }