ckadirt committed
Commit b8ea2b2 · verified · 1 Parent(s): 9a4be91

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +164 -0
  2. LICENSE +21 -0
  3. README.md +12 -0
  4. src/Train-with-memory-Copy1.ipynb +0 -0
  5. src/Train-with-memory-cat-trial.ipynb +0 -0
  6. src/Train-with-memory-cat.ipynb +0 -0
  7. src/Train-with-memory-cat.py +1056 -0
  8. src/Train-with-memory-rr-dropout.py +1040 -0
  9. src/Train-with-memory-rr-mlpmix.ipynb +0 -0
  10. src/Train-with-memory-rr.ipynb +0 -0
  11. src/Train-with-memory-rr.py +1018 -0
  12. src/Train-with-memory.ipynb +0 -0
  13. src/Train-with-memory.py +978 -0
  14. src/Train.ipynb +0 -0
  15. src/Train.py +761 -0
  16. src/Train_MLPMixer-Copy1.ipynb +0 -0
  17. src/Train_MLPMixer-Copy1.py +1352 -0
  18. src/Train_MLPMixer-Copy2.py +1275 -0
  19. src/Train_MLPMixer-img.ipynb +0 -0
  20. src/Train_MLPMixer-img.py +1444 -0
  21. src/Train_MLPMixer.ipynb +0 -0
  22. src/Train_MLPMixer.py +1275 -0
  23. src/Train_diffusion.ipynb +0 -0
  24. src/accel.slurm +38 -0
  25. src/accel2.slurm +40 -0
  26. src/accel3.slurm +40 -0
  27. src/accel4.slurm +40 -0
  28. src/accel5.slurm +40 -0
  29. src/accel6.slurm +40 -0
  30. src/accel7.slurm +41 -0
  31. src/accel8.slurm +41 -0
  32. src/accel9.slurm +44 -0
  33. src/blip2_captions.py +71 -0
  34. src/blip_tryal.ipynb +0 -0
  35. src/checking_models.ipynb +1526 -0
  36. src/deepspeed_config_stage1.json +1 -0
  37. src/deepspeed_config_stage2.json +1 -0
  38. src/deepspeed_config_stage2_cpuoffload.json +44 -0
  39. src/deepspeed_config_stage3.json +1 -0
  40. src/huggingface_to_s3.ipynb +422 -0
  41. src/models.py +210 -0
  42. src/setup.sh +15 -0
  43. src/train2-tryal.ipynb +2409 -0
  44. src/train2.ipynb +1856 -0
  45. src/train2.py +1141 -0
  46. src/utils.py +368 -0
  47. train_mem_logs/error.pth +3 -0
  48. train_mem_logs/error_tensors.pth +3 -0
  49. train_mem_logs/test/last.pth +3 -0
  50. train_mem_logs/test_mem/last.pth +3 -0
.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ wandb/
2
+ train_logs/
3
+ slurms/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # poetry
102
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106
+ #poetry.lock
107
+
108
+ # pdm
109
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110
+ #pdm.lock
111
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112
+ # in version control.
113
+ # https://pdm.fming.dev/#use-with-ide
114
+ .pdm.toml
115
+
116
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117
+ __pypackages__/
118
+
119
+ # Celery stuff
120
+ celerybeat-schedule
121
+ celerybeat.pid
122
+
123
+ # SageMath parsed files
124
+ *.sage.py
125
+
126
+ # Environments
127
+ .env
128
+ .venv
129
+ env/
130
+ venv/
131
+ ENV/
132
+ env.bak/
133
+ venv.bak/
134
+
135
+ # Spyder project settings
136
+ .spyderproject
137
+ .spyproject
138
+
139
+ # Rope project settings
140
+ .ropeproject
141
+
142
+ # mkdocs documentation
143
+ /site
144
+
145
+ # mypy
146
+ .mypy_cache/
147
+ .dmypy.json
148
+ dmypy.json
149
+
150
+ # Pyre type checker
151
+ .pyre/
152
+
153
+ # pytype static type analyzer
154
+ .pytype/
155
+
156
+ # Cython debug symbols
157
+ cython_debug/
158
+
159
+ # PyCharm
160
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
163
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 MedARC
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,12 @@
1
+ # MindEyeV2
2
+
3
+ In progress -- this repo is under active development in the MedARC Discord server (feel free to join us and help develop MindEyeV2!)
4
+
5
+ 1. Download all of https://huggingface.co/datasets/pscotti/mindeyev2 and place the files in a folder. You will need to specify the path to this folder as the "data_path" variable (see the sanity-check sketch after these steps).
6
+
7
+ 2. Run setup.sh to create a new "fmri" conda environment.
8
+
9
+ 3. Activate the conda environment with "conda activate fmri"
10
+
11
+ 4. Run Train.ipynb or Train.py (they contain the same code).
12
+
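A minimal sanity-check sketch for step 1 (not part of this commit; the folder path is a placeholder), assuming the file layout that the training scripts in src/ read from data_path:

```python
# Hypothetical helper: confirm data_path points at the downloaded mindeyev2
# dataset before launching src/Train.py. The expected names follow the paths
# used by the training scripts in this commit.
import os

data_path = "/path/to/mindeyev2_dataset"  # placeholder -- set to your download folder
subj = 1
expected = [
    f"betas_all_subj0{subj}.hdf5",
    "coco_images_224_float16.hdf5",
    f"wds/subj0{subj}/train",
    f"wds/subj0{subj}/test",
]
missing = [p for p in expected if not os.path.exists(os.path.join(data_path, p))]
print("missing:", missing if missing else "none -- data_path looks complete")
```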
src/Train-with-memory-Copy1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory-cat-trial.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory-cat.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory-cat.py ADDED
@@ -0,0 +1,1056 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[3]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ # from subprocess import call
9
+ # command = "jupyter nbconvert Train-with-memory-cat.ipynb --to python"
10
+ # call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[4]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import h5py
28
+ from tqdm import tqdm
29
+
30
+ import webdataset as wds
31
+ import gc
32
+
33
+ import matplotlib.pyplot as plt
34
+ import torch
35
+ import torch.nn as nn
36
+ from torchvision import transforms
37
+
38
+ from accelerate import Accelerator, DeepSpeedPlugin
39
+
40
+ # tf32 data type is faster than standard float32
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+
43
+ # custom functions #
44
+ import utils
45
+
46
+ global_batch_size = 128 #128
47
+
48
+
49
+ # In[5]:
50
+
51
+
52
+ ### Multi-GPU config ###
53
+ local_rank = os.getenv('RANK')
54
+ if local_rank is None:
55
+ local_rank = 0
56
+ else:
57
+ local_rank = int(local_rank)
58
+ print("LOCAL RANK ", local_rank)
59
+
60
+ num_devices = torch.cuda.device_count()
61
+ if num_devices==0: num_devices = 1
62
+
63
+ accelerator = Accelerator(split_batches=False)
64
+
65
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
66
+
67
+ # if num_devices <= 1 and utils.is_interactive():
68
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
69
+ # os.environ["MASTER_ADDR"] = "localhost"
70
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
71
+ # os.environ["RANK"] = "0"
72
+ # os.environ["LOCAL_RANK"] = "0"
73
+ # os.environ["WORLD_SIZE"] = "1"
74
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
75
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
76
+
77
+ # # alter the deepspeed config according to your global and local batch size
78
+ # if local_rank == 0:
79
+ # with open('deepspeed_config_stage2.json', 'r') as file:
80
+ # config = json.load(file)
81
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
82
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
83
+ # with open('deepspeed_config_stage2.json', 'w') as file:
84
+ # json.dump(config, file)
85
+ # else:
86
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
87
+ # time.sleep(10)
88
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
89
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
90
+
91
+
92
+ # In[6]:
93
+
94
+
95
+ print("PID of this process =",os.getpid())
96
+ device = accelerator.device
97
+ print("device:",device)
98
+ num_workers = num_devices
99
+ print(accelerator.state)
100
+ world_size = accelerator.state.num_processes
101
+ distributed = not accelerator.state.distributed_type == 'NO'
102
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
103
+ print = accelerator.print # only print if local_rank=0
104
+
105
+
106
+ # # Configurations
107
+
108
+ # In[7]:
109
+
110
+
111
+ # if running this interactively, can specify jupyter_args here for argparser to use
112
+ if utils.is_interactive():
113
+ # Example use
114
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
115
+ --model_name=test \
116
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
117
+ --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
118
+
119
+ jupyter_args = jupyter_args.split()
120
+ print(jupyter_args)
121
+
122
+ from IPython.display import clear_output # function to clear print outputs in cell
123
+ get_ipython().run_line_magic('load_ext', 'autoreload')
124
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
125
+ get_ipython().run_line_magic('autoreload', '2')
126
+
127
+
128
+ # In[8]:
129
+
130
+
131
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
132
+ parser.add_argument(
133
+ "--model_name", type=str, default="testing",
134
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
135
+ )
136
+ parser.add_argument(
137
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
138
+ help="Path to where NSD data is stored / where to download it to",
139
+ )
140
+ parser.add_argument(
141
+ "--subj",type=int, default=1, choices=[1,2,5,7],
142
+ )
143
+ parser.add_argument(
144
+ "--batch_size", type=int, default=32,
145
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
146
+ )
147
+ parser.add_argument(
148
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
149
+ help="whether to log to wandb",
150
+ )
151
+ parser.add_argument(
152
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
153
+ help="if not using wandb and want to resume from a ckpt",
154
+ )
155
+ parser.add_argument(
156
+ "--wandb_project",type=str,default="stability",
157
+ help="wandb project name",
158
+ )
159
+ parser.add_argument(
160
+ "--mixup_pct",type=float,default=.33,
161
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
162
+ )
163
+ parser.add_argument(
164
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
165
+ help="whether to use image augmentation",
166
+ )
167
+ parser.add_argument(
168
+ "--num_epochs",type=int,default=240,
169
+ help="number of epochs of training",
170
+ )
171
+ parser.add_argument(
172
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
173
+ )
174
+ parser.add_argument(
175
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
176
+ )
177
+ parser.add_argument(
178
+ "--ckpt_interval",type=int,default=5,
179
+ help="save backup ckpt and reconstruct every x epochs",
180
+ )
181
+ parser.add_argument(
182
+ "--seed",type=int,default=42,
183
+ )
184
+ parser.add_argument(
185
+ "--max_lr",type=float,default=3e-4,
186
+ )
187
+ parser.add_argument(
188
+ "--n_samples_save",type=int,default=0,choices=[0,1],
189
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
190
+ )
191
+
192
+ if utils.is_interactive():
193
+ args = parser.parse_args(jupyter_args)
194
+ else:
195
+ args = parser.parse_args()
196
+
197
+ # create global variables without the args prefix
198
+ for attribute_name in vars(args).keys():
199
+ globals()[attribute_name] = getattr(args, attribute_name)
200
+
201
+ print("global batch_size", batch_size)
202
+ batch_size = int(batch_size / num_devices)
203
+ print("batch_size", batch_size)
204
+
205
+
206
+ # In[9]:
207
+
208
+
209
+ outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
210
+ if not os.path.exists(outdir):
211
+ os.makedirs(outdir,exist_ok=True)
212
+ if use_image_aug:
213
+ import kornia
214
+ from kornia.augmentation.container import AugmentationSequential
215
+ img_augment = AugmentationSequential(
216
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
217
+ kornia.augmentation.Resize((224, 224)),
218
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
219
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
220
+ kornia.augmentation.RandomGrayscale(p=0.3),
221
+ same_on_batch=False,
222
+ data_keys=["input"],
223
+ )
224
+
225
+
226
+ # # Prep data, models, and dataloaders
227
+
228
+ # ## Dataloader
229
+
230
+ # In[10]:
231
+
232
+
233
+ if subj==1:
234
+ num_train = 24958
235
+ num_test = 2770
236
+ test_batch_size = num_test
237
+
238
+ def my_split_by_node(urls): return urls
239
+
240
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
241
+ print(train_url)
242
+
243
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
244
+ .shuffle(750, initial=1500, rng=random.Random(42))\
245
+ .decode("torch")\
246
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
247
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
248
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
249
+
250
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
251
+ print(test_url)
252
+
253
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
254
+ .shuffle(750, initial=1500, rng=random.Random(42))\
255
+ .decode("torch")\
256
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
257
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
258
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
259
+
260
+
261
+ # ### check dataloaders are working
262
+
263
+ # In[9]:
264
+
265
+
266
+ # test_indices = []
267
+ # test_images = []
268
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
269
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
270
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
271
+ # test_indices = test_indices.astype(np.int16)
272
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
273
+ # print("---\n")
274
+
275
+ # train_indices = []
276
+ # train_images = []
277
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
278
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
279
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
280
+ # train_indices = train_indices.astype(np.int16)
281
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
282
+
283
+ # # train_images = np.hstack((train_images, test_images))
284
+ # # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
285
+
286
+
287
+ # ## Load data and images
288
+
289
+ # In[ ]:
290
+
291
+
292
+ # load betas
293
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
294
+ voxels = f['betas'][:]
295
+ print(f"subj0{subj} betas loaded into memory")
296
+ voxels = torch.Tensor(voxels).to("cpu").half()
297
+ if subj==1:
298
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
299
+ print("voxels", voxels.shape)
300
+ num_voxels = voxels.shape[-1]
301
+
302
+ # load orig images
303
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
304
+ images = f['images'][:]
305
+ images = torch.Tensor(images).to("cpu").half()
306
+ print("images", images.shape)
307
+
308
+
309
+ # ## Load models
310
+
311
+ # ### CLIP image embeddings model
312
+
313
+ # In[ ]:
314
+
315
+
316
+ from models import Clipper
317
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
318
+
319
+ clip_seq_dim = 257
320
+ clip_emb_dim = 768
321
+ hidden_dim = 4096
322
+
323
+
324
+ # ### SD VAE (blurry images)
325
+
326
+ # In[ ]:
327
+
328
+
329
+ from diffusers import AutoencoderKL
330
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
331
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
332
+ autoenc.eval()
333
+ autoenc.requires_grad_(False)
334
+ autoenc.to(device)
335
+ utils.count_params(autoenc)
336
+
337
+
338
+ # ### MindEye modules
339
+
340
+ # In[13]:
341
+
342
+
343
+ class MindEyeModule(nn.Module):
344
+ def __init__(self):
345
+ super(MindEyeModule, self).__init__()
346
+ def forward(self, x):
347
+ return x
348
+
349
+ model = MindEyeModule()
350
+ model
351
+
352
+
353
+ # In[14]:
354
+
355
+
356
+ class RidgeRegression(torch.nn.Module):
357
+ # make sure to add weight_decay when initializing optimizer
358
+ def __init__(self, input_size, out_features):
359
+ super(RidgeRegression, self).__init__()
360
+ self.out_features = out_features
361
+ self.linear = torch.nn.Linear(input_size, out_features)
362
+ def forward(self, x):
363
+ return self.linear(x)
364
+
365
+ model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
366
+ utils.count_params(model.ridge)
367
+ utils.count_params(model)
368
+
369
+ b = torch.randn((2,1,voxels.shape[1]))
370
+ print(b.shape, model.ridge(b).shape)
371
+
372
+
373
+ # In[22]:
374
+
375
+
376
+ from functools import partial
377
+ from diffusers.models.vae import Decoder
378
+ class BrainNetwork(nn.Module):
379
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):
380
+ super().__init__()
381
+ self.blurry_dim = blurry_dim
382
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
383
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
384
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
385
+ self.lin0 = nn.Linear(in_dim, h)
386
+ self.mlp = nn.ModuleList([
387
+ nn.Sequential(
388
+ nn.Linear(h, h),
389
+ *[item() for item in act_and_norm],
390
+ nn.Dropout(drop)
391
+ ) for _ in range(n_blocks)
392
+ ])
393
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
394
+ self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
395
+ self.n_blocks = n_blocks
396
+ self.clip_size = clip_size
397
+ self.clip_proj = nn.Sequential(
398
+ nn.LayerNorm(clip_size),
399
+ nn.GELU(),
400
+ nn.Linear(clip_size, 2048),
401
+ nn.LayerNorm(2048),
402
+ nn.GELU(),
403
+ nn.Linear(2048, 2048),
404
+ nn.LayerNorm(2048),
405
+ nn.GELU(),
406
+ nn.Linear(2048, clip_size)
407
+ )
408
+ self.upsampler = Decoder(
409
+ in_channels=64,
410
+ out_channels=4,
411
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
412
+ block_out_channels=[64, 128, 256],
413
+ layers_per_block=1,
414
+ )
415
+
416
+ def forward(self, x):
417
+ x = self.lin0(x)
418
+ residual = x
419
+ for res_block in range(self.n_blocks):
420
+ x = self.mlp[res_block](x)
421
+ x += residual
422
+ residual = x
423
+ x = x.reshape(len(x), -1)
424
+ x = self.lin1(x)
425
+ b = self.blin1(x)
426
+ b = self.upsampler(b.reshape(len(b), -1, 7, 7))
427
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
428
+ return c, b
429
+
430
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
431
+ utils.count_params(model.backbone)
432
+ utils.count_params(model)
433
+
434
+ b = torch.randn((2,8192))
435
+ print(b.shape)
436
+ clip_, blur_ = model.backbone(b)
437
+ print(clip_.shape, blur_.shape)
438
+
439
+
440
+ # In[23]:
441
+
442
+
443
+ # memory model
444
+
445
+ from timm.layers.mlp import Mlp
446
+
447
+ class MemoryEncoder(nn.Module):
448
+ def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
449
+ super().__init__()
450
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
451
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
452
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
453
+ self.out_dim = out_dim
454
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
455
+ self.final_input_dim = in_dim + embedding_time_dim
456
+ self.lin0 = nn.Linear(self.final_input_dim, h)
457
+ self.mlp = nn.ModuleList([
458
+ nn.Sequential(
459
+ nn.Linear(h, h),
460
+ *[item() for item in act_and_norm],
461
+ nn.Dropout(drop)
462
+ ) for _ in range(n_blocks)
463
+ ])
464
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
465
+ self.n_blocks = n_blocks
466
+ self.num_past_voxels = num_past_voxels
467
+ self.embedding_time_dim = embedding_time_dim
468
+ self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))
469
+
470
+
471
+ def forward(self, x, time):
472
+ time = time.long()
473
+ time = self.embedding_time(time)
474
+ x = torch.cat((x, time), dim=-1)
475
+ x = self.lin0(x)
476
+ residual = x
477
+ for res_block in range(self.n_blocks):
478
+ x = self.mlp[res_block](x)
479
+ x += residual
480
+ residual = x
481
+ x = x.reshape(len(x), -1)
482
+ x = self.lin1(x)
483
+ return x
484
+
485
+ # # test the memory encoder
486
+ # memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)
487
+
488
+ # device = torch.device("cpu")
489
+ # memory_encoder.to(device)
490
+
491
+ # # count params
492
+ # total_parameters = 0
493
+ # for parameter in memory_encoder.parameters():
494
+ # total_parameters += parameter.numel()
495
+
496
+ # rand_input = torch.randn((2, 15279)).to(device)
497
+ # rand_time = torch.randint(0, 15, (2,)).to(device)
498
+ # print(rand_input.shape, rand_time.shape)
499
+ # memory_encoder(rand_input, rand_time).shape
500
+
501
+ class MemoryCompressor(nn.Module):
502
+ def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
503
+ super().__init__()
504
+ self.num_past = num_past
505
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
506
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
507
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
508
+ self.final_input_dim = in_dim * num_past
509
+ self.lin0 = nn.Linear(self.final_input_dim, h)
510
+ self.mlp = nn.ModuleList([
511
+ nn.Sequential(
512
+ nn.Linear(h, h),
513
+ *[item() for item in act_and_norm],
514
+ nn.Dropout(drop)
515
+ ) for _ in range(n_blocks)
516
+ ])
517
+ self.lin1 = nn.Linear(h, output_dim, bias=True)
518
+ self.n_blocks = n_blocks
519
+ self.num_past = num_past
520
+ self.output_dim = output_dim
521
+
522
+ def forward(self, x):
523
+ # x is (batch_size, num_past, in_dim)
524
+ x = x.reshape(len(x), -1)
525
+ x = self.lin0(x)
526
+ residual = x
527
+ for res_block in range(self.n_blocks):
528
+ x = self.mlp[res_block](x)
529
+ x += residual
530
+ residual = x
531
+ x = x.reshape(len(x), -1)
532
+ x = self.lin1(x)
533
+ return x
534
+
535
+ # # test the memory compressor
536
+ # memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)
537
+
538
+ # device = torch.device("cpu")
539
+ # memory_compressor.to(device)
540
+
541
+ # # count params
542
+ # total_parameters = 0
543
+ # for parameter in memory_compressor.parameters():
544
+ # total_parameters += parameter.numel()
545
+
546
+ # rand_input = torch.randn((2, 15, 768)).to(device)
547
+ # print(rand_input.shape)
548
+ # memory_compressor(rand_input).shape
549
+
550
+ model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
551
+ model.memory_compressor = MemoryCompressor(in_dim=model.memory_encoder.out_dim, num_past=15, output_dim=4096)
552
+
553
+ utils.count_params(model.memory_encoder)
554
+ utils.count_params(model.memory_compressor)
555
+ utils.count_params(model)
556
+
557
+
558
+
559
+ # In[24]:
560
+
561
+
562
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
563
+ opt_grouped_parameters = [
564
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
565
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
566
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
567
+ {'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
568
+ {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
569
+ ]
570
+
571
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
572
+
573
+ if lr_scheduler_type == 'linear':
574
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
575
+ optimizer,
576
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
577
+ last_epoch=-1
578
+ )
579
+ elif lr_scheduler_type == 'cycle':
580
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
581
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
582
+ optimizer,
583
+ max_lr=max_lr,
584
+ total_steps=total_steps,
585
+ final_div_factor=1000,
586
+ last_epoch=-1, pct_start=2/num_epochs
587
+ )
588
+
589
+ def save_ckpt(tag):
590
+ ckpt_path = outdir+f'/{tag}.pth'
591
+ print(f'saving {ckpt_path}',flush=True)
592
+ unwrapped_model = accelerator.unwrap_model(model)
593
+ try:
594
+ torch.save({
595
+ 'epoch': epoch,
596
+ 'model_state_dict': unwrapped_model.state_dict(),
597
+ 'optimizer_state_dict': optimizer.state_dict(),
598
+ 'lr_scheduler': lr_scheduler.state_dict(),
599
+ 'train_losses': losses,
600
+ 'test_losses': test_losses,
601
+ 'lrs': lrs,
602
+ }, ckpt_path)
603
+ except:
604
+ print("Couldn't save... moving on to prevent crashing.")
605
+ del unwrapped_model
606
+
607
+ print("\nDone with model preparations!")
608
+ utils.count_params(model)
609
+
610
+
611
+ # In[18]:
612
+
613
+
614
+
615
+
616
+
617
+ # # Weights and Biases
618
+
619
+ # In[25]:
620
+
621
+
622
+ # params for wandb
623
+ wandb_log = True
624
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
625
+ import wandb
626
+
627
+ wandb_project = 'stability'
628
+ wandb_run = model_name
629
+ wandb_notes = ''
630
+
631
+ print(f"wandb {wandb_project} run {wandb_run}")
632
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
633
+ wandb_config = {
634
+ "model_name": model_name,
635
+ "batch_size": batch_size,
636
+ "num_epochs": num_epochs,
637
+ "use_image_aug": use_image_aug,
638
+ "max_lr": max_lr,
639
+ "lr_scheduler_type": lr_scheduler_type,
640
+ "mixup_pct": mixup_pct,
641
+ "num_train": num_train,
642
+ "num_test": num_test,
643
+ "seed": seed,
644
+ "distributed": distributed,
645
+ "num_devices": num_devices,
646
+ "world_size": world_size,
647
+ }
648
+ print("wandb_config:\n",wandb_config)
649
+ if False: # wandb_auto_resume
650
+ print("wandb_id:",model_name)
651
+ wandb.init(
652
+ id = model_name,
653
+ project=wandb_project,
654
+ name=wandb_run,
655
+ config=wandb_config,
656
+ notes=wandb_notes,
657
+ resume="allow",
658
+ )
659
+ else:
660
+ wandb.init(
661
+ project=wandb_project,
662
+ name=model_name,
663
+ config=wandb_config,
664
+ notes=wandb_notes,
665
+ )
666
+ else:
667
+ wandb_log = False
668
+
669
+
670
+ # # More custom functions
671
+
672
+ # In[26]:
673
+
674
+
675
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
676
+ pixcorr_preprocess = transforms.Compose([
677
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
678
+ ])
679
+ def pixcorr(images,brains):
680
+ # Flatten images while keeping the batch dimension
681
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
682
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
683
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
684
+ return corrmean
685
+
686
+
687
+ # # Main
688
+
689
+ # In[27]:
690
+
691
+
692
+ epoch = 0
693
+ losses, test_losses, lrs = [], [], []
694
+ best_test_loss = 1e9
695
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
696
+
697
+ # Optionally resume from checkpoint #
698
+ if resume_from_ckpt:
699
+ print("\n---resuming from last.pth ckpt---\n")
700
+ try:
701
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
702
+ except:
703
+ print('last.pth failed... trying last_backup.pth')
704
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
705
+ epoch = checkpoint['epoch']
706
+ print("Epoch",epoch)
707
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
708
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
709
+ model.load_state_dict(checkpoint['model_state_dict'])
710
+ del checkpoint
711
+ elif wandb_log:
712
+ if wandb.run.resumed:
713
+ print("\n---resuming from last.pth ckpt---\n")
714
+ try:
715
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
716
+ except:
717
+ print('last.pth failed... trying last_backup.pth')
718
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
719
+ epoch = checkpoint['epoch']
720
+ print("Epoch",epoch)
721
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
722
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
723
+ model.load_state_dict(checkpoint['model_state_dict'])
724
+ del checkpoint
725
+ torch.cuda.empty_cache()
726
+
727
+
728
+ # In[28]:
729
+
730
+
731
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
732
+ model, optimizer, train_dl, test_dl, lr_scheduler
733
+ )
734
+
735
+
736
+ # In[29]:
737
+ no_more = False
738
+
739
+
740
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
741
+ progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))
742
+ test_image, test_voxel = None, None
743
+ mse = nn.MSELoss()
744
+ for epoch in progress_bar:
745
+ model.train()
746
+
747
+ fwd_percent_correct = 0.
748
+ bwd_percent_correct = 0.
749
+ test_fwd_percent_correct = 0.
750
+ test_bwd_percent_correct = 0.
751
+
752
+ loss_clip_total = 0.
753
+ loss_blurry_total = 0.
754
+ test_loss_clip_total = 0.
755
+ test_loss_blurry_total = 0.
756
+
757
+ blurry_pixcorr = 0.
758
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
759
+
760
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
761
+ #if epoch == 0 or epoch == 1:
762
+ # break
763
+ with torch.cuda.amp.autocast():
764
+ optimizer.zero_grad()
765
+
766
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
767
+
768
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
769
+
770
+ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
771
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
772
+
773
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
774
+
775
+ if use_image_aug: image = img_augment(image)
776
+
777
+ clip_target = clip_model.embed_image(image)
778
+ assert not torch.any(torch.isnan(clip_target))
779
+
780
+ if epoch < int(mixup_pct * num_epochs):
781
+ voxel, perm, betas, select = utils.mixco(voxel)
782
+
783
+ # reshape past voxels to be (batch_size * 15, 15279)
784
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
785
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
786
+ past_15_times = past_15_times.reshape(-1)
787
+
788
+ #print(past_15_voxels.shape, past_15_times.shape)
789
+
790
+ embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)
791
+ #print(embeds_past_voxels.shape)
792
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
793
+ #print(embeds_past_voxels.shape)
794
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
795
+
796
+
797
+ voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)
798
+
799
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
800
+
801
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
802
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
803
+
804
+ if epoch < int(mixup_pct * num_epochs):
805
+ loss_clip = utils.mixco_nce(
806
+ clip_voxels_norm,
807
+ clip_target_norm,
808
+ temp=.006,
809
+ perm=perm, betas=betas, select=select)
810
+ else:
811
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
812
+ loss_clip = utils.soft_clip_loss(
813
+ clip_voxels_norm,
814
+ clip_target_norm,
815
+ temp=epoch_temp)
816
+
817
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
818
+
819
+ loss_clip_total += loss_clip.item()
820
+ loss_blurry_total += loss_blurry.item()
821
+
822
+ loss = loss_blurry + loss_clip
823
+
824
+
825
+ print('voxel', voxel.shape, voxel)
826
+ print('image', image.shape, image)
827
+ print('past_15_voxels', past_15_voxels.shape)
828
+ print('past_15_times', past_15_times.shape)
829
+ print('embeds_past_voxels', embeds_past_voxels.shape)
830
+ print('information_past_voxels', information_past_voxels.shape)
831
+ print('voxel_ridge', voxel_ridge.shape)
832
+ print('clip_target', clip_target.shape)
833
+ print('clip_voxels', clip_voxels.shape)
834
+ print('clip_voxels_norm', clip_voxels_norm.shape)
835
+ print('clip_target_norm', clip_target_norm.shape)
836
+ print('loss_clip_total', loss_clip_total)
837
+ print('loss_blurry_total', loss_blurry_total)
838
+ print('loss_clip', loss_clip)
839
+ print('loss_blurry', loss_blurry)
840
+ print('train_i', train_i)
841
+ print('epoch', epoch)
842
+
843
+ e = utils.check_loss(loss)
844
+ if e == 'NaN loss':
845
+ no_more = True
846
+ print("saving ckpt")
847
+ save_ckpt(f'error')
848
+ # save all the tensors
849
+ torch.save({
850
+ 'voxel': voxel,
851
+ 'image': image,
852
+ 'past_15_voxels': past_15_voxels,
853
+ 'past_15_times': past_15_times,
854
+ 'embeds_past_voxels': embeds_past_voxels,
855
+ 'information_past_voxels': information_past_voxels,
856
+ 'voxel_ridge': voxel_ridge,
857
+ 'blurry_image_enc': blurry_image_enc,
858
+ 'clip_target': clip_target,
859
+ 'clip_voxels': clip_voxels,
860
+ 'blurry_image_enc_': blurry_image_enc_,
861
+ 'clip_voxels_norm': clip_voxels_norm,
862
+ 'clip_target_norm': clip_target_norm,
863
+ 'loss': loss,
864
+ 'loss_clip': loss_clip,
865
+ 'loss_blurry': loss_blurry,
866
+ 'loss_clip_total': loss_clip_total,
867
+ 'loss_blurry_total': loss_blurry_total,
868
+ 'train_i': train_i,
869
+ 'epoch': epoch,
870
+ 'model_state_dict': model.state_dict(),
871
+ }, outdir+f'/error_tensors.pth')
872
+
873
+ print("Error with loss here")
874
+ print('voxel', voxel.shape, voxel)
875
+ print('image', image.shape, image)
876
+ print('past_15_voxels', past_15_voxels.shape, past_15_voxels)
877
+ print('past_15_times', past_15_times.shape, past_15_times)
878
+ print('embeds_past_voxels', embeds_past_voxels.shape, embeds_past_voxels)
879
+ print('information_past_voxels', information_past_voxels.shape, information_past_voxels)
880
+ print('voxel_ridge', voxel_ridge.shape, voxel_ridge)
881
+ print('clip_target', clip_target.shape, clip_target)
882
+ print('clip_voxels', clip_voxels.shape, clip_voxels)
883
+ print('clip_voxels_norm', clip_voxels_norm.shape, clip_voxels_norm)
884
+ print('clip_target_norm', clip_target_norm.shape, clip_target_norm)
885
+ print('loss_clip_total', loss_clip_total)
886
+ print('loss_blurry_total', loss_blurry_total)
887
+ print('loss_clip', loss_clip)
888
+ print('loss_blurry', loss_blurry)
889
+ print('train_i', train_i)
890
+ print('epoch', epoch)
891
+
892
+
893
+
894
+
895
+
896
+ accelerator.backward(loss)
897
+ optimizer.step()
898
+
899
+ losses.append(loss.item())
900
+ lrs.append(optimizer.param_groups[0]['lr'])
901
+
902
+ # forward and backward top 1 accuracy
903
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
904
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
905
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
906
+
907
+ with torch.no_grad():
908
+ # only doing pixcorr eval on a small subset (2) of the samples per batch because it's costly & slow to compute autoenc.decode()
909
+ random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
910
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
911
+ blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
912
+
913
+ if lr_scheduler_type is not None:
914
+ lr_scheduler.step()
915
+
916
+ model.eval()
917
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
918
+ print('test')
919
+ with torch.cuda.amp.autocast():
920
+ with torch.no_grad():
921
+ # all test samples should be loaded per batch such that test_i should never exceed 0
922
+ if len(behav) != num_test: print("!",len(behav),num_test)
923
+
924
+
925
+ ## Average same-image repeats ##
926
+ if test_image is None:
927
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
928
+
929
+ image = behav[:,0,0].cpu().long()
930
+
931
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
932
+ for im in unique_image:
933
+ locs = torch.where(im == image)[0]
934
+ if test_image is None:
935
+ test_image = images[im][None]
936
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
937
+ else:
938
+ test_image = torch.vstack((test_image, images[im][None]))
939
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
940
+
941
+ # sample of batch_size
942
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
943
+ voxel = test_voxel[random_indices].to(device)
944
+ image = test_image[random_indices].to(device)
945
+
946
+ current_past_behav = past_behav[random_indices]
947
+
948
+ past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
949
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
950
+
951
+ assert len(image) == batch_size
952
+
953
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
954
+
955
+ clip_target = clip_model.embed_image(image.float())
956
+
957
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
958
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
959
+ past_15_times = past_15_times.reshape(-1)
960
+
961
+ print(past_15_voxels.shape, past_15_times.shape)
962
+
963
+ embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)
964
+ embeds_past_voxels = embeds_past_voxels.reshape(batch_size, 15, -1)
965
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
966
+
967
+
968
+ voxel_ridge = torch.cat([model.ridge(voxel), information_past_voxels], dim=-1)
969
+
970
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
971
+
972
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
973
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
974
+
975
+ loss_clip = utils.soft_clip_loss(
976
+ clip_voxels_norm,
977
+ clip_target_norm,
978
+ temp=.006)
979
+
980
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
981
+
982
+ loss = loss_blurry + loss_clip
983
+
984
+ utils.check_loss(loss)
985
+
986
+ test_losses.append(loss.item())
987
+
988
+ # forward and backward top 1 accuracy
989
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
990
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
991
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
992
+
993
+ # halving the batch size because the decoder is computationally heavy
994
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
995
+ blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
996
+ test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
997
+
998
+ # transform blurry recon latents to images and plot it
999
+ fig, axes = plt.subplots(1, 4, figsize=(8, 4))
1000
+ axes[0].imshow(utils.torch_to_Image(image[[0]]))
1001
+ axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
1002
+ axes[2].imshow(utils.torch_to_Image(image[[1]]))
1003
+ axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
1004
+ axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
1005
+ plt.show()
1006
+
1007
+ if local_rank==0:
1008
+ # if utils.is_interactive(): clear_output(wait=True)
1009
+ assert (test_i+1) == 1
1010
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1011
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1012
+ "train/lr": lrs[-1],
1013
+ "train/num_steps": len(losses),
1014
+ "test/num_steps": len(test_losses),
1015
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1016
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1017
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1018
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1019
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1020
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1021
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1022
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1023
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1024
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1025
+ }
1026
+ progress_bar.set_postfix(**logs)
1027
+
1028
+ # Save model checkpoint and reconstruct
1029
+ if no_more:
1030
+ print("not writing more")
1031
+ if ((epoch % ckpt_interval == 0) and (not no_more)):
1032
+ if not utils.is_interactive():
1033
+ save_ckpt(f'last')
1034
+
1035
+ if wandb_log: wandb.log(logs)
1036
+
1037
+ # wait for other GPUs to catch up if needed
1038
+ accelerator.wait_for_everyone()
1039
+ torch.cuda.empty_cache()
1040
+ gc.collect()
1041
+
1042
+ print("\n===Finished!===\n")
1043
+ if ckpt_saving:
1044
+ save_ckpt(f'last')
1045
+ if not utils.is_interactive():
1046
+ sys.exit(0)
1047
+
1048
+
1049
+ # In[ ]:
1050
+
1051
+
1052
+ plt.plot(losses)
1053
+ plt.show()
1054
+ plt.plot(test_losses)
1055
+ plt.show()
1056
+
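For reference, a minimal sketch (not part of this commit; the checkpoint path is an example) of inspecting a checkpoint written by the save_ckpt() helper above:

```python
# Hypothetical usage: peek at a checkpoint saved by save_ckpt() in
# src/Train-with-memory-cat.py. Adjust the path to your own outdir.
import torch

ckpt = torch.load("train_mem_logs/test/last.pth", map_location="cpu")
print("epoch:", ckpt["epoch"])
print("train steps logged:", len(ckpt["train_losses"]))
print("test steps logged:", len(ckpt["test_losses"]))

# Model weights are stored under 'model_state_dict' and can be restored into a
# MindEyeModule built with the same ridge/backbone/memory_encoder/memory_compressor
# submodules, e.g. model.load_state_dict(ckpt["model_state_dict"]).
```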
src/Train-with-memory-rr-dropout.py ADDED
@@ -0,0 +1,1040 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[ ]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train-with-memory-rr.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # In[24]:
14
+
15
+
16
+ #get_ipython().system('nvidia-smi')
17
+
18
+
19
+ # # Import packages & functions
20
+
21
+ # In[3]:
22
+
23
+
24
+ import os
25
+ import sys
26
+ import json
27
+ import argparse
28
+ import numpy as np
29
+ import math
30
+ from einops import rearrange
31
+ import time
32
+ import random
33
+ import h5py
34
+ from tqdm import tqdm
35
+
36
+ import webdataset as wds
37
+ import gc
38
+
39
+ import matplotlib.pyplot as plt
40
+ import torch
41
+ import torch.nn as nn
42
+ from torchvision import transforms
43
+
44
+ from accelerate import Accelerator, DeepSpeedPlugin
45
+
46
+ # tf32 data type is faster than standard float32
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+
49
+ # custom functions #
50
+ import utils
51
+
52
+ global_batch_size = 512 #128
53
+
54
+
55
+ # In[4]:
56
+
57
+
58
+ ### Multi-GPU config ###
59
+ local_rank = os.getenv('RANK')
60
+ if local_rank is None:
61
+ local_rank = 0
62
+ else:
63
+ local_rank = int(local_rank)
64
+ print("LOCAL RANK ", local_rank)
65
+
66
+ num_devices = torch.cuda.device_count()
67
+ if num_devices==0: num_devices = 1
68
+
69
+ accelerator = Accelerator(split_batches=False)
70
+
71
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
72
+
73
+ # if num_devices <= 1 and utils.is_interactive():
74
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
75
+ # os.environ["MASTER_ADDR"] = "localhost"
76
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
77
+ # os.environ["RANK"] = "0"
78
+ # os.environ["LOCAL_RANK"] = "0"
79
+ # os.environ["WORLD_SIZE"] = "1"
80
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
81
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
82
+
83
+ # # alter the deepspeed config according to your global and local batch size
84
+ # if local_rank == 0:
85
+ # with open('deepspeed_config_stage2.json', 'r') as file:
86
+ # config = json.load(file)
87
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
88
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
89
+ # with open('deepspeed_config_stage2.json', 'w') as file:
90
+ # json.dump(config, file)
91
+ # else:
92
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
93
+ # time.sleep(10)
94
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
95
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
96
+
97
+
98
+ # In[5]:
99
+
100
+
101
+ print("PID of this process =",os.getpid())
102
+ device = accelerator.device
103
+ print("device:",device)
104
+ num_workers = num_devices
105
+ print(accelerator.state)
106
+ world_size = accelerator.state.num_processes
107
+ distributed = not accelerator.state.distributed_type == 'NO'
108
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
109
+ print = accelerator.print # only print if local_rank=0
110
+
111
+
112
+ # # Configurations
113
+
114
+ # In[6]:
115
+
116
+
117
+ # if running this interactively, can specify jupyter_args here for argparser to use
118
+ if utils.is_interactive():
119
+ # Example use
120
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
121
+ --model_name=test \
122
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
123
+ --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
124
+
125
+ jupyter_args = jupyter_args.split()
126
+ print(jupyter_args)
127
+
128
+ from IPython.display import clear_output # function to clear print outputs in cell
129
+ get_ipython().run_line_magic('load_ext', 'autoreload')
130
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
131
+ get_ipython().run_line_magic('autoreload', '2')
132
+
133
+
134
+ # In[7]:
135
+
136
+
137
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
138
+ parser.add_argument(
139
+ "--model_name", type=str, default="memory_cat_rr",
140
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
141
+ )
142
+ parser.add_argument(
143
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
144
+ help="Path to where NSD data is stored / where to download it to",
145
+ )
146
+ parser.add_argument(
147
+ "--subj",type=int, default=1, choices=[1,2,5,7],
148
+ )
149
+ parser.add_argument(
150
+ "--batch_size", type=int, default=32,
151
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
152
+ )
153
+ parser.add_argument(
154
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
155
+ help="whether to log to wandb",
156
+ )
157
+ parser.add_argument(
158
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
159
+ help="if not using wandb and want to resume from a ckpt",
160
+ )
161
+ parser.add_argument(
162
+ "--wandb_project",type=str,default="stability",
163
+ help="wandb project name",
164
+ )
165
+ parser.add_argument(
166
+ "--mixup_pct",type=float,default=.33,
167
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
168
+ )
169
+ parser.add_argument(
170
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
171
+ help="whether to use image augmentation",
172
+ )
173
+ parser.add_argument(
174
+ "--num_epochs",type=int,default=240,
175
+ help="number of epochs of training",
176
+ )
177
+ parser.add_argument(
178
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
179
+ )
180
+ parser.add_argument(
181
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
182
+ )
183
+ parser.add_argument(
184
+ "--ckpt_interval",type=int,default=5,
185
+ help="save backup ckpt and reconstruct every x epochs",
186
+ )
187
+ parser.add_argument(
188
+ "--seed",type=int,default=42,
189
+ )
190
+ parser.add_argument(
191
+ "--max_lr",type=float,default=3e-4,
192
+ )
193
+ parser.add_argument(
194
+ "--n_samples_save",type=int,default=0,choices=[0,1],
195
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
196
+ )
197
+
198
+ if utils.is_interactive():
199
+ args = parser.parse_args(jupyter_args)
200
+ else:
201
+ args = parser.parse_args()
202
+
203
+ # create global variables without the args prefix
204
+ for attribute_name in vars(args).keys():
205
+ globals()[attribute_name] = getattr(args, attribute_name)
206
+
207
+ print("global batch_size", batch_size)
208
+ batch_size = int(batch_size / num_devices)
209
+ print("batch_size", batch_size)
210
+
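+ # e.g. (assuming 8 GPUs): the global batch of 512 becomes a per-process batch_size of
+ # 512 / 8 = 64, i.e. 24958 // 64 = 389 full train batches per epoch (plus one partial
+ # batch, since drop_last=False below).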
211
+
212
+ # In[8]:
213
+
214
+
215
+ outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
216
+ if not os.path.exists(outdir):
217
+ os.makedirs(outdir,exist_ok=True)
218
+ if use_image_aug:
219
+ import kornia
220
+ from kornia.augmentation.container import AugmentationSequential
221
+ img_augment = AugmentationSequential(
222
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
223
+ kornia.augmentation.Resize((224, 224)),
224
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
225
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
226
+ kornia.augmentation.RandomGrayscale(p=0.3),
227
+ same_on_batch=False,
228
+ data_keys=["input"],
229
+ )
230
+
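+ # usage sketch (commented out): img_augment expects a float image batch in [0,1] of
+ # shape (B, 3, 224, 224) and returns an augmented batch of the same shape, e.g.
+ # augmented = img_augment(torch.rand(4, 3, 224, 224)) # -> (4, 3, 224, 224)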
231
+
232
+ # # Prep data, models, and dataloaders
233
+
234
+ # ## Dataloader
235
+
236
+ # In[9]:
237
+
238
+
239
+ if subj==1:
240
+ num_train = 24958
241
+ num_test = 2770
242
+ test_batch_size = num_test
243
+
244
+ def my_split_by_node(urls): return urls
245
+
246
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
247
+ print(train_url)
248
+
249
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
250
+ .shuffle(750, initial=1500, rng=random.Random(42))\
251
+ .decode("torch")\
252
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
253
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
254
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
255
+
256
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
257
+ print(test_url)
258
+
259
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
260
+ .shuffle(750, initial=1500, rng=random.Random(42))\
261
+ .decode("torch")\
262
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
263
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
264
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
265
+
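+ # each batch from these loaders is a tuple (behav, past_behav, future_behav, olds_behav);
+ # as used further below, behav[:,0,0] holds the COCO image index, behav[:,0,5] the row
+ # index into the betas array, and past_behav[:,:,5] indexes the 15 preceding trials.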
266
+
267
+ # ### check dataloaders are working
268
+
269
+ # In[10]:
270
+
271
+
272
+ # test_indices = []
273
+ # test_images = []
274
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
275
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
276
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
277
+ # test_indices = test_indices.astype(np.int16)
278
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
279
+ # print("---\n")
280
+
281
+ # train_indices = []
282
+ # train_images = []
283
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
284
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
285
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
286
+ # train_indices = train_indices.astype(np.int16)
287
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
288
+
289
+ # # train_images = np.hstack((train_images, test_images))
290
+ # # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
291
+
292
+
293
+ # ## Load data and images
294
+
295
+ # In[11]:
296
+
297
+
298
+ # load betas
299
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
300
+ voxels = f['betas'][:]
301
+ print(f"subj0{subj} betas loaded into memory")
302
+ voxels = torch.Tensor(voxels).to("cpu").half()
303
+ if subj==1:
304
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
305
+ print("voxels", voxels.shape)
306
+ num_voxels = voxels.shape[-1]
307
+
308
+ # load orig images
309
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
310
+ images = f['images'][:]
311
+ images = torch.Tensor(images).to("cpu").half()
312
+ print("images", images.shape)
313
+
314
+
315
+ # ## Load models
316
+
317
+ # ### CLIP image embeddings model
318
+
319
+ # In[12]:
320
+
321
+
322
+ from models import Clipper
323
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
324
+
325
+ clip_seq_dim = 257
326
+ clip_emb_dim = 768
327
+ hidden_dim = 4096
328
+
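+ # note: for 224x224 inputs, ViT-L/14 yields 16*16 = 256 patch tokens plus one class
+ # token, hence clip_seq_dim = 257; each token is used as a clip_emb_dim = 768 vector,
+ # so the backbone's CLIP target is a (257, 768) sequence per image.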
329
+
330
+ # ### SD VAE (blurry images)
331
+
332
+ # In[13]:
333
+
334
+
335
+ from diffusers import AutoencoderKL
336
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
337
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
338
+ autoenc.eval()
339
+ autoenc.requires_grad_(False)
340
+ autoenc.to(device)
341
+ utils.count_params(autoenc)
342
+
343
+
344
+ # ### MindEye modules
345
+
346
+ # In[14]:
347
+
348
+
349
+ class MindEyeModule(nn.Module):
350
+ def __init__(self):
351
+ super(MindEyeModule, self).__init__()
352
+ def forward(self, x):
353
+ return x
354
+
355
+ model = MindEyeModule()
356
+ model
357
+
358
+
359
+ # In[15]:
360
+
361
+
362
+ time_embedding_dim = 512
363
+
364
+ class RidgeRegression(torch.nn.Module):
365
+ # make sure to add weight_decay when initializing optimizer
366
+ def __init__(self, input_size, out_features):
367
+ super(RidgeRegression, self).__init__()
368
+ self.out_features = out_features
369
+ self.linear = torch.nn.Linear(input_size, out_features)
370
+ def forward(self, x):
371
+ return self.linear(x)
372
+
373
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
374
+ utils.count_params(model.ridge)
375
+ utils.count_params(model)
376
+
377
+ b = torch.randn((2,1,voxels.shape[1]))
378
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
379
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
380
+
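+ # note: RidgeRegression is just a single linear layer; the "ridge" (L2) penalty comes
+ # from the weight_decay=1e-2 given to its parameter group in the AdamW setup below.
+ # shape sketch, mirroring the test above (commented out):
+ # x = torch.cat((voxel, time_emb), dim=-1) # (B, 1, num_voxels + 512)
+ # model.ridge(x) # -> (B, 1, 4096)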
381
+
382
+ # In[16]:
383
+
384
+
385
+ from functools import partial
386
+ from diffusers.models.vae import Decoder
387
+ class BrainNetwork(nn.Module):
388
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.75, blurry_dim=16):
389
+ super().__init__()
390
+ self.blurry_dim = blurry_dim
391
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
392
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
393
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
394
+ self.lin0 = nn.Linear(in_dim, h)
395
+ self.mlp = nn.ModuleList([
396
+ nn.Sequential(
397
+ nn.Linear(h, h),
398
+ *[item() for item in act_and_norm],
399
+ nn.Dropout(drop)
400
+ ) for _ in range(n_blocks)
401
+ ])
402
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
403
+ self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
404
+ self.n_blocks = n_blocks
405
+ self.clip_size = clip_size
406
+ self.clip_proj = nn.Sequential(
407
+ nn.LayerNorm(clip_size),
408
+ nn.GELU(),
409
+ nn.Linear(clip_size, 2048),
410
+ nn.LayerNorm(2048),
411
+ nn.GELU(),
412
+ nn.Linear(2048, 2048),
413
+ nn.LayerNorm(2048),
414
+ nn.GELU(),
415
+ nn.Linear(2048, clip_size)
416
+ )
417
+ self.upsampler = Decoder(
418
+ in_channels=64,
419
+ out_channels=4,
420
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
421
+ block_out_channels=[64, 128, 256],
422
+ layers_per_block=1,
423
+ )
424
+
425
+ def forward(self, x):
426
+ x = self.lin0(x)
427
+ residual = x
428
+ for res_block in range(self.n_blocks):
429
+ x = self.mlp[res_block](x)
430
+ x += residual
431
+ residual = x
432
+ x = x.reshape(len(x), -1)
433
+ x = self.lin1(x)
434
+ b = self.blin1(x)
435
+ b = self.upsampler(b.reshape(len(b), -1, 7, 7))
436
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
437
+ return c, b
438
+
439
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
440
+ utils.count_params(model.backbone)
441
+ utils.count_params(model)
442
+
443
+ b = torch.randn((2,8192))
444
+ print(b.shape)
445
+ clip_, blur_ = model.backbone(b)
446
+ print(clip_.shape, blur_.shape)
447
+
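+ # shape trace for the forward pass above (with h=2048 as constructed here):
+ # x: (B, 8192) -> lin0 -> (B, 2048) -> 4 residual MLP blocks -> lin1 -> (B, 257*768)
+ # c = clip_proj(x.reshape(B, 257, 768)) -> (B, 257, 768)
+ # b = blin1(x) -> (B, 64*7*7) -> reshape (B, 64, 7, 7) -> upsampler -> (B, 4, 28, 28),
+ # which lines up with the autoenc latent of a 224x224 image (224 / 8 = 28).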
448
+
449
+ # In[17]:
450
+
451
+
452
+ # memory model
453
+
454
+ from timm.layers.mlp import Mlp
455
+
456
+
457
+
458
+ class MemoryDropout(nn.Module):
459
+ def __init__(self, p):
460
+ super(MemoryDropout, self).__init__()
461
+ self.p = p
462
+
463
+ def forward(self, x):
464
+ if self.training:
465
+ mask = torch.zeros(x.size(0), x.size(1)).bernoulli_(1 - self.p).unsqueeze(2).expand_as(x)
466
+ mask = mask.to(x.device)
467
+ x = x * mask / (1 - self.p)
468
+ return x
469
+
470
+ memory_dropout_percentage = 0.75
471
+ memory_dropout = MemoryDropout(memory_dropout_percentage)
472
+
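+ # sketch of what MemoryDropout does (commented out): unlike nn.Dropout, the Bernoulli
+ # mask is drawn per (sample, memory slot) and broadcast over the feature dimension, so
+ # entire past-trial embeddings are zeroed and the survivors are rescaled by 1/(1-p).
+ # mem = torch.randn(2, 15, 4096)
+ # out = memory_dropout(mem) # in train mode ~75% of the 15 slots per sample are zeroed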
473
+ class MemoryEncoder(nn.Module):
474
+ def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.85):
475
+ super().__init__()
476
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
477
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
478
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
479
+ self.out_dim = out_dim
480
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
481
+ self.final_input_dim = in_dim + embedding_time_dim
482
+ self.lin0 = nn.Linear(self.final_input_dim, h)
483
+ self.mlp = nn.ModuleList([
484
+ nn.Sequential(
485
+ nn.Linear(h, h),
486
+ *[item() for item in act_and_norm],
487
+ nn.Dropout(drop)
488
+ ) for _ in range(n_blocks)
489
+ ])
490
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
491
+ self.n_blocks = n_blocks
492
+ self.num_past_voxels = num_past_voxels
493
+ self.embedding_time_dim = embedding_time_dim
494
+ self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))
495
+
496
+
497
+ def forward(self, x, time):
498
+ time = time.long()
499
+ time = self.embedding_time(time)
500
+ x = torch.cat((x, time), dim=-1)
501
+ x = self.lin0(x)
502
+ residual = x
503
+ for res_block in range(self.n_blocks):
504
+ x = self.mlp[res_block](x)
505
+ x += residual
506
+ residual = x
507
+ x = x.reshape(len(x), -1)
508
+ x = self.lin1(x)
509
+ return x
510
+
511
+ # # test the memory encoder
512
+ # memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)
513
+
514
+ # device = torch.device("cpu")
515
+ # memory_encoder.to(device)
516
+
517
+ # # count params
518
+ # total_parameters = 0
519
+ # for parameter in memory_encoder.parameters():
520
+ # total_parameters += parameter.numel()
521
+
522
+ # rand_input = torch.randn((2, 15279)).to(device)
523
+ # rand_time = torch.randint(0, 15, (2,)).to(device)
524
+ # print(rand_input.shape, rand_time.shape)
525
+ # memory_encoder(rand_input, rand_time).shape
526
+
527
+ class MemoryCompressor(nn.Module):
528
+ def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.75):
529
+ super().__init__()
530
+ self.num_past = num_past
531
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
532
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
533
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
534
+ self.final_input_dim = in_dim * num_past
535
+ self.lin0 = nn.Linear(self.final_input_dim, h)
536
+ self.mlp = nn.ModuleList([
537
+ nn.Sequential(
538
+ nn.Linear(h, h),
539
+ *[item() for item in act_and_norm],
540
+ nn.Dropout(drop)
541
+ ) for _ in range(n_blocks)
542
+ ])
543
+ self.lin1 = nn.Linear(h, output_dim, bias=True)
544
+ self.n_blocks = n_blocks
545
+ self.num_past = num_past
546
+ self.output_dim = output_dim
547
+
548
+ def forward(self, x):
549
+ # x is (batch_size, num_past, in_dim)
550
+ x = x.reshape(len(x), -1)
551
+ x = self.lin0(x)
552
+ residual = x
553
+ for res_block in range(self.n_blocks):
554
+ x = self.mlp[res_block](x)
555
+ x += residual
556
+ residual = x
557
+ x = x.reshape(len(x), -1)
558
+ x = self.lin1(x)
559
+ return x
560
+
561
+ # # test the memory compressor
562
+ # memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)
563
+
564
+ # device = torch.device("cpu")
565
+ # memory_compressor.to(device)
566
+
567
+ # # count params
568
+ # total_parameters = 0
569
+ # for parameter in memory_compressor.parameters():
570
+ # total_parameters += parameter.numel()
571
+
572
+ # rand_input = torch.randn((2, 15, 768)).to(device)
573
+ # print(rand_input.shape)
574
+ # memory_compressor(rand_input).shape
575
+
576
+ class TimeEmbedding(nn.Module):
577
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
578
+ super().__init__()
579
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
580
+ self.num_past_voxels = num_past_voxels
581
+ self.embedding_time_dim = embedding_time_dim
582
+
583
+ def forward(self, time):
584
+ # time is (batch_size,)
585
+ time = time.long()
586
+ time = self.embedding_time(time)
587
+ return time # (batch_size, embedding_time_dim)
588
+
589
+
590
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
591
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
592
+ model.memory_compressor = MemoryCompressor(in_dim=model.ridge.out_features, num_past=15, output_dim=4096)
593
+
594
+ #utils.count_params(model.memory_encoder)
595
+ utils.count_params(model.memory_compressor)
596
+ utils.count_params(model)
597
+
598
+
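+ # how these pieces are wired together in the training loop below:
+ # voxels[past_behav[:,:,5]] -> (B, 15, num_voxels)
+ # cat with time_embedding(0..14) per past trial -> (B*15, num_voxels + 512)
+ # model.ridge -> (B*15, 4096) -> reshape -> (B, 15, 4096) -> memory_dropout
+ # model.memory_compressor -> (B, 4096)
+ # which is concatenated with the ridge output for the current voxels (padded with a
+ # zero "time" vector) to form the (B, 8192) input expected by model.backbone.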
599
+
600
+ # In[18]:
601
+
602
+
603
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
604
+ opt_grouped_parameters = [
605
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
606
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
607
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
608
+ #{'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
609
+ {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
610
+ {'params': [p for n, p in model.time_embedding.named_parameters()], 'weight_decay': 0.0},
611
+ ]
612
+
613
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
614
+
615
+ if lr_scheduler_type == 'linear':
616
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
617
+ optimizer,
618
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
619
+ last_epoch=-1
620
+ )
621
+ elif lr_scheduler_type == 'cycle':
622
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
623
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
624
+ optimizer,
625
+ max_lr=max_lr,
626
+ total_steps=total_steps,
627
+ final_div_factor=1000,
628
+ last_epoch=-1, pct_start=2/num_epochs
629
+ )
630
+
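+ # e.g. (assuming 8 GPUs, per-GPU batch_size=64, num_epochs=240):
+ # total_steps = int(240 * (24958 * 8 // 64)) = 240 * 3119 = 748,560 scheduler steps.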
631
+ def save_ckpt(tag):
632
+ ckpt_path = outdir+f'/{tag}.pth'
633
+ print(f'saving {ckpt_path}',flush=True)
634
+ unwrapped_model = accelerator.unwrap_model(model)
635
+ try:
636
+ torch.save({
637
+ 'epoch': epoch,
638
+ 'model_state_dict': unwrapped_model.state_dict(),
639
+ 'optimizer_state_dict': optimizer.state_dict(),
640
+ 'lr_scheduler': lr_scheduler.state_dict(),
641
+ 'train_losses': losses,
642
+ 'test_losses': test_losses,
643
+ 'lrs': lrs,
644
+ }, ckpt_path)
645
+ except:
646
+ print("Couldn't save... moving on to prevent crashing.")
647
+ del unwrapped_model
648
+
649
+ print("\nDone with model preparations!")
650
+ utils.count_params(model)
651
+
652
+
653
+ # In[ ]:
654
+
655
+
656
+
657
+
658
+
659
+ # # Weights and Biases
660
+
661
+ # In[19]:
662
+
663
+
664
+ # params for wandb
665
+ wandb_log = True # hard-coded on here; this overrides the --wandb_log flag parsed above
666
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
667
+ import wandb
668
+
669
+ wandb_project = 'stability'
670
+ wandb_run = model_name
671
+ wandb_notes = ''
672
+
673
+ print(f"wandb {wandb_project} run {wandb_run}")
674
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
675
+ wandb_config = {
676
+ "model_name": model_name,
677
+ "batch_size": batch_size,
678
+ "num_epochs": num_epochs,
679
+ "use_image_aug": use_image_aug,
680
+ "max_lr": max_lr,
681
+ "lr_scheduler_type": lr_scheduler_type,
682
+ "mixup_pct": mixup_pct,
683
+ "num_train": num_train,
684
+ "num_test": num_test,
685
+ "seed": seed,
686
+ "distributed": distributed,
687
+ "num_devices": num_devices,
688
+ "world_size": world_size,
689
+ }
690
+ print("wandb_config:\n",wandb_config)
691
+ if False: # wandb_auto_resume
692
+ print("wandb_id:",model_name)
693
+ wandb.init(
694
+ id = model_name,
695
+ project=wandb_project,
696
+ name=wandb_run,
697
+ config=wandb_config,
698
+ notes=wandb_notes,
699
+ resume="allow",
700
+ )
701
+ else:
702
+ wandb.init(
703
+ project=wandb_project,
704
+ name=model_name,
705
+ config=wandb_config,
706
+ notes=wandb_notes,
707
+ )
708
+ else:
709
+ wandb_log = False
710
+
711
+
712
+ # # More custom functions
713
+
714
+ # In[20]:
715
+
716
+
717
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
718
+ pixcorr_preprocess = transforms.Compose([
719
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
720
+ ])
721
+ def pixcorr(images,brains):
722
+ # Flatten images while keeping the batch dimension
723
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
724
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
725
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
726
+ return corrmean
727
+
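+ # pixcorr resizes both batches to 425 px (matching the MindEye / BrainDiffuser
+ # preprocessing noted above), flattens each image, and takes the mean of the diagonal
+ # of the batchwise Pearson matrix, i.e. the average correlation between each
+ # ground-truth image and its own blurry reconstruction.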
728
+
729
+ # # Main
730
+
731
+ # In[21]:
732
+
733
+
734
+ epoch = 0
735
+ losses, test_losses, lrs = [], [], []
736
+ best_test_loss = 1e9
737
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
738
+
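+ # training schedule: for the first mixup_pct fraction of epochs the CLIP loss is the
+ # BiMixCo objective (utils.mixco + utils.mixco_nce at temp=.006); after that it switches
+ # to utils.soft_clip_loss with a temperature annealed from 0.004 to 0.0075 by
+ # soft_loss_temps over the remaining epochs.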
739
+ # Optionally resume from checkpoint #
740
+ if resume_from_ckpt:
741
+ print("\n---resuming from last.pth ckpt---\n")
742
+ try:
743
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
744
+ except:
745
+ print('last.pth failed... trying last_backup.pth')
746
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
747
+ epoch = checkpoint['epoch']
748
+ print("Epoch",epoch)
749
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
750
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
751
+ model.load_state_dict(checkpoint['model_state_dict'])
752
+ del checkpoint
753
+ elif wandb_log:
754
+ if wandb.run.resumed:
755
+ print("\n---resuming from last.pth ckpt---\n")
756
+ try:
757
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
758
+ except:
759
+ print('last.pth failed... trying last_backup.pth')
760
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
761
+ epoch = checkpoint['epoch']
762
+ print("Epoch",epoch)
763
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
764
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
765
+ model.load_state_dict(checkpoint['model_state_dict'])
766
+ del checkpoint
767
+ torch.cuda.empty_cache()
768
+
769
+
770
+ # In[22]:
771
+
772
+
773
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
774
+ model, optimizer, train_dl, test_dl, lr_scheduler
775
+ )
776
+
777
+
778
+ # In[23]:
779
+
780
+
781
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
782
+ progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))
783
+ test_image, test_voxel = None, None
784
+ mse = nn.MSELoss()
785
+ for epoch in progress_bar:
786
+ model.train()
787
+
788
+ fwd_percent_correct = 0.
789
+ bwd_percent_correct = 0.
790
+ test_fwd_percent_correct = 0.
791
+ test_bwd_percent_correct = 0.
792
+
793
+ loss_clip_total = 0.
794
+ loss_blurry_total = 0.
795
+ test_loss_clip_total = 0.
796
+ test_loss_blurry_total = 0.
797
+
798
+ blurry_pixcorr = 0.
799
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
800
+
801
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
802
+ #if epoch == 0 or epoch == 1:
803
+ # break
804
+ with torch.cuda.amp.autocast():
805
+ optimizer.zero_grad()
806
+
807
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
808
+
809
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
810
+
811
+ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
812
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
813
+
814
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
815
+
816
+ if use_image_aug: image = img_augment(image)
817
+
818
+ clip_target = clip_model.embed_image(image)
819
+ assert not torch.any(torch.isnan(clip_target))
820
+
821
+ if epoch < int(mixup_pct * num_epochs):
822
+ voxel, perm, betas, select = utils.mixco(voxel)
823
+
824
+ # reshape past voxels to be (batch_size * 15, 15279)
825
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
826
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
827
+ past_15_times = past_15_times.reshape(-1)
828
+
829
+ #print(past_15_voxels.shape, past_15_times.shape)
830
+ time_embeddings = model.time_embedding(past_15_times)
831
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
832
+ embeds_past_voxels = model.ridge(past_info_full)
833
+ #print(embeds_past_voxels.shape)
834
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
835
+ embeds_past_voxels = memory_dropout(embeds_past_voxels)
836
+ #print(embeds_past_voxels.shape)
837
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
838
+
839
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
840
+ #print(torch.cat((voxel, positional_current_voxel), dim=-1).shape, positional_current_voxel.shape, voxel.shape)
841
+ voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
842
+
843
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
844
+
845
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
846
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
847
+
848
+ if epoch < int(mixup_pct * num_epochs):
849
+ loss_clip = utils.mixco_nce(
850
+ clip_voxels_norm,
851
+ clip_target_norm,
852
+ temp=.006,
853
+ perm=perm, betas=betas, select=select)
854
+ else:
855
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
856
+ loss_clip = utils.soft_clip_loss(
857
+ clip_voxels_norm,
858
+ clip_target_norm,
859
+ temp=epoch_temp)
860
+
861
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
862
+
863
+ loss_clip_total += loss_clip.item()
864
+ loss_blurry_total += loss_blurry.item()
865
+
866
+ loss = loss_blurry + loss_clip
867
+
868
+ utils.check_loss(loss)
869
+
870
+ accelerator.backward(loss)
871
+ optimizer.step()
872
+
873
+ losses.append(loss.item())
874
+ lrs.append(optimizer.param_groups[0]['lr'])
875
+
876
+ # forward and backward top 1 accuracy
877
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
878
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
879
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
880
+
881
+ with torch.no_grad():
882
+ # only doing pixcorr eval on a small subset (2) of the samples per batch because it's costly & slow to compute autoenc.decode()
883
+ random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
884
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
885
+ blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
886
+
887
+ if lr_scheduler_type is not None:
888
+ lr_scheduler.step()
889
+
890
+ model.eval()
891
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
892
+ print('test')
893
+ with torch.cuda.amp.autocast():
894
+ with torch.no_grad():
895
+ # the whole test set is loaded as a single batch, so test_i should never exceed 0
896
+ if len(behav) != num_test: print("!",len(behav),num_test)
897
+
898
+
899
+ ## Average same-image repeats ##
900
+ if test_image is None:
901
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
902
+
903
+ image = behav[:,0,0].cpu().long()
904
+
905
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
906
+ for im in unique_image:
907
+ locs = torch.where(im == image)[0]
908
+ if test_image is None:
909
+ test_image = images[im][None]
910
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
911
+ else:
912
+ test_image = torch.vstack((test_image, images[im][None]))
913
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
914
+
915
+ # sample of batch_size
916
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
917
+ voxel = test_voxel[random_indices].to(device)
918
+ image = test_image[random_indices].to(device)
919
+
920
+ current_past_behav = past_behav[random_indices]
921
+
922
+ past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
923
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
924
+
925
+ assert len(image) == batch_size
926
+
927
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
928
+
929
+ clip_target = clip_model.embed_image(image.float())
930
+
931
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
932
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
933
+ past_15_times = past_15_times.reshape(-1)
934
+
935
+ print(past_15_voxels.shape, past_15_times.shape)
936
+
937
+ #print(past_15_voxels.shape, past_15_times.shape)
938
+ time_embeddings = model.time_embedding(past_15_times)
939
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
940
+ embeds_past_voxels = model.ridge(past_info_full)
941
+ #print(embeds_past_voxels.shape)
942
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
943
+ #print(embeds_past_voxels.shape)
944
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
945
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
946
+
947
+ voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
948
+
949
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
950
+
951
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
952
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
953
+
954
+ loss_clip = utils.soft_clip_loss(
955
+ clip_voxels_norm,
956
+ clip_target_norm,
957
+ temp=.006)
958
+
959
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
960
+
961
+ loss = loss_blurry + loss_clip
962
+
963
+ utils.check_loss(loss)
964
+
965
+ test_losses.append(loss.item())
966
+
967
+ # forward and backward top 1 accuracy
968
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
969
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
970
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
971
+
972
+ # halving the batch size because the decoder is computationally heavy
973
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
974
+ blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
975
+ test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
976
+
977
+ # transform blurry recon latents to images and plot it
978
+ #fig, axes = plt.subplots(1, 4, figsize=(8, 4))
979
+ #axes[0].imshow(utils.torch_to_Image(image[[0]]))
980
+ #axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
981
+ #axes[2].imshow(utils.torch_to_Image(image[[1]]))
982
+ #axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
983
+ #axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
984
+ #plt.show()
985
+
986
+
987
+ if local_rank==0:
988
+ # if utils.is_interactive(): clear_output(wait=True)
989
+ with torch.cuda.amp.autocast():
990
+ with torch.no_grad():
991
+ wandb.log({"gt": [wandb.Image(utils.torch_to_Image(image[[0]])), wandb.Image(utils.torch_to_Image(image[[1]])) ]})
992
+ wandb.log({"preds": [wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1))), wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1))) ]})
993
+
994
+ assert (test_i+1) == 1
995
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
996
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
997
+ "train/lr": lrs[-1],
998
+ "train/num_steps": len(losses),
999
+ "test/num_steps": len(test_losses),
1000
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1001
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1002
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1003
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1004
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1005
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1006
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1007
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1008
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1009
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1010
+ }
1011
+ progress_bar.set_postfix(**logs)
1012
+
1013
+ # Save model checkpoint and reconstruct
1014
+ if epoch % ckpt_interval == 0:
1015
+ if not utils.is_interactive():
1016
+ save_ckpt(f'last')
1017
+
1018
+ if wandb_log: wandb.log(logs)
1019
+
1020
+ # wait for other GPUs to catch up if needed
1021
+ accelerator.wait_for_everyone()
1022
+ torch.cuda.empty_cache()
1023
+ gc.collect()
1024
+
1025
+ print("\n===Finished!===\n")
1026
+ if ckpt_saving:
1027
+ save_ckpt(f'last')
1028
+ if not utils.is_interactive():
1029
+ sys.exit(0)
1030
+
1031
+
1032
+
1033
+ # In[ ]:
1034
+
1035
+
1036
+ plt.plot(losses)
1037
+ plt.show()
1038
+ plt.plot(test_losses)
1039
+ plt.show()
1040
+
src/Train-with-memory-rr-mlpmix.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory-rr.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory-rr.py ADDED
@@ -0,0 +1,1018 @@
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[ ]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train-with-memory-rr.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # In[24]:
14
+
15
+
16
+ #get_ipython().system('nvidia-smi')
17
+
18
+
19
+ # # Import packages & functions
20
+
21
+ # In[3]:
22
+
23
+
24
+ import os
25
+ import sys
26
+ import json
27
+ import argparse
28
+ import numpy as np
29
+ import math
30
+ from einops import rearrange
31
+ import time
32
+ import random
33
+ import h5py
34
+ from tqdm import tqdm
35
+
36
+ import webdataset as wds
37
+ import gc
38
+
39
+ import matplotlib.pyplot as plt
40
+ import torch
41
+ import torch.nn as nn
42
+ from torchvision import transforms
43
+
44
+ from accelerate import Accelerator, DeepSpeedPlugin
45
+
46
+ # tf32 data type is faster than standard float32
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+
49
+ # custom functions #
50
+ import utils
51
+
52
+ global_batch_size = 512 #128
53
+
54
+
55
+ # In[4]:
56
+
57
+
58
+ ### Multi-GPU config ###
59
+ local_rank = os.getenv('RANK')
60
+ if local_rank is None:
61
+ local_rank = 0
62
+ else:
63
+ local_rank = int(local_rank)
64
+ print("LOCAL RANK ", local_rank)
65
+
66
+ num_devices = torch.cuda.device_count()
67
+ if num_devices==0: num_devices = 1
68
+
69
+ accelerator = Accelerator(split_batches=False)
70
+
71
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
72
+
73
+ # if num_devices <= 1 and utils.is_interactive():
74
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
75
+ # os.environ["MASTER_ADDR"] = "localhost"
76
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
77
+ # os.environ["RANK"] = "0"
78
+ # os.environ["LOCAL_RANK"] = "0"
79
+ # os.environ["WORLD_SIZE"] = "1"
80
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
81
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
82
+
83
+ # # alter the deepspeed config according to your global and local batch size
84
+ # if local_rank == 0:
85
+ # with open('deepspeed_config_stage2.json', 'r') as file:
86
+ # config = json.load(file)
87
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
88
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
89
+ # with open('deepspeed_config_stage2.json', 'w') as file:
90
+ # json.dump(config, file)
91
+ # else:
92
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
93
+ # time.sleep(10)
94
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
95
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
96
+
97
+
98
+ # In[5]:
99
+
100
+
101
+ print("PID of this process =",os.getpid())
102
+ device = accelerator.device
103
+ print("device:",device)
104
+ num_workers = num_devices
105
+ print(accelerator.state)
106
+ world_size = accelerator.state.num_processes
107
+ distributed = not accelerator.state.distributed_type == 'NO'
108
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
109
+ print = accelerator.print # only print if local_rank=0
110
+
111
+
112
+ # # Configurations
113
+
114
+ # In[6]:
115
+
116
+
117
+ # if running this interactively, can specify jupyter_args here for argparser to use
118
+ if utils.is_interactive():
119
+ # Example use
120
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
121
+ --model_name=test \
122
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
123
+ --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
124
+
125
+ jupyter_args = jupyter_args.split()
126
+ print(jupyter_args)
127
+
128
+ from IPython.display import clear_output # function to clear print outputs in cell
129
+ get_ipython().run_line_magic('load_ext', 'autoreload')
130
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
131
+ get_ipython().run_line_magic('autoreload', '2')
132
+
133
+
134
+ # In[7]:
135
+
136
+
137
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
138
+ parser.add_argument(
139
+ "--model_name", type=str, default="memory_cat_rr",
140
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
141
+ )
142
+ parser.add_argument(
143
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
144
+ help="Path to where NSD data is stored / where to download it to",
145
+ )
146
+ parser.add_argument(
147
+ "--subj",type=int, default=1, choices=[1,2,5,7],
148
+ )
149
+ parser.add_argument(
150
+ "--batch_size", type=int, default=32,
151
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
152
+ )
153
+ parser.add_argument(
154
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
155
+ help="whether to log to wandb",
156
+ )
157
+ parser.add_argument(
158
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
159
+ help="if not using wandb and want to resume from a ckpt",
160
+ )
161
+ parser.add_argument(
162
+ "--wandb_project",type=str,default="stability",
163
+ help="wandb project name",
164
+ )
165
+ parser.add_argument(
166
+ "--mixup_pct",type=float,default=.33,
167
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
168
+ )
169
+ parser.add_argument(
170
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
171
+ help="whether to use image augmentation",
172
+ )
173
+ parser.add_argument(
174
+ "--num_epochs",type=int,default=240,
175
+ help="number of epochs of training",
176
+ )
177
+ parser.add_argument(
178
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
179
+ )
180
+ parser.add_argument(
181
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
182
+ )
183
+ parser.add_argument(
184
+ "--ckpt_interval",type=int,default=5,
185
+ help="save backup ckpt and reconstruct every x epochs",
186
+ )
187
+ parser.add_argument(
188
+ "--seed",type=int,default=42,
189
+ )
190
+ parser.add_argument(
191
+ "--max_lr",type=float,default=3e-4,
192
+ )
193
+ parser.add_argument(
194
+ "--n_samples_save",type=int,default=0,choices=[0,1],
195
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
196
+ )
197
+
198
+ if utils.is_interactive():
199
+ args = parser.parse_args(jupyter_args)
200
+ else:
201
+ args = parser.parse_args()
202
+
203
+ # create global variables without the args prefix
204
+ for attribute_name in vars(args).keys():
205
+ globals()[attribute_name] = getattr(args, attribute_name)
206
+
207
+ print("global batch_size", batch_size)
208
+ batch_size = int(batch_size / num_devices)
209
+ print("batch_size", batch_size)
210
+
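+ # hypothetical launch command (values mirror the interactive example above; the actual
+ # cluster configs live in the accel*.slurm scripts, not shown here):
+ # accelerate launch Train-with-memory-rr.py \
+ #     --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
+ #     --model_name=test --subj=1 --batch_size=512 \
+ #     --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug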
211
+
212
+ # In[8]:
213
+
214
+
215
+ outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
216
+ if not os.path.exists(outdir):
217
+ os.makedirs(outdir,exist_ok=True)
218
+ if use_image_aug:
219
+ import kornia
220
+ from kornia.augmentation.container import AugmentationSequential
221
+ img_augment = AugmentationSequential(
222
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
223
+ kornia.augmentation.Resize((224, 224)),
224
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
225
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
226
+ kornia.augmentation.RandomGrayscale(p=0.3),
227
+ same_on_batch=False,
228
+ data_keys=["input"],
229
+ )
230
+
231
+
232
+ # # Prep data, models, and dataloaders
233
+
234
+ # ## Dataloader
235
+
236
+ # In[9]:
237
+
238
+
239
+ if subj==1:
240
+ num_train = 24958
241
+ num_test = 2770
242
+ test_batch_size = num_test
243
+
244
+ def my_split_by_node(urls): return urls
245
+
246
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
247
+ print(train_url)
248
+
249
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
250
+ .shuffle(750, initial=1500, rng=random.Random(42))\
251
+ .decode("torch")\
252
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
253
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
254
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
255
+
256
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
257
+ print(test_url)
258
+
259
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
260
+ .shuffle(750, initial=1500, rng=random.Random(42))\
261
+ .decode("torch")\
262
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
263
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
264
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
265
+
266
+
267
+ # ### check dataloaders are working
268
+
269
+ # In[10]:
270
+
271
+
272
+ # test_indices = []
273
+ # test_images = []
274
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
275
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
276
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
277
+ # test_indices = test_indices.astype(np.int16)
278
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
279
+ # print("---\n")
280
+
281
+ # train_indices = []
282
+ # train_images = []
283
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
284
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
285
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
286
+ # train_indices = train_indices.astype(np.int16)
287
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
288
+
289
+ # # train_images = np.hstack((train_images, test_images))
290
+ # # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
291
+
292
+
293
+ # ## Load data and images
294
+
295
+ # In[11]:
296
+
297
+
298
+ # load betas
299
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
300
+ voxels = f['betas'][:]
301
+ print(f"subj0{subj} betas loaded into memory")
302
+ voxels = torch.Tensor(voxels).to("cpu").half()
303
+ if subj==1:
304
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
305
+ print("voxels", voxels.shape)
306
+ num_voxels = voxels.shape[-1]
307
+
308
+ # load orig images
309
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
310
+ images = f['images'][:]
311
+ images = torch.Tensor(images).to("cpu").half()
312
+ print("images", images.shape)
313
+
314
+
315
+ # ## Load models
316
+
317
+ # ### CLIP image embeddings model
318
+
319
+ # In[12]:
320
+
321
+
322
+ from models import Clipper
323
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
324
+
325
+ clip_seq_dim = 257
326
+ clip_emb_dim = 768
327
+ hidden_dim = 4096
328
+
329
+
330
+ # ### SD VAE (blurry images)
331
+
332
+ # In[13]:
333
+
334
+
335
+ from diffusers import AutoencoderKL
336
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
337
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
338
+ autoenc.eval()
339
+ autoenc.requires_grad_(False)
340
+ autoenc.to(device)
341
+ utils.count_params(autoenc)
342
+
343
+
344
+ # ### MindEye modules
345
+
346
+ # In[14]:
347
+
348
+
349
+ class MindEyeModule(nn.Module):
350
+ def __init__(self):
351
+ super(MindEyeModule, self).__init__()
352
+ def forward(self, x):
353
+ return x
354
+
355
+ model = MindEyeModule()
356
+ model
357
+
358
+
359
+ # In[15]:
360
+
361
+
362
+ time_embedding_dim = 512
363
+
364
+ class RidgeRegression(torch.nn.Module):
365
+ # make sure to add weight_decay when initializing optimizer
366
+ def __init__(self, input_size, out_features):
367
+ super(RidgeRegression, self).__init__()
368
+ self.out_features = out_features
369
+ self.linear = torch.nn.Linear(input_size, out_features)
370
+ def forward(self, x):
371
+ return self.linear(x)
372
+
373
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
374
+ utils.count_params(model.ridge)
375
+ utils.count_params(model)
376
+
377
+ b = torch.randn((2,1,voxels.shape[1]))
378
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
379
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
380
+
381
+
382
+ # In[16]:
383
+
384
+
385
+ from functools import partial
386
+ from diffusers.models.vae import Decoder
387
+ class BrainNetwork(nn.Module):
388
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25, blurry_dim=16):
389
+ super().__init__()
390
+ self.blurry_dim = blurry_dim
391
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
392
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
393
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
394
+ self.lin0 = nn.Linear(in_dim, h)
395
+ self.mlp = nn.ModuleList([
396
+ nn.Sequential(
397
+ nn.Linear(h, h),
398
+ *[item() for item in act_and_norm],
399
+ nn.Dropout(drop)
400
+ ) for _ in range(n_blocks)
401
+ ])
402
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
403
+ self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
404
+ self.n_blocks = n_blocks
405
+ self.clip_size = clip_size
406
+ self.clip_proj = nn.Sequential(
407
+ nn.LayerNorm(clip_size),
408
+ nn.GELU(),
409
+ nn.Linear(clip_size, 2048),
410
+ nn.LayerNorm(2048),
411
+ nn.GELU(),
412
+ nn.Linear(2048, 2048),
413
+ nn.LayerNorm(2048),
414
+ nn.GELU(),
415
+ nn.Linear(2048, clip_size)
416
+ )
417
+ self.upsampler = Decoder(
418
+ in_channels=64,
419
+ out_channels=4,
420
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
421
+ block_out_channels=[64, 128, 256],
422
+ layers_per_block=1,
423
+ )
424
+
425
+ def forward(self, x):
426
+ x = self.lin0(x)
427
+ residual = x
428
+ for res_block in range(self.n_blocks):
429
+ x = self.mlp[res_block](x)
430
+ x += residual
431
+ residual = x
432
+ x = x.reshape(len(x), -1)
433
+ x = self.lin1(x)
434
+ b = self.blin1(x)
435
+ b = self.upsampler(b.reshape(len(b), -1, 7, 7))
436
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
437
+ return c, b
438
+
439
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim*2, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
440
+ utils.count_params(model.backbone)
441
+ utils.count_params(model)
442
+
443
+ b = torch.randn((2,8192))
444
+ print(b.shape)
445
+ clip_, blur_ = model.backbone(b)
446
+ print(clip_.shape, blur_.shape)
447
+
448
+
449
+ # In[17]:
450
+
451
+
452
+ # memory model
453
+
454
+ from timm.layers.mlp import Mlp
455
+
456
+ class MemoryEncoder(nn.Module):
457
+ def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
458
+ super().__init__()
459
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
460
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
461
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
462
+ self.out_dim = out_dim
463
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
464
+ self.final_input_dim = in_dim + embedding_time_dim
465
+ self.lin0 = nn.Linear(self.final_input_dim, h)
466
+ self.mlp = nn.ModuleList([
467
+ nn.Sequential(
468
+ nn.Linear(h, h),
469
+ *[item() for item in act_and_norm],
470
+ nn.Dropout(drop)
471
+ ) for _ in range(n_blocks)
472
+ ])
473
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
474
+ self.n_blocks = n_blocks
475
+ self.num_past_voxels = num_past_voxels
476
+ self.embedding_time_dim = embedding_time_dim
477
+ self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))
478
+
479
+
480
+ def forward(self, x, time):
481
+ time = time.long()
482
+ time = self.embedding_time(time)
483
+ x = torch.cat((x, time), dim=-1)
484
+ x = self.lin0(x)
485
+ residual = x
486
+ for res_block in range(self.n_blocks):
487
+ x = self.mlp[res_block](x)
488
+ x += residual
489
+ residual = x
490
+ x = x.reshape(len(x), -1)
491
+ x = self.lin1(x)
492
+ return x
493
+
494
+ # # test the memory encoder
495
+ # memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)
496
+
497
+ # device = torch.device("cpu")
498
+ # memory_encoder.to(device)
499
+
500
+ # # count params
501
+ # total_parameters = 0
502
+ # for parameter in memory_encoder.parameters():
503
+ # total_parameters += parameter.numel()
504
+
505
+ # rand_input = torch.randn((2, 15279)).to(device)
506
+ # rand_time = torch.randint(0, 15, (2,)).to(device)
507
+ # print(rand_input.shape, rand_time.shape)
508
+ # memory_encoder(rand_input, rand_time).shape
509
+
510
+ class MemoryCompressor(nn.Module):
511
+ def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.25):
512
+ super().__init__()
513
+ self.num_past = num_past
514
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
515
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
516
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
517
+ self.final_input_dim = in_dim * num_past
518
+ self.lin0 = nn.Linear(self.final_input_dim, h)
519
+ self.mlp = nn.ModuleList([
520
+ nn.Sequential(
521
+ nn.Linear(h, h),
522
+ *[item() for item in act_and_norm],
523
+ nn.Dropout(drop)
524
+ ) for _ in range(n_blocks)
525
+ ])
526
+ self.lin1 = nn.Linear(h, output_dim, bias=True)
527
+ self.n_blocks = n_blocks
528
+ self.num_past = num_past
529
+ self.output_dim = output_dim
530
+
531
+ def forward(self, x):
532
+ # x is (batch_size, num_past, in_dim)
533
+ x = x.reshape(len(x), -1)
534
+ x = self.lin0(x)
535
+ residual = x
536
+ for res_block in range(self.n_blocks):
537
+ x = self.mlp[res_block](x)
538
+ x += residual
539
+ residual = x
540
+ x = x.reshape(len(x), -1)
541
+ x = self.lin1(x)
542
+ return x
543
+
544
+ # # test the memory compressor
545
+ # memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)
546
+
547
+ # device = torch.device("cpu")
548
+ # memory_compressor.to(device)
549
+
550
+ # # count params
551
+ # total_parameters = 0
552
+ # for parameter in memory_compressor.parameters():
553
+ # total_parameters += parameter.numel()
554
+
555
+ # rand_input = torch.randn((2, 15, 768)).to(device)
556
+ # print(rand_input.shape)
557
+ # memory_compressor(rand_input).shape
558
+
559
+ class TimeEmbedding(nn.Module):
560
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
561
+ super().__init__()
562
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
563
+ self.num_past_voxels = num_past_voxels
564
+ self.embedding_time_dim = embedding_time_dim
565
+
566
+ def forward(self, time):
567
+ # time is (batch_size,)
568
+ time = time.long()
569
+ time = self.embedding_time(time)
570
+ return time # (batch_size, embedding_time_dim)
571
+
572
+
573
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
574
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
575
+ model.memory_compressor = MemoryCompressor(in_dim=model.ridge.out_features, num_past=15, output_dim=4096)
576
+
577
+ #utils.count_params(model.memory_encoder)
578
+ utils.count_params(model.memory_compressor)
579
+ utils.count_params(model)
580
+
581
+
582
+
583
+ # In[18]:
584
+
585
+
586
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
587
+ opt_grouped_parameters = [
588
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
589
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
590
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
591
+ #{'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
592
+ {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
593
+ {'params': [p for n, p in model.time_embedding.named_parameters()], 'weight_decay': 0.0},
594
+ ]
595
+
596
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
597
+
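+ # note: the grouping above applies the usual "no weight decay on bias / LayerNorm"
+ # rule to the backbone only; ridge and memory_compressor keep weight_decay=1e-2
+ # throughout (this is what makes the linear "ridge" layer behave like ridge
+ # regression), and the time_embedding table is left undecayed.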
598
+ if lr_scheduler_type == 'linear':
599
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
600
+ optimizer,
601
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
602
+ last_epoch=-1
603
+ )
604
+ elif lr_scheduler_type == 'cycle':
605
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
606
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
607
+ optimizer,
608
+ max_lr=max_lr,
609
+ total_steps=total_steps,
610
+ final_div_factor=1000,
611
+ last_epoch=-1, pct_start=2/num_epochs
612
+ )
613
+
614
+ def save_ckpt(tag):
615
+ ckpt_path = outdir+f'/{tag}.pth'
616
+ print(f'saving {ckpt_path}',flush=True)
617
+ unwrapped_model = accelerator.unwrap_model(model)
618
+ try:
619
+ torch.save({
620
+ 'epoch': epoch,
621
+ 'model_state_dict': unwrapped_model.state_dict(),
622
+ 'optimizer_state_dict': optimizer.state_dict(),
623
+ 'lr_scheduler': lr_scheduler.state_dict(),
624
+ 'train_losses': losses,
625
+ 'test_losses': test_losses,
626
+ 'lrs': lrs,
627
+ }, ckpt_path)
628
+ except:
629
+ print("Couldn't save... moving on to prevent crashing.")
630
+ del unwrapped_model
631
+
632
+ print("\nDone with model preparations!")
633
+ utils.count_params(model)
634
+
635
+
636
+ # In[ ]:
637
+
638
+
639
+
640
+
641
+
642
+ # # Weights and Biases
643
+
644
+ # In[19]:
645
+
646
+
647
+ # params for wandb
648
+ wandb_log = True
649
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
650
+ import wandb
651
+
652
+ wandb_project = 'stability'
653
+ wandb_run = model_name
654
+ wandb_notes = ''
655
+
656
+ print(f"wandb {wandb_project} run {wandb_run}")
657
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
658
+ wandb_config = {
659
+ "model_name": model_name,
660
+ "batch_size": batch_size,
661
+ "num_epochs": num_epochs,
662
+ "use_image_aug": use_image_aug,
663
+ "max_lr": max_lr,
664
+ "lr_scheduler_type": lr_scheduler_type,
665
+ "mixup_pct": mixup_pct,
666
+ "num_train": num_train,
667
+ "num_test": num_test,
668
+ "seed": seed,
669
+ "distributed": distributed,
670
+ "num_devices": num_devices,
671
+ "world_size": world_size,
672
+ }
673
+ print("wandb_config:\n",wandb_config)
674
+ if False: # wandb_auto_resume
675
+ print("wandb_id:",model_name)
676
+ wandb.init(
677
+ id = model_name,
678
+ project=wandb_project,
679
+ name=wandb_run,
680
+ config=wandb_config,
681
+ notes=wandb_notes,
682
+ resume="allow",
683
+ )
684
+ else:
685
+ wandb.init(
686
+ project=wandb_project,
687
+ name=model_name,
688
+ config=wandb_config,
689
+ notes=wandb_notes,
690
+ )
691
+ else:
692
+ wandb_log = False
693
+
694
+
695
+ # # More custom functions
696
+
697
+ # In[20]:
698
+
699
+
700
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
701
+ pixcorr_preprocess = transforms.Compose([
702
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
703
+ ])
704
+ def pixcorr(images,brains):
705
+ # Flatten images while keeping the batch dimension
706
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
707
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
708
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
709
+ return corrmean
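+
+ # # Usage sketch (editor's note): pixcorr expects two (N, 3, H, W) batches in [0, 1] on the same
+ # # device; both are resized to 425 px, flattened, and the mean of the diagonal of the pairwise
+ # # Pearson-correlation matrix is returned as a scalar tensor, e.g.
+ # # pixcorr(torch.rand(4, 3, 224, 224), torch.rand(4, 3, 224, 224))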
710
+
711
+
712
+ # # Main
713
+
714
+ # In[21]:
715
+
716
+
717
+ epoch = 0
718
+ losses, test_losses, lrs = [], [], []
719
+ best_test_loss = 1e9
720
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
721
+
722
+ # Optionally resume from checkpoint #
723
+ if resume_from_ckpt:
724
+ print("\n---resuming from last.pth ckpt---\n")
725
+ try:
726
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
727
+ except:
728
+ print('last.pth failed... trying last_backup.pth')
729
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
730
+ epoch = checkpoint['epoch']
731
+ print("Epoch",epoch)
732
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
733
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
734
+ model.load_state_dict(checkpoint['model_state_dict'])
735
+ del checkpoint
736
+ elif wandb_log:
737
+ if wandb.run.resumed:
738
+ print("\n---resuming from last.pth ckpt---\n")
739
+ try:
740
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
741
+ except:
742
+ print('last.pth failed... trying last_backup.pth')
743
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
744
+ epoch = checkpoint['epoch']
745
+ print("Epoch",epoch)
746
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
747
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
748
+ model.load_state_dict(checkpoint['model_state_dict'])
749
+ del checkpoint
750
+ torch.cuda.empty_cache()
751
+
752
+
753
+ # In[22]:
754
+
755
+
756
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
757
+ model, optimizer, train_dl, test_dl, lr_scheduler
758
+ )
759
+
760
+
761
+ # In[23]:
762
+
763
+
764
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
765
+ progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))
766
+ test_image, test_voxel = None, None
767
+ mse = nn.MSELoss()
768
+ for epoch in progress_bar:
769
+ model.train()
770
+
771
+ fwd_percent_correct = 0.
772
+ bwd_percent_correct = 0.
773
+ test_fwd_percent_correct = 0.
774
+ test_bwd_percent_correct = 0.
775
+
776
+ loss_clip_total = 0.
777
+ loss_blurry_total = 0.
778
+ test_loss_clip_total = 0.
779
+ test_loss_blurry_total = 0.
780
+
781
+ blurry_pixcorr = 0.
782
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
783
+
784
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
785
+ #if epoch == 0 or epoch == 1:
786
+ # break
787
+ with torch.cuda.amp.autocast():
788
+ optimizer.zero_grad()
789
+
790
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
791
+
792
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
793
+
794
+ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
795
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
796
+
797
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
798
+
799
+ if use_image_aug: image = img_augment(image)
800
+
801
+ clip_target = clip_model.embed_image(image)
802
+ assert not torch.any(torch.isnan(clip_target))
803
+
804
+ if epoch < int(mixup_pct * num_epochs):
805
+ voxel, perm, betas, select = utils.mixco(voxel)
806
+
807
+ # reshape past voxels to be (batch_size * 15, 15279)
808
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
809
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
810
+ past_15_times = past_15_times.reshape(-1)
811
+
812
+ #print(past_15_voxels.shape, past_15_times.shape)
813
+ time_embeddings = model.time_embedding(past_15_times)
814
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
815
+ embeds_past_voxels = model.ridge(past_info_full)
816
+ #print(embeds_past_voxels.shape)
817
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
818
+ #print(embeds_past_voxels.shape)
819
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
820
+
821
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
822
+ #print(torch.cat((voxel, positional_current_voxel), dim=-1).shape, positional_current_voxel.shape, voxel.shape)
823
+ voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
824
+
825
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
826
+
827
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
828
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
829
+
830
+ if epoch < int(mixup_pct * num_epochs):
831
+ loss_clip = utils.mixco_nce(
832
+ clip_voxels_norm,
833
+ clip_target_norm,
834
+ temp=.006,
835
+ perm=perm, betas=betas, select=select)
836
+ else:
837
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
838
+ loss_clip = utils.soft_clip_loss(
839
+ clip_voxels_norm,
840
+ clip_target_norm,
841
+ temp=epoch_temp)
842
+
843
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
844
+
845
+ loss_clip_total += loss_clip.item()
846
+ loss_blurry_total += loss_blurry.item()
847
+
848
+ loss = loss_blurry + loss_clip
849
+
850
+ utils.check_loss(loss)
851
+
852
+ accelerator.backward(loss)
853
+ optimizer.step()
854
+
855
+ losses.append(loss.item())
856
+ lrs.append(optimizer.param_groups[0]['lr'])
857
+
858
+ # forward and backward top 1 accuracy
859
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
860
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
861
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
862
+
863
+ with torch.no_grad():
864
+ # only doing pixcorr eval on a small subset (2) of the samples per batch because it's costly & slow to compute autoenc.decode()
865
+ random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
866
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
867
+ blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
868
+
869
+ if lr_scheduler_type is not None:
870
+ lr_scheduler.step()
871
+
872
+ model.eval()
873
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
874
+ print('test')
875
+ with torch.cuda.amp.autocast():
876
+ with torch.no_grad():
877
+ # all test samples should be loaded per batch such that test_i should never exceed 0
878
+ if len(behav) != num_test: print("!",len(behav),num_test)
879
+
880
+
881
+ ## Average same-image repeats ##
882
+ if test_image is None:
883
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
884
+
885
+ image = behav[:,0,0].cpu().long()
886
+
887
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
888
+ for im in unique_image:
889
+ locs = torch.where(im == image)[0]
890
+ if test_image is None:
891
+ test_image = images[im][None]
892
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
893
+ else:
894
+ test_image = torch.vstack((test_image, images[im][None]))
895
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
896
+
897
+ # sample of batch_size
898
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
899
+ voxel = test_voxel[random_indices].to(device)
900
+ image = test_image[random_indices].to(device)
901
+
902
+ current_past_behav = past_behav[random_indices]
903
+
904
+ past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
905
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
906
+
907
+ assert len(image) == batch_size
908
+
909
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
910
+
911
+ clip_target = clip_model.embed_image(image.float())
912
+
913
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
914
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
915
+ past_15_times = past_15_times.reshape(-1)
916
+
917
+ print(past_15_voxels.shape, past_15_times.shape)
918
+
919
+ #print(past_15_voxels.shape, past_15_times.shape)
920
+ time_embeddings = model.time_embedding(past_15_times)
921
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
922
+ embeds_past_voxels = model.ridge(past_info_full)
923
+ #print(embeds_past_voxels.shape)
924
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
925
+ #print(embeds_past_voxels.shape)
926
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
927
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
928
+
929
+ voxel_ridge = torch.cat([model.ridge(torch.cat((voxel, positional_current_voxel), dim=-1)), information_past_voxels], dim=-1)
930
+
931
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
932
+
933
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
934
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
935
+
936
+ loss_clip = utils.soft_clip_loss(
937
+ clip_voxels_norm,
938
+ clip_target_norm,
939
+ temp=.006)
940
+
941
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
942
+
943
+ test_loss_clip_total += loss_clip.item()
+ test_loss_blurry_total += loss_blurry.item()
+
+ loss = loss_blurry + loss_clip
944
+
945
+ utils.check_loss(loss)
946
+
947
+ test_losses.append(loss.item())
948
+
949
+ # forward and backward top 1 accuracy
950
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
951
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
952
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
953
+
954
+ # halving the batch size because the decoder is computationally heavy
955
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
956
+ blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
957
+ test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
958
+
959
+ # transform blurry recon latents to images and plot it
960
+ #fig, axes = plt.subplots(1, 4, figsize=(8, 4))
961
+ #axes[0].imshow(utils.torch_to_Image(image[[0]]))
962
+ #axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
963
+ #axes[2].imshow(utils.torch_to_Image(image[[1]]))
964
+ #axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
965
+ #axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
966
+ #plt.show()
967
+ if wandb_log: wandb.log({"gt": [wandb.Image(utils.torch_to_Image(image[[0]])), wandb.Image(utils.torch_to_Image(image[[1]]))]})
968
+ if wandb_log: wandb.log({"preds": [wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1))), wandb.Image(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))]})
969
+
970
+ if local_rank==0:
971
+ # if utils.is_interactive(): clear_output(wait=True)
972
+ assert (test_i+1) == 1
973
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
974
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
975
+ "train/lr": lrs[-1],
976
+ "train/num_steps": len(losses),
977
+ "test/num_steps": len(test_losses),
978
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
979
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
980
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
981
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
982
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
983
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
984
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
985
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
986
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
987
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
988
+ }
989
+ progress_bar.set_postfix(**logs)
990
+
991
+ # Save model checkpoint and reconstruct
992
+ if epoch % ckpt_interval == 0:
993
+ if not utils.is_interactive():
994
+ save_ckpt(f'last')
995
+
996
+ if wandb_log: wandb.log(logs)
997
+
998
+ # wait for other GPUs to catch up if needed
999
+ accelerator.wait_for_everyone()
1000
+ torch.cuda.empty_cache()
1001
+ gc.collect()
1002
+
1003
+ print("\n===Finished!===\n")
1004
+ if ckpt_saving:
1005
+ save_ckpt(f'last')
1006
+ if not utils.is_interactive():
1007
+ sys.exit(0)
1008
+
1009
+
1010
+
1011
+ # In[ ]:
1012
+
1013
+
1014
+ plt.plot(losses)
1015
+ plt.show()
1016
+ plt.plot(test_losses)
1017
+ plt.show()
1018
+
src/Train-with-memory.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train-with-memory.py ADDED
@@ -0,0 +1,978 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ # from subprocess import call
9
+ # command = "jupyter nbconvert Train.ipynb --to python"
10
+ # call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ #from einops import rearrange
25
+ import time
26
+ import random
27
+ import h5py
28
+ from tqdm import tqdm
29
+
30
+ import webdataset as wds
31
+ import gc
32
+
33
+ import matplotlib.pyplot as plt
34
+ import torch
35
+ import torch.nn as nn
36
+ from torchvision import transforms
37
+
38
+ from accelerate import Accelerator, DeepSpeedPlugin
39
+
40
+ # tf32 data type is faster than standard float32
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+
43
+ # custom functions #
44
+ import utils
45
+
46
+ global_batch_size = 128 #128
47
+
48
+
49
+ # In[3]:
50
+
51
+
52
+ ### Multi-GPU config ###
53
+ local_rank = os.getenv('RANK')
54
+ if local_rank is None:
55
+ local_rank = 0
56
+ else:
57
+ local_rank = int(local_rank)
58
+ print("LOCAL RANK ", local_rank)
59
+
60
+ num_devices = torch.cuda.device_count()
61
+ if num_devices==0: num_devices = 1
62
+
63
+ accelerator = Accelerator(split_batches=False)
64
+
65
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
66
+
67
+ # if num_devices <= 1 and utils.is_interactive():
68
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
69
+ # os.environ["MASTER_ADDR"] = "localhost"
70
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
71
+ # os.environ["RANK"] = "0"
72
+ # os.environ["LOCAL_RANK"] = "0"
73
+ # os.environ["WORLD_SIZE"] = "1"
74
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
75
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
76
+
77
+ # # alter the deepspeed config according to your global and local batch size
78
+ # if local_rank == 0:
79
+ # with open('deepspeed_config_stage2.json', 'r') as file:
80
+ # config = json.load(file)
81
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
82
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
83
+ # with open('deepspeed_config_stage2.json', 'w') as file:
84
+ # json.dump(config, file)
85
+ # else:
86
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
87
+ # time.sleep(10)
88
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
89
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
90
+
91
+
92
+ # In[4]:
93
+
94
+
95
+ print("PID of this process =",os.getpid())
96
+ device = accelerator.device
97
+ print("device:",device)
98
+ num_workers = num_devices
99
+ print(accelerator.state)
100
+ world_size = accelerator.state.num_processes
101
+ distributed = not accelerator.state.distributed_type == 'NO'
102
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
103
+ print = accelerator.print # only print if local_rank=0
104
+
105
+
106
+ # # Configurations
107
+
108
+ # In[5]:
109
+
110
+
111
+ # if running this interactively, can specify jupyter_args here for argparser to use
112
+ if utils.is_interactive():
113
+ # Example use
114
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
115
+ --model_name=test \
116
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
117
+ --max_lr=3e-5 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug"
118
+
119
+ jupyter_args = jupyter_args.split()
120
+ print(jupyter_args)
121
+
122
+ from IPython.display import clear_output # function to clear print outputs in cell
123
+ get_ipython().run_line_magic('load_ext', 'autoreload')
124
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
125
+ get_ipython().run_line_magic('autoreload', '2')
126
+
127
+
128
+ # In[6]:
129
+
130
+
131
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
132
+ parser.add_argument(
133
+ "--model_name", type=str, default="testing",
134
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
135
+ )
136
+ parser.add_argument(
137
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
138
+ help="Path to where NSD data is stored / where to download it to",
139
+ )
140
+ parser.add_argument(
141
+ "--subj",type=int, default=1, choices=[1,2,5,7],
142
+ )
143
+ parser.add_argument(
144
+ "--batch_size", type=int, default=32,
145
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
146
+ )
147
+ parser.add_argument(
148
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
149
+ help="whether to log to wandb",
150
+ )
151
+ parser.add_argument(
152
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
153
+ help="if not using wandb and want to resume from a ckpt",
154
+ )
155
+ parser.add_argument(
156
+ "--wandb_project",type=str,default="stability",
157
+ help="wandb project name",
158
+ )
159
+ parser.add_argument(
160
+ "--mixup_pct",type=float,default=.33,
161
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
162
+ )
163
+ parser.add_argument(
164
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
165
+ help="whether to use image augmentation",
166
+ )
167
+ parser.add_argument(
168
+ "--num_epochs",type=int,default=240,
169
+ help="number of epochs of training",
170
+ )
171
+ parser.add_argument(
172
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
173
+ )
174
+ parser.add_argument(
175
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
176
+ )
177
+ parser.add_argument(
178
+ "--ckpt_interval",type=int,default=5,
179
+ help="save backup ckpt and reconstruct every x epochs",
180
+ )
181
+ parser.add_argument(
182
+ "--seed",type=int,default=42,
183
+ )
184
+ parser.add_argument(
185
+ "--max_lr",type=float,default=3e-4,
186
+ )
187
+ parser.add_argument(
188
+ "--n_samples_save",type=int,default=0,choices=[0,1],
189
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
190
+ )
191
+
192
+ if utils.is_interactive():
193
+ args = parser.parse_args(jupyter_args)
194
+ else:
195
+ args = parser.parse_args()
196
+
197
+ # create global variables without the args prefix
198
+ for attribute_name in vars(args).keys():
199
+ globals()[attribute_name] = getattr(args, attribute_name)
200
+
201
+ print("global batch_size", batch_size)
202
+ batch_size = int(batch_size / num_devices)
203
+ print("batch_size", batch_size)
204
+
205
+
206
+ # In[7]:
207
+
208
+
209
+ outdir = os.path.abspath(f'../train_mem_logs/{model_name}')
210
+ if not os.path.exists(outdir):
211
+ os.makedirs(outdir,exist_ok=True)
212
+ if use_image_aug:
213
+ import kornia
214
+ from kornia.augmentation.container import AugmentationSequential
215
+ img_augment = AugmentationSequential(
216
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
217
+ kornia.augmentation.Resize((224, 224)),
218
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
219
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
220
+ kornia.augmentation.RandomGrayscale(p=0.3),
221
+ same_on_batch=False,
222
+ data_keys=["input"],
223
+ )
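+ # # Usage sketch (editor's note): img_augment takes a float image batch in [0, 1] and returns an
+ # # augmented batch of the same shape, e.g. img_augment(torch.rand(8, 3, 224, 224)).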
224
+
225
+
226
+ # # Prep data, models, and dataloaders
227
+
228
+ # ## Dataloader
229
+
230
+ # In[8]:
231
+
232
+
233
+ if subj==1:
234
+ num_train = 24958
235
+ num_test = 2770
236
+ test_batch_size = num_test
237
+
238
+ def my_split_by_node(urls): return urls
239
+
240
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
241
+ print(train_url)
242
+
243
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
244
+ .shuffle(750, initial=1500, rng=random.Random(42))\
245
+ .decode("torch")\
246
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
247
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
248
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
249
+
250
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
251
+ print(test_url)
252
+
253
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
254
+ .shuffle(750, initial=1500, rng=random.Random(42))\
255
+ .decode("torch")\
256
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
257
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
258
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
259
+
260
+
261
+ # ### check dataloaders are working
262
+
263
+ # In[9]:
264
+
265
+
266
+ # test_indices = []
267
+ # test_images = []
268
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
269
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
270
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
271
+ # test_indices = test_indices.astype(np.int16)
272
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
273
+ # print("---\n")
274
+
275
+ # train_indices = []
276
+ # train_images = []
277
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
278
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
279
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
280
+ # train_indices = train_indices.astype(np.int16)
281
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
282
+
283
+ # # train_images = np.hstack((train_images, test_images))
284
+ # # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
285
+
286
+
287
+ # ## Load data and images
288
+
289
+ # In[10]:
290
+
291
+
292
+ # load betas
293
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
294
+ voxels = f['betas'][:]
295
+ print(f"subj0{subj} betas loaded into memory")
296
+ voxels = torch.Tensor(voxels).to("cpu").half()
297
+ if subj==1:
298
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
299
+ print("voxels", voxels.shape)
300
+ num_voxels = voxels.shape[-1]
301
+
302
+ # load orig images
303
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
304
+ images = f['images'][:]
305
+ images = torch.Tensor(images).to("cpu").half()
306
+ print("images", images.shape)
307
+
308
+
309
+ # ## Load models
310
+
311
+ # ### CLIP image embeddings model
312
+
313
+ # In[11]:
314
+
315
+
316
+ from models import Clipper
317
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
318
+
319
+ clip_seq_dim = 257
320
+ clip_emb_dim = 768
321
+ hidden_dim = 4096
322
+
323
+
324
+ # ### SD VAE (blurry images)
325
+
326
+ # In[12]:
327
+
328
+
329
+ from diffusers import AutoencoderKL
330
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
331
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
332
+ autoenc.eval()
333
+ autoenc.requires_grad_(False)
334
+ autoenc.to(device)
335
+ utils.count_params(autoenc)
336
+
337
+
338
+ # ### MindEye modules
339
+
340
+ # In[13]:
341
+
342
+
343
+ class MindEyeModule(nn.Module):
344
+ def __init__(self):
345
+ super(MindEyeModule, self).__init__()
346
+ def forward(self, x):
347
+ return x
348
+
349
+ model = MindEyeModule()
350
+ model
351
+
352
+
353
+ # In[14]:
354
+
355
+
356
+ class RidgeRegression(torch.nn.Module):
357
+ # make sure to add weight_decay when initializing optimizer
358
+ def __init__(self, input_size, out_features):
359
+ super(RidgeRegression, self).__init__()
360
+ self.out_features = out_features
361
+ self.linear = torch.nn.Linear(input_size, out_features)
362
+ def forward(self, x):
363
+ return self.linear(x)
364
+
365
+ model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
366
+ utils.count_params(model.ridge)
367
+ utils.count_params(model)
368
+
369
+ b = torch.randn((2,1,voxels.shape[1]))
370
+ print(b.shape, model.ridge(b).shape)
371
+
372
+
373
+ # In[15]:
374
+
375
+
376
+ from functools import partial
377
+ from diffusers.models.vae import Decoder
378
+ class BrainNetwork(nn.Module):
379
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):
380
+ super().__init__()
381
+ self.blurry_dim = blurry_dim
382
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
383
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
384
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
385
+ self.lin0 = nn.Linear(in_dim, h)
386
+ self.mlp = nn.ModuleList([
387
+ nn.Sequential(
388
+ nn.Linear(h, h),
389
+ *[item() for item in act_and_norm],
390
+ nn.Dropout(drop)
391
+ ) for _ in range(n_blocks)
392
+ ])
393
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
394
+ self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
395
+ self.n_blocks = n_blocks
396
+ self.clip_size = clip_size
397
+ self.clip_proj = nn.Sequential(
398
+ nn.LayerNorm(clip_size),
399
+ nn.GELU(),
400
+ nn.Linear(clip_size, 2048),
401
+ nn.LayerNorm(2048),
402
+ nn.GELU(),
403
+ nn.Linear(2048, 2048),
404
+ nn.LayerNorm(2048),
405
+ nn.GELU(),
406
+ nn.Linear(2048, clip_size)
407
+ )
408
+ self.upsampler = Decoder(
409
+ in_channels=64,
410
+ out_channels=4,
411
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
412
+ block_out_channels=[64, 128, 256],
413
+ layers_per_block=1,
414
+ )
415
+
416
+ def forward(self, x):
417
+ x = self.lin0(x)
418
+ residual = x
419
+ for res_block in range(self.n_blocks):
420
+ x = self.mlp[res_block](x)
421
+ x += residual
422
+ residual = x
423
+ x = x.reshape(len(x), -1)
424
+ x = self.lin1(x)
425
+ b = self.blin1(x)
426
+ b = self.upsampler(b.reshape(len(b), -1, 7, 7))
427
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
428
+ return c, b
429
+
430
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
431
+ utils.count_params(model.backbone)
432
+ utils.count_params(model)
433
+
434
+ b = torch.randn((2,hidden_dim))
435
+ print(b.shape)
436
+ clip_, blur_ = model.backbone(b)
437
+ print(clip_.shape, blur_.shape)
438
+
439
+
440
+ # In[19]:
441
+
442
+
443
+ # memory model
444
+
445
+ from timm.layers.mlp import Mlp
446
+
447
+ class MemoryEncoder(nn.Module):
448
+ def __init__(self, in_dim=15279, out_dim=768, h=4096, num_past_voxels=15, embedding_time_dim = 512, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
449
+ super().__init__()
450
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
451
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
452
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
453
+ self.out_dim = out_dim
454
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
455
+ self.final_input_dim = in_dim + embedding_time_dim
456
+ self.lin0 = nn.Linear(self.final_input_dim, h)
457
+ self.mlp = nn.ModuleList([
458
+ nn.Sequential(
459
+ nn.Linear(h, h),
460
+ *[item() for item in act_and_norm],
461
+ nn.Dropout(drop)
462
+ ) for _ in range(n_blocks)
463
+ ])
464
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
465
+ self.n_blocks = n_blocks
466
+ self.num_past_voxels = num_past_voxels
467
+ self.embedding_time_dim = embedding_time_dim
468
+ self.memory = nn.Parameter(torch.randn((self.num_past_voxels, self.embedding_time_dim)))
469
+
470
+
471
+ def forward(self, x, time):
472
+ time = time.long()
473
+ time = self.embedding_time(time)
474
+ x = torch.cat((x, time), dim=-1)
475
+ x = self.lin0(x)
476
+ residual = x
477
+ for res_block in range(self.n_blocks):
478
+ x = self.mlp[res_block](x)
479
+ x += residual
480
+ residual = x
481
+ x = x.reshape(len(x), -1)
482
+ x = self.lin1(x)
483
+ return x
484
+
485
+ # # test the memory encoder
486
+ # memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=hidden_dim, num_past_voxels=15, embedding_time_dim=512)
487
+
488
+ # device = torch.device("cpu")
489
+ # memory_encoder.to(device)
490
+
491
+ # # count params
492
+ # total_parameters = 0
493
+ # for parameter in memory_encoder.parameters():
494
+ # total_parameters += parameter.numel()
495
+
496
+ # rand_input = torch.randn((2, 15279)).to(device)
497
+ # rand_time = torch.randint(0, 15, (2,)).to(device)
498
+ # print(rand_input.shape, rand_time.shape)
499
+ # memory_encoder(rand_input, rand_time).shape
500
+
501
+ class MemoryCompressor(nn.Module):
502
+ def __init__(self, in_dim=768, num_past = 15, output_dim=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15):
503
+ super().__init__()
504
+ self.num_past = num_past
505
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
506
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
507
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
508
+ self.final_input_dim = in_dim * num_past
509
+ self.lin0 = nn.Linear(self.final_input_dim, h)
510
+ self.mlp = nn.ModuleList([
511
+ nn.Sequential(
512
+ nn.Linear(h, h),
513
+ *[item() for item in act_and_norm],
514
+ nn.Dropout(drop)
515
+ ) for _ in range(n_blocks)
516
+ ])
517
+ self.lin1 = nn.Linear(h, output_dim, bias=True)
518
+ self.n_blocks = n_blocks
519
+ self.num_past = num_past
520
+ self.output_dim = output_dim
521
+
522
+ def forward(self, x):
523
+ # x is (batch_size, num_past, in_dim)
524
+ x = x.reshape(len(x), -1)
525
+ x = self.lin0(x)
526
+ residual = x
527
+ for res_block in range(self.n_blocks):
528
+ x = self.mlp[res_block](x)
529
+ x += residual
530
+ residual = x
531
+ x = x.reshape(len(x), -1)
532
+ x = self.lin1(x)
533
+ return x
534
+
535
+ # # test the memory compressor
536
+ # memory_compressor = MemoryCompressor(in_dim=768, num_past=15, output_dim=768)
537
+
538
+ # device = torch.device("cpu")
539
+ # memory_compressor.to(device)
540
+
541
+ # # count params
542
+ # total_parameters = 0
543
+ # for parameter in memory_compressor.parameters():
544
+ # total_parameters += parameter.numel()
545
+
546
+ # rand_input = torch.randn((2, 15, 768)).to(device)
547
+ # print(rand_input.shape)
548
+ # memory_compressor(rand_input).shape
549
+
550
+ model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
551
+ model.memory_compressor = MemoryCompressor(in_dim=model.memory_encoder.out_dim, num_past=15, output_dim=4096)
552
+
553
+ utils.count_params(model.memory_encoder)
554
+ utils.count_params(model.memory_compressor)
555
+ utils.count_params(model)
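+
+ # # Illustrative end-to-end shape check for the two memory modules above (editor's sketch,
+ # # in the spirit of the commented-out tests; not executed during training):
+ # enc_in = torch.randn(2 * 15, voxels.shape[1])
+ # enc_t  = torch.arange(15).repeat(2)
+ # enc_out = model.memory_encoder(enc_in, enc_t)                     # (30, 4096)
+ # print(model.memory_compressor(enc_out.reshape(2, 15, -1)).shape)  # torch.Size([2, 4096])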
556
+
557
+
558
+
559
+ # In[17]:
560
+
561
+
562
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
563
+ opt_grouped_parameters = [
564
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
565
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
566
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
567
+ {'params': [p for n, p in model.memory_encoder.named_parameters()], 'weight_decay': 1e-2},
568
+ {'params': [p for n, p in model.memory_compressor.named_parameters()], 'weight_decay': 1e-2},
569
+ ]
570
+
571
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
572
+
573
+ if lr_scheduler_type == 'linear':
574
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
575
+ optimizer,
576
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
577
+ last_epoch=-1
578
+ )
579
+ elif lr_scheduler_type == 'cycle':
580
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
581
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
582
+ optimizer,
583
+ max_lr=max_lr,
584
+ total_steps=total_steps,
585
+ final_div_factor=1000,
586
+ last_epoch=-1, pct_start=2/num_epochs
587
+ )
588
+
589
+ def save_ckpt(tag):
590
+ ckpt_path = outdir+f'/{tag}.pth'
591
+ print(f'saving {ckpt_path}',flush=True)
592
+ unwrapped_model = accelerator.unwrap_model(model)
593
+ try:
594
+ torch.save({
595
+ 'epoch': epoch,
596
+ 'model_state_dict': unwrapped_model.state_dict(),
597
+ 'optimizer_state_dict': optimizer.state_dict(),
598
+ 'lr_scheduler': lr_scheduler.state_dict(),
599
+ 'train_losses': losses,
600
+ 'test_losses': test_losses,
601
+ 'lrs': lrs,
602
+ }, ckpt_path)
603
+ except:
604
+ print("Couldn't save... moving on to prevent crashing.")
605
+ del unwrapped_model
606
+
607
+ print("\nDone with model preparations!")
608
+ utils.count_params(model)
609
+
610
+
611
+
612
+ # # Weights and Biases
613
+
614
+ # In[ ]:
615
+
616
+
617
+ # params for wandb
618
+ wandb_log = True
619
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
620
+ import wandb
621
+
622
+ wandb_project = 'stability'
623
+ wandb_run = model_name
624
+ wandb_notes = ''
625
+
626
+ print(f"wandb {wandb_project} run {wandb_run}")
627
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
628
+ wandb_config = {
629
+ "model_name": model_name,
630
+ "batch_size": batch_size,
631
+ "num_epochs": num_epochs,
632
+ "use_image_aug": use_image_aug,
633
+ "max_lr": max_lr,
634
+ "lr_scheduler_type": lr_scheduler_type,
635
+ "mixup_pct": mixup_pct,
636
+ "num_train": num_train,
637
+ "num_test": num_test,
638
+ "seed": seed,
639
+ "distributed": distributed,
640
+ "num_devices": num_devices,
641
+ "world_size": world_size,
642
+ }
643
+ print("wandb_config:\n",wandb_config)
644
+ if False: # wandb_auto_resume
645
+ print("wandb_id:",model_name)
646
+ wandb.init(
647
+ id = model_name,
648
+ project=wandb_project,
649
+ name=wandb_run,
650
+ config=wandb_config,
651
+ notes=wandb_notes,
652
+ resume="allow",
653
+ )
654
+ else:
655
+ wandb.init(
656
+ project=wandb_project,
657
+ name=model_name,
658
+ config=wandb_config,
659
+ notes=wandb_notes,
660
+ )
661
+ else:
662
+ wandb_log = False
663
+
664
+
665
+ # # More custom functions
666
+
667
+ # In[ ]:
668
+
669
+
670
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
671
+ pixcorr_preprocess = transforms.Compose([
672
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
673
+ ])
674
+ def pixcorr(images,brains):
675
+ # Flatten images while keeping the batch dimension
676
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
677
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
678
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
679
+ return corrmean
680
+
681
+
682
+ # # Main
683
+
684
+ # In[ ]:
685
+
686
+
687
+ epoch = 0
688
+ losses, test_losses, lrs = [], [], []
689
+ best_test_loss = 1e9
690
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
691
+
692
+ # Optionally resume from checkpoint #
693
+ if resume_from_ckpt:
694
+ print("\n---resuming from last.pth ckpt---\n")
695
+ try:
696
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
697
+ except:
698
+ print('last.pth failed... trying last_backup.pth')
699
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
700
+ epoch = checkpoint['epoch']
701
+ print("Epoch",epoch)
702
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
703
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
704
+ model.load_state_dict(checkpoint['model_state_dict'])
705
+ del checkpoint
706
+ elif wandb_log:
707
+ if wandb.run.resumed:
708
+ print("\n---resuming from last.pth ckpt---\n")
709
+ try:
710
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
711
+ except:
712
+ print('last.pth failed... trying last_backup.pth')
713
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
714
+ epoch = checkpoint['epoch']
715
+ print("Epoch",epoch)
716
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
717
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
718
+ model.load_state_dict(checkpoint['model_state_dict'])
719
+ del checkpoint
720
+ torch.cuda.empty_cache()
721
+
722
+
723
+ # In[ ]:
724
+
725
+
726
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
727
+ model, optimizer, train_dl, test_dl, lr_scheduler
728
+ )
729
+
730
+
731
+ # In[ ]:
732
+
733
+
734
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
735
+ progress_bar = tqdm(range(0,num_epochs), ncols=1200, disable=(local_rank!=0))
736
+ test_image, test_voxel = None, None
737
+ mse = nn.MSELoss()
738
+ for epoch in progress_bar:
739
+ model.train()
740
+
741
+ fwd_percent_correct = 0.
742
+ bwd_percent_correct = 0.
743
+ test_fwd_percent_correct = 0.
744
+ test_bwd_percent_correct = 0.
745
+
746
+ loss_clip_total = 0.
747
+ loss_blurry_total = 0.
748
+ test_loss_clip_total = 0.
749
+ test_loss_blurry_total = 0.
750
+
751
+ blurry_pixcorr = 0.
752
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
753
+
754
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
755
+ #if epoch == 0 or epoch == 1:
756
+ # break
757
+ with torch.cuda.amp.autocast():
758
+ optimizer.zero_grad()
759
+
760
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
761
+
762
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
763
+
764
+ past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
765
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
766
+
767
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
768
+
769
+ if use_image_aug: image = img_augment(image)
770
+
771
+ clip_target = clip_model.embed_image(image)
772
+ assert not torch.any(torch.isnan(clip_target))
773
+
774
+ if epoch < int(mixup_pct * num_epochs):
775
+ voxel, perm, betas, select = utils.mixco(voxel)
776
+
777
+ # reshape past voxels to be (batch_size * 15, 15279)
778
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
779
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
780
+ past_15_times = past_15_times.reshape(-1)
781
+
782
+ #print(past_15_voxels.shape, past_15_times.shape)
783
+
784
+ embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)
785
+ #print(embeds_past_voxels.shape)
786
+ embeds_past_voxels = embeds_past_voxels.reshape(voxel.shape[0], 15, -1)
787
+ #print(embeds_past_voxels.shape)
788
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
789
+
790
+
791
+ voxel_ridge = model.ridge(voxel) + information_past_voxels
792
+
793
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
794
+
795
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
796
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
797
+
798
+ if epoch < int(mixup_pct * num_epochs):
799
+ loss_clip = utils.mixco_nce(
800
+ clip_voxels_norm,
801
+ clip_target_norm,
802
+ temp=.006,
803
+ perm=perm, betas=betas, select=select)
804
+ else:
805
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
806
+ loss_clip = utils.soft_clip_loss(
807
+ clip_voxels_norm,
808
+ clip_target_norm,
809
+ temp=epoch_temp)
810
+
811
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
812
+
813
+ loss_clip_total += loss_clip.item()
814
+ loss_blurry_total += loss_blurry.item()
815
+
816
+ loss = loss_blurry + loss_clip
817
+
818
+ utils.check_loss(loss)
819
+
820
+ accelerator.backward(loss)
821
+ optimizer.step()
822
+
823
+ losses.append(loss.item())
824
+ lrs.append(optimizer.param_groups[0]['lr'])
825
+
826
+ # forward and backward top 1 accuracy
827
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
828
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
829
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
830
+
831
+ with torch.no_grad():
832
+ # only doing pixcorr eval on a small subset (2) of the samples per batch because it's costly & slow to compute autoenc.decode()
833
+ random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
834
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
835
+ blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
836
+
837
+ if lr_scheduler_type is not None:
838
+ lr_scheduler.step()
839
+
840
+ model.eval()
841
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
842
+ print('test')
843
+ with torch.cuda.amp.autocast():
844
+ with torch.no_grad():
845
+ # all test samples should be loaded per batch such that test_i should never exceed 0
846
+ if len(behav) != num_test: print("!",len(behav),num_test)
847
+
848
+
849
+ ## Average same-image repeats ##
850
+ if test_image is None:
851
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
852
+
853
+ image = behav[:,0,0].cpu().long()
854
+
855
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
856
+ for im in unique_image:
857
+ locs = torch.where(im == image)[0]
858
+ if test_image is None:
859
+ test_image = images[im][None]
860
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
861
+ else:
862
+ test_image = torch.vstack((test_image, images[im][None]))
863
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
864
+
865
+ # sample of batch_size
866
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
867
+ voxel = test_voxel[random_indices].to(device)
868
+ image = test_image[random_indices].to(device)
869
+
870
+ current_past_behav = past_behav[random_indices]
871
+
872
+ past_15_voxels = voxels[current_past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
873
+ past_15_times = torch.Tensor([i for i in range(15)]).to(device) # 15
874
+
875
+ assert len(image) == batch_size
876
+
877
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
878
+
879
+ clip_target = clip_model.embed_image(image.float())
880
+
881
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
882
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
883
+ past_15_times = past_15_times.reshape(-1)
884
+
885
+ print(past_15_voxels.shape, past_15_times.shape)
886
+
887
+ embeds_past_voxels = model.memory_encoder(past_15_voxels, past_15_times)
888
+ embeds_past_voxels = embeds_past_voxels.reshape(batch_size, 15, -1)
889
+ information_past_voxels = model.memory_compressor(embeds_past_voxels)
890
+
891
+
892
+ voxel_ridge = model.ridge(voxel) + information_past_voxels
893
+
894
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
895
+
896
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
897
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
898
+
899
+ loss_clip = utils.soft_clip_loss(
900
+ clip_voxels_norm,
901
+ clip_target_norm,
902
+ temp=.006)
903
+
904
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
905
+
906
+ test_loss_clip_total += loss_clip.item()
+ test_loss_blurry_total += loss_blurry.item()
+
+ loss = loss_blurry + loss_clip
907
+
908
+ utils.check_loss(loss)
909
+
910
+ test_losses.append(loss.item())
911
+
912
+ # forward and backward top 1 accuracy
913
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
914
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
915
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
916
+
917
+ # halving the batch size because the decoder is computationally heavy
918
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
919
+ blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
920
+ test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
921
+
922
+ # transform blurry recon latents to images and plot it
923
+ fig, axes = plt.subplots(1, 4, figsize=(8, 4))
924
+ axes[0].imshow(utils.torch_to_Image(image[[0]]))
925
+ axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
926
+ axes[2].imshow(utils.torch_to_Image(image[[1]]))
927
+ axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
928
+ axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
929
+ plt.show()
930
+
931
+ if local_rank==0:
932
+ # if utils.is_interactive(): clear_output(wait=True)
933
+ assert (test_i+1) == 1
934
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
935
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
936
+ "train/lr": lrs[-1],
937
+ "train/num_steps": len(losses),
938
+ "test/num_steps": len(test_losses),
939
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
940
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
941
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
942
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
943
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
944
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
945
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
946
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
947
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
948
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
949
+ }
950
+ progress_bar.set_postfix(**logs)
951
+
952
+ # Save model checkpoint and reconstruct
953
+ if epoch % ckpt_interval == 0:
954
+ if not utils.is_interactive():
955
+ save_ckpt(f'last')
956
+
957
+ if wandb_log: wandb.log(logs)
958
+
959
+ # wait for other GPUs to catch up if needed
960
+ accelerator.wait_for_everyone()
961
+ torch.cuda.empty_cache()
962
+ gc.collect()
963
+
964
+ print("\n===Finished!===\n")
965
+ if ckpt_saving:
966
+ save_ckpt(f'last')
967
+ if not utils.is_interactive():
968
+ sys.exit(0)
969
+
970
+
971
+ # In[ ]:
972
+
973
+
974
+ plt.plot(losses)
975
+ plt.show()
976
+ plt.plot(test_losses)
977
+ plt.show()
978
+
src/Train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train.py ADDED
@@ -0,0 +1,761 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ # from subprocess import call
9
+ # command = "jupyter nbconvert Train.ipynb --to python"
10
+ # call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import h5py
28
+ from tqdm import tqdm
29
+
30
+ import webdataset as wds
31
+ import gc
32
+
33
+ import matplotlib.pyplot as plt
34
+ import torch
35
+ import torch.nn as nn
36
+ from torchvision import transforms
37
+
38
+ from accelerate import Accelerator, DeepSpeedPlugin
39
+
40
+ # tf32 data type is faster than standard float32
41
+ torch.backends.cuda.matmul.allow_tf32 = True
42
+
43
+ # custom functions #
44
+ import utils
45
+
46
+ global_batch_size = 128 #128
47
+
48
+
49
+ # In[ ]:
50
+
51
+
52
+ ### Multi-GPU config ###
53
+ local_rank = os.getenv('RANK')
54
+ if local_rank is None:
55
+ local_rank = 0
56
+ else:
57
+ local_rank = int(local_rank)
58
+ print("LOCAL RANK ", local_rank)
59
+
60
+ num_devices = torch.cuda.device_count()
61
+ if num_devices==0: num_devices = 1
62
+
63
+ accelerator = Accelerator(split_batches=False)
64
+
65
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
66
+
67
+ # if num_devices <= 1 and utils.is_interactive():
68
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
69
+ # os.environ["MASTER_ADDR"] = "localhost"
70
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
71
+ # os.environ["RANK"] = "0"
72
+ # os.environ["LOCAL_RANK"] = "0"
73
+ # os.environ["WORLD_SIZE"] = "1"
74
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
75
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
76
+
77
+ # # alter the deepspeed config according to your global and local batch size
78
+ # if local_rank == 0:
79
+ # with open('deepspeed_config_stage2.json', 'r') as file:
80
+ # config = json.load(file)
81
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
82
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
83
+ # with open('deepspeed_config_stage2.json', 'w') as file:
84
+ # json.dump(config, file)
85
+ # else:
86
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
87
+ # time.sleep(10)
88
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
89
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
90
+
91
+ ### Multi-GPU config ###
92
+ print("PID of this process =",os.getpid())
93
+ device = accelerator.device
94
+ print("device:",device)
95
+ num_workers = num_devices
96
+ print(accelerator.state)
97
+ world_size = accelerator.state.num_processes
98
+ distributed = not accelerator.state.distributed_type == 'NO'
99
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
100
+ print = accelerator.print # only print if local_rank=0
101
+
102
+ # In[ ]:
103
+
104
+
105
+
106
+
107
+
108
+ # # Configurations
109
+
110
+ # In[3]:
111
+
112
+
113
+
114
+ # In[4]:
115
+
116
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
117
+ parser.add_argument(
118
+ "--model_name", type=str, default="testing",
119
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
120
+ )
121
+ parser.add_argument(
122
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
123
+ help="Path to where NSD data is stored / where to download it to",
124
+ )
125
+ parser.add_argument(
126
+ "--subj",type=int, default=1, choices=[1,2,5,7],
127
+ )
128
+ parser.add_argument(
129
+ "--batch_size", type=int, default=32,
130
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
131
+ )
132
+ parser.add_argument(
133
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
134
+ help="whether to log to wandb",
135
+ )
136
+ parser.add_argument(
137
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
138
+ help="if not using wandb and want to resume from a ckpt",
139
+ )
140
+ parser.add_argument(
141
+ "--wandb_project",type=str,default="stability",
142
+ help="wandb project name",
143
+ )
144
+ parser.add_argument(
145
+ "--mixup_pct",type=float,default=.33,
146
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
147
+ )
148
+ parser.add_argument(
149
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
150
+ help="whether to use image augmentation",
151
+ )
152
+ parser.add_argument(
153
+ "--num_epochs",type=int,default=240,
154
+ help="number of epochs of training",
155
+ )
156
+ parser.add_argument(
157
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
158
+ )
159
+ parser.add_argument(
160
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
161
+ )
162
+ parser.add_argument(
163
+ "--ckpt_interval",type=int,default=5,
164
+ help="save backup ckpt and reconstruct every x epochs",
165
+ )
166
+ parser.add_argument(
167
+ "--seed",type=int,default=42,
168
+ )
169
+ parser.add_argument(
170
+ "--max_lr",type=float,default=3e-4,
171
+ )
172
+ parser.add_argument(
173
+ "--n_samples_save",type=int,default=0,choices=[0,1],
174
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
175
+ )
176
+
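+ # NOTE: `jupyter_args` is assumed to be defined in an earlier interactive cell (as in the MLPMixer notebooks); it is not set anywhere in this script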
177
+ if utils.is_interactive():
178
+ args = parser.parse_args(jupyter_args)
179
+ else:
180
+ args = parser.parse_args()
181
+
182
+ # create global variables without the args prefix
183
+ for attribute_name in vars(args).keys():
184
+ globals()[attribute_name] = getattr(args, attribute_name)
185
+
186
+ print("global batch_size", batch_size)
187
+ batch_size = int(batch_size / num_devices)
188
+ print("batch_size", batch_size)
189
+
190
+ # In[5]:
191
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
192
+ if not os.path.exists(outdir):
193
+ os.makedirs(outdir,exist_ok=True)
194
+ if use_image_aug:
195
+ import kornia
196
+ from kornia.augmentation.container import AugmentationSequential
197
+ img_augment = AugmentationSequential(
198
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
199
+ kornia.augmentation.Resize((224, 224)),
200
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
201
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
202
+ kornia.augmentation.RandomGrayscale(p=0.3),
203
+ same_on_batch=False,
204
+ data_keys=["input"],
205
+ )
206
+
207
+
208
+ # # Prep data, models, and dataloaders
209
+
210
+ # ## Dataloader
211
+
212
+ # In[6]:
213
+
214
+ if subj==1:
215
+ num_train = 24958
216
+ num_test = 2770
217
+ test_batch_size = num_test
218
+
219
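+ # identity node splitter: every node/worker sees the full list of shards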
+ def my_split_by_node(urls): return urls
220
+
221
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
222
+ print(train_url)
223
+
224
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
225
+ .shuffle(750, initial=1500, rng=random.Random(42))\
226
+ .decode("torch")\
227
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
228
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
229
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
230
+
231
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
232
+ print(test_url)
233
+
234
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
235
+ .shuffle(750, initial=1500, rng=random.Random(42))\
236
+ .decode("torch")\
237
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
238
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
239
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
240
+
241
+ # ### check dataloaders are working
242
+
243
+ # In[7]:
244
+
245
+
246
+ # test_indices = []
247
+ # test_images = []
248
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
249
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
250
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
251
+ # test_indices = test_indices.astype(np.int16)
252
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
253
+ # print("---\n")
254
+
255
+ # train_indices = []
256
+ # train_images = []
257
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
258
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
259
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
260
+ # train_indices = train_indices.astype(np.int16)
261
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
262
+
263
+
264
+ # ## Load voxel betas, K-means clustering model, and images
265
+
266
+ # In[8]:
267
+
268
+
269
+ # load betas
270
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
271
+ voxels = f['betas'][:]
272
+ print(f"subj0{subj} betas loaded into memory")
273
+ voxels = torch.Tensor(voxels).to("cpu").half()
274
+ if subj==1:
275
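+ # pad the subj01 betas with 5 zero columns (presumably so num_voxels matches the size expected downstream)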
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
276
+ print("voxels", voxels.shape)
277
+ num_voxels = voxels.shape[-1]
278
+
279
+ # load orig images
280
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
281
+ images = f['images'][:]
282
+ images = torch.Tensor(images).to("cpu").half()
283
+ print("images", images.shape)
284
+
285
+
286
+ # In[9]:
287
+
288
+
289
+ from models import Clipper
290
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
291
+
292
+ clip_seq_dim = 257
293
+ clip_emb_dim = 768
294
+ hidden_dim = 4096
295
+
296
+ from diffusers import AutoencoderKL
297
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
298
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
299
+ autoenc.eval()
300
+ autoenc.requires_grad_(False)
301
+ autoenc.to(device)
302
+ utils.count_params(autoenc)
303
+
304
+
305
+ # In[10]:
306
+
307
+
308
+ class MindEyeModule(nn.Module):
309
+ def __init__(self):
310
+ super(MindEyeModule, self).__init__()
311
+ def forward(self, x):
312
+ return x
313
+
314
+ model = MindEyeModule()
315
+ model
316
+
317
+
318
+ # In[11]:
319
+
320
+
321
+ class RidgeRegression(torch.nn.Module):
322
+ # make sure to add weight_decay when initializing optimizer
323
+ def __init__(self, input_size, out_features):
324
+ super(RidgeRegression, self).__init__()
325
+ self.out_features = out_features
326
+ self.linear = torch.nn.Linear(input_size, out_features)
327
+ def forward(self, x):
328
+ return self.linear(x)
329
+
330
+ model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
331
+ utils.count_params(model.ridge)
332
+ utils.count_params(model)
333
+
334
+ b = torch.randn((2,1,voxels.shape[1]))
335
+ print(b.shape, model.ridge(b).shape)
336
+
337
+ # In[12]:
338
+ from functools import partial
339
+ from diffusers.models.vae import Decoder
340
+ class BrainNetwork(nn.Module):
341
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):
342
+ super().__init__()
343
+ self.blurry_dim = blurry_dim
344
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
345
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
346
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
347
+ self.lin0 = nn.Linear(in_dim, h)
348
+ self.mlp = nn.ModuleList([
349
+ nn.Sequential(
350
+ nn.Linear(h, h),
351
+ *[item() for item in act_and_norm],
352
+ nn.Dropout(drop)
353
+ ) for _ in range(n_blocks)
354
+ ])
355
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
356
+ self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
357
+ self.n_blocks = n_blocks
358
+ self.clip_size = clip_size
359
+ self.clip_proj = nn.Sequential(
360
+ nn.LayerNorm(clip_size),
361
+ nn.GELU(),
362
+ nn.Linear(clip_size, 2048),
363
+ nn.LayerNorm(2048),
364
+ nn.GELU(),
365
+ nn.Linear(2048, 2048),
366
+ nn.LayerNorm(2048),
367
+ nn.GELU(),
368
+ nn.Linear(2048, clip_size)
369
+ )
370
+ self.upsampler = Decoder(
371
+ in_channels=64,
372
+ out_channels=4,
373
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
374
+ block_out_channels=[64, 128, 256],
375
+ layers_per_block=1,
376
+ )
377
+
378
+ def forward(self, x):
379
+ x = self.lin0(x)
380
+ residual = x
381
+ for res_block in range(self.n_blocks):
382
+ x = self.mlp[res_block](x)
383
+ x += residual
384
+ residual = x
385
+ x = x.reshape(len(x), -1)
386
+ x = self.lin1(x)
387
+ b = self.blin1(x)
388
+ b = self.upsampler(b.reshape(len(b), -1, 7, 7))
389
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
390
+ return c, b
391
+
392
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
393
+ utils.count_params(model.backbone)
394
+ utils.count_params(model)
395
+
396
+ b = torch.randn((2,hidden_dim))
397
+ print(b.shape)
398
+ clip_, blur_ = model.backbone(b)
399
+ print(clip_.shape, blur_.shape)
400
+
401
+
402
+ # In[13]:
403
+
404
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
405
+ opt_grouped_parameters = [
406
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
407
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
408
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
409
+ ]
410
+
411
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
412
+
413
+ if lr_scheduler_type == 'linear':
414
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
415
+ optimizer,
416
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
417
+ last_epoch=-1
418
+ )
419
+ elif lr_scheduler_type == 'cycle':
420
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
421
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
422
+ optimizer,
423
+ max_lr=max_lr,
424
+ total_steps=total_steps,
425
+ final_div_factor=1000,
426
+ last_epoch=-1, pct_start=2/num_epochs
427
+ )
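+ # pct_start=2/num_epochs gives roughly a two-epoch linear warmup before the (default cosine) decay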
428
+
429
+ def save_ckpt(tag):
430
+ ckpt_path = outdir+f'/{tag}.pth'
431
+ print(f'saving {ckpt_path}',flush=True)
432
+ unwrapped_model = accelerator.unwrap_model(model)
433
+ try:
434
+ torch.save({
435
+ 'epoch': epoch,
436
+ 'model_state_dict': unwrapped_model.state_dict(),
437
+ 'optimizer_state_dict': optimizer.state_dict(),
438
+ 'lr_scheduler': lr_scheduler.state_dict(),
439
+ 'train_losses': losses,
440
+ 'test_losses': test_losses,
441
+ 'lrs': lrs,
442
+ }, ckpt_path)
443
+ except:
444
+ print("Couldn't save... moving on to prevent crashing.")
445
+ del unwrapped_model
446
+
447
+ print("\nDone with model preparations!")
448
+ utils.count_params(model)
449
+
450
+ # # Weights and Biases
451
+
452
+ # In[14]:
453
+
454
+
455
+ # params for wandb
456
+ wandb_log = True # hard-coded here, overriding the --wandb_log flag parsed above
458
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
459
+ import wandb
460
+
461
+ wandb_project = 'stability'
462
+ wandb_run = model_name
463
+ wandb_notes = ''
464
+
465
+ print(f"wandb {wandb_project} run {wandb_run}")
466
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
467
+ wandb_config = {
468
+ "model_name": model_name,
469
+ "batch_size": batch_size,
470
+ "num_epochs": num_epochs,
471
+ "use_image_aug": use_image_aug,
472
+ "max_lr": max_lr,
473
+ "lr_scheduler_type": lr_scheduler_type,
474
+ "mixup_pct": mixup_pct,
475
+ "num_train": num_train,
476
+ "num_test": num_test,
477
+ "seed": seed,
478
+ "distributed": distributed,
479
+ "num_devices": num_devices,
480
+ "world_size": world_size,
481
+ }
482
+ print("wandb_config:\n",wandb_config)
483
+ if False: # wandb_auto_resume
484
+ print("wandb_id:",model_name)
485
+ wandb.init(
486
+ id = model_name,
487
+ project=wandb_project,
488
+ name=wandb_run,
489
+ config=wandb_config,
490
+ notes=wandb_notes,
491
+ resume="allow",
492
+ )
493
+ else:
494
+ wandb.init(
495
+ project=wandb_project,
496
+ name=model_name,
497
+ config=wandb_config,
498
+ notes=wandb_notes,
499
+ )
500
+ else:
501
+ wandb_log = False
502
+
503
+
504
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
505
+ pixcorr_preprocess = transforms.Compose([
506
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
507
+ ])
508
+ def pixcorr(images,brains):
509
+ # Flatten images while keeping the batch dimension
510
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
511
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
512
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
513
+ return corrmean
514
+
515
+ # # Main
516
+
517
+ # In[15]:
518
+
519
+ epoch = 0
520
+ losses, test_losses, lrs = [], [], []
521
+ best_test_loss = 1e9
522
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
523
+
524
+ # Optionally resume from checkpoint #
525
+ resume_from_ckpt = False
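+ # NOTE: this hard-coded False takes precedence over the --resume_from_ckpt flag parsed above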
526
+ if resume_from_ckpt:
527
+ print("\n---resuming from last.pth ckpt---\n")
528
+ try:
529
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
530
+ except:
531
+ print('last.pth failed... trying last_backup.pth')
532
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
533
+ epoch = checkpoint['epoch']
534
+ print("Epoch",epoch)
535
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
536
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
537
+ model.load_state_dict(checkpoint['model_state_dict'])
538
+ del checkpoint
539
+ elif False:
540
+ if wandb.run.resumed:
541
+ print("\n---resuming from last.pth ckpt---\n")
542
+ try:
543
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
544
+ except:
545
+ print('last.pth failed... trying last_backup.pth')
546
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
547
+ epoch = checkpoint['epoch']
548
+ print("Epoch",epoch)
549
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
550
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
551
+ model.load_state_dict(checkpoint['model_state_dict'])
552
+ del checkpoint
553
+ torch.cuda.empty_cache()
554
+
555
+
556
+ # In[16]:
557
+
558
+
559
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
560
+ model, optimizer, train_dl, test_dl, lr_scheduler
561
+ )
562
+
563
+
564
+ # In[17]:
565
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
566
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
567
+ test_image, test_voxel = None, None
568
+ mse = nn.MSELoss()
569
+ for epoch in progress_bar:
570
+ model.train()
571
+
572
+ fwd_percent_correct = 0.
573
+ bwd_percent_correct = 0.
574
+ test_fwd_percent_correct = 0.
575
+ test_bwd_percent_correct = 0.
576
+
577
+ loss_clip_total = 0.
578
+ loss_blurry_total = 0.
579
+ test_loss_clip_total = 0.
580
+ test_loss_blurry_total = 0.
581
+
582
+ blurry_pixcorr = 0.
583
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
584
+
585
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
586
+ with torch.cuda.amp.autocast():
587
+ optimizer.zero_grad()
588
+
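+ # behav[:,0,0] indexes into the 73k COCO images and behav[:,0,5] indexes into the voxel betas (see the dataloader checks above)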
589
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
590
+
591
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
592
+
593
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
594
+
595
+ if use_image_aug: image = img_augment(image)
596
+
597
+ clip_target = clip_model.embed_image(image)
598
+ assert not torch.any(torch.isnan(clip_target))
599
+
600
+ if epoch < int(mixup_pct * num_epochs):
601
+ voxel, perm, betas, select = utils.mixco(voxel)
602
+
603
+ voxel_ridge = model.ridge(voxel)
604
+
605
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
606
+
607
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
608
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
609
+
610
+ if epoch < int(mixup_pct * num_epochs):
611
+ loss_clip = utils.mixco_nce(
612
+ clip_voxels_norm,
613
+ clip_target_norm,
614
+ temp=.006,
615
+ perm=perm, betas=betas, select=select)
616
+ else:
617
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
618
+ loss_clip = utils.soft_clip_loss(
619
+ clip_voxels_norm,
620
+ clip_target_norm,
621
+ temp=epoch_temp)
622
+
623
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
624
+
625
+ loss_clip_total += loss_clip.item()
626
+ loss_blurry_total += loss_blurry.item()
627
+
628
+ loss = loss_blurry + loss_clip
629
+
630
+ utils.check_loss(loss)
631
+
632
+ accelerator.backward(loss)
633
+ optimizer.step()
634
+
635
+ losses.append(loss.item())
636
+ lrs.append(optimizer.param_groups[0]['lr'])
637
+
638
+ # forward and backward top 1 accuracy
639
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
640
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
641
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
642
+
643
+ with torch.no_grad():
644
+ # only doing pixcorr eval on a small subset (2 random samples) per batch because it's costly & slow to compute autoenc.decode()
645
+ random_samps = np.random.choice(np.arange(len(voxel)), size=2, replace=False)
646
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
647
+ blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
648
+
649
+ if lr_scheduler_type is not None:
650
+ lr_scheduler.step()
651
+
652
+ model.eval()
653
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
654
+ with torch.cuda.amp.autocast():
655
+ with torch.no_grad():
656
+ # all test samples should be loaded per batch such that test_i should never exceed 0
657
+ if len(behav) != num_test: print("!",len(behav),num_test)
658
+
659
+ ## Average same-image repeats ##
660
+ if test_image is None:
661
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
662
+
663
+ image = behav[:,0,0].cpu().long()
664
+
665
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
666
+ for im in unique_image:
667
+ locs = torch.where(im == image)[0]
668
+ if test_image is None:
669
+ test_image = images[im][None]
670
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
671
+ else:
672
+ test_image = torch.vstack((test_image, images[im][None]))
673
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
674
+
675
+ # sample of batch_size
676
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
677
+ voxel = test_voxel[random_indices].to(device)
678
+ image = test_image[random_indices].to(device)
679
+ assert len(image) == batch_size
680
+
681
+ blurry_image_enc = autoenc.encode(image).latent_dist.mode()
682
+
683
+ clip_target = clip_model.embed_image(image.float())
684
+
685
+ voxel_ridge = model.ridge(voxel)
686
+
687
+ clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
688
+
689
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
690
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
691
+
692
+ loss_clip = utils.soft_clip_loss(
693
+ clip_voxels_norm,
694
+ clip_target_norm,
695
+ temp=.006)
696
+
697
+ loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
+ test_loss_clip_total += loss_clip.item()
+ test_loss_blurry_total += loss_blurry.item()
698
+
699
+ loss = loss_blurry + loss_clip
700
+
701
+ utils.check_loss(loss)
702
+
703
+ test_losses.append(loss.item())
704
+
705
+ # forward and backward top 1 accuracy
706
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
707
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
708
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
709
+
710
+ # halving the batch size because the decoder is computationally heavy
711
+ blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
712
+ blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
713
+ test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
714
+
715
+ # transform blurry recon latents to images and plot it
716
+ fig, axes = plt.subplots(1, 4, figsize=(8, 4))
717
+ axes[0].imshow(utils.torch_to_Image(image[[0]]))
718
+ axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
719
+ axes[2].imshow(utils.torch_to_Image(image[[1]]))
720
+ axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
721
+ axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
722
+ plt.show()
723
+
724
+ if local_rank==0:
725
+ # if utils.is_interactive(): clear_output(wait=True)
726
+ assert (test_i+1) == 1
727
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
728
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
729
+ "train/lr": lrs[-1],
730
+ "train/num_steps": len(losses),
731
+ "test/num_steps": len(test_losses),
732
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
733
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
734
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
735
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
736
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
737
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
738
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
739
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
740
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
741
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
742
+ }
743
+ progress_bar.set_postfix(**logs)
744
+
745
+ # Save model checkpoint and reconstruct
746
+ if epoch % ckpt_interval == 0:
747
+ if not utils.is_interactive():
748
+ save_ckpt(f'last')
749
+
750
+ if wandb_log: wandb.log(logs)
751
+
752
+ # wait for other GPUs to catch up if needed
753
+ accelerator.wait_for_everyone()
754
+ torch.cuda.empty_cache()
755
+ gc.collect()
756
+
757
+ print("\n===Finished!===\n")
758
+ if ckpt_saving:
759
+ save_ckpt(f'last')
760
+ if not utils.is_interactive():
761
+ sys.exit(0)
src/Train_MLPMixer-Copy1.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train_MLPMixer-Copy1.py ADDED
@@ -0,0 +1,1352 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train_MLPMixer-Copy1.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import string
28
+ import h5py
29
+ from tqdm import tqdm
30
+
31
+ import webdataset as wds
32
+ import gc
33
+
34
+ import matplotlib.pyplot as plt
35
+ import torch
36
+ import torch.nn as nn
37
+ from torchvision import transforms
38
+
39
+ from accelerate import Accelerator, DeepSpeedPlugin
40
+
41
+ # tf32 data type is faster than standard float32
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+
44
+ # custom functions #
45
+ import utils
46
+
47
+
48
+ # In[3]:
49
+
50
+
51
+ ### Multi-GPU config ###
52
+ local_rank = os.getenv('RANK')
53
+ if local_rank is None:
54
+ local_rank = 0
55
+ else:
56
+ local_rank = int(local_rank)
57
+ print("LOCAL RANK ", local_rank)
58
+
59
+ num_devices = torch.cuda.device_count()
60
+ if num_devices==0: num_devices = 1
61
+
62
+ # ## UNCOMMENT BELOW SECTION AND COMMENT OUT DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###
63
+ # accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
64
+ # global_batch_size = batch_size = 128
65
+ # data_type = torch.float16 # change depending on your mixed_precision
66
+
67
+ ### DEEPSPEED INITIALIZATION ###
68
+ if num_devices <= 1 and utils.is_interactive():
69
+ global_batch_size = batch_size = 128
70
+ print(f"Setting batch_size to {batch_size}")
71
+ # can emulate a distributed environment for deepspeed to work in jupyter notebook
72
+ os.environ["MASTER_ADDR"] = "localhost"
73
+ os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
74
+ os.environ["RANK"] = "0"
75
+ os.environ["LOCAL_RANK"] = "0"
76
+ os.environ["WORLD_SIZE"] = "1"
77
+ os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
78
+ else:
79
+ global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
80
+ batch_size = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
81
+
82
+ # alter the deepspeed config according to your global and local batch size
83
+ if local_rank == 0:
84
+ with open('deepspeed_config_stage2.json', 'r') as file:
85
+ config = json.load(file)
86
+ config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
87
+ config['train_micro_batch_size_per_gpu'] = batch_size
88
+ config['bf16'] = {'enabled': False}
89
+ config['fp16'] = {'enabled': True}
90
+ with open('deepspeed_config_stage2.json', 'w') as file:
91
+ json.dump(config, file)
92
+ else:
93
+ # give some time for the local_rank=0 gpu to prep new deepspeed config file
94
+ time.sleep(10)
95
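+ # NOTE: the batch-size fields above are written to deepspeed_config_stage2.json, while the plugin below loads deepspeed_config_stage2_cpuoffload.json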
+ deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2_cpuoffload.json")
96
+ accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
97
+
98
+
99
+ # In[4]:
100
+
101
+
102
+ print("PID of this process =",os.getpid())
103
+ device = accelerator.device
104
+ print("device:",device)
105
+ num_workers = num_devices
106
+ print(accelerator.state)
107
+ world_size = accelerator.state.num_processes
108
+ distributed = not accelerator.state.distributed_type == 'NO'
109
+
110
+ # set data_type to match your mixed precision (automatically set based on deepspeed config)
111
+ if accelerator.mixed_precision == "bf16":
112
+ data_type = torch.bfloat16
113
+ elif accelerator.mixed_precision == "fp16":
114
+ data_type = torch.float16
115
+ else:
116
+ data_type = torch.float32
117
+
118
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
119
+ print = accelerator.print # only print if local_rank=0
120
+
121
+
122
+ # # Configurations
123
+
124
+ # In[5]:
125
+
126
+
127
+ # if running this interactively, can specify jupyter_args here for argparser to use
128
+ if utils.is_interactive():
129
+ # create random model_name
130
+ model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
131
+ model_name = model_name + "_interactive"
132
+ print("model_name:", model_name)
133
+
134
+ # global_batch_size and batch_size should already be defined in the above cells
135
+ # other variables can be specified in the following string:
136
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
137
+ --model_name={model_name} \
138
+ --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \
139
+ --clip_scale=1. --blur_scale=100. --depth_scale=100. \
140
+ --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"
141
+
142
+ jupyter_args = jupyter_args.split()
143
+ print(jupyter_args)
144
+
145
+ from IPython.display import clear_output # function to clear print outputs in cell
146
+ get_ipython().run_line_magic('load_ext', 'autoreload')
147
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
148
+ get_ipython().run_line_magic('autoreload', '2')
149
+
150
+
151
+ # In[6]:
152
+
153
+
154
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
155
+ parser.add_argument(
156
+ "--model_name", type=str, default="testing",
157
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
158
+ )
159
+ parser.add_argument(
160
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
161
+ help="Path to where NSD data is stored / where to download it to",
162
+ )
163
+ parser.add_argument(
164
+ "--subj",type=int, default=1, choices=[1,2,5,7],
165
+ )
166
+ parser.add_argument(
167
+ "--batch_size", type=int, default=32,
168
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
169
+ )
170
+ parser.add_argument(
171
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
172
+ help="whether to log to wandb",
173
+ )
174
+ parser.add_argument(
175
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
176
+ help="if not using wandb and want to resume from a ckpt",
177
+ )
178
+ parser.add_argument(
179
+ "--wandb_project",type=str,default="stability",
180
+ help="wandb project name",
181
+ )
182
+ parser.add_argument(
183
+ "--mixup_pct",type=float,default=.33,
184
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
185
+ )
186
+ parser.add_argument(
187
+ "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
188
+ help="whether to output blurry reconstructions",
189
+ )
190
+ parser.add_argument(
191
+ "--depth_recon",action=argparse.BooleanOptionalAction,default=True,
192
+ help="whether to output depth reconstructions",
193
+ )
194
+ parser.add_argument(
195
+ "--blur_scale",type=float,default=100.,
196
+ help="multiply loss from blurry recons by this number",
197
+ )
198
+ parser.add_argument(
199
+ "--depth_scale",type=float,default=100.,
200
+ help="multiply loss from depth recons by this number",
201
+ )
202
+ parser.add_argument(
203
+ "--clip_scale",type=float,default=1.,
204
+ help="multiply contrastive loss by this number",
205
+ )
206
+ parser.add_argument(
207
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
208
+ help="whether to use image augmentation",
209
+ )
210
+ parser.add_argument(
211
+ "--num_epochs",type=int,default=120,
212
+ help="number of epochs of training",
213
+ )
214
+ parser.add_argument(
215
+ "--hidden_dim",type=int,default=4096,
216
+ )
217
+ parser.add_argument(
218
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
219
+ )
220
+ parser.add_argument(
221
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
222
+ )
223
+ parser.add_argument(
224
+ "--ckpt_interval",type=int,default=5,
225
+ help="save backup ckpt and reconstruct every x epochs",
226
+ )
227
+ parser.add_argument(
228
+ "--seed",type=int,default=42,
229
+ )
230
+ parser.add_argument(
231
+ "--max_lr",type=float,default=3e-4,
232
+ )
233
+
234
+ if utils.is_interactive():
235
+ args = parser.parse_args(jupyter_args)
236
+ else:
237
+ args = parser.parse_args()
238
+
239
+ # create global variables without the args prefix
240
+ for attribute_name in vars(args).keys():
241
+ globals()[attribute_name] = getattr(args, attribute_name)
242
+
243
+
244
+ # In[7]:
245
+
246
+
247
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
248
+ if not os.path.exists(outdir) and ckpt_saving:
249
+ os.makedirs(outdir,exist_ok=True)
250
+ if use_image_aug:
251
+ import kornia
252
+ from kornia.augmentation.container import AugmentationSequential
253
+ img_augment = AugmentationSequential(
254
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
255
+ kornia.augmentation.Resize((224, 224)),
256
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
257
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
258
+ kornia.augmentation.RandomGrayscale(p=0.3),
259
+ same_on_batch=False,
260
+ data_keys=["input"],
261
+ )
262
+
263
+
264
+ # # Prep data, models, and dataloaders
265
+
266
+ # ## Dataloader
267
+
268
+ # In[8]:
269
+
270
+
271
+ if subj==1:
272
+ num_train = 24958
273
+ num_test = 2770
274
+ test_batch_size = num_test
275
+
276
+ def my_split_by_node(urls): return urls
277
+
278
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
279
+ # train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..1}.tar"
280
+ print(train_url)
281
+
282
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
283
+ .shuffle(750, initial=1500, rng=random.Random(42))\
284
+ .decode("torch")\
285
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
286
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
287
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
288
+
289
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
290
+ print(test_url)
291
+
292
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
293
+ .shuffle(750, initial=1500, rng=random.Random(42))\
294
+ .decode("torch")\
295
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
296
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
297
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)
298
+
299
+
300
+ # ### check dataloaders are working
301
+
302
+ # In[9]:
303
+
304
+
305
+ test_vox_indices = []
306
+ test_73k_images = []
307
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
308
+ test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())
309
+ test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())
310
+ test_vox_indices = test_vox_indices.astype(np.int16)
311
+ print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
312
+ print("---\n")
313
+
314
+ train_vox_indices = []
315
+ train_73k_images = []
316
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
317
+ train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())
318
+ train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())
319
+ train_vox_indices = train_vox_indices.astype(np.int16)
320
+ print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
321
+
322
+
323
+ # ## Load data and images
324
+
325
+ # In[10]:
326
+
327
+
328
+ # load betas
329
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
330
+ # f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')
331
+
332
+ voxels = f['betas'][:]
333
+ print(f"subj0{subj} betas loaded into memory")
334
+ voxels = torch.Tensor(voxels).to("cpu").to(data_type)
335
+ print("voxels", voxels.shape)
336
+ num_voxels = voxels.shape[-1]
337
+
338
+ # load orig images
339
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
340
+ images = f['images'][:]
341
+ images = torch.Tensor(images).to("cpu").to(data_type)
342
+ print("images", images.shape)
343
+
344
+
345
+ # ## Load models
346
+
347
+ # ### CLIP image embeddings model
348
+
349
+ # In[11]:
350
+
351
+
352
+ from models import Clipper
353
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
354
+ clip_seq_dim = 257
355
+ clip_emb_dim = 768 #1024
356
+ # hidden_dim = 4096
357
+ seq_len = 1 #2 #32
358
+
359
+
360
+ # ### SD VAE
361
+
362
+ # In[12]:
363
+
364
+
365
+ # if blurry_recon:
366
+ # from diffusers import AutoencoderKL
367
+ # autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
368
+ # # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
369
+ # autoenc.eval()
370
+ # autoenc.requires_grad_(False)
371
+ # autoenc.to(device)
372
+ # utils.count_params(autoenc)
373
+
374
+ if blurry_recon:# or depth_recon:
375
+ from diffusers import VQModel
376
+ autoenc = VQModel.from_pretrained("/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae", torch_dtype=data_type)
377
+ autoenc.eval()
378
+ autoenc.requires_grad_(False)
379
+ autoenc.to(device)
380
+ utils.count_params(autoenc)
381
+
382
+
383
+ # #### downsampled images
384
+
385
+ # In[13]:
386
+
387
+
388
+ if blurry_recon:
389
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
390
+
391
+ input_batch = images[[30]].to(device)
392
+ print(input_batch.shape)
393
+
394
+ downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)
395
+ re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')
396
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
397
+ print(re_upsampled_enc.shape)
398
+
399
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
400
+
401
+
402
+ # #### MiDaS depth
403
+
404
+ # In[14]:
405
+
406
+
407
+ if depth_recon:
408
+ from controlnet_aux.midas import MidasDetector
409
+
410
+ midas_depth = MidasDetector.from_pretrained(
411
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large", cache_dir="/fsx/proj-fmri/shared/cache").to(device)
412
+ midas_depth.model.eval()
413
+ midas_depth.model.requires_grad_(False)
414
+ midas_depth.model.to(device)
415
+ pass
416
+
417
+
418
+ # In[15]:
419
+
420
+
421
+ if depth_recon:
422
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
423
+
424
+ input_batch = images[[30,31]].float().to(device)
425
+ print(input_batch.shape)
426
+
427
+ midas_emb = midas_depth.model(input_batch).unsqueeze(1)
428
+ print(midas_emb.shape)
429
+
430
+ prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()
431
+ print(prediction.shape)
432
+
433
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
434
+ midas_emb_size = prediction.flatten(1).shape[1]
435
+ print("midas_emb", prediction.shape, prediction.min(), prediction.max())
436
+ print("midas_emb_size", midas_emb_size)
437
+
438
+ if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224)))
439
+
440
+ if blurry_recon:
441
+ prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
442
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
443
+ prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
444
+ print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())
445
+
446
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
447
+
448
+
449
+ # ### MindEye modules
450
+
451
+ # In[16]:
452
+
453
+
454
+ class MindEyeModule(nn.Module):
455
+ def __init__(self):
456
+ super(MindEyeModule, self).__init__()
457
+ def forward(self, x):
458
+ return x
459
+
460
+ model = MindEyeModule()
461
+ model
462
+
463
+
464
+ # In[17]:
465
+
466
+
467
+ time_embedding_dim = 512
468
+
469
+ class RidgeRegression(torch.nn.Module):
470
+ # make sure to add weight_decay when initializing optimizer
471
+ def __init__(self, input_size, out_features):
472
+ super(RidgeRegression, self).__init__()
473
+ self.out_features = out_features
474
+ self.linear = torch.nn.Linear(input_size, out_features)
475
+ def forward(self, x):
476
+ return self.linear(x)
477
+
478
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
479
+ utils.count_params(model.ridge)
480
+ utils.count_params(model)
481
+
482
+ b = torch.randn((2,1,voxels.shape[1]))
483
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
484
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
485
+
486
+
487
+ # In[59]:
488
+
489
+
490
+ num_past_voxels = 15
491
+ seq_len = 1 + 1
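+ # presumably the current trial plus one past trial; seq_len is redefined to 4 in a later cell before the backbone smoke tests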
492
+
493
+
494
+ # In[73]:
495
+
496
+
497
+ class BrainNetwork(nn.Module):
498
+ def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
499
+ super().__init__()
500
+ self.seq_len = seq_len
501
+ self.h = h
502
+ self.clip_size = clip_size
503
+
504
+ # Initial linear layer to match the input dimensions to hidden dimensions
505
+ # self.lin0 = nn.Linear(in_dim, seq_len * h)
506
+
507
+ # Mixer Blocks
508
+ self.mixer_blocks1 = nn.ModuleList([
509
+ self.mixer_block1(h, drop) for _ in range(n_blocks)
510
+ ])
511
+ self.mixer_blocks2 = nn.ModuleList([
512
+ self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
513
+ ])
514
+
515
+ # Output linear layer
516
+ self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
517
+
518
+ # low-rank matrices
519
+ # self.rank = 500
520
+ # self.U = nn.Parameter(torch.randn(self.rank, out_dim))
521
+ # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))
522
+ # self.S = nn.Parameter(torch.randn(out_dim))
523
+
524
+ self.clip_proj = nn.Sequential(
525
+ nn.LayerNorm(clip_size),
526
+ nn.GELU(),
527
+ nn.Linear(clip_size, 2048),
528
+ nn.LayerNorm(2048),
529
+ nn.GELU(),
530
+ nn.Linear(2048, 2048),
531
+ nn.LayerNorm(2048),
532
+ nn.GELU(),
533
+ nn.Linear(2048, clip_size)
534
+ )
535
+
536
+ if blurry_recon:
537
+ # self.blin1 = nn.Sequential(
538
+ # nn.Linear(out_dim, 4096, bias=True),
539
+ # nn.LayerNorm(4096),
540
+ # nn.GELU(),
541
+ # nn.Linear(4096, 4096))
542
+ self.blin1 = nn.Linear(h*seq_len, 4096)
543
+ self.bgroupnorm = nn.GroupNorm(1, 256)
544
+ self.bupsampler = Decoder(
545
+ in_channels=256,
546
+ out_channels=128,
547
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
548
+ block_out_channels=[32, 64, 128],
549
+ layers_per_block=1,
550
+ )
551
+
552
+ if depth_recon:
553
+ # self.dlin1 = nn.Sequential(
554
+ # nn.Linear(h, midas_emb_size),
555
+ # nn.Sigmoid(),
556
+ # )
557
+ self.dlin1 = nn.Linear(h*seq_len, 4096)
558
+ self.dgroupnorm = nn.GroupNorm(1, 256)
559
+ self.dupsampler = Decoder(
560
+ in_channels=256,
561
+ out_channels=1,#128,
562
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
563
+ block_out_channels=[32, 64, 128, 256],
564
+ layers_per_block=1,
565
+ )
566
+
567
+ def mixer_block1(self, h, drop):
568
+ return nn.Sequential(
569
+ nn.LayerNorm(h),
570
+ self.mlp(h, h, drop), # Token mixing
571
+ )
572
+
573
+ def mixer_block2(self, seq_len, drop):
574
+ return nn.Sequential(
575
+ nn.LayerNorm(seq_len),
576
+ self.mlp(seq_len, seq_len, drop) # Channel mixing
577
+ )
578
+
579
+ def mlp(self, in_dim, out_dim, drop):
580
+ return nn.Sequential(
581
+ nn.Linear(in_dim, out_dim),
582
+ nn.GELU(),
583
+ nn.Dropout(drop),
584
+ nn.Linear(out_dim, out_dim),
585
+ )
586
+
587
+ def forward(self, x):
588
+ # make empty tensors for blur and depth outputs
589
+ b,d = torch.Tensor([0.]), torch.Tensor([0.])
590
+
591
+ # Initial linear layer
592
+ # x = self.lin0(x)
593
+
594
+ # Reshape to seq_len by dim
595
+ # x = x.reshape(-1, self.seq_len, self.h)
596
+
597
+ # Mixer blocks
598
+ residual1 = x
599
+ residual2 = x.permute(0,2,1)
600
+ for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
601
+ x = block1(x) + residual1
602
+ residual1 = x
603
+ x = x.permute(0,2,1)
604
+
605
+ x = block2(x) + residual2
606
+ residual2 = x
607
+ x = x.permute(0,2,1)
608
+
609
+ # Flatten
610
+ x = x.reshape(x.size(0), -1)
611
+
612
+ c = self.clin1(x)
613
+
614
+ # low rank linear to out dim cuts # params by nearly half compared to full linear mapping
615
+ # c = (x @ (self.V/100) @ (self.U/100)) + self.S
616
+
617
+ c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))
618
+
619
+ if blurry_recon:
620
+ b = self.blin1(x)
621
+ b = b.reshape(len(b), 256, 4, 4)
622
+ b = self.bgroupnorm(b)
623
+ b = self.bupsampler(b)
624
+
625
+ if depth_recon:
626
+ d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)
627
+ d = d.reshape(len(d), 256, 4, 4)
628
+ d = self.dgroupnorm(d)
629
+ d = self.dupsampler(d)
630
+
631
+ return c, b, d
632
+
633
+
634
+ class TimeEmbedding(nn.Module):
635
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
636
+ super().__init__()
637
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
638
+ self.num_past_voxels = num_past_voxels
639
+ self.embedding_time_dim = embedding_time_dim
640
+
641
+ def forward(self, time):
642
+ # time is (batch_size,)
643
+ time = time.long()
644
+ time = self.embedding_time(time)
645
+ return time # (batch_size, embedding_time_dim)
646
+
647
+
648
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
649
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
650
+
651
+ model.backbone = BrainNetwork(h=1024, in_dim=1024, seq_len=4, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
652
+ utils.count_params(model.backbone)
653
+ utils.count_params(model)
654
+
655
+
656
+ # test that the model works on some fake data
657
+ b = torch.randn((256,4,1024))
658
+ print("b.shape",b.shape)
659
+ with torch.no_grad():
660
+ clip_, blur_, depth_ = model.backbone(b)
661
+ print(clip_.shape, blur_.shape, depth_.shape)
662
+
663
+
664
+ # In[70]:
665
+
666
+
667
+ voxel_ridge = torch.randn(512,4096)
668
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
669
+ print("b.shape",voxel_ridge.shape)
670
+ with torch.no_grad():
671
+ clip_, blur_, depth_ = model.backbone(voxel_ridge)
672
+ print(clip_.shape, blur_.shape, depth_.shape)
673
+
674
+
675
+ # In[64]:
676
+
677
+
678
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
679
+ opt_grouped_parameters = [
680
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
681
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
682
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
683
+ ]
684
+
685
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
686
+
687
+ if lr_scheduler_type == 'linear':
688
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
689
+ optimizer,
690
+ total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),
691
+ last_epoch=-1
692
+ )
693
+ elif lr_scheduler_type == 'cycle':
694
+ total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))
695
+ print("total_steps", total_steps)
696
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
697
+ optimizer,
698
+ max_lr=max_lr,
699
+ total_steps=total_steps,
700
+ final_div_factor=1000,
701
+ last_epoch=-1, pct_start=2/num_epochs
702
+ )
703
+
704
+ def save_ckpt(tag):
705
+ ckpt_path = outdir+f'/{tag}.pth'
706
+ print(f'saving {ckpt_path}',flush=True)
707
+ unwrapped_model = accelerator.unwrap_model(model)
708
+ try:
709
+ torch.save({
710
+ 'epoch': epoch,
711
+ 'model_state_dict': unwrapped_model.state_dict(),
712
+ 'optimizer_state_dict': optimizer.state_dict(),
713
+ 'lr_scheduler': lr_scheduler.state_dict(),
714
+ 'train_losses': losses,
715
+ 'test_losses': test_losses,
716
+ 'lrs': lrs,
717
+ }, ckpt_path)
718
+ except:
719
+ print("Couldn't save... moving on to prevent crashing.")
720
+ del unwrapped_model
721
+
722
+ print("\nDone with model preparations!")
723
+ utils.count_params(model)
724
+
725
+
726
+ # In[49]:
727
+
728
+
729
+ seq_len = 4
730
+
731
+
732
+ # In[57]:
733
+
734
+
735
+ voxel_ridge = torch.randn(512,4096)
736
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
737
+
738
+
739
+ # In[58]:
740
+
741
+
742
+ voxel_ridge.shape
743
+
744
+
745
+ # In[55]:
746
+
747
+
748
+ pp = None
749
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
750
+ with torch.cuda.amp.autocast(dtype=data_type):
751
+ #optimizer.zero_grad()
752
+
753
+ voxel = voxels[behav[:,0,5].cpu().long()]#.to(device)
754
+ image = images[behav[:,0,0].cpu().long()].float()#.to(device).float()
755
+
756
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279
757
+ past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15
758
+
759
+ print(past_15_times)
760
+ #for past in range(1):
761
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
762
+ """
763
+ #if blurry_recon:
764
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
765
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
766
+
767
+ if depth_recon:
768
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
769
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
770
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
771
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
772
+
773
+ if use_image_aug:
774
+ image = img_augment(image)
775
+
776
+ clip_target = clip_model.embed_image(image)
777
+ assert not torch.any(torch.isnan(clip_target))
778
+
779
+ if epoch < int(mixup_pct * num_epochs):
780
+ voxel, perm, betas, select = utils.mixco(voxel)
781
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
782
+ """
783
+ for p in range(seq_len-1):
784
+ print(past_behav.shape) #128, 15, 17
785
+ print(past_behav[:,p,-1])
786
+ print(past_15_voxels.shape) # 128, 1, 15724
787
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
788
+ print(mask) # 128
789
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
790
+ print(past_15_voxels)
791
+ pp = past_15_voxels
792
+
793
+ break
794
+
795
+
796
+ # In[54]:
797
+
798
+
799
+ pp[20, 0, :]
800
+
801
+
802
+ # # Weights and Biases
803
+
804
+ # In[66]:
805
+
806
+
807
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
808
+ import wandb
809
+ wandb_project = 'mindeyev2'
810
+ wandb_run = model_name
811
+ wandb_notes = ''
812
+
813
+ print(f"wandb {wandb_project} run {wandb_run}")
814
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
815
+ wandb_config = {
816
+ "model_name": model_name,
817
+ "global_batch_size": global_batch_size,
818
+ "batch_size": batch_size,
819
+ "num_epochs": num_epochs,
820
+ "clip_scale": clip_scale,
821
+ "blur_scale": blur_scale,
822
+ "use_image_aug": use_image_aug,
823
+ "max_lr": max_lr,
824
+ "mixup_pct": mixup_pct,
825
+ "num_train": num_train,
826
+ "num_test": num_test,
827
+ "ckpt_interval": ckpt_interval,
828
+ "ckpt_saving": ckpt_saving,
829
+ "seed": seed,
830
+ "distributed": distributed,
831
+ "num_devices": num_devices,
832
+ "world_size": world_size,
833
+ "train_url": train_url,
834
+ "test_url": test_url,
835
+ }
836
+ print("wandb_config:\n",wandb_config)
837
+ if True: # wandb_auto_resume
838
+ print("wandb_id:",model_name)
839
+ wandb.init(
840
+ id = model_name,
841
+ project=wandb_project,
842
+ name=wandb_run,
843
+ config=wandb_config,
844
+ notes=wandb_notes,
845
+ resume="allow",
846
+ )
847
+ else:
848
+ wandb.init(
849
+ project=wandb_project,
850
+ name=wandb_run,
851
+ config=wandb_config,
852
+ notes=wandb_notes,
853
+ )
854
+ else:
855
+ wandb_log = False
856
+
857
+
858
+ # # Main
859
+
860
+ # In[67]:
861
+
862
+
863
+ epoch = 0
864
+ losses, test_losses, lrs = [], [], []
865
+ best_test_loss = 1e9
866
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
867
+
868
+ # Optionally resume from checkpoint #
869
+ if resume_from_ckpt:
870
+ print("\n---resuming from last.pth ckpt---\n")
871
+ try:
872
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
873
+ except:
874
+ print('last.pth failed... trying last_backup.pth')
875
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
876
+ epoch = checkpoint['epoch']
877
+ print("Epoch",epoch)
878
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
879
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
880
+ model.load_state_dict(checkpoint['model_state_dict'])
881
+ del checkpoint
882
+ elif wandb_log:
883
+ if wandb.run.resumed:
884
+ print("\n---resuming from last.pth ckpt---\n")
885
+ try:
886
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
887
+ except:
888
+ print('last.pth failed... trying last_backup.pth')
889
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
890
+ epoch = checkpoint['epoch']
891
+ print("Epoch",epoch)
892
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
893
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
894
+ model.load_state_dict(checkpoint['model_state_dict'])
895
+ del checkpoint
896
+ torch.cuda.empty_cache()
897
+
898
+
899
+ # In[68]:
900
+
901
+
902
+ model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
903
+ model, optimizer, train_dl, lr_scheduler
904
+ )
905
+ # leaving out test_dl since we will only have local_rank 0 device do evals
906
+
907
+
908
+ # In[ ]:
909
+
910
+
911
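+ # boost color saturation by extrapolating each pixel away from its Rec.601 grayscale value
+ # (alpha=2 doubles the chroma), then clamp to [0,1]; used when building the blurry-recon targets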
+ def add_saturation(image, alpha=2):
912
+ gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
913
+ gray_image = gray_image.unsqueeze(1).expand_as(image)
914
+ saturated_image = alpha * image + (1 - alpha) * gray_image
915
+ return torch.clamp(saturated_image, 0, 1)
916
+
917
+
918
+ # In[65]:
919
+
920
+
921
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
922
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
923
+ test_image, test_voxel = None, None
924
+ mse = nn.MSELoss()
925
+ l1 = nn.L1Loss()
926
+
927
+ for epoch in progress_bar:
928
+ model.train()
929
+
930
+ fwd_percent_correct = 0.
931
+ bwd_percent_correct = 0.
932
+ test_fwd_percent_correct = 0.
933
+ test_bwd_percent_correct = 0.
934
+
935
+ loss_clip_total = 0.
936
+ loss_blurry_total = 0.
937
+ loss_depth_total = 0.
938
+ test_loss_clip_total = 0.
939
+ test_loss_blurry_total = 0.
940
+ test_loss_depth_total = 0.
941
+
942
+ blurry_pixcorr = 0.
943
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
944
+
945
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
946
+ with torch.cuda.amp.autocast(dtype=data_type):
947
+ optimizer.zero_grad()
948
+
949
+ #voxel = voxels[behav[:,0,5].cpu().long()].to(device)
950
+ #image = images[behav[:,0,0].cpu().long()].to(device).float()
951
+
952
+ #past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
953
+ #past_15_times = torch.Tensor([i for i in range(seq_len - 1)]).to(device) # 15
954
+
955
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
956
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
957
+
958
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
959
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
960
+ #for past in range(1):
961
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
962
+
963
+ if blurry_recon:
964
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
965
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
966
+
967
+ if depth_recon:
968
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
969
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
970
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
971
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
972
+
973
+ if use_image_aug:
974
+ image = img_augment(image)
975
+
976
+ clip_target = clip_model.embed_image(image)
977
+ assert not torch.any(torch.isnan(clip_target))
978
+
979
+ if epoch < int(mixup_pct * num_epochs):
980
+ voxel, perm, betas, select = utils.mixco(voxel)
981
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
982
+
983
+ for p in range(seq_len-1):
984
+ #print(past_behav.shape) #128, 15, 17
985
+ #print(past_behav[:,p,-1])
986
+ #print(past_15_voxels.shape) # 128, 1, 15724
987
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
988
+ #print(mask) # 128
989
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
990
+ #print(past_15_voxels)
991
+
992
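+ # flatten the past trials to (batch*(seq_len-1), num_voxels), tag each with a learned lag embedding,
+ # zero-pad the current voxel in place of a lag embedding, stack current + past rows, and map them
+ # through the shared ridge layer before regrouping into per-sample token sequences for the backbone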
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
993
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
994
+ past_15_times = past_15_times.reshape(-1)
995
+ time_embeddings = model.time_embedding(past_15_times)
996
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
997
+
998
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
999
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1000
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
1001
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
1002
+ #unsqueeze(1) # bz * 2, 1, 4096
1003
+
1004
+ # past_voxel_ridge = model.ridge(past_voxel)
1005
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)
1006
+ print(voxel_ridge.shape)
1007
+
1008
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1009
+
1010
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1011
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1012
+
1013
+ if epoch < int(mixup_pct * num_epochs):
1014
+ loss_clip = utils.mixco_nce(
1015
+ clip_voxels_norm,
1016
+ clip_target_norm,
1017
+ temp=.006,
1018
+ perm=perm, betas=betas, select=select)
1019
+ else:
1020
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
1021
+ loss_clip = utils.soft_clip_loss(
1022
+ clip_voxels_norm,
1023
+ clip_target_norm,
1024
+ temp=epoch_temp)
1025
+
1026
+ loss_clip_total += loss_clip.item()
1027
+ loss_clip *= clip_scale
1028
+ loss = loss_clip
1029
+
1030
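+ # blurry-recon target #2: an 8x8-pixelated, saturated version of the image re-encoded by the VQ autoencoder;
+ # the predicted latents get an L1 loss to both blurry targets plus a term matching the latent variances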
+ if blurry_recon:
1031
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1032
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1033
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1034
+
1035
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1036
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1037
+ loss_blurry_total += loss_blurry.item()
1038
+ loss_blurry *= blur_scale
1039
+ loss += loss_blurry
1040
+
1041
+ if depth_recon:
1042
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1043
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1044
+ loss_depth_total += loss_depth.item()
1045
+ loss_depth *= depth_scale
1046
+ loss += loss_depth
1047
+
1048
+ # forward (brain-to-image) and backward (image-to-brain) top-1 retrieval accuracy within the batch
1049
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1050
+ fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()
1051
+ bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()
1052
+
1053
+ if blurry_recon:
1054
+ with torch.no_grad():
1055
+ # only doing pixcorr eval on a subset of the samples per batch because it's costly & slow to compute autoenc.decode()
1056
+ random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)
1057
+ # random_samps = np.arange(batch_size//5)
1058
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
1059
+ # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean
1060
+ pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)
1061
+ # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
1062
+ # loss += (1 - pixcorr)
1063
+ blurry_pixcorr += pixcorr.item()
1064
+ # utils.check_loss(pixcorr)
1065
+
1066
+ utils.check_loss(loss)
1067
+ accelerator.backward(loss)
1068
+ optimizer.step()
1069
+
1070
+ losses.append(loss.item())
1071
+ lrs.append(optimizer.param_groups[0]['lr'])
1072
+
1073
+ if lr_scheduler_type is not None:
1074
+ lr_scheduler.step()
1075
+
1076
+ model.eval()
1077
+ if local_rank==0:
1078
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
1079
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
1080
+ # all test samples should be loaded per batch such that test_i should never exceed 0
1081
+ assert len(behav) == num_test
1082
+
1083
+ ## Average same-image repeats ##
1084
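+ # each unique test image was shown multiple times; average the voxel patterns across those repeats
+ # so there is one (less noisy) voxel vector per unique test image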
+ if test_image is None:
1085
+ voxel = voxels[behav[:,0,5].cpu().long()]
1086
+ image = behav[:,0,0].cpu().long()
1087
+
1088
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
1089
+ for im in unique_image:
1090
+ locs = torch.where(im == image)[0]
1091
+ if test_image is None:
1092
+ test_image = images[im][None]
1093
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
1094
+ else:
1095
+ test_image = torch.vstack((test_image, images[im][None]))
1096
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
1097
+
1098
+ # random sample of 300
1099
+ random_indices = torch.arange(len(test_voxel))[:300]
1100
+ voxel = test_voxel[random_indices].to(device)
1101
+ image = test_image[random_indices].to(device)
1102
+ assert len(image) == 300
1103
+
1104
+ current_past_behav = past_behav[random_indices]
1105
+
1106
+ past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
1107
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
1108
+
1109
+ if blurry_recon:
1110
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
1111
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
1112
+
1113
+ if depth_recon:
1114
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
1115
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
1116
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
1117
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
1118
+
1119
+ clip_target = clip_model.embed_image(image.float())
1120
+
1121
+
1122
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
1123
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
1124
+ past_15_times = past_15_times.reshape(-1)
1125
+ time_embeddings = model.time_embedding(past_15_times)
1126
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
1127
+
1128
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
1129
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1130
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2)).unsqueeze(1)
1131
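+ # note: the train loop reshapes the stacked ridge outputs into (batch, seq_len, hidden_dim) token
+ # sequences, whereas this eval path only unsqueezes a length-1 sequence dim before the backbone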
+
1132
+ #voxel_ridge = model.ridge(voxel).unsqueeze(1)
1133
+
1134
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)
1135
+
1136
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1137
+
1138
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1139
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1140
+
1141
+ loss_clip = utils.soft_clip_loss(
1142
+ clip_voxels_norm,
1143
+ clip_target_norm,
1144
+ temp=.006)
1145
+ test_loss_clip_total += loss_clip.item()
1146
+ loss_clip = loss_clip * clip_scale
1147
+ loss = loss_clip
1148
+
1149
+ if blurry_recon:
1150
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1151
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1152
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1153
+
1154
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1155
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1156
+ test_loss_blurry_total += loss_blurry.item()
1157
+ loss_blurry *= blur_scale
1158
+ loss += loss_blurry
1159
+
1160
+ # halving the batch size because the decoder is computationally heavy
1161
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)
1162
+ blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1163
+ pixcorr = utils.pixcorr(image, blurry_recon_images)
1164
+ loss += (1 - pixcorr)
1165
+ test_blurry_pixcorr += pixcorr.item()
1166
+
1167
+ if depth_recon:
1168
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1169
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1170
+ test_loss_depth_total += loss_depth.item()
1171
+ loss_depth *= depth_scale
1172
+ loss += loss_depth
1173
+
1174
+ # forward and backward top 1 accuracy
1175
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1176
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
1177
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()
1178
+
1179
+ utils.check_loss(loss)
1180
+ test_losses.append(loss.item())
1181
+
1182
+ # if utils.is_interactive(): clear_output(wait=True)
1183
+ print("---")
1184
+
1185
+ assert (test_i+1) == 1
1186
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1187
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1188
+ "train/lr": lrs[-1],
1189
+ "train/num_steps": len(losses),
1190
+ "test/num_steps": len(test_losses),
1191
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1192
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1193
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1194
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1195
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1196
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1197
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1198
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1199
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1200
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1201
+ "train/loss_depth_total": loss_depth_total / (train_i + 1),
1202
+ "test/loss_depth_total": test_loss_depth_total / (test_i + 1),
1203
+ }
1204
+
1205
+ if blurry_recon:
1206
+ # transform blurry recon latents to images and plot them
1207
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1208
+ jj=-1
1209
+ for j in [0,1,2,3]:
1210
+ jj+=1
1211
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1212
+ axes[jj].axis('off')
1213
+ jj+=1
1214
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1215
+ axes[jj].axis('off')
1216
+
1217
+ if wandb_log:
1218
+ logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1219
+ plt.close()
1220
+ else:
1221
+ plt.show()
1222
+
1223
+ if depth_recon:
1224
+ # transform depth recon outputs to images and plot them
1225
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1226
+ # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1227
+ # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1228
+ jj=-1
1229
+ for j in [0,1,2,3]:
1230
+ jj+=1
1231
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))
1232
+ axes[jj].axis('off')
1233
+ jj+=1
1234
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))
1235
+ axes[jj].axis('off')
1236
+ if wandb_log:
1237
+ logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1238
+ plt.close()
1239
+ else:
1240
+ plt.show()
1241
+
1242
+ progress_bar.set_postfix(**logs)
1243
+
1244
+ # Save model checkpoint and reconstruct
1245
+ if epoch % ckpt_interval == 0:
1246
+ if not utils.is_interactive():
1247
+ save_ckpt(f'last')
1248
+
1249
+ if wandb_log: wandb.log(logs)
1250
+
1251
+ # wait for other GPUs to catch up if needed
1252
+ accelerator.wait_for_everyone()
1253
+ torch.cuda.empty_cache()
1254
+ gc.collect()
1255
+
1256
+ print("\n===Finished!===\n")
1257
+ if ckpt_saving:
1258
+ save_ckpt(f'last')
1259
+ if not utils.is_interactive():
1260
+ sys.exit(0)
1261
+
1262
+
1263
+ # In[ ]:
1264
+
1265
+
1266
+ plt.plot(losses)
1267
+ plt.show()
1268
+ plt.plot(test_losses)
1269
+ plt.show()
1270
+
1271
+
1272
+ # # Retrieve nearest neighbor in the training set using test set data
1273
+
1274
+ # In[ ]:
1275
+
1276
+
1277
+ annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")
1278
+
1279
+
1280
+ # In[ ]:
1281
+
1282
+
1283
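+ # embed every unique training image with CLIP in chunks of 512, then rank them by cosine similarity
+ # against the embedding decoded from a single held-out test voxel (index ii)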
+ ii=2
1284
+ all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))
1285
+ with torch.no_grad(), torch.cuda.amp.autocast():
1286
+ for batch in tqdm(range(0,len(all_indices),512)):
1287
+ if batch==0:
1288
+ clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1289
+ else:
1290
+ target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1291
+ clip_target = torch.vstack((clip_target,target))
1292
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1293
+
1294
+ voxel = test_voxel[[ii]].to(device)
1295
+ image = test_image[[ii]].to(device)
1296
+
1297
+ print("Original Image (test set)")
1298
+ display(utils.torch_to_Image(image))
1299
+
1300
+ clip_target = clip_model.embed_image(image).cpu()
1301
+ # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))
1302
+
1303
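+ # note: earlier in this script model.ridge is fed the voxel concatenated with a (zero) time embedding;
+ # passing the raw test voxel alone here looks like a leftover from the single-timepoint variant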
+ voxel_ridge = model.ridge(voxel).unsqueeze(1)
1304
+ clip_voxels, _, _ = model.backbone(voxel_ridge)
1305
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1306
+ # clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) # removed: this overwrote the brain-predicted embedding with the ground-truth CLIP target, making the retrieval below trivial
1307
+
1308
+ print("clip_voxels_norm", clip_voxels_norm.shape)
1309
+ print("clip_target_norm", clip_target_norm.shape)
1310
+
1311
+ sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(),
1312
+ clip_target_norm).flatten()).flip(0)
1313
+ picks = all_indices[sortt[:5]]
1314
+
1315
+ print("\nNearest neighbors in training set")
1316
+ for ip,p in enumerate(picks):
1317
+ display(utils.torch_to_Image(images[[p]]))
1318
+ # print(utils.select_annotations([annots[int(p)]]))
1319
+ if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]
1320
+
1321
+ print("\n=====\npredicted_caption:\n", predicted_caption)
1322
+
1323
+
1324
+ # # Feed into Stable Diffusion XL for reconstructions
1325
+
1326
+ # In[ ]:
1327
+
1328
+
1329
+ from diffusers import StableDiffusionXLPipeline
1330
+ pipe = StableDiffusionXLPipeline.from_pretrained(
1331
+ "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
1332
+ )
1333
+ pipe.to("cuda")
1334
+ pass
1335
+
1336
+
1337
+ # In[ ]:
1338
+
1339
+
1340
+ prompt = predicted_caption
1341
+ recon = pipe(prompt=prompt).images[0]
1342
+
1343
+
1344
+ # In[ ]:
1345
+
1346
+
1347
+ print("Seen image")
1348
+ display(utils.torch_to_Image(image))
1349
+
1350
+ print("Reconstruction")
1351
+ utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))
1352
+
src/Train_MLPMixer-Copy2.py ADDED
@@ -0,0 +1,1275 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train_MLPMixer-Copy1.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import string
28
+ import h5py
29
+ from tqdm import tqdm
30
+
31
+ import webdataset as wds
32
+ import gc
33
+
34
+ import matplotlib.pyplot as plt
35
+ import torch
36
+ import torch.nn as nn
37
+ from torchvision import transforms
38
+
39
+ from accelerate import Accelerator, DeepSpeedPlugin
40
+
41
+ # tf32 data type is faster than standard float32
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+
44
+ # custom functions #
45
+ import utils
46
+
47
+
48
+ # In[3]:
49
+
50
+
51
+ ### Multi-GPU config ###
52
+ local_rank = os.getenv('RANK')
53
+ if local_rank is None:
54
+ local_rank = 0
55
+ else:
56
+ local_rank = int(local_rank)
57
+ print("LOCAL RANK ", local_rank)
58
+
59
+ num_devices = torch.cuda.device_count()
60
+ if num_devices==0: num_devices = 1
61
+
62
+ # ## UNCOMMENT BELOW SECTION AND COMMENT OUT DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###
63
+ # accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
64
+ # global_batch_size = batch_size = 32
65
+ # data_type = torch.float16 # change depending on your mixed_precision
66
+
67
+ ### DEEPSPEED INITIALIZATION ###
68
+ if num_devices <= 1 and utils.is_interactive():
69
+ global_batch_size = batch_size = 32
70
+ print(f"Setting batch_size to {batch_size}")
71
+ # can emulate a distributed environment for deepspeed to work in jupyter notebook
72
+ os.environ["MASTER_ADDR"] = "localhost"
73
+ os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
74
+ os.environ["RANK"] = "0"
75
+ os.environ["LOCAL_RANK"] = "0"
76
+ os.environ["WORLD_SIZE"] = "1"
77
+ os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
78
+ else:
79
+ global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
80
+ batch_size = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
81
+
82
+ # alter the deepspeed config according to your global and local batch size
83
+ if local_rank == 0:
84
+ with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json', 'r') as file:
85
+ config = json.load(file)
86
+ config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
87
+ config['train_micro_batch_size_per_gpu'] = batch_size
88
+ config['bf16'] = {'enabled': False}
89
+ config['fp16'] = {'enabled': True}
90
+ with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json', 'w') as file:
91
+ json.dump(config, file)
92
+ else:
93
+ # give some time for the local_rank=0 gpu to prep new deepspeed config file
94
+ time.sleep(10)
95
+ deepspeed_plugin = DeepSpeedPlugin("/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json")
96
+ accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
97
+
98
+
99
+ # In[4]:
100
+
101
+
102
+ print("PID of this process =",os.getpid())
103
+ device = accelerator.device
104
+ print("device:",device)
105
+ num_workers = num_devices
106
+ print(accelerator.state)
107
+ world_size = accelerator.state.num_processes
108
+ distributed = not accelerator.state.distributed_type == 'NO'
109
+
110
+ # set data_type to match your mixed precision (automatically set based on deepspeed config)
111
+ if accelerator.mixed_precision == "bf16":
112
+ data_type = torch.bfloat16
113
+ elif accelerator.mixed_precision == "fp16":
114
+ data_type = torch.float16
115
+ else:
116
+ data_type = torch.float32
117
+
118
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
119
+ print = accelerator.print # only print if local_rank=0
120
+
121
+
122
+ # # Configurations
123
+
124
+ # In[5]:
125
+
126
+
127
+ # if running this interactively, can specify jupyter_args here for argparser to use
128
+ if utils.is_interactive():
129
+ # create random model_name
130
+ model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
131
+ model_name = model_name + "_interactive"
132
+ print("model_name:", model_name)
133
+
134
+ # global_batch_size and batch_size should already be defined in the above cells
135
+ # other variables can be specified in the following string:
136
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
137
+ --model_name={model_name} \
138
+ --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \
139
+ --clip_scale=1. --blur_scale=100. --depth_scale=100. \
140
+ --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"
141
+
142
+ jupyter_args = jupyter_args.split()
143
+ print(jupyter_args)
144
+
145
+ from IPython.display import clear_output # function to clear print outputs in cell
146
+ get_ipython().run_line_magic('load_ext', 'autoreload')
147
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
148
+ get_ipython().run_line_magic('autoreload', '2')
149
+
150
+
151
+ # In[6]:
152
+
153
+
154
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
155
+ parser.add_argument(
156
+ "--model_name", type=str, default="testing",
157
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
158
+ )
159
+ parser.add_argument(
160
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
161
+ help="Path to where NSD data is stored / where to download it to",
162
+ )
163
+ parser.add_argument(
164
+ "--subj",type=int, default=1, choices=[1,2,5,7],
165
+ )
166
+ parser.add_argument(
167
+ "--batch_size", type=int, default=32,
168
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
169
+ )
170
+ parser.add_argument(
171
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
172
+ help="whether to log to wandb",
173
+ )
174
+ parser.add_argument(
175
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
176
+ help="if not using wandb and want to resume from a ckpt",
177
+ )
178
+ parser.add_argument(
179
+ "--wandb_project",type=str,default="stability",
180
+ help="wandb project name",
181
+ )
182
+ parser.add_argument(
183
+ "--mixup_pct",type=float,default=.33,
184
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
185
+ )
186
+ parser.add_argument(
187
+ "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
188
+ help="whether to output blurry reconstructions",
189
+ )
190
+ parser.add_argument(
191
+ "--depth_recon",action=argparse.BooleanOptionalAction,default=True,
192
+ help="whether to output depth reconstructions",
193
+ )
194
+ parser.add_argument(
195
+ "--blur_scale",type=float,default=100.,
196
+ help="multiply loss from blurry recons by this number",
197
+ )
198
+ parser.add_argument(
199
+ "--depth_scale",type=float,default=100.,
200
+ help="multiply loss from depth recons by this number",
201
+ )
202
+ parser.add_argument(
203
+ "--clip_scale",type=float,default=1.,
204
+ help="multiply contrastive loss by this number",
205
+ )
206
+ parser.add_argument(
207
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
208
+ help="whether to use image augmentation",
209
+ )
210
+ parser.add_argument(
211
+ "--num_epochs",type=int,default=120,
212
+ help="number of epochs of training",
213
+ )
214
+ parser.add_argument(
215
+ "--hidden_dim",type=int,default=4096,
216
+ )
217
+ parser.add_argument(
218
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
219
+ )
220
+ parser.add_argument(
221
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
222
+ )
223
+ parser.add_argument(
224
+ "--ckpt_interval",type=int,default=5,
225
+ help="save backup ckpt and reconstruct every x epochs",
226
+ )
227
+ parser.add_argument(
228
+ "--seed",type=int,default=42,
229
+ )
230
+ parser.add_argument(
231
+ "--max_lr",type=float,default=3e-4,
232
+ )
233
+ parser.add_argument(
234
+ "--seq_len",type=int,default=2,
235
+ )
236
+
237
+ if utils.is_interactive():
238
+ args = parser.parse_args(jupyter_args)
239
+ else:
240
+ args = parser.parse_args()
241
+
242
+ # create global variables without the args prefix
243
+ for attribute_name in vars(args).keys():
244
+ globals()[attribute_name] = getattr(args, attribute_name)
245
+
246
+
247
+ # In[7]:
248
+
249
+
250
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
251
+ if not os.path.exists(outdir) and ckpt_saving:
252
+ os.makedirs(outdir,exist_ok=True)
253
+ if use_image_aug:
254
+ import kornia
255
+ from kornia.augmentation.container import AugmentationSequential
256
+ img_augment = AugmentationSequential(
257
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
258
+ kornia.augmentation.Resize((224, 224)),
259
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
260
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
261
+ kornia.augmentation.RandomGrayscale(p=0.3),
262
+ same_on_batch=False,
263
+ data_keys=["input"],
264
+ )
265
+
266
+
267
+ # # Prep data, models, and dataloaders
268
+
269
+ # ## Dataloader
270
+
271
+ # In[8]:
272
+
273
+
274
+ if subj==1:
275
+ num_train = 24958
276
+ num_test = 2770
277
+ test_batch_size = num_test
278
+
279
+ def my_split_by_node(urls): return urls
280
+
281
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
282
+ # train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..1}.tar"
283
+ print(train_url)
284
+
285
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
286
+ .shuffle(750, initial=1500, rng=random.Random(42))\
287
+ .decode("torch")\
288
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
289
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
290
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
291
+
292
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
293
+ print(test_url)
294
+
295
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
296
+ .shuffle(750, initial=1500, rng=random.Random(42))\
297
+ .decode("torch")\
298
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
299
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
300
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)
301
+
302
+
303
+ # ### check dataloaders are working
304
+
305
+ # In[9]:
306
+
307
+
308
+ test_vox_indices = []
309
+ test_73k_images = []
310
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
311
+ test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())
312
+ test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())
313
+ test_vox_indices = test_vox_indices.astype(np.int16)
314
+ print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
315
+ print("---\n")
316
+
317
+ train_vox_indices = []
318
+ train_73k_images = []
319
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
320
+ train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())
321
+ train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())
322
+ train_vox_indices = train_vox_indices.astype(np.int16)
323
+ print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
324
+
325
+
326
+ # ## Load data and images
327
+
328
+ # In[10]:
329
+
330
+
331
+ # load betas
332
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
333
+ # f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')
334
+
335
+ voxels = f['betas'][:]
336
+ print(f"subj0{subj} betas loaded into memory")
337
+ voxels = torch.Tensor(voxels).to("cpu").to(data_type)
338
+ print("voxels", voxels.shape)
339
+ num_voxels = voxels.shape[-1]
340
+
341
+ # load orig images
342
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
343
+ images = f['images'][:]
344
+ images = torch.Tensor(images).to("cpu").to(data_type)
345
+ print("images", images.shape)
346
+
347
+
348
+ # ## Load models
349
+
350
+ # ### CLIP image embeddings model
351
+
352
+ # In[11]:
353
+
354
+
355
+ from models import Clipper
356
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
357
+ clip_seq_dim = 257
358
+ clip_emb_dim = 768 #1024
359
+ # hidden_dim = 4096
360
+ #seq_len = 1 #2 #32
361
+
362
+
363
+ # ### SD VAE
364
+
365
+ # In[12]:
366
+
367
+
368
+ # if blurry_recon:
369
+ # from diffusers import AutoencoderKL
370
+ # autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
371
+ # # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
372
+ # autoenc.eval()
373
+ # autoenc.requires_grad_(False)
374
+ # autoenc.to(device)
375
+ # utils.count_params(autoenc)
376
+
377
+ if blurry_recon:# or depth_recon:
378
+ from diffusers import VQModel
379
+ autoenc = VQModel.from_pretrained("/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae", torch_dtype=data_type)
380
+ autoenc.eval()
381
+ autoenc.requires_grad_(False)
382
+ autoenc.to(device)
383
+ utils.count_params(autoenc)
384
+
385
+
386
+ # #### downsampled images
387
+
388
+ # In[13]:
389
+
390
+
391
+ if blurry_recon:
392
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
393
+
394
+ input_batch = images[[30]].to(device)
395
+ print(input_batch.shape)
396
+
397
+ downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)
398
+ re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')
399
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
400
+ print(re_upsampled_enc.shape)
401
+
402
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
403
+
404
+
405
+ # #### MiDaS depth
406
+
407
+ # In[14]:
408
+
409
+
410
+ if depth_recon:
411
+ from controlnet_aux.midas import MidasDetector
412
+
413
+ midas_depth = MidasDetector.from_pretrained(
414
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large", cache_dir="/fsx/proj-fmri/shared/cache").to(device)
415
+ midas_depth.model.eval()
416
+ midas_depth.model.requires_grad_(False)
417
+ midas_depth.model.to(device)
418
+ pass
419
+
420
+
421
+ # In[15]:
422
+
423
+
424
+ if depth_recon:
425
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
426
+
427
+ input_batch = images[[30,31]].float().to(device)
428
+ print(input_batch.shape)
429
+
430
+ midas_emb = midas_depth.model(input_batch).unsqueeze(1)
431
+ print(midas_emb.shape)
432
+
433
+ prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()
434
+ print(prediction.shape)
435
+
436
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
437
+ midas_emb_size = prediction.flatten(1).shape[1]
438
+ print("midas_emb", prediction.shape, prediction.min(), prediction.max())
439
+ print("midas_emb_size", midas_emb_size)
440
+
441
+ if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224)))
442
+
443
+ if blurry_recon:
444
+ prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
445
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
446
+ prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
447
+ print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())
448
+
449
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
450
+
451
+
452
+ # ### MindEye modules
453
+
454
+ # In[17]:
455
+
456
+
457
+ class MindEyeModule(nn.Module):
458
+ def __init__(self):
459
+ super(MindEyeModule, self).__init__()
460
+ def forward(self, x):
461
+ return x
462
+
463
+ model = MindEyeModule()
464
+ model
465
+
466
+
467
+ # In[18]:
468
+
469
+
470
+ time_embedding_dim = 512
471
+
472
+ class RidgeRegression(torch.nn.Module):
473
+ # make sure to add weight_decay when initializing optimizer
474
+ def __init__(self, input_size, out_features):
475
+ super(RidgeRegression, self).__init__()
476
+ self.out_features = out_features
477
+ self.linear = torch.nn.Linear(input_size, out_features)
478
+ def forward(self, x):
479
+ return self.linear(x)
480
+
481
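+ # the ridge layer maps each voxel pattern concatenated with a 512-d time embedding to a hidden_dim token for the mixer backbone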
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
482
+ utils.count_params(model.ridge)
483
+ utils.count_params(model)
484
+
485
+ b = torch.randn((2,1,voxels.shape[1]))
486
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
487
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
488
+
489
+
490
+ # In[24]:
491
+
492
+
493
+ num_past_voxels = 15
494
+
495
+
496
+
497
+ # In[25]:
498
+
499
+
500
+ from functools import partial
501
+ from diffusers.models.vae import Decoder
502
+ class BrainNetwork(nn.Module):
503
+ def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
504
+ super().__init__()
505
+ self.seq_len = seq_len
506
+ self.h = h
507
+ self.clip_size = clip_size
508
+
509
+ # Initial linear layer to match the input dimensions to hidden dimensions
510
+ # self.lin0 = nn.Linear(in_dim, seq_len * h)
511
+
512
+ # Mixer Blocks
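+ # mixer_blocks1 run an MLP over the hidden dimension of each token; mixer_blocks2 run an MLP across
+ # the seq_len token positions; forward() alternates the two with residual connections, MLP-Mixer style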
513
+ self.mixer_blocks1 = nn.ModuleList([
514
+ self.mixer_block1(h, drop) for _ in range(n_blocks)
515
+ ])
516
+ self.mixer_blocks2 = nn.ModuleList([
517
+ self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
518
+ ])
519
+
520
+ # Output linear layer
521
+ self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
522
+
523
+ # low-rank matrices
524
+ # self.rank = 500
525
+ # self.U = nn.Parameter(torch.randn(self.rank, out_dim))
526
+ # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))
527
+ # self.S = nn.Parameter(torch.randn(out_dim))
528
+
529
+ self.clip_proj = nn.Sequential(
530
+ nn.LayerNorm(clip_size),
531
+ nn.GELU(),
532
+ nn.Linear(clip_size, 2048),
533
+ nn.LayerNorm(2048),
534
+ nn.GELU(),
535
+ nn.Linear(2048, 2048),
536
+ nn.LayerNorm(2048),
537
+ nn.GELU(),
538
+ nn.Linear(2048, clip_size)
539
+ )
540
+
541
+ if blurry_recon:
542
+ # self.blin1 = nn.Sequential(
543
+ # nn.Linear(out_dim, 4096, bias=True),
544
+ # nn.LayerNorm(4096),
545
+ # nn.GELU(),
546
+ # nn.Linear(4096, 4096))
547
+ self.blin1 = nn.Linear(h*seq_len, 4096)
548
+ self.bgroupnorm = nn.GroupNorm(1, 256)
549
+ self.bupsampler = Decoder(
550
+ in_channels=256,
551
+ out_channels=128,
552
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
553
+ block_out_channels=[32, 64, 128],
554
+ layers_per_block=1,
555
+ )
556
+
557
+ if depth_recon:
558
+ # self.dlin1 = nn.Sequential(
559
+ # nn.Linear(h, midas_emb_size),
560
+ # nn.Sigmoid(),
561
+ # )
562
+ self.dlin1 = nn.Linear(h*seq_len, 4096)
563
+ self.dgroupnorm = nn.GroupNorm(1, 256)
564
+ self.dupsampler = Decoder(
565
+ in_channels=256,
566
+ out_channels=1,#128,
567
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
568
+ block_out_channels=[32, 64, 128, 256],
569
+ layers_per_block=1,
570
+ )
571
+
572
+ def mixer_block1(self, h, drop):
573
+ return nn.Sequential(
574
+ nn.LayerNorm(h),
575
+ self.mlp(h, h, drop), # Token mixing
576
+ )
577
+
578
+ def mixer_block2(self, seq_len, drop):
579
+ return nn.Sequential(
580
+ nn.LayerNorm(seq_len),
581
+ self.mlp(seq_len, seq_len, drop) # Channel mixing
582
+ )
583
+
584
+ def mlp(self, in_dim, out_dim, drop):
585
+ return nn.Sequential(
586
+ nn.Linear(in_dim, out_dim),
587
+ nn.GELU(),
588
+ nn.Dropout(drop),
589
+ nn.Linear(out_dim, out_dim),
590
+ )
591
+
592
+ def forward(self, x, idx = None):
593
+ print(idx)
594
+ # make empty tensors for blur and depth outputs
595
+ b,d = torch.Tensor([0.]), torch.Tensor([0.])
596
+
597
+ # Initial linear layer
598
+ # x = self.lin0(x)
599
+
600
+ # Reshape to seq_len by dim
601
+ # x = x.reshape(-1, self.seq_len, self.h)
602
+
603
+ # Mixer blocks
604
+ #print("x shape ", x.shape)
605
+ residual1 = x
606
+ residual2 = x.permute(0,2,1)
607
+ #print("residual 2", residual2.shape)
608
+ for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
609
+ x = block1(x) + residual1
610
+ #print("xblo", x.shape)
611
+ residual1 = x
612
+ x = x.permute(0,2,1)
613
+
614
+ x = block2(x) + residual2
615
+ #print("xblo2", x.shape)
616
+ residual2 = x
617
+ x = x.permute(0,2,1)
618
+
619
+ # Flatten
620
+ x = x.reshape(x.size(0), -1)
621
+
622
+ c = self.clin1(x)
623
+
624
+ # low rank linear to out dim cuts # params by nearly half compared to full linear mapping
625
+ # c = (x @ (self.V/100) @ (self.U/100)) + self.S
626
+
627
+ c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))
628
+
629
+ if blurry_recon:
630
+ b = self.blin1(x)
631
+ b = b.reshape(len(b), 256, 4, 4)
632
+ b = self.bgroupnorm(b)
633
+ b = self.bupsampler(b)
634
+
635
+ if depth_recon:
636
+ d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)
637
+ d = d.reshape(len(d), 256, 4, 4)
638
+ d = self.dgroupnorm(d)
639
+ d = self.dupsampler(d)
640
+
641
+ return c, b, d
642
+
643
+
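+ # learned lookup table: each of the num_past_voxels lag positions gets its own 512-d embedding,
+ # which is concatenated onto the corresponding past voxel pattern before the ridge layer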
644
+ class TimeEmbedding(nn.Module):
645
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
646
+ super().__init__()
647
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
648
+ self.num_past_voxels = num_past_voxels
649
+ self.embedding_time_dim = embedding_time_dim
650
+
651
+ def forward(self, time):
652
+ # time is (batch_size,)
653
+ time = time.long()
654
+ time = self.embedding_time(time)
655
+ return time # (batch_size, embedding_time_dim)
656
+
657
+
658
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
659
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
660
+
661
+ model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
662
+ utils.count_params(model.backbone)
663
+ utils.count_params(model)
664
+
665
+ # test that the model works on some fake data
666
+ b = torch.randn((1,seq_len,hidden_dim))
667
+ print("b.shape",b.shape)
668
+ with torch.no_grad():
669
+ clip_, blur_, depth_ = model.backbone(b)
670
+ print(clip_.shape, blur_.shape, depth_.shape)
671
+
672
+
673
+ # In[ ]:
674
+
675
+
676
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
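+ # apply weight decay (1e-2) to the ridge layer and to backbone weights, but not to biases or LayerNorm
+ # parameters; the decay on model.ridge is what gives the linear 'ridge' layer its ridge-regression behavior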
677
+ opt_grouped_parameters = [
678
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
679
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
680
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
681
+ ]
682
+
683
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
684
+
685
+ if lr_scheduler_type == 'linear':
686
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
687
+ optimizer,
688
+ total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),
689
+ last_epoch=-1
690
+ )
691
+ elif lr_scheduler_type == 'cycle':
692
+ total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))
693
+ print("total_steps", total_steps)
694
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
695
+ optimizer,
696
+ max_lr=max_lr,
697
+ total_steps=total_steps,
698
+ final_div_factor=1000,
699
+ last_epoch=-1, pct_start=2/num_epochs
700
+ )
701
+
702
+ def save_ckpt(tag):
703
+ ckpt_path = outdir+f'/{tag}.pth'
704
+ print(f'saving {ckpt_path}',flush=True)
705
+ unwrapped_model = accelerator.unwrap_model(model)
706
+ try:
707
+ torch.save({
708
+ 'epoch': epoch,
709
+ 'model_state_dict': unwrapped_model.state_dict(),
710
+ 'optimizer_state_dict': optimizer.state_dict(),
711
+ 'lr_scheduler': lr_scheduler.state_dict(),
712
+ 'train_losses': losses,
713
+ 'test_losses': test_losses,
714
+ 'lrs': lrs,
715
+ }, ckpt_path)
716
+ except:
717
+ print("Couldn't save... moving on to prevent crashing.")
718
+ del unwrapped_model
719
+
720
+ print("\nDone with model preparations!")
721
+ utils.count_params(model)
722
+
723
+
724
+ # # Weights and Biases
725
+
726
+ # In[ ]:
727
+
728
+
729
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
730
+ import wandb
731
+ wandb_project = 'mindeyev2'
732
+ wandb_run = model_name
733
+ wandb_notes = ''
734
+
735
+ print(f"wandb {wandb_project} run {wandb_run}")
736
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
737
+ wandb_config = {
738
+ "model_name": model_name,
739
+ "global_batch_size": global_batch_size,
740
+ "batch_size": batch_size,
741
+ "num_epochs": num_epochs,
742
+ "clip_scale": clip_scale,
743
+ "blur_scale": blur_scale,
744
+ "use_image_aug": use_image_aug,
745
+ "max_lr": max_lr,
746
+ "mixup_pct": mixup_pct,
747
+ "num_train": num_train,
748
+ "num_test": num_test,
749
+ "ckpt_interval": ckpt_interval,
750
+ "ckpt_saving": ckpt_saving,
751
+ "seed": seed,
752
+ "distributed": distributed,
753
+ "num_devices": num_devices,
754
+ "world_size": world_size,
755
+ "train_url": train_url,
756
+ "test_url": test_url,
757
+ }
758
+ print("wandb_config:\n",wandb_config)
759
+ if False: # wandb_auto_resume
760
+ print("wandb_id:",model_name)
761
+ wandb.init(
762
+ id = model_name,
763
+ project=wandb_project,
764
+ name=wandb_run,
765
+ config=wandb_config,
766
+ notes=wandb_notes,
767
+ resume="allow",
768
+ )
769
+ else:
770
+ wandb.init(
771
+ project=wandb_project,
772
+ name=wandb_run,
773
+ config=wandb_config,
774
+ notes=wandb_notes,
775
+ )
776
+ else:
777
+ wandb_log = False
778
+
779
+
780
+ # # Main
781
+
782
+ # In[ ]:
783
+
784
+
785
+ epoch = 0
786
+ losses, test_losses, lrs = [], [], []
787
+ best_test_loss = 1e9
788
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
789
+
790
+ # Optionally resume from checkpoint #
791
+ if resume_from_ckpt:
792
+ print("\n---resuming from last.pth ckpt---\n")
793
+ try:
794
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
795
+ except:
796
+ print('last.pth failed... trying last_backup.pth')
797
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
798
+ epoch = checkpoint['epoch']
799
+ print("Epoch",epoch)
800
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
801
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
802
+ model.load_state_dict(checkpoint['model_state_dict'])
803
+ del checkpoint
804
+ elif wandb_log:
805
+ if wandb.run.resumed:
806
+ print("\n---resuming from last.pth ckpt---\n")
807
+ try:
808
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
809
+ except:
810
+ print('last.pth failed... trying last_backup.pth')
811
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
812
+ epoch = checkpoint['epoch']
813
+ print("Epoch",epoch)
814
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
815
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
816
+ model.load_state_dict(checkpoint['model_state_dict'])
817
+ del checkpoint
818
+ torch.cuda.empty_cache()
819
+
820
+
821
+ # In[ ]:
822
+
823
+
824
+ model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
825
+ model, optimizer, train_dl, lr_scheduler
826
+ )
827
+ # leaving out test_dl since we will only have local_rank 0 device do evals
828
+
829
+
830
+ # In[ ]:
831
+
832
+
833
+ def add_saturation(image, alpha=2):
834
+ gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
835
+ gray_image = gray_image.unsqueeze(1).expand_as(image)
836
+ saturated_image = alpha * image + (1 - alpha) * gray_image
837
+ return torch.clamp(saturated_image, 0, 1)
838
+
839
+
840
+ # In[ ]:
841
+
842
+
843
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
844
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
845
+ test_image, test_voxel = None, None
846
+ mse = nn.MSELoss()
847
+ l1 = nn.L1Loss()
848
+
849
+ for epoch in progress_bar:
850
+ model.train()
851
+
852
+ fwd_percent_correct = 0.
853
+ bwd_percent_correct = 0.
854
+ test_fwd_percent_correct = 0.
855
+ test_bwd_percent_correct = 0.
856
+
857
+ loss_clip_total = 0.
858
+ loss_blurry_total = 0.
859
+ loss_depth_total = 0.
860
+ test_loss_clip_total = 0.
861
+ test_loss_blurry_total = 0.
862
+ test_loss_depth_total = 0.
863
+
864
+ blurry_pixcorr = 0.
865
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
866
+
867
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
868
+ with torch.cuda.amp.autocast(dtype=data_type):
869
+ optimizer.zero_grad()
870
+
871
+ #voxel = voxels[behav[:,0,5].cpu().long()].to(device)
872
+ #image = images[behav[:,0,0].cpu().long()].to(device).float()
873
+
874
+ #past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
875
+ #past_15_times = torch.Tensor([i for i in range(seq_len - 1)]).to(device) # 15
876
+
877
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
878
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
879
+
880
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
881
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
882
+ #for past in range(1):
883
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
884
+
885
+ if blurry_recon:
886
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
887
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
888
+
889
+ if depth_recon:
890
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
891
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
892
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
893
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
894
+
895
+ if use_image_aug:
896
+ image = img_augment(image)
897
+
898
+ clip_target = clip_model.embed_image(image)
899
+ assert not torch.any(torch.isnan(clip_target))
900
+
901
+ if epoch < int(mixup_pct * num_epochs):
902
+ voxel, perm, betas, select = utils.mixco(voxel)
903
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
904
+
905
+ for p in range(seq_len-1):
906
+ #print(past_behav.shape) #128, 15, 17
907
+ #print(past_behav[:,p,-1])
908
+ #print(past_15_voxels.shape) # 128, 1, 15724
909
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
910
+ #print(mask) # 128
911
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
912
+ #print(past_15_voxels)
913
+
914
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
915
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
916
+ past_15_times = past_15_times.reshape(-1)
917
+ time_embeddings = model.time_embedding(past_15_times)
918
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
919
+
920
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
921
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
922
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
923
+ voxel_ridge = voxel_ridge.view( seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)
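+ # the ridge output is stacked as [current-trial rows; past-trial rows], so view(seq_len, batch, hidden_dim)
+ # followed by permute(1,0,2) regroups it into (batch, seq_len, hidden_dim) token sequences
+ # (with the default seq_len=2 that is one current + one past token per sample)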
924
+ #unsqueeze(1) # bz * 2, 1, 4096
925
+
926
+ # past_voxel_ridge = model.ridge(past_voxel)
927
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)
928
+ #print(voxel_ridge.shape)
929
+
930
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge, idx = train_i)
931
+
932
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
933
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
934
+
935
+ if epoch < int(mixup_pct * num_epochs):
936
+ loss_clip = utils.mixco_nce(
937
+ clip_voxels_norm,
938
+ clip_target_norm,
939
+ temp=.006,
940
+ perm=perm, betas=betas, select=select)
941
+ else:
942
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
943
+ loss_clip = utils.soft_clip_loss(
944
+ clip_voxels_norm,
945
+ clip_target_norm,
946
+ temp=epoch_temp)
947
+
948
+ loss_clip_total += loss_clip.item()
949
+ loss_clip *= clip_scale
950
+ loss = loss_clip
951
+
952
+ if blurry_recon:
953
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
954
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
955
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
956
+
957
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
958
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
959
+ loss_blurry_total += loss_blurry.item()
960
+ loss_blurry *= blur_scale
961
+ loss += loss_blurry
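+ # The blurry-recon auxiliary loss pulls the predicted VQ latents toward both the saturated
+ # full image and its 8x8-downsampled/re-upsampled version, plus a variance-matching term so
+ # the predicted latents do not collapse toward a flat solution.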
962
+
963
+ if depth_recon:
964
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
965
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
966
+ loss_depth_total += loss_depth.item()
967
+ loss_depth *= depth_scale
968
+ loss += loss_depth
969
+
970
+ # forward and backward top 1 accuracy
971
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
972
+ fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()
973
+ bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()
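+ # fwd = given a brain embedding, is the matching image embedding the top-1 hit within the
+ # batch; bwd is the same retrieval in the image-to-brain direction.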
974
+
975
+ if blurry_recon:
976
+ with torch.no_grad():
977
+ # only running pixcorr eval on a subset of the samples per batch because it's costly & slow to compute autoenc.decode()
978
+ random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)
979
+ # random_samps = np.arange(batch_size//5)
980
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
981
+ # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean
982
+ pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)
983
+ # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
984
+ # loss += (1 - pixcorr)
985
+ blurry_pixcorr += pixcorr.item()
986
+ # utils.check_loss(pixcorr)
987
+
988
+ utils.check_loss(loss)
989
+ accelerator.backward(loss)
990
+ optimizer.step()
991
+
992
+ losses.append(loss.item())
993
+ lrs.append(optimizer.param_groups[0]['lr'])
994
+
995
+ if lr_scheduler_type is not None:
996
+ lr_scheduler.step()
997
+
998
+ model.eval()
999
+ if local_rank==0:
1000
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
1001
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
1002
+ # all test samples are loaded in a single batch, so test_i should never exceed 0
1003
+ assert len(behav) == num_test
1004
+
1005
+ ## Average same-image repeats ##
1006
+ if test_image is None:
1007
+ voxel = voxels[behav[:,0,5].cpu().long()]
1008
+ image = behav[:,0,0].cpu().long()
1009
+
1010
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
1011
+ for im in unique_image:
1012
+ locs = torch.where(im == image)[0]
1013
+ if test_image is None:
1014
+ test_image = images[im][None]
1015
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
1016
+ else:
1017
+ test_image = torch.vstack((test_image, images[im][None]))
1018
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
1019
+
1020
+ # take the first 300 averaged test samples (not actually a random sample)
1021
+ random_indices = torch.arange(len(test_voxel))[:300]
1022
+ voxel = test_voxel[random_indices].to(device)
1023
+ image = test_image[random_indices].to(device)
1024
+ assert len(image) == 300
1025
+
1026
+ current_past_behav = past_behav[random_indices]
1027
+
1028
+ past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
1029
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
1030
+
1031
+ if blurry_recon:
1032
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
1033
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
1034
+
1035
+ if depth_recon:
1036
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
1037
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
1038
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
1039
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
1040
+
1041
+ clip_target = clip_model.embed_image(image.float())
1042
+
1043
+
1044
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
1045
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
1046
+ past_15_times = past_15_times.reshape(-1)
1047
+ time_embeddings = model.time_embedding(past_15_times)
1048
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
1049
+
1050
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
1051
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1052
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
1053
+ voxel_ridge = voxel_ridge.view(seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)
1054
+
1055
+ #voxel_ridge = model.ridge(voxel).unsqueeze(1)
1056
+
1057
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)
1058
+
1059
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1060
+
1061
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1062
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1063
+
1064
+ loss_clip = utils.soft_clip_loss(
1065
+ clip_voxels_norm,
1066
+ clip_target_norm,
1067
+ temp=.006)
1068
+ test_loss_clip_total += loss_clip.item()
1069
+ loss_clip = loss_clip * clip_scale
1070
+ loss = loss_clip
1071
+
1072
+ if blurry_recon:
1073
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1074
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1075
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1076
+
1077
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1078
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1079
+ test_loss_blurry_total += loss_blurry.item()
1080
+ loss_blurry *= blur_scale
1081
+ loss += loss_blurry
1082
+
1083
+ # decode in two halves because running the decoder on the full batch is computationally heavy
1084
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)
1085
+ blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1086
+ pixcorr = utils.pixcorr(image, blurry_recon_images)
1087
+ loss += (1 - pixcorr)
1088
+ test_blurry_pixcorr += pixcorr.item()
1089
+
1090
+ if depth_recon:
1091
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1092
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1093
+ test_loss_depth_total += loss_depth.item()
1094
+ loss_depth *= depth_scale
1095
+ loss += loss_depth
1096
+
1097
+ # forward and backward top 1 accuracy
1098
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1099
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
1100
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()
1101
+
1102
+ utils.check_loss(loss)
1103
+ test_losses.append(loss.item())
1104
+
1105
+ # if utils.is_interactive(): clear_output(wait=True)
1106
+ print("---")
1107
+
1108
+ assert (test_i+1) == 1
1109
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1110
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1111
+ "train/lr": lrs[-1],
1112
+ "train/num_steps": len(losses),
1113
+ "test/num_steps": len(test_losses),
1114
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1115
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1116
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1117
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1118
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1119
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1120
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1121
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1122
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1123
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1124
+ "train/loss_depth_total": loss_depth_total / (train_i + 1),
1125
+ "test/loss_depth_total": test_loss_depth_total / (test_i + 1),
1126
+ }
1127
+
1128
+ if blurry_recon:
1129
+ # transform blurry recon latents to images and plot them
1130
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1131
+ jj=-1
1132
+ for j in [0,1,2,3]:
1133
+ jj+=1
1134
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1135
+ axes[jj].axis('off')
1136
+ jj+=1
1137
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1138
+ axes[jj].axis('off')
1139
+
1140
+ if wandb_log:
1141
+ logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1142
+ plt.close()
1143
+ else:
1144
+ plt.show()
1145
+
1146
+ if depth_recon:
1147
+ # transform depth recon predictions to images and plot them
1148
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1149
+ # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1150
+ # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1151
+ jj=-1
1152
+ for j in [0,1,2,3]:
1153
+ jj+=1
1154
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))
1155
+ axes[jj].axis('off')
1156
+ jj+=1
1157
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))
1158
+ axes[jj].axis('off')
1159
+ if wandb_log:
1160
+ logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1161
+ plt.close()
1162
+ else:
1163
+ plt.show()
1164
+
1165
+ progress_bar.set_postfix(**logs)
1166
+
1167
+ # Save model checkpoint and reconstruct
1168
+ if epoch % ckpt_interval == 0:
1169
+ if not utils.is_interactive():
1170
+ save_ckpt(f'last')
1171
+
1172
+ if wandb_log: wandb.log(logs)
1173
+
1174
+ # wait for other GPUs to catch up if needed
1175
+ accelerator.wait_for_everyone()
1176
+ torch.cuda.empty_cache()
1177
+ gc.collect()
1178
+
1179
+ print("\n===Finished!===\n")
1180
+ if ckpt_saving:
1181
+ save_ckpt(f'last')
1182
+ if not utils.is_interactive():
1183
+ sys.exit(0)
1184
+
1185
+
1186
+ # In[ ]:
1187
+
1188
+
1189
+ plt.plot(losses)
1190
+ plt.show()
1191
+ plt.plot(test_losses)
1192
+ plt.show()
1193
+
1194
+
1195
+ # # Retrieve nearest neighbor in the training set using test set data
1196
+
1197
+ # In[ ]:
1198
+
1199
+
1200
+ annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")
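+ # (COCO_73k_annots_curated.npy appears to hold caption annotations aligned to the 73k NSD
+ # image indices; utils.select_annotations below picks a single caption string per image.)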
1201
+
1202
+
1203
+ # In[ ]:
1204
+
1205
+
1206
+ ii=2
1207
+ all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))
1208
+ with torch.no_grad(), torch.cuda.amp.autocast():
1209
+ for batch in tqdm(range(0,len(all_indices),512)):
1210
+ if batch==0:
1211
+ clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1212
+ else:
1213
+ target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1214
+ clip_target = torch.vstack((clip_target,target))
1215
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1216
+
1217
+ voxel = test_voxel[[ii]].to(device)
1218
+ image = test_image[[ii]].to(device)
1219
+
1220
+ print("Original Image (test set)")
1221
+ display(utils.torch_to_Image(image))
1222
+
1223
+ clip_target = clip_model.embed_image(image).cpu()
1224
+ # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))
1225
+
1226
+ voxel_ridge = model.ridge(voxel).unsqueeze(1)
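+ # Note: this cell still calls the ridge layer on the raw voxels; with the time-embedding
+ # setup used in the training loop above, model.ridge expects the voxels concatenated with a
+ # zero "current trial" time slot (and the past-trial sequence), so the same preprocessing
+ # would be needed here for this cell to run as-is.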
1227
+ clip_voxels, _, _ = model.backbone(voxel_ridge)
1228
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1229
+ # clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)  # bug if left in: overwrites the brain prediction with the ground-truth target and makes the retrieval below trivial
1230
+
1231
+ print("clip_voxels_norm", clip_voxels_norm.shape)
1232
+ print("clip_target_norm", clip_target_norm.shape)
1233
+
1234
+ sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(),
1235
+ clip_target_norm).flatten()).flip(0)
1236
+ picks = all_indices[sortt[:5]]
1237
+
1238
+ print("\nNearest neighbors in training set")
1239
+ for ip,p in enumerate(picks):
1240
+ display(utils.torch_to_Image(images[[p]]))
1241
+ # print(utils.select_annotations([annots[int(p)]]))
1242
+ if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]
1243
+
1244
+ print("\n=====\npredicted_caption:\n", predicted_caption)
1245
+
1246
+
1247
+ # # Feed into Stable Diffusion XL for reconstructions
1248
+
1249
+ # In[ ]:
1250
+
1251
+
1252
+ from diffusers import StableDiffusionXLPipeline
1253
+ pipe = StableDiffusionXLPipeline.from_pretrained(
1254
+ "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
1255
+ )
1256
+ pipe.to("cuda")
1257
+ pass
1258
+
1259
+
1260
+ # In[ ]:
1261
+
1262
+
1263
+ prompt = predicted_caption
1264
+ recon = pipe(prompt=prompt).images[0]
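+ # (Optional: standard diffusers pipeline arguments, not specific to this repo, can steer the
+ # recon, e.g. pipe(prompt=prompt, num_inference_steps=30, negative_prompt="text, watermark").images[0])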
1265
+
1266
+
1267
+ # In[ ]:
1268
+
1269
+
1270
+ print("Seen image")
1271
+ display(utils.torch_to_Image(image))
1272
+
1273
+ print("Reconstruction")
1274
+ utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))
1275
+
src/Train_MLPMixer-img.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train_MLPMixer-img.py ADDED
@@ -0,0 +1,1444 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train_MLPMixer-Copy1.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import string
28
+ import h5py
29
+ from tqdm import tqdm
30
+
31
+ import webdataset as wds
32
+ import gc
33
+
34
+ import matplotlib.pyplot as plt
35
+ import torch
36
+ import torch.nn as nn
37
+ from torchvision import transforms
38
+
39
+ from accelerate import Accelerator, DeepSpeedPlugin
40
+
41
+ # tf32 data type is faster than standard float32
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+
44
+ # custom functions #
45
+ import utils
46
+
47
+ global_batch_size = 16 #128
48
+
49
+ import os
50
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
51
+
52
+
53
+ # In[3]:
54
+
55
+
56
+ ### Multi-GPU config ###
57
+ local_rank = os.getenv('RANK')
58
+ if local_rank is None:
59
+ local_rank = 0
60
+ else:
61
+ local_rank = int(local_rank)
62
+ print("LOCAL RANK ", local_rank)
63
+
64
+ num_devices = torch.cuda.device_count()
65
+ if num_devices==0: num_devices = 1
66
+
67
+ #accelerator = Accelerator(split_batches=False)
68
+
69
+ ### UNCOMMENT THE BLOCK BELOW TO USE DEEPSPEED (and comment out the "accelerator = " line immediately above) ###
70
+
71
+ if num_devices <= 1 and utils.is_interactive():
72
+ # can emulate a distributed environment for deepspeed to work in jupyter notebook
73
+ os.environ["MASTER_ADDR"] = "localhost"
74
+ os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
75
+ os.environ["RANK"] = "0"
76
+ os.environ["LOCAL_RANK"] = "0"
77
+ os.environ["WORLD_SIZE"] = "1"
78
+ os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
79
+ global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
80
+
81
+ # alter the deepspeed config according to your global and local batch size
82
+ if local_rank == 0:
83
+ with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2_cpuoffload.json', 'r') as file:
84
+ config = json.load(file)
85
+ config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
86
+ config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
87
+ with open('deepspeed_config_stage2.json', 'w') as file:
88
+ json.dump(config, file)
89
+ else:
90
+ # give some time for the local_rank=0 gpu to prep new deepspeed config file
91
+ time.sleep(10)
92
+ deepspeed_plugin = DeepSpeedPlugin("/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2_cpuoffload.json")
93
+ accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
94
+
95
+
96
+ # In[4]:
97
+
98
+
99
+ print("PID of this process =",os.getpid())
100
+ device = accelerator.device
101
+ print("device:",device)
102
+ num_workers = num_devices
103
+ print(accelerator.state)
104
+ world_size = accelerator.state.num_processes
105
+ distributed = not accelerator.state.distributed_type == 'NO'
106
+
107
+ # set data_type to match your mixed precision (automatically set based on deepspeed config)
108
+ if accelerator.mixed_precision == "bf16":
109
+ data_type = torch.bfloat16
110
+ elif accelerator.mixed_precision == "fp16":
111
+ data_type = torch.float16
112
+ else:
113
+ data_type = torch.float32
114
+
115
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
116
+ print = accelerator.print # only print if local_rank=0
117
+
118
+
119
+ # In[5]:
120
+
121
+
122
+ accelerator.state.distributed_type
123
+
124
+
125
+ # # Configurations
126
+
127
+ # In[6]:
128
+
129
+
130
+ # if running this interactively, can specify jupyter_args here for argparser to use
131
+ if utils.is_interactive():
132
+ # create random model_name
133
+ model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
134
+ model_name = model_name + "_interactive"
135
+ print("model_name:", model_name)
136
+
137
+ # global_batch_size and batch_size should already be defined in the above cells
138
+ # other variables can be specified in the following string:
139
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
140
+ --model_name={model_name} \
141
+ --subj=1 --batch_size={global_batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=1024 \
142
+ --clip_scale=1. --blur_scale=100. --depth_scale=100. \
143
+ --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"
144
+
145
+ jupyter_args = jupyter_args.split()
146
+ print(jupyter_args)
147
+
148
+ from IPython.display import clear_output # function to clear print outputs in cell
149
+ get_ipython().run_line_magic('load_ext', 'autoreload')
150
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
151
+ get_ipython().run_line_magic('autoreload', '2')
152
+
153
+
154
+ # In[7]:
155
+
156
+
157
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
158
+ parser.add_argument(
159
+ "--model_name", type=str, default="testing",
160
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
161
+ )
162
+ parser.add_argument(
163
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
164
+ help="Path to where NSD data is stored / where to download it to",
165
+ )
166
+ parser.add_argument(
167
+ "--subj",type=int, default=1, choices=[1,2,5,7],
168
+ )
169
+ parser.add_argument(
170
+ "--batch_size", type=int, default=32,
171
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
172
+ )
173
+ parser.add_argument(
174
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
175
+ help="whether to log to wandb",
176
+ )
177
+ parser.add_argument(
178
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
179
+ help="if not using wandb and want to resume from a ckpt",
180
+ )
181
+ parser.add_argument(
182
+ "--wandb_project",type=str,default="stability",
183
+ help="wandb project name",
184
+ )
185
+ parser.add_argument(
186
+ "--mixup_pct",type=float,default=.33,
187
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
188
+ )
189
+ parser.add_argument(
190
+ "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
191
+ help="whether to output blurry reconstructions",
192
+ )
193
+ parser.add_argument(
194
+ "--depth_recon",action=argparse.BooleanOptionalAction,default=True,
195
+ help="whether to output depth reconstructions",
196
+ )
197
+ parser.add_argument(
198
+ "--blur_scale",type=float,default=100.,
199
+ help="multiply loss from blurry recons by this number",
200
+ )
201
+ parser.add_argument(
202
+ "--depth_scale",type=float,default=100.,
203
+ help="multiply loss from depth recons by this number",
204
+ )
205
+ parser.add_argument(
206
+ "--clip_scale",type=float,default=1.,
207
+ help="multiply contrastive loss by this number",
208
+ )
209
+ parser.add_argument(
210
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
211
+ help="whether to use image augmentation",
212
+ )
213
+ parser.add_argument(
214
+ "--num_epochs",type=int,default=120,
215
+ help="number of epochs of training",
216
+ )
217
+ parser.add_argument(
218
+ "--hidden_dim",type=int,default=4096,
219
+ )
220
+ parser.add_argument(
221
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
222
+ )
223
+ parser.add_argument(
224
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
225
+ )
226
+ parser.add_argument(
227
+ "--ckpt_interval",type=int,default=5,
228
+ help="save backup ckpt and reconstruct every x epochs",
229
+ )
230
+ parser.add_argument(
231
+ "--seed",type=int,default=42,
232
+ )
233
+ parser.add_argument(
234
+ "--max_lr",type=float,default=3e-4,
235
+ )
236
+ parser.add_argument(
237
+ "--seq_len",type=int,default=2,
238
+ )
239
+
240
+ if utils.is_interactive():
241
+ args = parser.parse_args(jupyter_args)
242
+ else:
243
+ args = parser.parse_args()
244
+
245
+ # create global variables without the args prefix
246
+ for attribute_name in vars(args).keys():
247
+ globals()[attribute_name] = getattr(args, attribute_name)
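+ # (This promotes every parsed argument to a module-level global, so e.g. args.batch_size is
+ # used as plain batch_size throughout the rest of the script.)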
248
+
249
+
250
+ # In[8]:
251
+
252
+
253
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
254
+ if not os.path.exists(outdir) and ckpt_saving:
255
+ os.makedirs(outdir,exist_ok=True)
256
+ if use_image_aug:
257
+ import kornia
258
+ from kornia.augmentation.container import AugmentationSequential
259
+ img_augment = AugmentationSequential(
260
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
261
+ kornia.augmentation.Resize((224, 224)),
262
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
263
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
264
+ kornia.augmentation.RandomGrayscale(p=0.3),
265
+ same_on_batch=False,
266
+ data_keys=["input"],
267
+ )
268
+
269
+
270
+ # # Prep data, models, and dataloaders
271
+
272
+ # ## Dataloader
273
+
274
+ # In[9]:
275
+
276
+
277
+ if subj==1:
278
+ num_train = 24958
279
+ num_test = 2770
280
+ test_batch_size = num_test
281
+
282
+ def my_split_by_node(urls): return urls
283
+
284
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
285
+ # train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..1}.tar"
286
+ print(train_url)
287
+
288
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
289
+ .shuffle(750, initial=1500, rng=random.Random(42))\
290
+ .decode("torch")\
291
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
292
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
293
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
294
+
295
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
296
+ print(test_url)
297
+
298
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
299
+ .shuffle(750, initial=1500, rng=random.Random(42))\
300
+ .decode("torch")\
301
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
302
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
303
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)
304
+
305
+
306
+ # ### check dataloaders are working
307
+
308
+ # In[10]:
309
+
310
+
311
+ test_vox_indices = []
312
+ test_73k_images = []
313
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
314
+ test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())
315
+ test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())
316
+ test_vox_indices = test_vox_indices.astype(np.int16)
317
+ print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
318
+ print("---\n")
319
+
320
+ train_vox_indices = []
321
+ train_73k_images = []
322
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
323
+ train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())
324
+ train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())
325
+ train_vox_indices = train_vox_indices.astype(np.int16)
326
+ print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
327
+
328
+
329
+ # ## Load data and images
330
+
331
+ # In[11]:
332
+
333
+
334
+ # load betas
335
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
336
+ # f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')
337
+
338
+ voxels = f['betas'][:]
339
+ print(f"subj0{subj} betas loaded into memory")
340
+ voxels = torch.Tensor(voxels).to("cpu").to(data_type)
341
+ print("voxels", voxels.shape)
342
+ num_voxels = voxels.shape[-1]
343
+
344
+ # load orig images
345
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
346
+ images = f['images'][:]
347
+ images = torch.Tensor(images).to("cpu").to(data_type)
348
+ print("images", images.shape)
349
+
350
+
351
+ # ## Load models
352
+
353
+ # ### CLIP image embeddings model
354
+
355
+ # In[12]:
356
+
357
+
358
+ from models import Clipper
359
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
360
+ clip_seq_dim = 257
361
+ clip_emb_dim = 768 #1024
362
+ # hidden_dim = 4096
363
+ #seq_len = 1 #2 #32
364
+
365
+
366
+ # In[13]:
367
+
368
+
369
+ clip_model2 = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=False, norm_embs=True)
370
+
371
+
372
+ # In[14]:
373
+
374
+
375
+ #out2t = clip_model2.embed_image(torch.randn(32,3,224,224))
376
+
377
+
378
+ # In[15]:
379
+
380
+
381
+ #out2t.shape
382
+
383
+
384
+ # ### SD VAE
385
+
386
+ # In[16]:
387
+
388
+
389
+ # if blurry_recon:
390
+ # from diffusers import AutoencoderKL
391
+ # autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
392
+ # # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
393
+ # autoenc.eval()
394
+ # autoenc.requires_grad_(False)
395
+ # autoenc.to(device)
396
+ # utils.count_params(autoenc)
397
+
398
+ if blurry_recon:# or depth_recon:
399
+ from diffusers import VQModel
400
+ autoenc = VQModel.from_pretrained("/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae", torch_dtype=data_type)
401
+ autoenc.eval()
402
+ autoenc.requires_grad_(False)
403
+ autoenc.to(device)
404
+ utils.count_params(autoenc)
405
+
406
+
407
+ # #### downsampled images
408
+
409
+ # In[17]:
410
+
411
+
412
+ if blurry_recon:
413
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
414
+
415
+ input_batch = images[[30]].to(device)
416
+ print(input_batch.shape)
417
+
418
+ downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)
419
+ re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')
420
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
421
+ print(re_upsampled_enc.shape)
422
+
423
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
424
+
425
+
426
+ # #### MiDaS depth
427
+
428
+ # In[18]:
429
+
430
+
431
+ if depth_recon:
432
+ from controlnet_aux.midas import MidasDetector
433
+
434
+ midas_depth = MidasDetector.from_pretrained(
435
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large", cache_dir="/fsx/proj-fmri/shared/cache").to(device)
436
+ midas_depth.model.eval()
437
+ midas_depth.model.requires_grad_(False)
438
+ midas_depth.model.to(device)
439
+ pass
440
+
441
+
442
+ # In[19]:
443
+
444
+
445
+ if depth_recon:
446
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
447
+
448
+ input_batch = images[[30,31]].float().to(device)
449
+ print(input_batch.shape)
450
+
451
+ midas_emb = midas_depth.model(input_batch).unsqueeze(1)
452
+ print(midas_emb.shape)
453
+
454
+ prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()
455
+ print(prediction.shape)
456
+
457
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
458
+ midas_emb_size = prediction.flatten(1).shape[1]
459
+ print("midas_emb", prediction.shape, prediction.min(), prediction.max())
460
+ print("midas_emb_size", midas_emb_size)
461
+
462
+ if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224)))
463
+
464
+ if blurry_recon:
465
+ prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
466
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
467
+ prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
468
+ print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())
469
+
470
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
471
+
472
+
473
+ # ### MindEye modules
474
+
475
+ # In[20]:
476
+
477
+
478
+ class MindEyeModule(nn.Module):
479
+ def __init__(self):
480
+ super(MindEyeModule, self).__init__()
481
+ def forward(self, x):
482
+ return x
483
+
484
+ model = MindEyeModule()
485
+ model
486
+
487
+
488
+ # In[21]:
489
+
490
+
491
+ time_embedding_dim = 512
492
+
493
+ class RidgeRegression(torch.nn.Module):
494
+ # make sure to add weight_decay when initializing optimizer
495
+ def __init__(self, input_size, out_features):
496
+ super(RidgeRegression, self).__init__()
497
+ self.out_features = out_features
498
+ self.linear = torch.nn.Linear(input_size, out_features)
499
+ def forward(self, x):
500
+ return self.linear(x)
501
+
502
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
503
+ utils.count_params(model.ridge)
504
+ utils.count_params(model)
505
+
506
+ b = torch.randn((2,1,voxels.shape[1]))
507
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
508
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
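+ # The "ridge" behavior comes from the optimizer rather than the module: the AdamW
+ # weight_decay applied to model.ridge's parameters below acts as the L2 shrinkage term, so
+ # this linear layer plus decay approximates ridge regression on the voxel inputs.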
509
+
510
+
511
+ # In[22]:
512
+
513
+
514
+ num_past_voxels = 15
515
+ #seq_len = 1 + 1
516
+
517
+
518
+ # In[23]:
519
+
520
+
521
+ from functools import partial
522
+ from diffusers.models.vae import Decoder
523
+ class BrainNetwork(nn.Module):
524
+ def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
525
+ super().__init__()
526
+ self.seq_len = seq_len
527
+ self.h = h
528
+ self.clip_size = clip_size
529
+
530
+ # Initial linear layer to match the input dimensions to hidden dimensions
531
+ # self.lin0 = nn.Linear(in_dim, seq_len * h)
532
+
533
+ # Mixer Blocks
534
+ self.mixer_blocks1 = nn.ModuleList([
535
+ self.mixer_block1(h, drop) for _ in range(n_blocks)
536
+ ])
537
+ self.mixer_blocks2 = nn.ModuleList([
538
+ self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
539
+ ])
540
+
541
+ # Output linear layer
542
+ self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
543
+
544
+ # low-rank matrices
545
+ # self.rank = 500
546
+ # self.U = nn.Parameter(torch.randn(self.rank, out_dim))
547
+ # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))
548
+ # self.S = nn.Parameter(torch.randn(out_dim))
549
+
550
+ self.clip_proj = nn.Sequential(
551
+ nn.LayerNorm(clip_size),
552
+ nn.GELU(),
553
+ nn.Linear(clip_size, 2048),
554
+ nn.LayerNorm(2048),
555
+ nn.GELU(),
556
+ nn.Linear(2048, 2048),
557
+ nn.LayerNorm(2048),
558
+ nn.GELU(),
559
+ nn.Linear(2048, clip_size)
560
+ )
561
+
562
+ if blurry_recon:
563
+ # self.blin1 = nn.Sequential(
564
+ # nn.Linear(out_dim, 4096, bias=True),
565
+ # nn.LayerNorm(4096),
566
+ # nn.GELU(),
567
+ # nn.Linear(4096, 4096))
568
+ self.blin1 = nn.Linear(h*seq_len, 4096)
569
+ self.bgroupnorm = nn.GroupNorm(1, 256)
570
+ self.bupsampler = Decoder(
571
+ in_channels=256,
572
+ out_channels=128,
573
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
574
+ block_out_channels=[32, 64, 128],
575
+ layers_per_block=1,
576
+ )
577
+
578
+ if depth_recon:
579
+ # self.dlin1 = nn.Sequential(
580
+ # nn.Linear(h, midas_emb_size),
581
+ # nn.Sigmoid(),
582
+ # )
583
+ self.dlin1 = nn.Linear(h*seq_len, 4096)
584
+ self.dgroupnorm = nn.GroupNorm(1, 256)
585
+ self.dupsampler = Decoder(
586
+ in_channels=256,
587
+ out_channels=1,#128,
588
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
589
+ block_out_channels=[32, 64, 128, 256],
590
+ layers_per_block=1,
591
+ )
592
+
593
+ def mixer_block1(self, h, drop):
594
+ return nn.Sequential(
595
+ nn.LayerNorm(h),
596
+ self.mlp(h, h, drop), # Token mixing
597
+ )
598
+
599
+ def mixer_block2(self, seq_len, drop):
600
+ return nn.Sequential(
601
+ nn.LayerNorm(seq_len),
602
+ self.mlp(seq_len, seq_len, drop) # Channel mixing
603
+ )
604
+
605
+ def mlp(self, in_dim, out_dim, drop):
606
+ return nn.Sequential(
607
+ nn.Linear(in_dim, out_dim),
608
+ nn.GELU(),
609
+ nn.Dropout(drop),
610
+ nn.Linear(out_dim, out_dim),
611
+ )
612
+
613
+ def forward(self, x, idx = None):
614
+ # print(idx)  # debug print; the backbone is called without idx in the training loop, so this would just print None every step
615
+ # make empty tensors for blur and depth outputs
616
+ b,d = torch.Tensor([0.]), torch.Tensor([0.])
617
+
618
+ # Initial linear layer
619
+ # x = self.lin0(x)
620
+
621
+ # Reshape to seq_len by dim
622
+ # x = x.reshape(-1, self.seq_len, self.h)
623
+
624
+ # Mixer blocks
625
+ #print("x shape ", x.shape)
626
+ residual1 = x
627
+ residual2 = x.permute(0,2,1)
628
+ #print("residual 2", residual2.shape)
629
+ for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
630
+ x = block1(x) + residual1
631
+ #print("xblo", x.shape)
632
+ residual1 = x
633
+ x = x.permute(0,2,1)
634
+
635
+ x = block2(x) + residual2
636
+ #print("xblo2", x.shape)
637
+ residual2 = x
638
+ x = x.permute(0,2,1)
639
+
640
+ # Flatten
641
+ x = x.reshape(x.size(0), -1)
642
+
643
+ c = self.clin1(x)
644
+
645
+ # low rank linear to out dim cuts # params by nearly half compared to full linear mapping
646
+ # c = (x @ (self.V/100) @ (self.U/100)) + self.S
647
+
648
+ c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))
649
+
650
+ if blurry_recon:
651
+ b = self.blin1(x)
652
+ b = b.reshape(len(b), 256, 4, 4)
653
+ b = self.bgroupnorm(b)
654
+ b = self.bupsampler(b)
655
+
656
+ if depth_recon:
657
+ d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)
658
+ d = d.reshape(len(d), 256, 4, 4)
659
+ d = self.dgroupnorm(d)
660
+ d = self.dupsampler(d)
661
+
662
+ return c, b, d
663
+
664
+
665
+ class TimeEmbedding(nn.Module):
666
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
667
+ super().__init__()
668
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
669
+ self.num_past_voxels = num_past_voxels
670
+ self.embedding_time_dim = embedding_time_dim
671
+
672
+ def forward(self, time):
673
+ # time is (batch_size,)
674
+ time = time.long()
675
+ time = self.embedding_time(time)
676
+ return time # (batch_size, embedding_time_dim)
677
+
678
+
679
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
680
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
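+ # e.g. model.time_embedding(torch.tensor([0, 1, 2])) returns a (3, 512) tensor of learned
+ # per-lag embeddings (one row per past time step; indices are cast to long internally).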
681
+
682
+ model.backbone = BrainNetwork(h=hidden_dim + clip_emb_dim, in_dim=hidden_dim + clip_emb_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
683
+ utils.count_params(model.backbone)
684
+ utils.count_params(model)
685
+
686
+ # test that the model works on some fake data
687
+ b = torch.randn((1,seq_len,hidden_dim + clip_emb_dim))
688
+ print("b.shape",b.shape)
689
+ with torch.no_grad():
690
+ clip_, blur_, depth_ = model.backbone(b)
691
+ print(clip_.shape, blur_.shape, depth_.shape)
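+ # With the settings above, clip_ should come out as (1, 257, 768) = (batch, clip_seq_dim,
+ # clip_emb_dim); blur_ and depth_ are decoder feature maps when those branches are enabled,
+ # otherwise the scalar placeholders created at the top of forward().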
692
+
693
+
694
+ # In[24]:
695
+
696
+
697
+ """
698
+ voxel_ridge = torch.randn(512,4096)
699
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
700
+ print("b.shape",voxel_ridge.shape)
701
+ with torch.no_grad():
702
+ clip_, blur_, depth_ = model.backbone(voxel_ridge)
703
+ print(clip_.shape, blur_.shape, depth_.shape)"""
704
+
705
+
706
+ # In[25]:
707
+
708
+
709
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
710
+ opt_grouped_parameters = [
711
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
712
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
713
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
714
+ ]
715
+
716
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
717
+
718
+ if lr_scheduler_type == 'linear':
719
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
720
+ optimizer,
721
+ total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),
722
+ last_epoch=-1
723
+ )
724
+ elif lr_scheduler_type == 'cycle':
725
+ total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))
726
+ print("total_steps", total_steps)
727
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
728
+ optimizer,
729
+ max_lr=max_lr,
730
+ total_steps=total_steps,
731
+ final_div_factor=1000,
732
+ last_epoch=-1, pct_start=2/num_epochs
733
+ )
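+ # pct_start=2/num_epochs spends roughly the first two epochs warming the LR up to max_lr,
+ # after which the one-cycle schedule anneals it back down for the rest of training.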
734
+
735
+ def save_ckpt(tag):
736
+ ckpt_path = outdir+f'/{tag}.pth'
737
+ print(f'saving {ckpt_path}',flush=True)
738
+ unwrapped_model = accelerator.unwrap_model(model)
739
+ try:
740
+ torch.save({
741
+ 'epoch': epoch,
742
+ 'model_state_dict': unwrapped_model.state_dict(),
743
+ 'optimizer_state_dict': optimizer.state_dict(),
744
+ 'lr_scheduler': lr_scheduler.state_dict(),
745
+ 'train_losses': losses,
746
+ 'test_losses': test_losses,
747
+ 'lrs': lrs,
748
+ }, ckpt_path)
749
+ except:
750
+ print("Couldn't save... moving on to prevent crashing.")
751
+ del unwrapped_model
752
+
753
+ print("\nDone with model preparations!")
754
+ utils.count_params(model)
755
+
756
+
757
+ # In[26]:
758
+
759
+
760
+ #nn++
761
+
762
+
763
+ # In[27]:
764
+
765
+
766
+ """pp = None
767
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
768
+ #with torch.cuda.amp.autocast(dtype=data_type):
769
+ #optimizer.zero_grad()
770
+
771
+ voxel = voxels[behav[:,0,5].cpu().long()]#.to(device)
772
+ image = images[behav[:,0,0].cpu().long()].float()#.to(device).float()
773
+
774
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()]#.to(device) # batch_size, 15, 15279
775
+ past_15_times = torch.Tensor([i for i in range(seq_len)])#.to(device) # 15
776
+ print(past_behav[:,:seq_len-1,0].cpu().long())
777
+ past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()]
778
+
779
+ break
780
+
781
+ print(past_15_times)
782
+ #for past in range(1):
783
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
784
+
785
+ #if blurry_recon:
786
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
787
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
788
+
789
+ if depth_recon:
790
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
791
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
792
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
793
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
794
+
795
+ if use_image_aug:
796
+ image = img_augment(image)
797
+
798
+ clip_target = clip_model.embed_image(image)
799
+ assert not torch.any(torch.isnan(clip_target))
800
+
801
+ if epoch < int(mixup_pct * num_epochs):
802
+ voxel, perm, betas, select = utils.mixco(voxel)
803
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
804
+
805
+ for p in range(seq_len-1):
806
+ print(past_behav.shape) #128, 15, 17
807
+ print(past_behav[:,p,-1])
808
+ print(past_15_voxels.shape) # 128, 1, 15724
809
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
810
+ print(mask) # 128
811
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
812
+ print(past_15_voxels)
813
+ pp = past_15_voxels
814
+
815
+ break"""
816
+
817
+
818
+ # In[28]:
819
+
820
+
821
+ #pp[20, 0, :]
822
+
823
+
824
+ # # Weights and Biases
825
+
826
+ # In[29]:
827
+
828
+
829
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
830
+ import wandb
831
+ wandb_project = 'mindeyev2'
832
+ wandb_run = model_name
833
+ wandb_notes = ''
834
+
835
+ print(f"wandb {wandb_project} run {wandb_run}")
836
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
837
+ wandb_config = {
838
+ "model_name": model_name,
839
+ "global_batch_size": global_batch_size,
840
+ "batch_size": batch_size,
841
+ "num_epochs": num_epochs,
842
+ "clip_scale": clip_scale,
843
+ "blur_scale": blur_scale,
844
+ "use_image_aug": use_image_aug,
845
+ "max_lr": max_lr,
846
+ "mixup_pct": mixup_pct,
847
+ "num_train": num_train,
848
+ "num_test": num_test,
849
+ "ckpt_interval": ckpt_interval,
850
+ "ckpt_saving": ckpt_saving,
851
+ "seed": seed,
852
+ "distributed": distributed,
853
+ "num_devices": num_devices,
854
+ "world_size": world_size,
855
+ "train_url": train_url,
856
+ "test_url": test_url,
857
+ }
858
+ print("wandb_config:\n",wandb_config)
859
+ if False: # wandb_auto_resume
860
+ print("wandb_id:",model_name)
861
+ wandb.init(
862
+ id = model_name,
863
+ project=wandb_project,
864
+ name=wandb_run,
865
+ config=wandb_config,
866
+ notes=wandb_notes,
867
+ resume="allow",
868
+ )
869
+ else:
870
+ wandb.init(
871
+ project=wandb_project,
872
+ name=wandb_run,
873
+ config=wandb_config,
874
+ notes=wandb_notes,
875
+ )
876
+ else:
877
+ wandb_log = False
878
+
879
+
880
+ # # Main
881
+
882
+ # In[30]:
883
+
884
+
885
+ epoch = 0
886
+ losses, test_losses, lrs = [], [], []
887
+ best_test_loss = 1e9
888
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
889
+
890
+ # Optionally resume from checkpoint #
891
+ if resume_from_ckpt:
892
+ print("\n---resuming from last.pth ckpt---\n")
893
+ try:
894
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
895
+ except:
896
+ print('last.pth failed... trying last_backup.pth')
897
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
898
+ epoch = checkpoint['epoch']
899
+ print("Epoch",epoch)
900
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
901
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
902
+ model.load_state_dict(checkpoint['model_state_dict'])
903
+ del checkpoint
904
+ elif wandb_log:
905
+ if wandb.run.resumed:
906
+ print("\n---resuming from last.pth ckpt---\n")
907
+ try:
908
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
909
+ except:
910
+ print('last.pth failed... trying last_backup.pth')
911
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
912
+ epoch = checkpoint['epoch']
913
+ print("Epoch",epoch)
914
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
915
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
916
+ model.load_state_dict(checkpoint['model_state_dict'])
917
+ del checkpoint
918
+ torch.cuda.empty_cache()
919
+
920
+
921
+ # In[31]:
922
+
923
+
924
+ model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
925
+ model, optimizer, train_dl, lr_scheduler
926
+ )
927
+ # leaving out test_dl since we will only have local_rank 0 device do evals
928
+
929
+
930
+ # In[32]:
931
+
932
+
933
+ def add_saturation(image, alpha=2):
934
+ gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
935
+ gray_image = gray_image.unsqueeze(1).expand_as(image)
936
+ saturated_image = alpha * image + (1 - alpha) * gray_image
937
+ return torch.clamp(saturated_image, 0, 1)
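+ # add_saturation converts to luma with Rec.601 weights and extrapolates away from gray by
+ # alpha (alpha=1 is identity, alpha=2 doubles saturation), clamping the result back to [0,1].
+ # e.g. boosted = add_saturation(images[:4].float().to(device), alpha=2)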
938
+
939
+
940
+ # In[33]:
941
+
942
+
943
+ #b = torch.randn(1,2)
944
+ #b.to(device)
945
+
946
+
947
+ # In[34]:
948
+
949
+
950
+ #device
951
+
952
+
953
+ # In[35]:
954
+
955
+
956
+ #past_15_times = torch.Tensor([i for i in range(seq_len-1)]).long() # 15
957
+ #past_15_times.to(device)
958
+
959
+
960
+ # In[36]:
961
+
962
+
963
+ #nn++
964
+
965
+
966
+ # In[ ]:
967
+
968
+
969
+ #images.shape
970
+
971
+
972
+ # In[94]:
973
+
974
+
975
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
976
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
977
+ test_image, test_voxel = None, None
978
+ mse = nn.MSELoss()
979
+ l1 = nn.L1Loss()
980
+
981
+ for epoch in progress_bar:
982
+ model.train()
983
+
984
+ fwd_percent_correct = 0.
985
+ bwd_percent_correct = 0.
986
+ test_fwd_percent_correct = 0.
987
+ test_bwd_percent_correct = 0.
988
+
989
+ loss_clip_total = 0.
990
+ loss_blurry_total = 0.
991
+ loss_depth_total = 0.
992
+ test_loss_clip_total = 0.
993
+ test_loss_blurry_total = 0.
994
+ test_loss_depth_total = 0.
995
+
996
+ blurry_pixcorr = 0.
997
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
998
+
999
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
1000
+ with torch.cuda.amp.autocast():
1001
+ optimizer.zero_grad()
1002
+
1003
+ #voxel = voxels[behav[:,0,5].cpu().long()].to(device)
1004
+ #image = images[behav[:,0,0].cpu().long()].to(device).float()
1005
+
1006
+ #past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
1007
+ #past_15_times = torch.Tensor([i for i in range(seq_len - 1)]).to(device) # 15
1008
+
1009
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
1010
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
1011
+
1012
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
1013
+ #print(past_behav[:,:seq_len-1,0].cpu().long(), behav[:,0,0].cpu().long(), past_behav[:,:seq_len-1,0].cpu().long()[0])
1014
+ past_15_images = images[past_behav[:,:seq_len-1,0].cpu().long()].to(device).float()
1015
+ past_array = [i for i in range(seq_len-1)]
1016
+ past_15_times = torch.Tensor(past_array) # 15
1017
+ #print(past_15_times)
1018
+ #print(past_15_voxels.shape, past_behav[:,:seq_len-1,5].cpu().long())
1019
+ past_15_times = past_15_times.to(device)
1020
+ #for past in range(1):
1021
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
1022
+
1023
+ if blurry_recon:
1024
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
1025
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
1026
+
1027
+ if depth_recon:
1028
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
1029
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
1030
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
1031
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
1032
+
1033
+ if use_image_aug:
1034
+ image = img_augment(image)
1035
+
1036
+ clip_target = clip_model.embed_image(image)
1037
+ assert not torch.any(torch.isnan(clip_target))
1038
+
1039
+ if epoch < int(mixup_pct * num_epochs):
1040
+ voxel, perm, betas, select = utils.mixco(voxel)
1041
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
1042
+
1043
+ #print(past_15_images.shape)
1044
+
1045
+ for p in range(seq_len-1):
1046
+ #print(past_behav.shape) #128, 15, 17
1047
+ #print(past_behav[:,p,-1])
1048
+ #print(past_15_voxels.shape) # 128, 1, 15724
1049
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
1050
+ #print(mask) # 128
1051
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
1052
+ past_15_images[mask, p, :] = torch.zeros_like(past_15_images[0, p, :])
1053
+ #print(past_15_voxels)
1054
+
1055
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
1056
+ past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1])
1057
+ #print(past_15_images.shape)
1058
+ past_15_embeddings = clip_model2.embed_image(past_15_images)
1059
+ #print(past_15_embeddings.shape, 'uteho')
1060
+ past_15_embeddings = torch.cat([torch.zeros(batch_size, past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0)
1061
+ #print('tuhet', past_15_embeddings.shape)
1062
+ #print('yepe', past_15_embeddings[0,:])
1063
+ #print('yepe', past_15_embeddings[17,:])
1064
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
1065
+ past_15_times = past_15_times.reshape(-1)
1066
+ time_embeddings = model.time_embedding(past_15_times)
1067
+
1068
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
1069
+
1070
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
1071
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1072
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
1073
+ voxel_ridge = voxel_ridge.view(seq_len,int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)
1074
+ #past_15_embeddings = torch.split(past_15_embeddings, seq_len)
1075
+ #print(past_15_embeddings, 'ttt')
1076
+ past_15_embeddings = past_15_embeddings.reshape(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2)
1077
+ #unsqueeze(1) # bz * 2, 1, 4096
1078
+ #print(voxel_ridge.shape, past_15_embeddings.shape)
1079
+ #print('yepe', past_15_embeddings[10,0,:])
1080
+ #print('yepe', past_15_embeddings[10,1,:])
1081
+ voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1)
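+ # In this -img variant each sequence slot also carries the CLIP embedding of the image shown
+ # on that past trial (zeros for the current trial, which is what must be predicted), so the
+ # backbone sees hidden_dim + clip_emb_dim features per slot, matching h above.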
1082
+ #print(voxel_ridge[:,0,-10:-1])
1083
+ #print(voxel_ridge[:,0,10:20])
1084
+ #raise("uehot")
1085
+ # past_voxel_ridge = model.ridge(past_voxel)
1086
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)
1087
+ #print(voxel_ridge.shape)
1088
+
1089
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1090
+
1091
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1092
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1093
+
1094
+ if epoch < int(mixup_pct * num_epochs):
1095
+ loss_clip = utils.mixco_nce(
1096
+ clip_voxels_norm,
1097
+ clip_target_norm,
1098
+ temp=.006,
1099
+ perm=perm, betas=betas, select=select)
1100
+ else:
1101
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
1102
+ loss_clip = utils.soft_clip_loss(
1103
+ clip_voxels_norm,
1104
+ clip_target_norm,
1105
+ temp=epoch_temp)
1106
+
1107
+ loss_clip_total += loss_clip.item()
1108
+ loss_clip *= clip_scale
1109
+ loss = loss_clip
1110
+
1111
+ if blurry_recon:
1112
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1113
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1114
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1115
+
1116
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1117
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1118
+ loss_blurry_total += loss_blurry.item()
1119
+ loss_blurry *= blur_scale
1120
+ loss += loss_blurry
1121
+
1122
+ if depth_recon:
1123
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1124
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1125
+ loss_depth_total += loss_depth.item()
1126
+ loss_depth *= depth_scale
1127
+ loss += loss_depth
1128
+
1129
+ # forward and backward top 1 accuracy
1130
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1131
+ fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()
1132
+ bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()
1133
+
1134
+ if blurry_recon:
1135
+ with torch.no_grad():
1136
+ # only doing pixcorr eval on a subset of the samples per batch because it's costly & slow to compute autoenc.decode()
1137
+ random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)
1138
+ # random_samps = np.arange(batch_size//5)
1139
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
1140
+ # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean
1141
+ pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)
1142
+ # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
1143
+ # loss += (1 - pixcorr)
1144
+ blurry_pixcorr += pixcorr.item()
1145
+ # utils.check_loss(pixcorr)
1146
+
1147
+ utils.check_loss(loss)
1148
+ accelerator.backward(loss)
1149
+ optimizer.step()
1150
+
1151
+ losses.append(loss.item())
1152
+ lrs.append(optimizer.param_groups[0]['lr'])
1153
+
1154
+ if lr_scheduler_type is not None:
1155
+ lr_scheduler.step()
1156
+
1157
+ model.eval()
1158
+ if local_rank==0:
1159
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
1160
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
1161
+ # all test samples should be loaded per batch such that test_i should never exceed 0
1162
+ assert len(behav) == num_test
1163
+
1164
+ ## Average same-image repeats ##
1165
+ if test_image is None:
1166
+ voxel = voxels[behav[:,0,5].cpu().long()]
1167
+ image = behav[:,0,0].cpu().long()
1168
+
1169
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
1170
+ for im in unique_image:
1171
+ locs = torch.where(im == image)[0]
1172
+ if test_image is None:
1173
+ test_image = images[im][None]
1174
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
1175
+ else:
1176
+ test_image = torch.vstack((test_image, images[im][None]))
1177
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
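+ # each unique test image ends up with a single voxel pattern averaged over its repeat presentations, which reduces noise in the test betas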
1178
+
1179
+ # random sample of 300
1180
+ random_indices = torch.arange(len(test_voxel))[:300]
1181
+ voxel = test_voxel[random_indices].to(device)
1182
+ image = test_image[random_indices].to(device)
1183
+ assert len(image) == 300
1184
+
1185
+ current_past_behav = past_behav[random_indices]
1186
+
1187
+ past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
1188
+ past_15_images = images[current_past_behav[:,:seq_len-1,0].cpu().long()].to(device).float()
1189
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
1190
+
1191
+ if blurry_recon:
1192
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
1193
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
1194
+
1195
+ if depth_recon:
1196
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
1197
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
1198
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
1199
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
1200
+
1201
+ clip_target = clip_model.embed_image(image.float())
1202
+
1203
+
1204
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
1205
+ past_15_images = past_15_images.reshape(-1, past_15_images.shape[-3], past_15_images.shape[-2], past_15_images.shape[-1])
1206
+ #print(past_15_images.shape)
1207
+ past_15_embeddings = clip_model2.embed_image(past_15_images)
1208
+ #print(past_15_embeddings.shape)
1209
+ past_15_embeddings = torch.cat([torch.zeros(image.shape[0], past_15_embeddings.shape[-1]).to(past_15_embeddings.device), past_15_embeddings], dim = 0)
1210
+ #print(past_15_embeddings.shape)
1211
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
1212
+ past_15_times = past_15_times.reshape(-1)
1213
+ time_embeddings = model.time_embedding(past_15_times)
1214
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
1215
+
1216
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
1217
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1218
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
1219
+ voxel_ridge = voxel_ridge.view(seq_len, int(voxel_ridge.shape[0]/seq_len), hidden_dim).permute(1,0,2)
1220
+ past_15_embeddings = past_15_embeddings.view(seq_len, int(past_15_embeddings.shape[0]/seq_len), clip_emb_dim).permute(1,0,2)
1221
+ #print(past_15_embeddings.shape, voxel_ridge.shape)
1222
+ voxel_ridge = torch.cat((voxel_ridge, past_15_embeddings), dim=-1)
1223
+
1224
+ #voxel_ridge = model.ridge(voxel).unsqueeze(1)
1225
+
1226
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)
1227
+
1228
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1229
+
1230
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1231
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1232
+
1233
+ loss_clip = utils.soft_clip_loss(
1234
+ clip_voxels_norm,
1235
+ clip_target_norm,
1236
+ temp=.006)
1237
+ test_loss_clip_total += loss_clip.item()
1238
+ loss_clip = loss_clip * clip_scale
1239
+ loss = loss_clip
1240
+
1241
+ if blurry_recon:
1242
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1243
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1244
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1245
+
1246
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1247
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1248
+ test_loss_blurry_total += loss_blurry.item()
1249
+ loss_blurry *= blur_scale
1250
+ loss += loss_blurry
1251
+
1252
+ # halving the batch size because the decoder is computationally heavy
1253
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)
1254
+ blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1255
+ pixcorr = utils.pixcorr(image, blurry_recon_images)
1256
+ loss += (1 - pixcorr)
1257
+ test_blurry_pixcorr += pixcorr.item()
1258
+
1259
+ if depth_recon:
1260
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1261
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1262
+ test_loss_depth_total += loss_depth.item()
1263
+ loss_depth *= depth_scale
1264
+ loss += loss_depth
1265
+
1266
+ # forward and backward top 1 accuracy
1267
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1268
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
1269
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()
1270
+
1271
+ utils.check_loss(loss)
1272
+ test_losses.append(loss.item())
1273
+
1274
+ # if utils.is_interactive(): clear_output(wait=True)
1275
+ print("---")
1276
+
1277
+ assert (test_i+1) == 1
1278
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1279
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1280
+ "train/lr": lrs[-1],
1281
+ "train/num_steps": len(losses),
1282
+ "test/num_steps": len(test_losses),
1283
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1284
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1285
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1286
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1287
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1288
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1289
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1290
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1291
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1292
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1293
+ "train/loss_depth_total": loss_depth_total / (train_i + 1),
1294
+ "test/loss_depth_total": test_loss_depth_total / (test_i + 1),
1295
+ }
1296
+
1297
+ if blurry_recon:
1298
+ # transform blurry recon latents to images and plot it
1299
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1300
+ jj=-1
1301
+ for j in [0,1,2,3]:
1302
+ jj+=1
1303
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1304
+ axes[jj].axis('off')
1305
+ jj+=1
1306
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1307
+ axes[jj].axis('off')
1308
+
1309
+ if wandb_log:
1310
+ logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1311
+ plt.close()
1312
+ else:
1313
+ plt.show()
1314
+
1315
+ if depth_recon:
1316
+ # transform blurry recon latents to images and plot it
1317
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1318
+ # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1319
+ # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1320
+ jj=-1
1321
+ for j in [0,1,2,3]:
1322
+ jj+=1
1323
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))
1324
+ axes[jj].axis('off')
1325
+ jj+=1
1326
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))
1327
+ axes[jj].axis('off')
1328
+ if wandb_log:
1329
+ logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1330
+ plt.close()
1331
+ else:
1332
+ plt.show()
1333
+
1334
+ progress_bar.set_postfix(**logs)
1335
+
1336
+ # Save model checkpoint and reconstruct
1337
+ if epoch % ckpt_interval == 0:
1338
+ if not utils.is_interactive():
1339
+ save_ckpt(f'last')
1340
+
1341
+ if wandb_log: wandb.log(logs)
1342
+
1343
+ # wait for other GPUs to catch up if needed
1344
+ accelerator.wait_for_everyone()
1345
+ torch.cuda.empty_cache()
1346
+ gc.collect()
1347
+
1348
+ print("\n===Finished!===\n")
1349
+ if ckpt_saving:
1350
+ save_ckpt(f'last')
1351
+ if not utils.is_interactive():
1352
+ sys.exit(0)
1353
+
1354
+
1355
+ # In[ ]:
1356
+
1357
+
1358
+ plt.plot(losses)
1359
+ plt.show()
1360
+ plt.plot(test_losses)
1361
+ plt.show()
1362
+
1363
+
1364
+ # # Retrieve nearest neighbor in the training set using test set data
1365
+
1366
+ # In[ ]:
1367
+
1368
+
1369
+ annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")
1370
+
1371
+
1372
+ # In[ ]:
1373
+
1374
+
1375
+ ii=2
1376
+ all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))
1377
+ with torch.no_grad(), torch.cuda.amp.autocast():
1378
+ for batch in tqdm(range(0,len(all_indices),512)):
1379
+ if batch==0:
1380
+ clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1381
+ else:
1382
+ target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1383
+ clip_target = torch.vstack((clip_target,target))
1384
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
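+ # clip_target_norm now holds normalized CLIP embeddings for every unique training image, computed in chunks of 512; this is the retrieval database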
1385
+
1386
+ voxel = test_voxel[[ii]].to(device)
1387
+ image = test_image[[ii]].to(device)
1388
+
1389
+ print("Original Image (test set)")
1390
+ display(utils.torch_to_Image(image))
1391
+
1392
+ clip_target = clip_model.embed_image(image).cpu()
1393
+ # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))
1394
+
1395
+ voxel_ridge = model.ridge(voxel).unsqueeze(1)
1396
+ clip_voxels, _, _ = model.backbone(voxel_ridge)
1397
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1398
+ # clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1) # NOTE: keep disabled; this would replace the predicted embeddings with the ground-truth CLIP targets and make the retrieval trivial
1399
+
1400
+ print("clip_voxels_norm", clip_voxels_norm.shape)
1401
+ print("clip_target_norm", clip_target_norm.shape)
1402
+
1403
+ sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(),
1404
+ clip_target_norm).flatten()).flip(0)
1405
+ picks = all_indices[sortt[:5]]
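+ # cosine similarities against the whole training set are argsorted ascending and then flipped, so picks holds the 5 training images most similar to the predicted embedding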
1406
+
1407
+ print("\nNearest neighbors in training set")
1408
+ for ip,p in enumerate(picks):
1409
+ display(utils.torch_to_Image(images[[p]]))
1410
+ # print(utils.select_annotations([annots[int(p)]]))
1411
+ if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]
1412
+
1413
+ print("\n=====\npredicted_caption:\n", predicted_caption)
1414
+
1415
+
1416
+ # # Feed into Stable Diffusion XL for reconstructions
1417
+
1418
+ # In[ ]:
1419
+
1420
+
1421
+ from diffusers import StableDiffusionXLPipeline
1422
+ pipe = StableDiffusionXLPipeline.from_pretrained(
1423
+ "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
1424
+ )
1425
+ pipe.to("cuda")
1426
+ pass
1427
+
1428
+
1429
+ # In[ ]:
1430
+
1431
+
1432
+ prompt = predicted_caption
1433
+ recon = pipe(prompt=prompt).images[0]
1434
+
1435
+
1436
+ # In[ ]:
1437
+
1438
+
1439
+ print("Seen image")
1440
+ display(utils.torch_to_Image(image))
1441
+
1442
+ print("Reconstruction")
1443
+ utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))
1444
+
src/Train_MLPMixer.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/Train_MLPMixer.py ADDED
@@ -0,0 +1,1275 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ #from subprocess import call
9
+ #command = "jupyter nbconvert Train_MLPMixer-Copy1.ipynb --to python"
10
+ #call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import string
28
+ import h5py
29
+ from tqdm import tqdm
30
+
31
+ import webdataset as wds
32
+ import gc
33
+
34
+ import matplotlib.pyplot as plt
35
+ import torch
36
+ import torch.nn as nn
37
+ from torchvision import transforms
38
+
39
+ from accelerate import Accelerator, DeepSpeedPlugin
40
+
41
+ # tf32 data type is faster than standard float32
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+
44
+ # custom functions #
45
+ import utils
46
+
47
+
48
+ # In[3]:
49
+
50
+
51
+ ### Multi-GPU config ###
52
+ local_rank = os.getenv('RANK')
53
+ if local_rank is None:
54
+ local_rank = 0
55
+ else:
56
+ local_rank = int(local_rank)
57
+ print("LOCAL RANK ", local_rank)
58
+
59
+ num_devices = torch.cuda.device_count()
60
+ if num_devices==0: num_devices = 1
61
+
62
+ # ## UNCOMMENT BELOW SECTION AND COMMENT OUT DEEPSPEED SECTION TO AVOID USING DEEPSPEED ###
63
+ # accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
64
+ # global_batch_size = batch_size = 32
65
+ # data_type = torch.float16 # change depending on your mixed_precision
66
+
67
+ ### DEEPSPEED INITIALIZATION ###
68
+ if num_devices <= 1 and utils.is_interactive():
69
+ global_batch_size = batch_size = 32
70
+ print(f"Setting batch_size to {batch_size}")
71
+ # can emulate a distributed environment for deepspeed to work in jupyter notebook
72
+ os.environ["MASTER_ADDR"] = "localhost"
73
+ os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
74
+ os.environ["RANK"] = "0"
75
+ os.environ["LOCAL_RANK"] = "0"
76
+ os.environ["WORLD_SIZE"] = "1"
77
+ os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
78
+ else:
79
+ global_batch_size = int(os.environ["GLOBAL_BATCH_SIZE"])
80
+ batch_size = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
81
+
82
+ # alter the deepspeed config according to your global and local batch size
83
+ if local_rank == 0:
84
+ with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json', 'r') as file:
85
+ config = json.load(file)
86
+ config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
87
+ config['train_micro_batch_size_per_gpu'] = batch_size
88
+ config['bf16'] = {'enabled': False}
89
+ config['fp16'] = {'enabled': True}
90
+ with open('/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2.json', 'w') as file:
91
+ json.dump(config, file)
92
+ else:
93
+ # give some time for the local_rank=0 gpu to prep new deepspeed config file
94
+ time.sleep(10)
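+ # note that the DeepSpeedPlugin below is pointed at deepspeed_config_stage2_cpuoffload.json, which is a different file from the stage2 config edited above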
95
+ deepspeed_plugin = DeepSpeedPlugin("/fsx/proj-fmri/ckadirt/MindEyeV2/src/deepspeed_config_stage2_cpuoffload.json")
96
+ accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
97
+
98
+
99
+ # In[4]:
100
+
101
+
102
+ print("PID of this process =",os.getpid())
103
+ device = accelerator.device
104
+ print("device:",device)
105
+ num_workers = num_devices
106
+ print(accelerator.state)
107
+ world_size = accelerator.state.num_processes
108
+ distributed = not accelerator.state.distributed_type == 'NO'
109
+
110
+ # set data_type to match your mixed precision (automatically set based on deepspeed config)
111
+ if accelerator.mixed_precision == "bf16":
112
+ data_type = torch.bfloat16
113
+ elif accelerator.mixed_precision == "fp16":
114
+ data_type = torch.float16
115
+ else:
116
+ data_type = torch.float32
117
+
118
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)
119
+ print = accelerator.print # only print if local_rank=0
120
+
121
+
122
+ # # Configurations
123
+
124
+ # In[5]:
125
+
126
+
127
+ # if running this interactively, can specify jupyter_args here for argparser to use
128
+ if utils.is_interactive():
129
+ # create random model_name
130
+ model_name = ''.join(random.choices(string.ascii_letters + string.digits, k=10))
131
+ model_name = model_name + "_interactive"
132
+ print("model_name:", model_name)
133
+
134
+ # global_batch_size and batch_size should already be defined in the above cells
135
+ # other variables can be specified in the following string:
136
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
137
+ --model_name={model_name} \
138
+ --subj=1 --batch_size={batch_size} --no-blurry_recon --no-depth_recon --hidden_dim=4096 \
139
+ --clip_scale=1. --blur_scale=100. --depth_scale=100. \
140
+ --max_lr=3e-4 --mixup_pct=.66 --num_epochs=12 --ckpt_interval=999 --no-use_image_aug --no-ckpt_saving"
141
+
142
+ jupyter_args = jupyter_args.split()
143
+ print(jupyter_args)
144
+
145
+ from IPython.display import clear_output # function to clear print outputs in cell
146
+ get_ipython().run_line_magic('load_ext', 'autoreload')
147
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
148
+ get_ipython().run_line_magic('autoreload', '2')
149
+
150
+
151
+ # In[6]:
152
+
153
+
154
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
155
+ parser.add_argument(
156
+ "--model_name", type=str, default="testing",
157
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
158
+ )
159
+ parser.add_argument(
160
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
161
+ help="Path to where NSD data is stored / where to download it to",
162
+ )
163
+ parser.add_argument(
164
+ "--subj",type=int, default=1, choices=[1,2,5,7],
165
+ )
166
+ parser.add_argument(
167
+ "--batch_size", type=int, default=32,
168
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
169
+ )
170
+ parser.add_argument(
171
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
172
+ help="whether to log to wandb",
173
+ )
174
+ parser.add_argument(
175
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
176
+ help="if not using wandb and want to resume from a ckpt",
177
+ )
178
+ parser.add_argument(
179
+ "--wandb_project",type=str,default="stability",
180
+ help="wandb project name",
181
+ )
182
+ parser.add_argument(
183
+ "--mixup_pct",type=float,default=.33,
184
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
185
+ )
186
+ parser.add_argument(
187
+ "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
188
+ help="whether to output blurry reconstructions",
189
+ )
190
+ parser.add_argument(
191
+ "--depth_recon",action=argparse.BooleanOptionalAction,default=True,
192
+ help="whether to output depth reconstructions",
193
+ )
194
+ parser.add_argument(
195
+ "--blur_scale",type=float,default=100.,
196
+ help="multiply loss from blurry recons by this number",
197
+ )
198
+ parser.add_argument(
199
+ "--depth_scale",type=float,default=100.,
200
+ help="multiply loss from depth recons by this number",
201
+ )
202
+ parser.add_argument(
203
+ "--clip_scale",type=float,default=1.,
204
+ help="multiply contrastive loss by this number",
205
+ )
206
+ parser.add_argument(
207
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
208
+ help="whether to use image augmentation",
209
+ )
210
+ parser.add_argument(
211
+ "--num_epochs",type=int,default=120,
212
+ help="number of epochs of training",
213
+ )
214
+ parser.add_argument(
215
+ "--hidden_dim",type=int,default=4096,
216
+ )
217
+ parser.add_argument(
218
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
219
+ )
220
+ parser.add_argument(
221
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
222
+ )
223
+ parser.add_argument(
224
+ "--ckpt_interval",type=int,default=5,
225
+ help="save backup ckpt and reconstruct every x epochs",
226
+ )
227
+ parser.add_argument(
228
+ "--seed",type=int,default=42,
229
+ )
230
+ parser.add_argument(
231
+ "--max_lr",type=float,default=3e-4,
232
+ )
233
+ parser.add_argument(
234
+ "--seq_len",type=int,default=2,
235
+ )
236
+
237
+ if utils.is_interactive():
238
+ args = parser.parse_args(jupyter_args)
239
+ else:
240
+ args = parser.parse_args()
241
+
242
+ # create global variables without the args prefix
243
+ for attribute_name in vars(args).keys():
244
+ globals()[attribute_name] = getattr(args, attribute_name)
245
+
246
+
247
+ # In[7]:
248
+
249
+
250
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
251
+ if not os.path.exists(outdir) and ckpt_saving:
252
+ os.makedirs(outdir,exist_ok=True)
253
+ if use_image_aug:
254
+ import kornia
255
+ from kornia.augmentation.container import AugmentationSequential
256
+ img_augment = AugmentationSequential(
257
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
258
+ kornia.augmentation.Resize((224, 224)),
259
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
260
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
261
+ kornia.augmentation.RandomGrayscale(p=0.3),
262
+ same_on_batch=False,
263
+ data_keys=["input"],
264
+ )
265
+
266
+
267
+ # # Prep data, models, and dataloaders
268
+
269
+ # ## Dataloader
270
+
271
+ # In[8]:
272
+
273
+
274
+ if subj==1:
275
+ num_train = 24958
276
+ num_test = 2770
277
+ test_batch_size = num_test
278
+
279
+ def my_split_by_node(urls): return urls
280
+
281
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
282
+ # train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..1}.tar"
283
+ print(train_url)
284
+
285
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
286
+ .shuffle(750, initial=1500, rng=random.Random(42))\
287
+ .decode("torch")\
288
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
289
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
290
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)
291
+
292
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
293
+ print(test_url)
294
+
295
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
296
+ .shuffle(750, initial=1500, rng=random.Random(42))\
297
+ .decode("torch")\
298
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
299
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
300
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=True, pin_memory=True)
301
+
302
+
303
+ # ### check dataloaders are working
304
+
305
+ # In[9]:
306
+
307
+
308
+ test_vox_indices = []
309
+ test_73k_images = []
310
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
311
+ test_vox_indices = np.append(test_vox_indices, behav[:,0,5].cpu().numpy())
312
+ test_73k_images = np.append(test_73k_images, behav[:,0,0].cpu().numpy())
313
+ test_vox_indices = test_vox_indices.astype(np.int16)
314
+ print(test_i, (test_i+1) * test_batch_size, len(test_vox_indices))
315
+ print("---\n")
316
+
317
+ train_vox_indices = []
318
+ train_73k_images = []
319
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
320
+ train_vox_indices = np.append(train_vox_indices, behav[:,0,5].long().cpu().numpy())
321
+ train_73k_images = np.append(train_73k_images, behav[:,0,0].cpu().numpy())
322
+ train_vox_indices = train_vox_indices.astype(np.int16)
323
+ print(train_i, (train_i+1) * batch_size, len(train_vox_indices))
324
+
325
+
326
+ # ## Load data and images
327
+
328
+ # In[10]:
329
+
330
+
331
+ # load betas
332
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
333
+ # f = h5py.File(f'{data_path}/betas_subj0{subj}_thresholded_wholebrain.hdf5', 'r')
334
+
335
+ voxels = f['betas'][:]
336
+ print(f"subj0{subj} betas loaded into memory")
337
+ voxels = torch.Tensor(voxels).to("cpu").to(data_type)
338
+ print("voxels", voxels.shape)
339
+ num_voxels = voxels.shape[-1]
340
+
341
+ # load orig images
342
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
343
+ images = f['images'][:]
344
+ images = torch.Tensor(images).to("cpu").to(data_type)
345
+ print("images", images.shape)
346
+
347
+
348
+ # ## Load models
349
+
350
+ # ### CLIP image embeddings model
351
+
352
+ # In[11]:
353
+
354
+
355
+ from models import Clipper
356
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
357
+ clip_seq_dim = 257
358
+ clip_emb_dim = 768 #1024
359
+ # hidden_dim = 4096
360
+ #seq_len = 1 #2 #32
361
+
362
+
363
+ # ### SD VAE
364
+
365
+ # In[12]:
366
+
367
+
368
+ # if blurry_recon:
369
+ # from diffusers import AutoencoderKL
370
+ # autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
371
+ # # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
372
+ # autoenc.eval()
373
+ # autoenc.requires_grad_(False)
374
+ # autoenc.to(device)
375
+ # utils.count_params(autoenc)
376
+
377
+ if blurry_recon:# or depth_recon:
378
+ from diffusers import VQModel
379
+ autoenc = VQModel.from_pretrained("/fsx/proj-fmri/shared/cache/models--microsoft--vq-diffusion-ithq/snapshots/3f796fb49ee559370dc638dea1d8116af131d993/vqvae", torch_dtype=data_type)
380
+ autoenc.eval()
381
+ autoenc.requires_grad_(False)
382
+ autoenc.to(device)
383
+ utils.count_params(autoenc)
384
+
385
+
386
+ # #### downsampled images
387
+
388
+ # In[13]:
389
+
390
+
391
+ if blurry_recon:
392
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
393
+
394
+ input_batch = images[[30]].to(device)
395
+ print(input_batch.shape)
396
+
397
+ downsampled_image = nn.functional.interpolate(input_batch, size=(8, 8), mode='bilinear', align_corners=False)
398
+ re_upsampled_image = nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest')
399
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
400
+ print(re_upsampled_enc.shape)
401
+
402
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(re_upsampled_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
403
+
404
+
405
+ # #### MiDaS depth
406
+
407
+ # In[14]:
408
+
409
+
410
+ if depth_recon:
411
+ from controlnet_aux.midas import MidasDetector
412
+
413
+ midas_depth = MidasDetector.from_pretrained(
414
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large", cache_dir="/fsx/proj-fmri/shared/cache").to(device)
415
+ midas_depth.model.eval()
416
+ midas_depth.model.requires_grad_(False)
417
+ midas_depth.model.to(device)
418
+ pass
419
+
420
+
421
+ # In[15]:
422
+
423
+
424
+ if depth_recon:
425
+ if utils.is_interactive(): display(utils.torch_to_Image(images[[30]]))
426
+
427
+ input_batch = images[[30,31]].float().to(device)
428
+ print(input_batch.shape)
429
+
430
+ midas_emb = midas_depth.model(input_batch).unsqueeze(1)
431
+ print(midas_emb.shape)
432
+
433
+ prediction = utils.resize(midas_emb, 32) #/30).clamp(0,1).half() # 30 is roughly prediction.max()
434
+ print(prediction.shape)
435
+
436
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
437
+ midas_emb_size = prediction.flatten(1).shape[1]
438
+ print("midas_emb", prediction.shape, prediction.min(), prediction.max())
439
+ print("midas_emb_size", midas_emb_size)
440
+
441
+ if utils.is_interactive(): display(utils.torch_to_Image(utils.resize(prediction, 224)))
442
+
443
+ if blurry_recon:
444
+ prediction = utils.resize(midas_emb, 128).half().repeat(1,3,1,1)
445
+ prediction = (prediction / prediction.view(prediction.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(prediction)).half()
446
+ prediction_enc = autoenc.encode(2*prediction-1).latents * 0.18215
447
+ print("vae midas_emb", prediction_enc.shape, prediction_enc.min(), prediction_enc.max())
448
+
449
+ if utils.is_interactive(): display(utils.torch_to_Image((autoenc.decode(prediction_enc/0.18215).sample / 2 + 0.5).clamp(0,1)))
450
+
451
+
452
+ # ### MindEye modules
453
+
454
+ # In[17]:
455
+
456
+
457
+ class MindEyeModule(nn.Module):
458
+ def __init__(self):
459
+ super(MindEyeModule, self).__init__()
460
+ def forward(self, x):
461
+ return x
462
+
463
+ model = MindEyeModule()
464
+ model
465
+
466
+
467
+ # In[18]:
468
+
469
+
470
+ time_embedding_dim = 512
471
+
472
+ class RidgeRegression(torch.nn.Module):
473
+ # make sure to add weight_decay when initializing optimizer
474
+ def __init__(self, input_size, out_features):
475
+ super(RidgeRegression, self).__init__()
476
+ self.out_features = out_features
477
+ self.linear = torch.nn.Linear(input_size, out_features)
478
+ def forward(self, x):
479
+ return self.linear(x)
480
+
481
+ model.ridge = RidgeRegression(voxels.shape[1] + time_embedding_dim, out_features=hidden_dim)
482
+ utils.count_params(model.ridge)
483
+ utils.count_params(model)
484
+
485
+ b = torch.randn((2,1,voxels.shape[1]))
486
+ time_emb_test = torch.randn((2,1,time_embedding_dim))
487
+ print(b.shape, model.ridge(torch.cat((b,time_emb_test),dim=-1)).shape)
488
+
489
+
490
+ # In[24]:
491
+
492
+
493
+ num_past_voxels = 15
494
+
495
+
496
+
497
+ # In[25]:
498
+
499
+
500
+ from functools import partial
501
+ from diffusers.models.vae import Decoder
502
+ class BrainNetwork(nn.Module):
503
+ def __init__(self, out_dim=768, in_dim=15724, seq_len=2, h=4096, n_blocks=4, drop=.15, clip_size=768):
504
+ super().__init__()
505
+ self.seq_len = seq_len
506
+ self.h = h
507
+ self.clip_size = clip_size
508
+
509
+ # Initial linear layer to match the input dimensions to hidden dimensions
510
+ # self.lin0 = nn.Linear(in_dim, seq_len * h)
511
+
512
+ # Mixer Blocks
513
+ self.mixer_blocks1 = nn.ModuleList([
514
+ self.mixer_block1(h, drop) for _ in range(n_blocks)
515
+ ])
516
+ self.mixer_blocks2 = nn.ModuleList([
517
+ self.mixer_block2(seq_len, drop) for _ in range(n_blocks)
518
+ ])
519
+
520
+ # Output linear layer
521
+ self.clin1 = nn.Linear(h * seq_len, out_dim, bias=True)
522
+
523
+ # low-rank matrices
524
+ # self.rank = 500
525
+ # self.U = nn.Parameter(torch.randn(self.rank, out_dim))
526
+ # self.V = nn.Parameter(torch.randn(h * seq_len, self.rank))
527
+ # self.S = nn.Parameter(torch.randn(out_dim))
528
+
529
+ self.clip_proj = nn.Sequential(
530
+ nn.LayerNorm(clip_size),
531
+ nn.GELU(),
532
+ nn.Linear(clip_size, 2048),
533
+ nn.LayerNorm(2048),
534
+ nn.GELU(),
535
+ nn.Linear(2048, 2048),
536
+ nn.LayerNorm(2048),
537
+ nn.GELU(),
538
+ nn.Linear(2048, clip_size)
539
+ )
540
+
541
+ if blurry_recon:
542
+ # self.blin1 = nn.Sequential(
543
+ # nn.Linear(out_dim, 4096, bias=True),
544
+ # nn.LayerNorm(4096),
545
+ # nn.GELU(),
546
+ # nn.Linear(4096, 4096))
547
+ self.blin1 = nn.Linear(h*seq_len, 4096)
548
+ self.bgroupnorm = nn.GroupNorm(1, 256)
549
+ self.bupsampler = Decoder(
550
+ in_channels=256,
551
+ out_channels=128,
552
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
553
+ block_out_channels=[32, 64, 128],
554
+ layers_per_block=1,
555
+ )
556
+
557
+ if depth_recon:
558
+ # self.dlin1 = nn.Sequential(
559
+ # nn.Linear(h, midas_emb_size),
560
+ # nn.Sigmoid(),
561
+ # )
562
+ self.dlin1 = nn.Linear(h*seq_len, 4096)
563
+ self.dgroupnorm = nn.GroupNorm(1, 256)
564
+ self.dupsampler = Decoder(
565
+ in_channels=256,
566
+ out_channels=1,#128,
567
+ up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
568
+ block_out_channels=[32, 64, 128, 256],
569
+ layers_per_block=1,
570
+ )
571
+
572
+ def mixer_block1(self, h, drop):
573
+ return nn.Sequential(
574
+ nn.LayerNorm(h),
575
+ self.mlp(h, h, drop), # Token mixing
576
+ )
577
+
578
+ def mixer_block2(self, seq_len, drop):
579
+ return nn.Sequential(
580
+ nn.LayerNorm(seq_len),
581
+ self.mlp(seq_len, seq_len, drop) # Channel mixing
582
+ )
583
+
584
+ def mlp(self, in_dim, out_dim, drop):
585
+ return nn.Sequential(
586
+ nn.Linear(in_dim, out_dim),
587
+ nn.GELU(),
588
+ nn.Dropout(drop),
589
+ nn.Linear(out_dim, out_dim),
590
+ )
591
+
592
+ def forward(self, x, idx = None):
593
+ print(idx)
594
+ # make empty tensors for blur and depth outputs
595
+ b,d = torch.Tensor([0.]), torch.Tensor([0.])
596
+
597
+ # Initial linear layer
598
+ # x = self.lin0(x)
599
+
600
+ # Reshape to seq_len by dim
601
+ # x = x.reshape(-1, self.seq_len, self.h)
602
+
603
+ # Mixer blocks
604
+ #print("x shape ", x.shape)
605
+ residual1 = x
606
+ residual2 = x.permute(0,2,1)
607
+ #print("residual 2", residual2.shape)
608
+ for block1, block2 in zip(self.mixer_blocks1,self.mixer_blocks2):
609
+ x = block1(x) + residual1
610
+ #print("xblo", x.shape)
611
+ residual1 = x
612
+ x = x.permute(0,2,1)
613
+
614
+ x = block2(x) + residual2
615
+ #print("xblo2", x.shape)
616
+ residual2 = x
617
+ x = x.permute(0,2,1)
618
+
619
+ # Flatten
620
+ x = x.reshape(x.size(0), -1)
621
+
622
+ c = self.clin1(x)
623
+
624
+ # low rank linear to out dim cuts # params by nearly half compared to full linear mapping
625
+ # c = (x @ (self.V/100) @ (self.U/100)) + self.S
626
+
627
+ c = self.clip_proj(c.reshape(len(c), -1, self.clip_size))
628
+
629
+ if blurry_recon:
630
+ b = self.blin1(x)
631
+ b = b.reshape(len(b), 256, 4, 4)
632
+ b = self.bgroupnorm(b)
633
+ b = self.bupsampler(b)
634
+
635
+ if depth_recon:
636
+ d = self.dlin1(x)#.reshape(len(x), 1, 32, 32)
637
+ d = d.reshape(len(d), 256, 4, 4)
638
+ d = self.dgroupnorm(d)
639
+ d = self.dupsampler(d)
640
+
641
+ return c, b, d
642
+
643
+
644
+ class TimeEmbedding(nn.Module):
645
+ def __init__(self, embedding_time_dim=512, num_past_voxels=15):
646
+ super().__init__()
647
+ self.embedding_time = nn.Embedding(num_past_voxels, embedding_time_dim)
648
+ self.num_past_voxels = num_past_voxels
649
+ self.embedding_time_dim = embedding_time_dim
650
+
651
+ def forward(self, time):
652
+ # time is (batch_size,)
653
+ time = time.long()
654
+ time = self.embedding_time(time)
655
+ return time # (batch_size, embedding_time_dim)
656
+
657
+
658
+ #model.memory_encoder = MemoryEncoder(in_dim=voxels.shape[1], out_dim=4096, num_past_voxels=15, embedding_time_dim=512)
659
+ model.time_embedding = TimeEmbedding(embedding_time_dim=512, num_past_voxels=15)
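+ # each of the 15 possible lags gets its own learned 512-d embedding; the training loop looks these up for the past trials and concatenates them onto the corresponding voxel patterns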
660
+
661
+ model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim)
662
+ utils.count_params(model.backbone)
663
+ utils.count_params(model)
664
+
665
+ # test that the model works on some fake data
666
+ b = torch.randn((1,seq_len,hidden_dim))
667
+ print("b.shape",b.shape)
668
+ with torch.no_grad():
669
+ clip_, blur_, depth_ = model.backbone(b)
670
+ print(clip_.shape, blur_.shape, depth_.shape)
671
+
672
+
673
+ # In[ ]:
674
+
675
+
676
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
677
+ opt_grouped_parameters = [
678
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
679
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
680
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
681
+ ]
682
+
683
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
684
+
685
+ if lr_scheduler_type == 'linear':
686
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
687
+ optimizer,
688
+ total_iters=int(np.floor(num_epochs*(num_train/num_devices/batch_size))),
689
+ last_epoch=-1
690
+ )
691
+ elif lr_scheduler_type == 'cycle':
692
+ total_steps=int(np.floor(num_epochs*(num_train/num_devices/batch_size)))
693
+ print("total_steps", total_steps)
694
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
695
+ optimizer,
696
+ max_lr=max_lr,
697
+ total_steps=total_steps,
698
+ final_div_factor=1000,
699
+ last_epoch=-1, pct_start=2/num_epochs
700
+ )
701
+
702
+ def save_ckpt(tag):
703
+ ckpt_path = outdir+f'/{tag}.pth'
704
+ print(f'saving {ckpt_path}',flush=True)
705
+ unwrapped_model = accelerator.unwrap_model(model)
706
+ try:
707
+ torch.save({
708
+ 'epoch': epoch,
709
+ 'model_state_dict': unwrapped_model.state_dict(),
710
+ 'optimizer_state_dict': optimizer.state_dict(),
711
+ 'lr_scheduler': lr_scheduler.state_dict(),
712
+ 'train_losses': losses,
713
+ 'test_losses': test_losses,
714
+ 'lrs': lrs,
715
+ }, ckpt_path)
716
+ except:
717
+ print("Couldn't save... moving on to prevent crashing.")
718
+ del unwrapped_model
719
+
720
+ print("\nDone with model preparations!")
721
+ utils.count_params(model)
722
+
723
+
724
+ # # Weights and Biases
725
+
726
+ # In[ ]:
727
+
728
+
729
+ if local_rank==0 and wandb_log: # only use main process for wandb logging
730
+ import wandb
731
+ wandb_project = 'mindeyev2'
732
+ wandb_run = model_name
733
+ wandb_notes = ''
734
+
735
+ print(f"wandb {wandb_project} run {wandb_run}")
736
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
737
+ wandb_config = {
738
+ "model_name": model_name,
739
+ "global_batch_size": global_batch_size,
740
+ "batch_size": batch_size,
741
+ "num_epochs": num_epochs,
742
+ "clip_scale": clip_scale,
743
+ "blur_scale": blur_scale,
744
+ "use_image_aug": use_image_aug,
745
+ "max_lr": max_lr,
746
+ "mixup_pct": mixup_pct,
747
+ "num_train": num_train,
748
+ "num_test": num_test,
749
+ "ckpt_interval": ckpt_interval,
750
+ "ckpt_saving": ckpt_saving,
751
+ "seed": seed,
752
+ "distributed": distributed,
753
+ "num_devices": num_devices,
754
+ "world_size": world_size,
755
+ "train_url": train_url,
756
+ "test_url": test_url,
757
+ }
758
+ print("wandb_config:\n",wandb_config)
759
+ if False: # wandb_auto_resume
760
+ print("wandb_id:",model_name)
761
+ wandb.init(
762
+ id = model_name,
763
+ project=wandb_project,
764
+ name=wandb_run,
765
+ config=wandb_config,
766
+ notes=wandb_notes,
767
+ resume="allow",
768
+ )
769
+ else:
770
+ wandb.init(
771
+ project=wandb_project,
772
+ name=wandb_run,
773
+ config=wandb_config,
774
+ notes=wandb_notes,
775
+ )
776
+ else:
777
+ wandb_log = False
778
+
779
+
780
+ # # Main
781
+
782
+ # In[ ]:
783
+
784
+
785
+ epoch = 0
786
+ losses, test_losses, lrs = [], [], []
787
+ best_test_loss = 1e9
788
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
789
+
790
+ # Optionally resume from checkpoint #
791
+ if resume_from_ckpt:
792
+ print("\n---resuming from last.pth ckpt---\n")
793
+ try:
794
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
795
+ except:
796
+ print('last.pth failed... trying last_backup.pth')
797
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
798
+ epoch = checkpoint['epoch']
799
+ print("Epoch",epoch)
800
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
801
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
802
+ model.load_state_dict(checkpoint['model_state_dict'])
803
+ del checkpoint
804
+ elif wandb_log:
805
+ if wandb.run.resumed:
806
+ print("\n---resuming from last.pth ckpt---\n")
807
+ try:
808
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
809
+ except:
810
+ print('last.pth failed... trying last_backup.pth')
811
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
812
+ epoch = checkpoint['epoch']
813
+ print("Epoch",epoch)
814
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
815
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
816
+ model.load_state_dict(checkpoint['model_state_dict'])
817
+ del checkpoint
818
+ torch.cuda.empty_cache()
819
+
820
+
821
+ # In[ ]:
822
+
823
+
824
+ model, optimizer, train_dl, lr_scheduler = accelerator.prepare(
825
+ model, optimizer, train_dl, lr_scheduler
826
+ )
827
+ # leaving out test_dl since we will only have local_rank 0 device do evals
828
+
829
+
830
+ # In[ ]:
831
+
832
+
833
+ def add_saturation(image, alpha=2):
834
+ gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
835
+ gray_image = gray_image.unsqueeze(1).expand_as(image)
836
+ saturated_image = alpha * image + (1 - alpha) * gray_image
837
+ return torch.clamp(saturated_image, 0, 1)
838
+
839
+
840
+ # In[ ]:
841
+
842
+
843
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
844
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
845
+ test_image, test_voxel = None, None
846
+ mse = nn.MSELoss()
847
+ l1 = nn.L1Loss()
848
+
849
+ for epoch in progress_bar:
850
+ model.train()
851
+
852
+ fwd_percent_correct = 0.
853
+ bwd_percent_correct = 0.
854
+ test_fwd_percent_correct = 0.
855
+ test_bwd_percent_correct = 0.
856
+
857
+ loss_clip_total = 0.
858
+ loss_blurry_total = 0.
859
+ loss_depth_total = 0.
860
+ test_loss_clip_total = 0.
861
+ test_loss_blurry_total = 0.
862
+ test_loss_depth_total = 0.
863
+
864
+ blurry_pixcorr = 0.
865
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
866
+
867
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
868
+ with torch.cuda.amp.autocast(dtype=data_type):
869
+ optimizer.zero_grad()
870
+
871
+ #voxel = voxels[behav[:,0,5].cpu().long()].to(device)
872
+ #image = images[behav[:,0,0].cpu().long()].to(device).float()
873
+
874
+ #past_15_voxels = voxels[past_behav[:,:,5].cpu().long()].to(device) # batch_size, 15, 15279
875
+ #past_15_times = torch.Tensor([i for i in range(seq_len - 1)]).to(device) # 15
876
+
877
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
878
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
879
+
880
+ past_15_voxels = voxels[past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
881
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
882
+ #for past in range(1):
883
+ # past_voxel = voxels[past_behav[:,past,5].cpu().long()].to(device)
884
+
885
+ if blurry_recon:
886
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
887
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
888
+
889
+ if depth_recon:
890
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
891
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
892
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
893
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
894
+
895
+ if use_image_aug:
896
+ image = img_augment(image)
897
+
898
+ clip_target = clip_model.embed_image(image)
899
+ assert not torch.any(torch.isnan(clip_target))
900
+
901
+ if epoch < int(mixup_pct * num_epochs):
902
+ voxel, perm, betas, select = utils.mixco(voxel)
903
+ past_voxel, _, _, _ = utils.mixco(voxel, perm=perm, betas=betas, select=select)
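+ # BiMixCo phase: voxels are mixed with a permuted copy of the batch, and the same perm/betas/select are reused by utils.mixco_nce for the contrastive loss below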
904
+
905
+ for p in range(seq_len-1):
906
+ #print(past_behav.shape) #128, 15, 17
907
+ #print(past_behav[:,p,-1])
908
+ #print(past_15_voxels.shape) # 128, 1, 15724
909
+ mask = past_behav[:,p,-1] == torch.ones_like(past_behav[:,p,-1])
910
+ #print(mask) # 128
911
+ past_15_voxels[mask, p, :] = torch.zeros_like(past_15_voxels[0, p, :])
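+ # where the last behav column is 1 (which appears to flag past entries that are not valid for this trial), the corresponding past voxels are zeroed out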
912
+ #print(past_15_voxels)
913
+
914
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
915
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
916
+ past_15_times = past_15_times.reshape(-1)
917
+ time_embeddings = model.time_embedding(past_15_times)
918
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
919
+
920
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
921
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
922
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
923
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
924
+ #unsqueeze(1) # bz * 2, 1, 4096
925
+
926
+ # past_voxel_ridge = model.ridge(past_voxel)
927
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1), past_voxel_ridge.unsqueeze(1)), axis=1)
928
+ #print(voxel_ridge.shape)
929
+
930
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge, idx = train_i)
931
+
932
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
933
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
934
+
935
+ if epoch < int(mixup_pct * num_epochs):
936
+ loss_clip = utils.mixco_nce(
937
+ clip_voxels_norm,
938
+ clip_target_norm,
939
+ temp=.006,
940
+ perm=perm, betas=betas, select=select)
941
+ else:
942
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
943
+ loss_clip = utils.soft_clip_loss(
944
+ clip_voxels_norm,
945
+ clip_target_norm,
946
+ temp=epoch_temp)
947
+
948
+ loss_clip_total += loss_clip.item()
949
+ loss_clip *= clip_scale
950
+ loss = loss_clip
951
+
952
+ if blurry_recon:
953
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
954
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
955
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
956
+
957
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
958
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
959
+ loss_blurry_total += loss_blurry.item()
960
+ loss_blurry *= blur_scale
961
+ loss += loss_blurry
962
+
963
+ if depth_recon:
964
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
965
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
966
+ loss_depth_total += loss_depth.item()
967
+ loss_depth *= depth_scale
968
+ loss += loss_depth
969
+
970
+ # forward and backward top 1 accuracy
971
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
972
+ fwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm)), labels, k=1).item()
973
+ bwd_percent_correct += utils.topk(torch.abs(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm)), labels, k=1).item()
974
+
975
+ if blurry_recon:
976
+ with torch.no_grad():
977
+ # only doing pixcorr eval on a subset of the samples per batch because it's costly & slow to compute autoenc.decode()
978
+ random_samps = np.random.choice(np.arange(len(voxel)), size=batch_size//5, replace=False)
979
+ # random_samps = np.arange(batch_size//5)
980
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
981
+ # pixcorr_origsize_nanmean is computationally less intense than utils.pixcorr and uses nanmean instead of mean
982
+ pixcorr = utils.pixcorr_origsize_nanmean(image[random_samps], blurry_recon_images)
983
+ # pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
984
+ # loss += (1 - pixcorr)
985
+ blurry_pixcorr += pixcorr.item()
986
+ # utils.check_loss(pixcorr)
987
+
988
+ utils.check_loss(loss)
989
+ accelerator.backward(loss)
990
+ optimizer.step()
991
+
992
+ losses.append(loss.item())
993
+ lrs.append(optimizer.param_groups[0]['lr'])
994
+
995
+ if lr_scheduler_type is not None:
996
+ lr_scheduler.step()
997
+
998
+ model.eval()
999
+ if local_rank==0:
1000
+ with torch.no_grad(), torch.cuda.amp.autocast(dtype=data_type):
1001
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
1002
+ # all test samples should be loaded per batch such that test_i should never exceed 0
1003
+ assert len(behav) == num_test
1004
+
1005
+ ## Average same-image repeats ##
1006
+ if test_image is None:
1007
+ voxel = voxels[behav[:,0,5].cpu().long()]
1008
+ image = behav[:,0,0].cpu().long()
1009
+
1010
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
1011
+ for im in unique_image:
1012
+ locs = torch.where(im == image)[0]
1013
+ if test_image is None:
1014
+ test_image = images[im][None]
1015
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
1016
+ else:
1017
+ test_image = torch.vstack((test_image, images[im][None]))
1018
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
1019
+
1020
+ # random sample of 300
1021
+ random_indices = torch.arange(len(test_voxel))[:300]
1022
+ voxel = test_voxel[random_indices].to(device)
1023
+ image = test_image[random_indices].to(device)
1024
+ assert len(image) == 300
1025
+
1026
+ current_past_behav = past_behav[random_indices]
1027
+
1028
+ past_15_voxels = voxels[current_past_behav[:,:seq_len-1,5].cpu().long()].to(device) # batch_size, 15, 15279
1029
+ past_15_times = torch.Tensor([i for i in range(seq_len-1)]).to(device) # 15
1030
+
1031
+ if blurry_recon:
1032
+ # blurry_image_enc = autoenc.encode(2*utils.resize(image,128)-1).latent_dist.mode() * 0.18215
1033
+ blurry_image_enc = autoenc.encode(2*utils.resize(add_saturation(image),128)-1).latents * 0.18215
1034
+
1035
+ if depth_recon:
1036
+ # depth_images = utils.resize(midas_depth.model(image).unsqueeze(1).repeat(1,3,1,1), 128)
1037
+ depth_images = utils.resize(midas_depth.model(image).unsqueeze(1), 32)
1038
+ depth_images = (depth_images / depth_images.view(depth_images.shape[0], -1).max(dim=1)[0].view(-1, 1, 1, 1).expand_as(depth_images)).half()
1039
+ depth_image_enc = depth_images # autoenc.encode(2*depth_images-1).latents * 0.18215
1040
+
1041
+ clip_target = clip_model.embed_image(image.float())
1042
+
1043
+
1044
+ past_15_voxels = past_15_voxels.reshape(-1, past_15_voxels.shape[-1])
1045
+ past_15_times = past_15_times.repeat(voxel.shape[0], 1)
1046
+ past_15_times = past_15_times.reshape(-1)
1047
+ time_embeddings = model.time_embedding(past_15_times)
1048
+ past_info_full = torch.cat((past_15_voxels, time_embeddings), dim=-1)
1049
+
1050
+ positional_current_voxel = torch.zeros((voxel.shape[0], time_embeddings.shape[-1])).to(voxel.device)
1051
+ voxel = torch.cat((voxel, positional_current_voxel), dim=-1)
1052
+ voxel_ridge = model.ridge(torch.cat((voxel, past_info_full), dim=-2))
1053
+ voxel_ridge = voxel_ridge.view(int(voxel_ridge.shape[0]/seq_len), seq_len, hidden_dim)
1054
+
1055
+ #voxel_ridge = model.ridge(voxel).unsqueeze(1)
1056
+
1057
+ # voxel_ridge = torch.cat((voxel_ridge.unsqueeze(1),voxel_ridge.unsqueeze(1)),axis=1)
1058
+
1059
+ clip_voxels, blurry_image_enc_, depth_image_enc_ = model.backbone(voxel_ridge)
1060
+
1061
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1062
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1063
+
1064
+ loss_clip = utils.soft_clip_loss(
1065
+ clip_voxels_norm,
1066
+ clip_target_norm,
1067
+ temp=.006)
1068
+ test_loss_clip_total += loss_clip.item()
1069
+ loss_clip = loss_clip * clip_scale
1070
+ loss = loss_clip
1071
+
1072
+ if blurry_recon:
1073
+ downsampled_image = nn.functional.interpolate(image, size=(8, 8), mode='bilinear', align_corners=False)
1074
+ re_upsampled_image = add_saturation(nn.functional.interpolate(downsampled_image, size=(128, 128), mode='nearest'))
1075
+ re_upsampled_enc = autoenc.encode(2*re_upsampled_image-1).latents * 0.18215
1076
+
1077
+ loss_blurry = (l1(blurry_image_enc_, blurry_image_enc) + l1(blurry_image_enc_, re_upsampled_enc))
1078
+ loss_blurry += l1(torch.var(blurry_image_enc), torch.var(blurry_image_enc_))
1079
+ test_loss_blurry_total += loss_blurry.item()
1080
+ loss_blurry *= blur_scale
1081
+ loss += loss_blurry
1082
+
1083
+ # halving the batch size because the decoder is computationally heavy
1084
+ blurry_recon_images = (autoenc.decode(blurry_image_enc_[:len(voxel)//2]/0.18215).sample / 2 + 0.5).clamp(0,1)
1085
+ blurry_recon_images = torch.vstack((blurry_recon_images, (autoenc.decode(blurry_image_enc_[len(voxel)//2:]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1086
+ pixcorr = utils.pixcorr(image, blurry_recon_images)
1087
+ loss += (1 - pixcorr)
1088
+ test_blurry_pixcorr += pixcorr.item()
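utils.pixcorr is likewise defined elsewhere; a plausible minimal version of a batchwise pixel-correlation metric, for reference only (not the repo's verbatim code):

    import torch

    def pixel_correlation(preds, targets):
        # per-image Pearson correlation over flattened pixels, averaged over the batch
        p = preds.flatten(1) - preds.flatten(1).mean(dim=1, keepdim=True)
        t = targets.flatten(1) - targets.flatten(1).mean(dim=1, keepdim=True)
        return ((p * t).sum(dim=1) / (p.norm(dim=1) * t.norm(dim=1) + 1e-8)).mean()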
1089
+
1090
+ if depth_recon:
1091
+ loss_depth = l1(depth_image_enc_, depth_image_enc)
1092
+ # loss_depth += l1(torch.var(depth_image_enc_), torch.var(depth_image_enc))
1093
+ test_loss_depth_total += loss_depth.item()
1094
+ loss_depth *= depth_scale
1095
+ loss += loss_depth
1096
+
1097
+ # forward and backward top 1 accuracy
1098
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
1099
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
1100
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()
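For the two retrieval metrics above, a compact sketch of what batchwise cosine similarity plus top-k accuracy typically look like (assumed behavior of utils.batchwise_cosine_similarity and utils.topk, not their verbatim code):

    import torch
    import torch.nn.functional as F

    def batchwise_cosine_similarity(a, b):
        # (batch, dim) x (batch, dim) -> (batch, batch) cosine-similarity matrix
        return F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T

    def topk_accuracy(sims, labels, k=1):
        # fraction of rows whose correct column index lands in the top-k similarities
        topk = sims.topk(k, dim=-1).indices
        return (topk == labels.unsqueeze(-1)).any(dim=-1).float().mean()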
1101
+
1102
+ utils.check_loss(loss)
1103
+ test_losses.append(loss.item())
1104
+
1105
+ # if utils.is_interactive(): clear_output(wait=True)
1106
+ print("---")
1107
+
1108
+ assert (test_i+1) == 1
1109
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1110
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1111
+ "train/lr": lrs[-1],
1112
+ "train/num_steps": len(losses),
1113
+ "test/num_steps": len(test_losses),
1114
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1115
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1116
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1117
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1118
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1119
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1120
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1121
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1122
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1123
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1124
+ "train/loss_depth_total": loss_depth_total / (train_i + 1),
1125
+ "test/loss_depth_total": test_loss_depth_total / (test_i + 1),
1126
+ }
1127
+
1128
+ if blurry_recon:
1129
+ # transform blurry recon latents to images and plot it
1130
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1131
+ jj=-1
1132
+ for j in [0,1,2,3]:
1133
+ jj+=1
1134
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1135
+ axes[jj].axis('off')
1136
+ jj+=1
1137
+ axes[jj].imshow(utils.torch_to_Image((autoenc.decode(blurry_image_enc_[[j]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1138
+ axes[jj].axis('off')
1139
+
1140
+ if wandb_log:
1141
+ logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1142
+ plt.close()
1143
+ else:
1144
+ plt.show()
1145
+
1146
+ if depth_recon:
1147
+                 # transform predicted depth maps to images and plot them
1148
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1149
+ # axes[0].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1150
+ # axes[1].imshow(utils.torch_to_Image((autoenc.decode(depth_image_enc_[[0]]/0.18215).sample / 2 + 0.5).clamp(0,1)))
1151
+ jj=-1
1152
+ for j in [0,1,2,3]:
1153
+ jj+=1
1154
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc[[j]].view(1,1,32,32).clamp(0,1), 224)))
1155
+ axes[jj].axis('off')
1156
+ jj+=1
1157
+ axes[jj].imshow(utils.torch_to_Image(utils.resize(depth_image_enc_[[j]].view(1,1,32,32).clamp(0,1), 224)))
1158
+ axes[jj].axis('off')
1159
+ if wandb_log:
1160
+ logs[f"test/depth_recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}")
1161
+ plt.close()
1162
+ else:
1163
+ plt.show()
1164
+
1165
+ progress_bar.set_postfix(**logs)
1166
+
1167
+ # Save model checkpoint and reconstruct
1168
+ if epoch % ckpt_interval == 0:
1169
+ if not utils.is_interactive():
1170
+ save_ckpt(f'last')
1171
+
1172
+ if wandb_log: wandb.log(logs)
1173
+
1174
+ # wait for other GPUs to catch up if needed
1175
+ accelerator.wait_for_everyone()
1176
+ torch.cuda.empty_cache()
1177
+ gc.collect()
1178
+
1179
+ print("\n===Finished!===\n")
1180
+ if ckpt_saving:
1181
+ save_ckpt(f'last')
1182
+ if not utils.is_interactive():
1183
+ sys.exit(0)
1184
+
1185
+
1186
+ # In[ ]:
1187
+
1188
+
1189
+ plt.plot(losses)
1190
+ plt.show()
1191
+ plt.plot(test_losses)
1192
+ plt.show()
1193
+
1194
+
1195
+ # # Retrieve nearest neighbor in the training set using test set data
1196
+
1197
+ # In[ ]:
1198
+
1199
+
1200
+ annots = np.load("/fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_annots_curated.npy")
1201
+
1202
+
1203
+ # In[ ]:
1204
+
1205
+
1206
+ ii=2
1207
+ all_indices = np.unique(train_73k_images) #np.hstack((test_vox_indices[ii],train_vox_indices))
1208
+ with torch.no_grad(), torch.cuda.amp.autocast():
1209
+ for batch in tqdm(range(0,len(all_indices),512)):
1210
+ if batch==0:
1211
+ clip_target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1212
+ else:
1213
+ target = clip_model.embed_image(images[all_indices[batch:batch+512]]).cpu()
1214
+ clip_target = torch.vstack((clip_target,target))
1215
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1216
+
1217
+ voxel = test_voxel[[ii]].to(device)
1218
+ image = test_image[[ii]].to(device)
1219
+
1220
+ print("Original Image (test set)")
1221
+ display(utils.torch_to_Image(image))
1222
+
1223
+ clip_target = clip_model.embed_image(image).cpu()
1224
+ # clip_target_norm = torch.vstack((clip_target_norm, nn.functional.normalize(clip_target.flatten(1), dim=-1)))
1225
+
1226
+ voxel_ridge = model.ridge(voxel).unsqueeze(1)
1227
+ clip_voxels, _, _ = model.backbone(voxel_ridge)
1228
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
1229
+     # NOTE: this line overwrote the brain-predicted embedding with the ground-truth image embedding, making the retrieval trivial; keep it commented out to retrieve from the model's prediction.
+     # clip_voxels_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
1230
+
1231
+ print("clip_voxels_norm", clip_voxels_norm.shape)
1232
+ print("clip_target_norm", clip_target_norm.shape)
1233
+
1234
+ sortt = torch.argsort(utils.batchwise_cosine_similarity(clip_voxels_norm.cpu(),
1235
+ clip_target_norm).flatten()).flip(0)
1236
+ picks = all_indices[sortt[:5]]
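As a quick reminder of what the sort above does (toy numbers): torch.argsort returns ascending order, so .flip(0) puts the most similar training images first.

    import torch

    sims = torch.tensor([0.10, 0.92, 0.45, 0.33])
    order = torch.argsort(sims).flip(0)  # tensor([1, 2, 3, 0]) -- most similar first
    top2 = order[:2]                     # indices of the two nearest neighbors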
1237
+
1238
+ print("\nNearest neighbors in training set")
1239
+ for ip,p in enumerate(picks):
1240
+ display(utils.torch_to_Image(images[[p]]))
1241
+ # print(utils.select_annotations([annots[int(p)]]))
1242
+ if ip==0: predicted_caption = utils.select_annotations([annots[int(p)]])[0]
1243
+
1244
+ print("\n=====\npredicted_caption:\n", predicted_caption)
1245
+
1246
+
1247
+ # # Feed into Stable Diffusion XL for reconstructions
1248
+
1249
+ # In[ ]:
1250
+
1251
+
1252
+ from diffusers import StableDiffusionXLPipeline
1253
+ pipe = StableDiffusionXLPipeline.from_pretrained(
1254
+ "/fsx/proj-fmri/shared/cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/f898a3e026e802f68796b95e9702464bac78d76f", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
1255
+ )
1256
+ pipe.to("cuda")
1257
+ pass
1258
+
1259
+
1260
+ # In[ ]:
1261
+
1262
+
1263
+ prompt = predicted_caption
1264
+ recon = pipe(prompt=prompt).images[0]
1265
+
1266
+
1267
+ # In[ ]:
1268
+
1269
+
1270
+ print("Seen image")
1271
+ display(utils.torch_to_Image(image))
1272
+
1273
+ print("Reconstruction")
1274
+ utils.torch_to_Image(utils.resize(transforms.ToTensor()(recon),224))
1275
+
src/Train_diffusion.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/accel.slurm ADDED
@@ -0,0 +1,38 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=topfmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=ms
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=4 # should = number of gpus
7
+ #SBATCH --gres=gpu:4
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=topfmri
12
+
13
+ module load cuda/11.7 # should match torch.cuda.version
14
+
15
+ export NUM_GPUS=4 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/paulscotti/MindEyeV2/wandb/"
26
+ export WANDB_CACHE_DIR="/fsx/home-paulscotti/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ ###########
34
+
35
+ cd /fsx/proj-fmri/paulscotti/MindEyeV2
36
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=test --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=240 --ckpt_interval=999 --no-use_image_aug
37
+
38
+ # --wandb_log
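For orientation, the arithmetic behind the launch arguments above, with values mirroring this script (illustration only, not part of the job):

    # one accelerate process per GPU, across however many nodes SLURM allocates
    NUM_GPUS = 4     # matches gres=gpu:4
    COUNT_NODE = 1   # single node in this script
    num_processes = NUM_GPUS * COUNT_NODE
    assert num_processes == 4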
src/accel2.slurm ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=ms
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
26
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ source /admin/home-ckadirt/.bashrc
34
+
35
+ ###########
36
+
37
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
38
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train-with-memory-cat.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=test_mem_cat_r --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-5 --mixup_pct=.33 --num_epochs=240 --ckpt_interval=999 --no-use_image_aug
39
+
40
+ # --wandb_log
src/accel3.slurm ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=ms
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=4 # should = number of gpus
7
+ #SBATCH --gres=gpu:4
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=4 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=128
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
26
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ source /admin/home-ckadirt/.bashrc
34
+
35
+ ###########
36
+
37
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
38
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train-with-memory.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=test_mem --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=40 --ckpt_interval=999 --no-use_image_aug
39
+
40
+ # --wandb_log
src/accel4.slurm ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=ms
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
26
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ source /admin/home-ckadirt/.bashrc
34
+
35
+ ###########
36
+
37
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
38
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=baseline_r --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-5 --mixup_pct=.33 --num_epochs=240 --ckpt_interval=999 --no-use_image_aug
39
+
40
+ # --wandb_log
src/accel5.slurm ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=memoryrr
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
26
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ source /admin/home-ckadirt/.bashrc
34
+
35
+ ###########
36
+
37
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
38
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train-with-memory-rr.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=testing-rr-uni_r --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=1e-5 --mixup_pct=.66 --num_epochs=120 --ckpt_interval=999 --no-use_image_aug
39
+
40
+ # --wandb_log
src/accel6.slurm ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=memoryrr
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#
16
+ export GLOBAL_BATCH_SIZE=512
17
+
18
+ # Make sure another job doesn't use the same port, here using a random number
19
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
20
+
21
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
22
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
23
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
24
+
25
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
26
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
27
+ export WANDB_MODE="online"
28
+
29
+ echo MASTER_ADDR=${MASTER_ADDR}
30
+ echo MASTER_PORT=${MASTER_PORT}
31
+ echo WORLD_SIZE=${COUNT_NODE}
32
+
33
+ source /admin/home-ckadirt/.bashrc
34
+
35
+ ###########
36
+
37
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
38
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train-with-memory-rr-dropout.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=testing-rr-uni_r --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --n_samples_save=0 --max_lr=3e-5 --mixup_pct=.66 --num_epochs=120 --ckpt_interval=999 --no-use_image_aug
39
+
40
+ # --wandb_log
src/accel7.slurm ADDED
@@ -0,0 +1,41 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=memoryrr
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#!
16
+ export BATCH_SIZE=32
17
+ export GLOBAL_BATCH_SIZE=$((BATCH_SIZE * NUM_GPUS))
18
+
19
+ # Make sure another job doesn't use the same port, here using a random number
20
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
21
+
22
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
23
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
24
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
25
+
26
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
27
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
28
+ export WANDB_MODE="online"
29
+
30
+ echo MASTER_ADDR=${MASTER_ADDR}
31
+ echo MASTER_PORT=${MASTER_PORT}
32
+ echo WORLD_SIZE=${COUNT_NODE}
33
+
34
+ source /admin/home-ckadirt/.bashrc
35
+
36
+ ###########
37
+
38
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
39
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train_MLPMixer-Copy2.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=testing-rr-1024-past-5 --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --max_lr=3e-4 --mixup_pct=.66 --num_epochs=120 --ckpt_interval=10 --no-use_image_aug --hidden_dim=1024 --seq_len=5
40
+
41
+ # --wandb_log
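The batch-size bookkeeping this script relies on, written out (illustration only):

    NUM_GPUS = 8
    BATCH_SIZE = 32                            # per-GPU batch
    GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_GPUS  # 256, passed as --batch_size
    assert GLOBAL_BATCH_SIZE == 256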
src/accel8.slurm ADDED
@@ -0,0 +1,41 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=memoryrr
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=8 # should = number of gpus
7
+ #SBATCH --gres=gpu:8
8
+ #SBATCH --time=32:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH -e slurms/%j.err
10
+ #SBATCH -o slurms/%j.out
11
+ #SBATCH --comment=fmri
12
+
13
+
14
+
15
+ export NUM_GPUS=8 # Set to equal gres=gpu:#!
16
+ export BATCH_SIZE=32
17
+ export GLOBAL_BATCH_SIZE=$((BATCH_SIZE * NUM_GPUS))
18
+
19
+ # Make sure another job doesn't use the same port, here using a random number
20
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
21
+
22
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
23
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
24
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
25
+
26
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
27
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
28
+ export WANDB_MODE="online"
29
+
30
+ echo MASTER_ADDR=${MASTER_ADDR}
31
+ echo MASTER_PORT=${MASTER_PORT}
32
+ echo WORLD_SIZE=${COUNT_NODE}
33
+
34
+ source /admin/home-ckadirt/.bashrc
35
+
36
+ ###########
37
+
38
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
39
+ accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT Train_MLPMixer-img.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=testing-rr-1024-img-past-2 --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --max_lr=3e-4 --mixup_pct=.66 --num_epochs=120 --ckpt_interval=10 --no-use_image_aug --hidden_dim=1024 --seq_len=2
40
+
41
+ # --wandb_log
src/accel9.slurm ADDED
@@ -0,0 +1,44 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=fmri
3
+ #SBATCH --partition=g40x
4
+ #SBATCH --job-name=blip2captions
5
+ #SBATCH --nodes=1
6
+ #SBATCH --ntasks-per-node=1 # should = number of gpus
7
+ #SBATCH --gres=gpu:1
8
+ #SBATCH --time=24:00:00 # total run time limit (HH:MM:SS)
9
+ #SBATCH --comment=medarc
10
+ #SBATCH --requeue
11
+ #SBATCH -e slurms/%j.err
12
+ #SBATCH -o slurms/%j.out
13
+
14
+
15
+
16
+
17
+ export NUM_GPUS=1 # Set to equal gres=gpu:#!
18
+ export BATCH_SIZE=128
19
+ export GLOBAL_BATCH_SIZE=$((BATCH_SIZE * NUM_GPUS))
20
+
21
+ # Make sure another job doesn't use the same port, here using a random number
22
+ export MASTER_PORT=$((RANDOM % (19000 - 11000 + 1) + 11000))
23
+
24
+ export HOSTNAMES=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
25
+ export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
26
+ export COUNT_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l)
27
+
28
+ export WANDB_DIR="/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb"
29
+ export WANDB_CACHE_DIR="/admin/home-ckadirt/.cache"
30
+ export WANDB_MODE="online"
31
+
32
+ echo MASTER_ADDR=${MASTER_ADDR}
33
+ echo MASTER_PORT=${MASTER_PORT}
34
+ echo WORLD_SIZE=${COUNT_NODE}
35
+
36
+ source /admin/home-ckadirt/.bashrc
37
+
38
+ ###########
39
+
40
+ cd /fsx/proj-fmri/ckadirt/MindEyeV2/src/
41
+ # accelerate launch --num_processes=$(($NUM_GPUS * $COUNT_NODE)) --num_machines=$COUNT_NODE --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT
42
+ python train2.py --data_path=/fsx/proj-fmri/shared/mindeyev2_dataset --model_name=caption_clip_0.5_bz --subj=1 --batch_size=${GLOBAL_BATCH_SIZE} --max_lr=1e-4 --mixup_pct=.66 --num_epochs=50 --use_image_aug --ckpt_interval=15 --clip_mse_ratio=0.5
43
+
44
+ # --wandb_log
src/blip2_captions.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import argparse
5
+ import numpy as np
6
+ import math
7
+ from einops import rearrange
8
+ import time
9
+ import random
10
+ import h5py
11
+ from tqdm import tqdm
12
+
13
+ import webdataset as wds
14
+ import gc
15
+
16
+ import matplotlib.pyplot as plt
17
+ import torch
18
+ import torch.nn as nn
19
+ from torchvision import transforms
20
+ from torchvision.transforms import ToPILImage #CHANGED (added)
21
+
22
+ from accelerate import Accelerator, DeepSpeedPlugin
23
+
24
+ # tf32 data type is faster than standard float32
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+
27
+ # custom functions #
28
+ import utils
29
+
30
+ global_batch_size = 128 #128
31
+
32
+ ### Multi-GPU config ###
33
+ local_rank = os.getenv('RANK')
34
+ if local_rank is None:
35
+ local_rank = 0
36
+ else:
37
+ local_rank = int(local_rank)
38
+ print("LOCAL RANK ", local_rank)
39
+
40
+ num_devices = torch.cuda.device_count()
41
+ if num_devices==0: num_devices = 1
42
+
43
+ accelerator = Accelerator(split_batches=False)
44
+
45
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
46
+
47
+ # if num_devices <= 1 and utils.is_interactive():
48
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
49
+ # os.environ["MASTER_ADDR"] = "localhost"
50
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
51
+ # os.environ["RANK"] = "0"
52
+ # os.environ["LOCAL_RANK"] = "0"
53
+ # os.environ["WORLD_SIZE"] = "1"
54
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
55
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
56
+
57
+ # # alter the deepspeed config according to your global and local batch size
58
+ # if local_rank == 0:
59
+ # with open('deepspeed_config_stage2.json', 'r') as file:
60
+ # config = json.load(file)
61
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
62
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
63
+ # with open('deepspeed_config_stage2.json', 'w') as file:
64
+ # json.dump(config, file)
65
+ # else:
66
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
67
+ # time.sleep(10)
68
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
69
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
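The commented-out block above keeps two DeepSpeed settings in sync; the relationship it enforces, shown with hypothetical numbers:

    train_batch_size = 128                # global batch across all GPUs
    num_devices = 8
    train_micro_batch_size_per_gpu = train_batch_size // num_devices  # 16
    assert train_micro_batch_size_per_gpu * num_devices == train_batch_size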
70
+
71
+
src/blip_tryal.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/checking_models.ipynb ADDED
@@ -0,0 +1,1526 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 25,
6
+ "id": "ef9e1556-7840-4004-b181-a2c97ac2ab17",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import torch\n",
12
+ "import torch.nn as nn\n",
13
+ "import numpy as np\n",
14
+ "import matplotlib.pyplot as plt"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "id": "b6f12dd4-f3aa-4981-b604-b72e67229011",
20
+ "metadata": {},
21
+ "source": [
22
+ "# DinoV2"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 26,
28
+ "id": "2a604617-b602-4503-b288-e9828684505e",
29
+ "metadata": {},
30
+ "outputs": [
31
+ {
32
+ "name": "stderr",
33
+ "output_type": "stream",
34
+ "text": [
35
+ "Using cache found in /fsx/proj-fmri/shared/cache/dinov2/hub/facebookresearch_dinov2_main\n"
36
+ ]
37
+ }
38
+ ],
39
+ "source": [
40
+ "# need to change TORCH_HOME env variable to specify pretrained model should go in shared folder, not home directory\n",
41
+ "os.environ['TORCH_HOME'] = '/fsx/proj-fmri/shared/cache/dinov2'\n",
42
+ "dinov2_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')\n",
43
+ "# remove initial image patching\n",
44
+ "dinov2_model.patch_embed = nn.Identity()\n",
45
+ "dinov2_model.patch_embed = nn.Identity()"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 27,
51
+ "id": "32da913d-d931-4967-a5e8-bd40c21d1ad9",
52
+ "metadata": {},
53
+ "outputs": [
54
+ {
55
+ "name": "stdout",
56
+ "output_type": "stream",
57
+ "text": [
58
+ "torch.Size([2, 33, 1024])\n"
59
+ ]
60
+ }
61
+ ],
62
+ "source": [
63
+ "dinov2_model.to(\"cuda\")\n",
64
+ "input = torch.randn((2,33,1024)).to(\"cuda\")\n",
65
+ "\n",
66
+ "for block in dinov2_model.blocks: input = block(input)\n",
67
+ "input = dinov2_model.norm(input)\n",
68
+ "\n",
69
+ "print(input.shape)"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "markdown",
74
+ "id": "febe89c0-06d0-4309-b378-a8d58b99bf4c",
75
+ "metadata": {},
76
+ "source": [
77
+ "# eva"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 28,
83
+ "id": "690204d0-13d7-452b-97af-14d144800e81",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "from urllib.request import urlopen\n",
88
+ "from PIL import Image\n",
89
+ "import timm\n",
90
+ "\n",
91
+ "img = Image.open(urlopen(\n",
92
+ " 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'\n",
93
+ "))\n",
94
+ "\n",
95
+ "model = timm.create_model(\n",
96
+ " \"eva02_enormous_patch14_clip_224.laion2b\",\n",
97
+ " pretrained=True,\n",
98
+ " num_classes=0, # remove classifier nn.Linear\n",
99
+ ")\n",
100
+ "model = model.eval()"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 39,
106
+ "id": "035e3e9d-86c9-4ddf-b760-7b78dded7d2e",
107
+ "metadata": {},
108
+ "outputs": [
109
+ {
110
+ "ename": "ValueError",
111
+ "evalue": "You have to specify pixel_values",
112
+ "output_type": "error",
113
+ "traceback": [
114
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
115
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
116
+ "Cell \u001b[0;32mIn[39], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m data_config \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mresolve_model_data_config(model)\n\u001b[1;32m 3\u001b[0m transforms \u001b[38;5;241m=\u001b[39m timm\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mcreate_transform(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_config, is_training\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[0;32m----> 5\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtransforms\u001b[49m\u001b[43m(\u001b[49m\u001b[43mimg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# output is (batch_size, num_features) shaped tensor\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(output\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# or equivalently (without needing to set num_classes=0)\u001b[39;00m\n",
117
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
118
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:1433\u001b[0m, in \u001b[0;36mCLIPSegForImageSegmentation.forward\u001b[0;34m(self, input_ids, pixel_values, conditional_pixel_values, conditional_embeddings, attention_mask, position_ids, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1431\u001b[0m \u001b[38;5;66;03m# step 1: forward the query images through the frozen CLIP vision encoder\u001b[39;00m\n\u001b[1;32m 1432\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1433\u001b[0m vision_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclip\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvision_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1434\u001b[0m \u001b[43m \u001b[49m\u001b[43mpixel_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpixel_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1435\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1436\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# we need the intermediate hidden states\u001b[39;49;00m\n\u001b[1;32m 1437\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1438\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1439\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclip\u001b[38;5;241m.\u001b[39mvisual_projection(vision_outputs[\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 1441\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m vision_outputs\u001b[38;5;241m.\u001b[39mhidden_states \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;28;01melse\u001b[39;00m vision_outputs[\u001b[38;5;241m2\u001b[39m]\n",
119
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
120
+ "File \u001b[0;32m~/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/models/clipseg/modeling_clipseg.py:872\u001b[0m, in \u001b[0;36mCLIPSegVisionTransformer.forward\u001b[0;34m(self, pixel_values, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 869\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pixel_values \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 872\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou have to specify pixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 874\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(pixel_values)\n\u001b[1;32m 875\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpre_layrnorm(hidden_states)\n",
121
+ "\u001b[0;31mValueError\u001b[0m: You have to specify pixel_values"
122
+ ]
123
+ }
124
+ ],
125
+ "source": [
126
+ "# get model specific transforms (normalization, resize)\n",
127
+ "data_config = timm.data.resolve_model_data_config(model)\n",
128
+ "transforms = timm.data.create_transform(**data_config, is_training=False)\n",
129
+ "\n",
130
+ "output = model(transforms(img).unsqueeze(0)) # output is (batch_size, num_features) shaped tensor\n",
131
+ "print(output.shape)\n",
132
+ "# or equivalently (without needing to set num_classes=0)\n",
133
+ "\n",
134
+ "output = model.forward_features(transforms(img).unsqueeze(0))\n",
135
+ "# output is unpooled, a (1, 257, 768) shaped tensor\n",
136
+ "print(output.shape)\n",
137
+ "\n",
138
+ "output = model.forward_head(output, pre_logits=True)\n",
139
+ "# output is a (1, num_features) shaped tensor\n",
140
+ "print(output.shape)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "54275c4c-e506-4959-92f1-29e584f5ce51",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "model.forward_features(transforms(img).unsqueeze(0)).shape"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "id": "6546c673-f3ab-4d43-a051-cab20e782bab",
156
+ "metadata": {},
157
+ "source": [
158
+ "# Eva02-clip"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 29,
164
+ "id": "dfbc95de-9af9-4583-98fc-b8061114ef64",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "import timm \n",
169
+ "# couldnt figure out how to load pretrained model from shared folder rather than home directory using timm...\n",
170
+ "eva02_model = timm.create_model(\"eva02_enormous_patch14_clip_224.laion2b\", pretrained=True)\n",
171
+ "# eva02_model.head_drop = nn.Identity()\n",
172
+ "# eva02_model.head = nn.Identity()"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 30,
178
+ "id": "97e3ea29-ae6b-4bd2-b3d7-17839098a6e4",
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "text/plain": [
184
+ "torch.Size([2, 1024])"
185
+ ]
186
+ },
187
+ "execution_count": 30,
188
+ "metadata": {},
189
+ "output_type": "execute_result"
190
+ }
191
+ ],
192
+ "source": [
193
+ "eva02_model(torch.randn((2,3,224,224))).shape"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": 31,
199
+ "id": "069b76f0-029f-42b1-85f5-a492ee1cc5d1",
200
+ "metadata": {},
201
+ "outputs": [
202
+ {
203
+ "name": "stdout",
204
+ "output_type": "stream",
205
+ "text": [
206
+ "torch.Size([2, 256, 1024])\n"
207
+ ]
208
+ }
209
+ ],
210
+ "source": [
211
+ "image = torch.randn((2,3,224,224))\n",
212
+ "\n",
213
+ "input = eva02_model.patch_embed(image)\n",
214
+ "input = eva02_model.pos_drop(input)\n",
215
+ "for block in eva02_model.blocks: input = block(input)\n",
216
+ "input = eva02_model.norm(input)\n",
217
+ "input = eva02_model.fc_norm(input)\n",
218
+ "input = eva02_model.head_drop(input)\n",
219
+ "input = eva02_model.head(input)\n",
220
+ "\n",
221
+ "print(input.shape)"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": 32,
227
+ "id": "90e4e8e7-3dd1-43b0-a305-066a6ec13c2e",
228
+ "metadata": {},
229
+ "outputs": [
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "Help on Eva in module timm.models.eva object:\n",
235
+ "\n",
236
+ "class Eva(torch.nn.modules.module.Module)\n",
237
+ " | Eva(img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
238
+ " | \n",
239
+ " | Eva Vision Transformer w/ Abs & Rotary Pos Embed\n",
240
+ " | \n",
241
+ " | This class implements the EVA and EVA02 models that were based on the BEiT ViT variant\n",
242
+ " | * EVA - abs pos embed, global avg pool\n",
243
+ " | * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)\n",
244
+ " | \n",
245
+ " | Method resolution order:\n",
246
+ " | Eva\n",
247
+ " | torch.nn.modules.module.Module\n",
248
+ " | builtins.object\n",
249
+ " | \n",
250
+ " | Methods defined here:\n",
251
+ " | \n",
252
+ " | __init__(self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: str = 'avg', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, qkv_bias: bool = True, qkv_fused: bool = True, mlp_ratio: float = 4.0, swiglu_mlp: bool = False, scale_mlp: bool = False, scale_attn_inner: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, norm_layer: Callable = <class 'timm.layers.norm.LayerNorm'>, init_values: Optional[float] = None, class_token: bool = True, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, use_post_norm: bool = False, ref_feat_shape: Union[int, Tuple[int, int], NoneType] = None, head_init_scale: float = 0.001)\n",
253
+ " | Args:\n",
254
+ " | img_size:\n",
255
+ " | patch_size:\n",
256
+ " | in_chans:\n",
257
+ " | num_classes:\n",
258
+ " | global_pool:\n",
259
+ " | embed_dim:\n",
260
+ " | depth:\n",
261
+ " | num_heads:\n",
262
+ " | qkv_bias:\n",
263
+ " | qkv_fused:\n",
264
+ " | mlp_ratio:\n",
265
+ " | swiglu_mlp:\n",
266
+ " | scale_mlp:\n",
267
+ " | scale_attn_inner:\n",
268
+ " | drop_rate:\n",
269
+ " | pos_drop_rate:\n",
270
+ " | proj_drop_rate:\n",
271
+ " | attn_drop_rate:\n",
272
+ " | drop_path_rate:\n",
273
+ " | norm_layer:\n",
274
+ " | init_values:\n",
275
+ " | class_token:\n",
276
+ " | use_abs_pos_emb:\n",
277
+ " | use_rot_pos_emb:\n",
278
+ " | use_post_norm:\n",
279
+ " | ref_feat_shape:\n",
280
+ " | head_init_scale:\n",
281
+ " | \n",
282
+ " | fix_init_weight(self)\n",
283
+ " | \n",
284
+ " | forward(self, x)\n",
285
+ " | Defines the computation performed at every call.\n",
286
+ " | \n",
287
+ " | Should be overridden by all subclasses.\n",
288
+ " | \n",
289
+ " | .. note::\n",
290
+ " | Although the recipe for forward pass needs to be defined within\n",
291
+ " | this function, one should call the :class:`Module` instance afterwards\n",
292
+ " | instead of this since the former takes care of running the\n",
293
+ " | registered hooks while the latter silently ignores them.\n",
294
+ " | \n",
295
+ " | forward_features(self, x)\n",
296
+ " | \n",
297
+ " | forward_head(self, x, pre_logits: bool = False)\n",
298
+ " | \n",
299
+ " | get_classifier(self)\n",
300
+ " | \n",
301
+ " | group_matcher(self, coarse=False)\n",
302
+ " | \n",
303
+ " | no_weight_decay(self)\n",
304
+ " | \n",
305
+ " | reset_classifier(self, num_classes, global_pool=None)\n",
306
+ " | \n",
307
+ " | set_grad_checkpointing(self, enable=True)\n",
308
+ " | \n",
309
+ " | ----------------------------------------------------------------------\n",
310
+ " | Data and other attributes defined here:\n",
311
+ " | \n",
312
+ " | __annotations__ = {}\n",
313
+ " | \n",
314
+ " | ----------------------------------------------------------------------\n",
315
+ " | Methods inherited from torch.nn.modules.module.Module:\n",
316
+ " | \n",
317
+ " | __call__ = _call_impl(self, *args, **kwargs)\n",
318
+ " | \n",
319
+ " | __delattr__(self, name)\n",
320
+ " | Implement delattr(self, name).\n",
321
+ " | \n",
322
+ " | __dir__(self)\n",
323
+ " | Default dir() implementation.\n",
324
+ " | \n",
325
+ " | __getattr__(self, name: str) -> Union[torch.Tensor, ForwardRef('Module')]\n",
326
+ " | \n",
327
+ " | __repr__(self)\n",
328
+ " | Return repr(self).\n",
329
+ " | \n",
330
+ " | __setattr__(self, name: str, value: Union[torch.Tensor, ForwardRef('Module')]) -> None\n",
331
+ " | Implement setattr(self, name, value).\n",
332
+ " | \n",
333
+ " | __setstate__(self, state)\n",
334
+ " | \n",
335
+ " | add_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
336
+ " | Adds a child module to the current module.\n",
337
+ " | \n",
338
+ " | The module can be accessed as an attribute using the given name.\n",
339
+ " | \n",
340
+ " | Args:\n",
341
+ " | name (str): name of the child module. The child module can be\n",
342
+ " | accessed from this module using the given name\n",
343
+ " | module (Module): child module to be added to the module.\n",
344
+ " | \n",
345
+ " | apply(self: ~T, fn: Callable[[ForwardRef('Module')], NoneType]) -> ~T\n",
346
+ " | Applies ``fn`` recursively to every submodule (as returned by ``.children()``)\n",
347
+ " | as well as self. Typical use includes initializing the parameters of a model\n",
348
+ " | (see also :ref:`nn-init-doc`).\n",
349
+ " | \n",
350
+ " | Args:\n",
351
+ " | fn (:class:`Module` -> None): function to be applied to each submodule\n",
352
+ " | \n",
353
+ " | Returns:\n",
354
+ " | Module: self\n",
355
+ " | \n",
356
+ " | Example::\n",
357
+ " | \n",
358
+ " | >>> @torch.no_grad()\n",
359
+ " | >>> def init_weights(m):\n",
360
+ " | >>> print(m)\n",
361
+ " | >>> if type(m) == nn.Linear:\n",
362
+ " | >>> m.weight.fill_(1.0)\n",
363
+ " | >>> print(m.weight)\n",
364
+ " | >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
365
+ " | >>> net.apply(init_weights)\n",
366
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
367
+ " | Parameter containing:\n",
368
+ " | tensor([[1., 1.],\n",
369
+ " | [1., 1.]], requires_grad=True)\n",
370
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
371
+ " | Parameter containing:\n",
372
+ " | tensor([[1., 1.],\n",
373
+ " | [1., 1.]], requires_grad=True)\n",
374
+ " | Sequential(\n",
375
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
376
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
377
+ " | )\n",
378
+ " | \n",
379
+ " | bfloat16(self: ~T) -> ~T\n",
380
+ " | Casts all floating point parameters and buffers to ``bfloat16`` datatype.\n",
381
+ " | \n",
382
+ " | .. note::\n",
383
+ " | This method modifies the module in-place.\n",
384
+ " | \n",
385
+ " | Returns:\n",
386
+ " | Module: self\n",
387
+ " | \n",
388
+ " | buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]\n",
389
+ " | Returns an iterator over module buffers.\n",
390
+ " | \n",
391
+ " | Args:\n",
392
+ " | recurse (bool): if True, then yields buffers of this module\n",
393
+ " | and all submodules. Otherwise, yields only buffers that\n",
394
+ " | are direct members of this module.\n",
395
+ " | \n",
396
+ " | Yields:\n",
397
+ " | torch.Tensor: module buffer\n",
398
+ " | \n",
399
+ " | Example::\n",
400
+ " | \n",
401
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
402
+ " | >>> for buf in model.buffers():\n",
403
+ " | >>> print(type(buf), buf.size())\n",
404
+ " | <class 'torch.Tensor'> (20L,)\n",
405
+ " | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
406
+ " | \n",
407
+ " | children(self) -> Iterator[ForwardRef('Module')]\n",
408
+ " | Returns an iterator over immediate children modules.\n",
409
+ " | \n",
410
+ " | Yields:\n",
411
+ " | Module: a child module\n",
412
+ " | \n",
413
+ " | cpu(self: ~T) -> ~T\n",
414
+ " | Moves all model parameters and buffers to the CPU.\n",
415
+ " | \n",
416
+ " | .. note::\n",
417
+ " | This method modifies the module in-place.\n",
418
+ " | \n",
419
+ " | Returns:\n",
420
+ " | Module: self\n",
421
+ " | \n",
422
+ " | cuda(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
423
+ " | Moves all model parameters and buffers to the GPU.\n",
424
+ " | \n",
425
+ " | This also makes associated parameters and buffers different objects. So\n",
426
+ " | it should be called before constructing optimizer if the module will\n",
427
+ " | live on GPU while being optimized.\n",
428
+ " | \n",
429
+ " | .. note::\n",
430
+ " | This method modifies the module in-place.\n",
431
+ " | \n",
432
+ " | Args:\n",
433
+ " | device (int, optional): if specified, all parameters will be\n",
434
+ " | copied to that device\n",
435
+ " | \n",
436
+ " | Returns:\n",
437
+ " | Module: self\n",
438
+ " | \n",
439
+ " | double(self: ~T) -> ~T\n",
440
+ " | Casts all floating point parameters and buffers to ``double`` datatype.\n",
441
+ " | \n",
442
+ " | .. note::\n",
443
+ " | This method modifies the module in-place.\n",
444
+ " | \n",
445
+ " | Returns:\n",
446
+ " | Module: self\n",
447
+ " | \n",
448
+ " | eval(self: ~T) -> ~T\n",
449
+ " | Sets the module in evaluation mode.\n",
450
+ " | \n",
451
+ " | This has any effect only on certain modules. See documentations of\n",
452
+ " | particular modules for details of their behaviors in training/evaluation\n",
453
+ " | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
454
+ " | etc.\n",
455
+ " | \n",
456
+ " | This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.\n",
457
+ " | \n",
458
+ " | See :ref:`locally-disable-grad-doc` for a comparison between\n",
459
+ " | `.eval()` and several similar mechanisms that may be confused with it.\n",
460
+ " | \n",
461
+ " | Returns:\n",
462
+ " | Module: self\n",
463
+ " | \n",
464
+ " | extra_repr(self) -> str\n",
465
+ " | Set the extra representation of the module\n",
466
+ " | \n",
467
+ " | To print customized extra information, you should re-implement\n",
468
+ " | this method in your own modules. Both single-line and multi-line\n",
469
+ " | strings are acceptable.\n",
470
+ " | \n",
471
+ " | float(self: ~T) -> ~T\n",
472
+ " | Casts all floating point parameters and buffers to ``float`` datatype.\n",
473
+ " | \n",
474
+ " | .. note::\n",
475
+ " | This method modifies the module in-place.\n",
476
+ " | \n",
477
+ " | Returns:\n",
478
+ " | Module: self\n",
479
+ " | \n",
480
+ " | get_buffer(self, target: str) -> 'Tensor'\n",
481
+ " | Returns the buffer given by ``target`` if it exists,\n",
482
+ " | otherwise throws an error.\n",
483
+ " | \n",
484
+ " | See the docstring for ``get_submodule`` for a more detailed\n",
485
+ " | explanation of this method's functionality as well as how to\n",
486
+ " | correctly specify ``target``.\n",
487
+ " | \n",
488
+ " | Args:\n",
489
+ " | target: The fully-qualified string name of the buffer\n",
490
+ " | to look for. (See ``get_submodule`` for how to specify a\n",
491
+ " | fully-qualified string.)\n",
492
+ " | \n",
493
+ " | Returns:\n",
494
+ " | torch.Tensor: The buffer referenced by ``target``\n",
495
+ " | \n",
496
+ " | Raises:\n",
497
+ " | AttributeError: If the target string references an invalid\n",
498
+ " | path or resolves to something that is not a\n",
499
+ " | buffer\n",
500
+ " | \n",
501
+ " | get_extra_state(self) -> Any\n",
502
+ " | Returns any extra state to include in the module's state_dict.\n",
503
+ " | Implement this and a corresponding :func:`set_extra_state` for your module\n",
504
+ " | if you need to store extra state. This function is called when building the\n",
505
+ " | module's `state_dict()`.\n",
506
+ " | \n",
507
+ " | Note that extra state should be picklable to ensure working serialization\n",
508
+ " | of the state_dict. We only provide backwards compatibility guarantees\n",
509
+ " | for serializing Tensors; other objects may break backwards compatibility if\n",
510
+ " | their serialized pickled form changes.\n",
511
+ " | \n",
512
+ " | Returns:\n",
513
+ " | object: Any extra state to store in the module's state_dict\n",
514
+ " | \n",
515
+ " | get_parameter(self, target: str) -> 'Parameter'\n",
516
+ " | Returns the parameter given by ``target`` if it exists,\n",
517
+ " | otherwise throws an error.\n",
518
+ " | \n",
519
+ " | See the docstring for ``get_submodule`` for a more detailed\n",
520
+ " | explanation of this method's functionality as well as how to\n",
521
+ " | correctly specify ``target``.\n",
522
+ " | \n",
523
+ " | Args:\n",
524
+ " | target: The fully-qualified string name of the Parameter\n",
525
+ " | to look for. (See ``get_submodule`` for how to specify a\n",
526
+ " | fully-qualified string.)\n",
527
+ " | \n",
528
+ " | Returns:\n",
529
+ " | torch.nn.Parameter: The Parameter referenced by ``target``\n",
530
+ " | \n",
531
+ " | Raises:\n",
532
+ " | AttributeError: If the target string references an invalid\n",
533
+ " | path or resolves to something that is not an\n",
534
+ " | ``nn.Parameter``\n",
535
+ " | \n",
536
+ " | get_submodule(self, target: str) -> 'Module'\n",
537
+ " | Returns the submodule given by ``target`` if it exists,\n",
538
+ " | otherwise throws an error.\n",
539
+ " | \n",
540
+ " | For example, let's say you have an ``nn.Module`` ``A`` that\n",
541
+ " | looks like this:\n",
542
+ " | \n",
543
+ " | .. code-block:: text\n",
544
+ " | \n",
545
+ " | A(\n",
546
+ " | (net_b): Module(\n",
547
+ " | (net_c): Module(\n",
548
+ " | (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n",
549
+ " | )\n",
550
+ " | (linear): Linear(in_features=100, out_features=200, bias=True)\n",
551
+ " | )\n",
552
+ " | )\n",
553
+ " | \n",
554
+ " | (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested\n",
555
+ " | submodule ``net_b``, which itself has two submodules ``net_c``\n",
556
+ " | and ``linear``. ``net_c`` then has a submodule ``conv``.)\n",
557
+ " | \n",
558
+ " | To check whether or not we have the ``linear`` submodule, we\n",
559
+ " | would call ``get_submodule(\"net_b.linear\")``. To check whether\n",
560
+ " | we have the ``conv`` submodule, we would call\n",
561
+ " | ``get_submodule(\"net_b.net_c.conv\")``.\n",
562
+ " | \n",
563
+ " | The runtime of ``get_submodule`` is bounded by the degree\n",
564
+ " | of module nesting in ``target``. A query against\n",
565
+ " | ``named_modules`` achieves the same result, but it is O(N) in\n",
566
+ " | the number of transitive modules. So, for a simple check to see\n",
567
+ " | if some submodule exists, ``get_submodule`` should always be\n",
568
+ " | used.\n",
569
+ " | \n",
570
+ " | Args:\n",
571
+ " | target: The fully-qualified string name of the submodule\n",
572
+ " | to look for. (See above example for how to specify a\n",
573
+ " | fully-qualified string.)\n",
574
+ " | \n",
575
+ " | Returns:\n",
576
+ " | torch.nn.Module: The submodule referenced by ``target``\n",
577
+ " | \n",
578
+ " | Raises:\n",
579
+ " | AttributeError: If the target string references an invalid\n",
580
+ " | path or resolves to something that is not an\n",
581
+ " | ``nn.Module``\n",
582
+ " | \n",
583
+ " | half(self: ~T) -> ~T\n",
584
+ " | Casts all floating point parameters and buffers to ``half`` datatype.\n",
585
+ " | \n",
586
+ " | .. note::\n",
587
+ " | This method modifies the module in-place.\n",
588
+ " | \n",
589
+ " | Returns:\n",
590
+ " | Module: self\n",
591
+ " | \n",
592
+ " | ipu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
593
+ " | Moves all model parameters and buffers to the IPU.\n",
594
+ " | \n",
595
+ " | This also makes associated parameters and buffers different objects. So\n",
596
+ " | it should be called before constructing optimizer if the module will\n",
597
+ " | live on IPU while being optimized.\n",
598
+ " | \n",
599
+ " | .. note::\n",
600
+ " | This method modifies the module in-place.\n",
601
+ " | \n",
602
+ " | Arguments:\n",
603
+ " | device (int, optional): if specified, all parameters will be\n",
604
+ " | copied to that device\n",
605
+ " | \n",
606
+ " | Returns:\n",
607
+ " | Module: self\n",
608
+ " | \n",
609
+ " | load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True)\n",
610
+ " | Copies parameters and buffers from :attr:`state_dict` into\n",
611
+ " | this module and its descendants. If :attr:`strict` is ``True``, then\n",
612
+ " | the keys of :attr:`state_dict` must exactly match the keys returned\n",
613
+ " | by this module's :meth:`~torch.nn.Module.state_dict` function.\n",
614
+ " | \n",
615
+ " | Args:\n",
616
+ " | state_dict (dict): a dict containing parameters and\n",
617
+ " | persistent buffers.\n",
618
+ " | strict (bool, optional): whether to strictly enforce that the keys\n",
619
+ " | in :attr:`state_dict` match the keys returned by this module's\n",
620
+ " | :meth:`~torch.nn.Module.state_dict` function. Default: ``True``\n",
621
+ " | \n",
622
+ " | Returns:\n",
623
+ " | ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:\n",
624
+ " | * **missing_keys** is a list of str containing the missing keys\n",
625
+ " | * **unexpected_keys** is a list of str containing the unexpected keys\n",
626
+ " | \n",
627
+ " | Note:\n",
628
+ " | If a parameter or buffer is registered as ``None`` and its corresponding key\n",
629
+ " | exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a\n",
630
+ " | ``RuntimeError``.\n",
631
+ " | \n",
632
+ " | modules(self) -> Iterator[ForwardRef('Module')]\n",
633
+ " | Returns an iterator over all modules in the network.\n",
634
+ " | \n",
635
+ " | Yields:\n",
636
+ " | Module: a module in the network\n",
637
+ " | \n",
638
+ " | Note:\n",
639
+ " | Duplicate modules are returned only once. In the following\n",
640
+ " | example, ``l`` will be returned only once.\n",
641
+ " | \n",
642
+ " | Example::\n",
643
+ " | \n",
644
+ " | >>> l = nn.Linear(2, 2)\n",
645
+ " | >>> net = nn.Sequential(l, l)\n",
646
+ " | >>> for idx, m in enumerate(net.modules()):\n",
647
+ " | ... print(idx, '->', m)\n",
648
+ " | \n",
649
+ " | 0 -> Sequential(\n",
650
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
651
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
652
+ " | )\n",
653
+ " | 1 -> Linear(in_features=2, out_features=2, bias=True)\n",
654
+ " | \n",
655
+ " | named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.Tensor]]\n",
656
+ " | Returns an iterator over module buffers, yielding both the\n",
657
+ " | name of the buffer as well as the buffer itself.\n",
658
+ " | \n",
659
+ " | Args:\n",
660
+ " | prefix (str): prefix to prepend to all buffer names.\n",
661
+ " | recurse (bool, optional): if True, then yields buffers of this module\n",
662
+ " | and all submodules. Otherwise, yields only buffers that\n",
663
+ " | are direct members of this module. Defaults to True.\n",
664
+ " | remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.\n",
665
+ " | \n",
666
+ " | Yields:\n",
667
+ " | (str, torch.Tensor): Tuple containing the name and buffer\n",
668
+ " | \n",
669
+ " | Example::\n",
670
+ " | \n",
671
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
672
+ " | >>> for name, buf in self.named_buffers():\n",
673
+ " | >>> if name in ['running_var']:\n",
674
+ " | >>> print(buf.size())\n",
675
+ " | \n",
676
+ " | named_children(self) -> Iterator[Tuple[str, ForwardRef('Module')]]\n",
677
+ " | Returns an iterator over immediate children modules, yielding both\n",
678
+ " | the name of the module as well as the module itself.\n",
679
+ " | \n",
680
+ " | Yields:\n",
681
+ " | (str, Module): Tuple containing a name and child module\n",
682
+ " | \n",
683
+ " | Example::\n",
684
+ " | \n",
685
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
686
+ " | >>> for name, module in model.named_children():\n",
687
+ " | >>> if name in ['conv4', 'conv5']:\n",
688
+ " | >>> print(module)\n",
689
+ " | \n",
690
+ " | named_modules(self, memo: Optional[Set[ForwardRef('Module')]] = None, prefix: str = '', remove_duplicate: bool = True)\n",
691
+ " | Returns an iterator over all modules in the network, yielding\n",
692
+ " | both the name of the module as well as the module itself.\n",
693
+ " | \n",
694
+ " | Args:\n",
695
+ " | memo: a memo to store the set of modules already added to the result\n",
696
+ " | prefix: a prefix that will be added to the name of the module\n",
697
+ " | remove_duplicate: whether to remove the duplicated module instances in the result\n",
698
+ " | or not\n",
699
+ " | \n",
700
+ " | Yields:\n",
701
+ " | (str, Module): Tuple of name and module\n",
702
+ " | \n",
703
+ " | Note:\n",
704
+ " | Duplicate modules are returned only once. In the following\n",
705
+ " | example, ``l`` will be returned only once.\n",
706
+ " | \n",
707
+ " | Example::\n",
708
+ " | \n",
709
+ " | >>> l = nn.Linear(2, 2)\n",
710
+ " | >>> net = nn.Sequential(l, l)\n",
711
+ " | >>> for idx, m in enumerate(net.named_modules()):\n",
712
+ " | ... print(idx, '->', m)\n",
713
+ " | \n",
714
+ " | 0 -> ('', Sequential(\n",
715
+ " | (0): Linear(in_features=2, out_features=2, bias=True)\n",
716
+ " | (1): Linear(in_features=2, out_features=2, bias=True)\n",
717
+ " | ))\n",
718
+ " | 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))\n",
719
+ " | \n",
720
+ " | named_parameters(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, torch.nn.parameter.Parameter]]\n",
721
+ " | Returns an iterator over module parameters, yielding both the\n",
722
+ " | name of the parameter as well as the parameter itself.\n",
723
+ " | \n",
724
+ " | Args:\n",
725
+ " | prefix (str): prefix to prepend to all parameter names.\n",
726
+ " | recurse (bool): if True, then yields parameters of this module\n",
727
+ " | and all submodules. Otherwise, yields only parameters that\n",
728
+ " | are direct members of this module.\n",
729
+ " | remove_duplicate (bool, optional): whether to remove the duplicated\n",
730
+ " | parameters in the result. Defaults to True.\n",
731
+ " | \n",
732
+ " | Yields:\n",
733
+ " | (str, Parameter): Tuple containing the name and parameter\n",
734
+ " | \n",
735
+ " | Example::\n",
736
+ " | \n",
737
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
738
+ " | >>> for name, param in self.named_parameters():\n",
739
+ " | >>> if name in ['bias']:\n",
740
+ " | >>> print(param.size())\n",
741
+ " | \n",
742
+ " | parameters(self, recurse: bool = True) -> Iterator[torch.nn.parameter.Parameter]\n",
743
+ " | Returns an iterator over module parameters.\n",
744
+ " | \n",
745
+ " | This is typically passed to an optimizer.\n",
746
+ " | \n",
747
+ " | Args:\n",
748
+ " | recurse (bool): if True, then yields parameters of this module\n",
749
+ " | and all submodules. Otherwise, yields only parameters that\n",
750
+ " | are direct members of this module.\n",
751
+ " | \n",
752
+ " | Yields:\n",
753
+ " | Parameter: module parameter\n",
754
+ " | \n",
755
+ " | Example::\n",
756
+ " | \n",
757
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
758
+ " | >>> for param in model.parameters():\n",
759
+ " | >>> print(type(param), param.size())\n",
760
+ " | <class 'torch.Tensor'> (20L,)\n",
761
+ " | <class 'torch.Tensor'> (20L, 1L, 5L, 5L)\n",
762
+ " | \n",
763
+ " | register_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]) -> torch.utils.hooks.RemovableHandle\n",
764
+ " | Registers a backward hook on the module.\n",
765
+ " | \n",
766
+ " | This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and\n",
767
+ " | the behavior of this function will change in future versions.\n",
768
+ " | \n",
769
+ " | Returns:\n",
770
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
771
+ " | a handle that can be used to remove the added hook by calling\n",
772
+ " | ``handle.remove()``\n",
773
+ " | \n",
774
+ " | register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None\n",
775
+ " | Adds a buffer to the module.\n",
776
+ " | \n",
777
+ " | This is typically used to register a buffer that should not to be\n",
778
+ " | considered a model parameter. For example, BatchNorm's ``running_mean``\n",
779
+ " | is not a parameter, but is part of the module's state. Buffers, by\n",
780
+ " | default, are persistent and will be saved alongside parameters. This\n",
781
+ " | behavior can be changed by setting :attr:`persistent` to ``False``. The\n",
782
+ " | only difference between a persistent buffer and a non-persistent buffer\n",
783
+ " | is that the latter will not be a part of this module's\n",
784
+ " | :attr:`state_dict`.\n",
785
+ " | \n",
786
+ " | Buffers can be accessed as attributes using given names.\n",
787
+ " | \n",
788
+ " | Args:\n",
789
+ " | name (str): name of the buffer. The buffer can be accessed\n",
790
+ " | from this module using the given name\n",
791
+ " | tensor (Tensor or None): buffer to be registered. If ``None``, then operations\n",
792
+ " | that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,\n",
793
+ " | the buffer is **not** included in the module's :attr:`state_dict`.\n",
794
+ " | persistent (bool): whether the buffer is part of this module's\n",
795
+ " | :attr:`state_dict`.\n",
796
+ " | \n",
797
+ " | Example::\n",
798
+ " | \n",
799
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
800
+ " | >>> self.register_buffer('running_mean', torch.zeros(num_features))\n",
801
+ " | \n",
802
+ " | register_forward_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
803
+ " | Registers a forward hook on the module.\n",
804
+ " | \n",
805
+ " | The hook will be called every time after :func:`forward` has computed an output.\n",
806
+ " | \n",
807
+ " | If ``with_kwargs`` is ``False`` or not specified, the input contains only\n",
808
+ " | the positional arguments given to the module. Keyword arguments won't be\n",
809
+ " | passed to the hooks and only to the ``forward``. The hook can modify the\n",
810
+ " | output. It can modify the input inplace but it will not have effect on\n",
811
+ " | forward since this is called after :func:`forward` is called. The hook\n",
812
+ " | should have the following signature::\n",
813
+ " | \n",
814
+ " | hook(module, args, output) -> None or modified output\n",
815
+ " | \n",
816
+ " | If ``with_kwargs`` is ``True``, the forward hook will be passed the\n",
817
+ " | ``kwargs`` given to the forward function and be expected to return the\n",
818
+ " | output possibly modified. The hook should have the following signature::\n",
819
+ " | \n",
820
+ " | hook(module, args, kwargs, output) -> None or modified output\n",
821
+ " | \n",
822
+ " | Args:\n",
823
+ " | hook (Callable): The user defined hook to be registered.\n",
824
+ " | prepend (bool): If ``True``, the provided ``hook`` will be fired\n",
825
+ " | before all existing ``forward`` hooks on this\n",
826
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
827
+ " | ``hook`` will be fired after all existing ``forward`` hooks on\n",
828
+ " | this :class:`torch.nn.modules.Module`. Note that global\n",
829
+ " | ``forward`` hooks registered with\n",
830
+ " | :func:`register_module_forward_hook` will fire before all hooks\n",
831
+ " | registered by this method.\n",
832
+ " | Default: ``False``\n",
833
+ " | with_kwargs (bool): If ``True``, the ``hook`` will be passed the\n",
834
+ " | kwargs given to the forward function.\n",
835
+ " | Default: ``False``\n",
836
+ " | \n",
837
+ " | Returns:\n",
838
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
839
+ " | a handle that can be used to remove the added hook by calling\n",
840
+ " | ``handle.remove()``\n",
841
+ " | \n",
842
+ " | register_forward_pre_hook(self, hook: Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) -> torch.utils.hooks.RemovableHandle\n",
843
+ " | Registers a forward pre-hook on the module.\n",
844
+ " | \n",
845
+ " | The hook will be called every time before :func:`forward` is invoked.\n",
846
+ " | \n",
847
+ " | \n",
848
+ " | If ``with_kwargs`` is false or not specified, the input contains only\n",
849
+ " | the positional arguments given to the module. Keyword arguments won't be\n",
850
+ " | passed to the hooks and only to the ``forward``. The hook can modify the\n",
851
+ " | input. User can either return a tuple or a single modified value in the\n",
852
+ " | hook. We will wrap the value into a tuple if a single value is returned\n",
853
+ " | (unless that value is already a tuple). The hook should have the\n",
854
+ " | following signature::\n",
855
+ " | \n",
856
+ " | hook(module, args) -> None or modified input\n",
857
+ " | \n",
858
+ " | If ``with_kwargs`` is true, the forward pre-hook will be passed the\n",
859
+ " | kwargs given to the forward function. And if the hook modifies the\n",
860
+ " | input, both the args and kwargs should be returned. The hook should have\n",
861
+ " | the following signature::\n",
862
+ " | \n",
863
+ " | hook(module, args, kwargs) -> None or a tuple of modified input and kwargs\n",
864
+ " | \n",
865
+ " | Args:\n",
866
+ " | hook (Callable): The user defined hook to be registered.\n",
867
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
868
+ " | all existing ``forward_pre`` hooks on this\n",
869
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
870
+ " | ``hook`` will be fired after all existing ``forward_pre`` hooks\n",
871
+ " | on this :class:`torch.nn.modules.Module`. Note that global\n",
872
+ " | ``forward_pre`` hooks registered with\n",
873
+ " | :func:`register_module_forward_pre_hook` will fire before all\n",
874
+ " | hooks registered by this method.\n",
875
+ " | Default: ``False``\n",
876
+ " | with_kwargs (bool): If true, the ``hook`` will be passed the kwargs\n",
877
+ " | given to the forward function.\n",
878
+ " | Default: ``False``\n",
879
+ " | \n",
880
+ " | Returns:\n",
881
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
882
+ " | a handle that can be used to remove the added hook by calling\n",
883
+ " | ``handle.remove()``\n",
884
+ " | \n",
885
+ " | register_full_backward_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
886
+ " | Registers a backward hook on the module.\n",
887
+ " | \n",
888
+ " | The hook will be called every time the gradients with respect to a module\n",
889
+ " | are computed, i.e. the hook will execute if and only if the gradients with\n",
890
+ " | respect to module outputs are computed. The hook should have the following\n",
891
+ " | signature::\n",
892
+ " | \n",
893
+ " | hook(module, grad_input, grad_output) -> tuple(Tensor) or None\n",
894
+ " | \n",
895
+ " | The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients\n",
896
+ " | with respect to the inputs and outputs respectively. The hook should\n",
897
+ " | not modify its arguments, but it can optionally return a new gradient with\n",
898
+ " | respect to the input that will be used in place of :attr:`grad_input` in\n",
899
+ " | subsequent computations. :attr:`grad_input` will only correspond to the inputs given\n",
900
+ " | as positional arguments and all kwarg arguments are ignored. Entries\n",
901
+ " | in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor\n",
902
+ " | arguments.\n",
903
+ " | \n",
904
+ " | For technical reasons, when this hook is applied to a Module, its forward function will\n",
905
+ " | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
906
+ " | of each Tensor returned by the Module's forward function.\n",
907
+ " | \n",
908
+ " | .. warning ::\n",
909
+ " | Modifying inputs or outputs inplace is not allowed when using backward hooks and\n",
910
+ " | will raise an error.\n",
911
+ " | \n",
912
+ " | Args:\n",
913
+ " | hook (Callable): The user-defined hook to be registered.\n",
914
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
915
+ " | all existing ``backward`` hooks on this\n",
916
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
917
+ " | ``hook`` will be fired after all existing ``backward`` hooks on\n",
918
+ " | this :class:`torch.nn.modules.Module`. Note that global\n",
919
+ " | ``backward`` hooks registered with\n",
920
+ " | :func:`register_module_full_backward_hook` will fire before\n",
921
+ " | all hooks registered by this method.\n",
922
+ " | \n",
923
+ " | Returns:\n",
924
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
925
+ " | a handle that can be used to remove the added hook by calling\n",
926
+ " | ``handle.remove()``\n",
927
+ " | \n",
928
+ " | register_full_backward_pre_hook(self, hook: Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]], prepend: bool = False) -> torch.utils.hooks.RemovableHandle\n",
929
+ " | Registers a backward pre-hook on the module.\n",
930
+ " | \n",
931
+ " | The hook will be called every time the gradients for the module are computed.\n",
932
+ " | The hook should have the following signature::\n",
933
+ " | \n",
934
+ " | hook(module, grad_output) -> Tensor or None\n",
935
+ " | \n",
936
+ " | The :attr:`grad_output` is a tuple. The hook should\n",
937
+ " | not modify its arguments, but it can optionally return a new gradient with\n",
938
+ " | respect to the output that will be used in place of :attr:`grad_output` in\n",
939
+ " | subsequent computations. Entries in :attr:`grad_output` will be ``None`` for\n",
940
+ " | all non-Tensor arguments.\n",
941
+ " | \n",
942
+ " | For technical reasons, when this hook is applied to a Module, its forward function will\n",
943
+ " | receive a view of each Tensor passed to the Module. Similarly the caller will receive a view\n",
944
+ " | of each Tensor returned by the Module's forward function.\n",
945
+ " | \n",
946
+ " | .. warning ::\n",
947
+ " | Modifying inputs inplace is not allowed when using backward hooks and\n",
948
+ " | will raise an error.\n",
949
+ " | \n",
950
+ " | Args:\n",
951
+ " | hook (Callable): The user-defined hook to be registered.\n",
952
+ " | prepend (bool): If true, the provided ``hook`` will be fired before\n",
953
+ " | all existing ``backward_pre`` hooks on this\n",
954
+ " | :class:`torch.nn.modules.Module`. Otherwise, the provided\n",
955
+ " | ``hook`` will be fired after all existing ``backward_pre`` hooks\n",
956
+ " | on this :class:`torch.nn.modules.Module`. Note that global\n",
957
+ " | ``backward_pre`` hooks registered with\n",
958
+ " | :func:`register_module_full_backward_pre_hook` will fire before\n",
959
+ " | all hooks registered by this method.\n",
960
+ " | \n",
961
+ " | Returns:\n",
962
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
963
+ " | a handle that can be used to remove the added hook by calling\n",
964
+ " | ``handle.remove()``\n",
965
+ " | \n",
966
+ " | register_load_state_dict_post_hook(self, hook)\n",
967
+ " | Registers a post hook to be run after module's ``load_state_dict``\n",
968
+ " | is called.\n",
969
+ " | \n",
970
+ " | It should have the following signature::\n",
971
+ " | hook(module, incompatible_keys) -> None\n",
972
+ " | \n",
973
+ " | The ``module`` argument is the current module that this hook is registered\n",
974
+ " | on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting\n",
975
+ " | of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``\n",
976
+ " | is a ``list`` of ``str`` containing the missing keys and\n",
977
+ " | ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.\n",
978
+ " | \n",
979
+ " | The given incompatible_keys can be modified inplace if needed.\n",
980
+ " | \n",
981
+ " | Note that the checks performed when calling :func:`load_state_dict` with\n",
982
+ " | ``strict=True`` are affected by modifications the hook makes to\n",
983
+ " | ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either\n",
984
+ " | set of keys will result in an error being thrown when ``strict=True``, and\n",
985
+ " | clearing out both missing and unexpected keys will avoid an error.\n",
986
+ " | \n",
987
+ " | Returns:\n",
988
+ " | :class:`torch.utils.hooks.RemovableHandle`:\n",
989
+ " | a handle that can be used to remove the added hook by calling\n",
990
+ " | ``handle.remove()``\n",
991
+ " | \n",
992
+ " | register_module(self, name: str, module: Optional[ForwardRef('Module')]) -> None\n",
993
+ " | Alias for :func:`add_module`.\n",
994
+ " | \n",
995
+ " | register_parameter(self, name: str, param: Optional[torch.nn.parameter.Parameter]) -> None\n",
996
+ " | Adds a parameter to the module.\n",
997
+ " | \n",
998
+ " | The parameter can be accessed as an attribute using given name.\n",
999
+ " | \n",
1000
+ " | Args:\n",
1001
+ " | name (str): name of the parameter. The parameter can be accessed\n",
1002
+ " | from this module using the given name\n",
1003
+ " | param (Parameter or None): parameter to be added to the module. If\n",
1004
+ " | ``None``, then operations that run on parameters, such as :attr:`cuda`,\n",
1005
+ " | are ignored. If ``None``, the parameter is **not** included in the\n",
1006
+ " | module's :attr:`state_dict`.\n",
1007
+ " | \n",
1008
+ " | register_state_dict_pre_hook(self, hook)\n",
1009
+ " | These hooks will be called with arguments: ``self``, ``prefix``,\n",
1010
+ " | and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered\n",
1011
+ " | hooks can be used to perform pre-processing before the ``state_dict``\n",
1012
+ " | call is made.\n",
1013
+ " | \n",
1014
+ " | requires_grad_(self: ~T, requires_grad: bool = True) -> ~T\n",
1015
+ " | Change if autograd should record operations on parameters in this\n",
1016
+ " | module.\n",
1017
+ " | \n",
1018
+ " | This method sets the parameters' :attr:`requires_grad` attributes\n",
1019
+ " | in-place.\n",
1020
+ " | \n",
1021
+ " | This method is helpful for freezing part of the module for finetuning\n",
1022
+ " | or training parts of a model individually (e.g., GAN training).\n",
1023
+ " | \n",
1024
+ " | See :ref:`locally-disable-grad-doc` for a comparison between\n",
1025
+ " | `.requires_grad_()` and several similar mechanisms that may be confused with it.\n",
1026
+ " | \n",
1027
+ " | Args:\n",
1028
+ " | requires_grad (bool): whether autograd should record operations on\n",
1029
+ " | parameters in this module. Default: ``True``.\n",
1030
+ " | \n",
1031
+ " | Returns:\n",
1032
+ " | Module: self\n",
1033
+ " | \n",
1034
+ " | set_extra_state(self, state: Any)\n",
1035
+ " | This function is called from :func:`load_state_dict` to handle any extra state\n",
1036
+ " | found within the `state_dict`. Implement this function and a corresponding\n",
1037
+ " | :func:`get_extra_state` for your module if you need to store extra state within its\n",
1038
+ " | `state_dict`.\n",
1039
+ " | \n",
1040
+ " | Args:\n",
1041
+ " | state (dict): Extra state from the `state_dict`\n",
1042
+ " | \n",
1043
+ " | share_memory(self: ~T) -> ~T\n",
1044
+ " | See :meth:`torch.Tensor.share_memory_`\n",
1045
+ " | \n",
1046
+ " | state_dict(self, *args, destination=None, prefix='', keep_vars=False)\n",
1047
+ " | Returns a dictionary containing references to the whole state of the module.\n",
1048
+ " | \n",
1049
+ " | Both parameters and persistent buffers (e.g. running averages) are\n",
1050
+ " | included. Keys are corresponding parameter and buffer names.\n",
1051
+ " | Parameters and buffers set to ``None`` are not included.\n",
1052
+ " | \n",
1053
+ " | .. note::\n",
1054
+ " | The returned object is a shallow copy. It contains references\n",
1055
+ " | to the module's parameters and buffers.\n",
1056
+ " | \n",
1057
+ " | .. warning::\n",
1058
+ " | Currently ``state_dict()`` also accepts positional arguments for\n",
1059
+ " | ``destination``, ``prefix`` and ``keep_vars`` in order. However,\n",
1060
+ " | this is being deprecated and keyword arguments will be enforced in\n",
1061
+ " | future releases.\n",
1062
+ " | \n",
1063
+ " | .. warning::\n",
1064
+ " | Please avoid the use of argument ``destination`` as it is not\n",
1065
+ " | designed for end-users.\n",
1066
+ " | \n",
1067
+ " | Args:\n",
1068
+ " | destination (dict, optional): If provided, the state of module will\n",
1069
+ " | be updated into the dict and the same object is returned.\n",
1070
+ " | Otherwise, an ``OrderedDict`` will be created and returned.\n",
1071
+ " | Default: ``None``.\n",
1072
+ " | prefix (str, optional): a prefix added to parameter and buffer\n",
1073
+ " | names to compose the keys in state_dict. Default: ``''``.\n",
1074
+ " | keep_vars (bool, optional): by default the :class:`~torch.Tensor` s\n",
1075
+ " | returned in the state dict are detached from autograd. If it's\n",
1076
+ " | set to ``True``, detaching will not be performed.\n",
1077
+ " | Default: ``False``.\n",
1078
+ " | \n",
1079
+ " | Returns:\n",
1080
+ " | dict:\n",
1081
+ " | a dictionary containing a whole state of the module\n",
1082
+ " | \n",
1083
+ " | Example::\n",
1084
+ " | \n",
1085
+ " | >>> # xdoctest: +SKIP(\"undefined vars\")\n",
1086
+ " | >>> module.state_dict().keys()\n",
1087
+ " | ['bias', 'weight']\n",
1088
+ " | \n",
1089
+ " | to(self, *args, **kwargs)\n",
1090
+ " | Moves and/or casts the parameters and buffers.\n",
1091
+ " | \n",
1092
+ " | This can be called as\n",
1093
+ " | \n",
1094
+ " | .. function:: to(device=None, dtype=None, non_blocking=False)\n",
1095
+ " | :noindex:\n",
1096
+ " | \n",
1097
+ " | .. function:: to(dtype, non_blocking=False)\n",
1098
+ " | :noindex:\n",
1099
+ " | \n",
1100
+ " | .. function:: to(tensor, non_blocking=False)\n",
1101
+ " | :noindex:\n",
1102
+ " | \n",
1103
+ " | .. function:: to(memory_format=torch.channels_last)\n",
1104
+ " | :noindex:\n",
1105
+ " | \n",
1106
+ " | Its signature is similar to :meth:`torch.Tensor.to`, but only accepts\n",
1107
+ " | floating point or complex :attr:`dtype`\\ s. In addition, this method will\n",
1108
+ " | only cast the floating point or complex parameters and buffers to :attr:`dtype`\n",
1109
+ " | (if given). The integral parameters and buffers will be moved\n",
1110
+ " | :attr:`device`, if that is given, but with dtypes unchanged. When\n",
1111
+ " | :attr:`non_blocking` is set, it tries to convert/move asynchronously\n",
1112
+ " | with respect to the host if possible, e.g., moving CPU Tensors with\n",
1113
+ " | pinned memory to CUDA devices.\n",
1114
+ " | \n",
1115
+ " | See below for examples.\n",
1116
+ " | \n",
1117
+ " | .. note::\n",
1118
+ " | This method modifies the module in-place.\n",
1119
+ " | \n",
1120
+ " | Args:\n",
1121
+ " | device (:class:`torch.device`): the desired device of the parameters\n",
1122
+ " | and buffers in this module\n",
1123
+ " | dtype (:class:`torch.dtype`): the desired floating point or complex dtype of\n",
1124
+ " | the parameters and buffers in this module\n",
1125
+ " | tensor (torch.Tensor): Tensor whose dtype and device are the desired\n",
1126
+ " | dtype and device for all parameters and buffers in this module\n",
1127
+ " | memory_format (:class:`torch.memory_format`): the desired memory\n",
1128
+ " | format for 4D parameters and buffers in this module (keyword\n",
1129
+ " | only argument)\n",
1130
+ " | \n",
1131
+ " | Returns:\n",
1132
+ " | Module: self\n",
1133
+ " | \n",
1134
+ " | Examples::\n",
1135
+ " | \n",
1136
+ " | >>> # xdoctest: +IGNORE_WANT(\"non-deterministic\")\n",
1137
+ " | >>> linear = nn.Linear(2, 2)\n",
1138
+ " | >>> linear.weight\n",
1139
+ " | Parameter containing:\n",
1140
+ " | tensor([[ 0.1913, -0.3420],\n",
1141
+ " | [-0.5113, -0.2325]])\n",
1142
+ " | >>> linear.to(torch.double)\n",
1143
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1144
+ " | >>> linear.weight\n",
1145
+ " | Parameter containing:\n",
1146
+ " | tensor([[ 0.1913, -0.3420],\n",
1147
+ " | [-0.5113, -0.2325]], dtype=torch.float64)\n",
1148
+ " | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)\n",
1149
+ " | >>> gpu1 = torch.device(\"cuda:1\")\n",
1150
+ " | >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)\n",
1151
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1152
+ " | >>> linear.weight\n",
1153
+ " | Parameter containing:\n",
1154
+ " | tensor([[ 0.1914, -0.3420],\n",
1155
+ " | [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')\n",
1156
+ " | >>> cpu = torch.device(\"cpu\")\n",
1157
+ " | >>> linear.to(cpu)\n",
1158
+ " | Linear(in_features=2, out_features=2, bias=True)\n",
1159
+ " | >>> linear.weight\n",
1160
+ " | Parameter containing:\n",
1161
+ " | tensor([[ 0.1914, -0.3420],\n",
1162
+ " | [-0.5112, -0.2324]], dtype=torch.float16)\n",
1163
+ " | \n",
1164
+ " | >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)\n",
1165
+ " | >>> linear.weight\n",
1166
+ " | Parameter containing:\n",
1167
+ " | tensor([[ 0.3741+0.j, 0.2382+0.j],\n",
1168
+ " | [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)\n",
1169
+ " | >>> linear(torch.ones(3, 2, dtype=torch.cdouble))\n",
1170
+ " | tensor([[0.6122+0.j, 0.1150+0.j],\n",
1171
+ " | [0.6122+0.j, 0.1150+0.j],\n",
1172
+ " | [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)\n",
1173
+ " | \n",
1174
+ " | to_empty(self: ~T, *, device: Union[str, torch.device]) -> ~T\n",
1175
+ " | Moves the parameters and buffers to the specified device without copying storage.\n",
1176
+ " | \n",
1177
+ " | Args:\n",
1178
+ " | device (:class:`torch.device`): The desired device of the parameters\n",
1179
+ " | and buffers in this module.\n",
1180
+ " | \n",
1181
+ " | Returns:\n",
1182
+ " | Module: self\n",
1183
+ " | \n",
1184
+ " | train(self: ~T, mode: bool = True) -> ~T\n",
1185
+ " | Sets the module in training mode.\n",
1186
+ " | \n",
1187
+ " | This has any effect only on certain modules. See documentations of\n",
1188
+ " | particular modules for details of their behaviors in training/evaluation\n",
1189
+ " | mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,\n",
1190
+ " | etc.\n",
1191
+ " | \n",
1192
+ " | Args:\n",
1193
+ " | mode (bool): whether to set training mode (``True``) or evaluation\n",
1194
+ " | mode (``False``). Default: ``True``.\n",
1195
+ " | \n",
1196
+ " | Returns:\n",
1197
+ " | Module: self\n",
1198
+ " | \n",
1199
+ " | type(self: ~T, dst_type: Union[torch.dtype, str]) -> ~T\n",
1200
+ " | Casts all parameters and buffers to :attr:`dst_type`.\n",
1201
+ " | \n",
1202
+ " | .. note::\n",
1203
+ " | This method modifies the module in-place.\n",
1204
+ " | \n",
1205
+ " | Args:\n",
1206
+ " | dst_type (type or string): the desired type\n",
1207
+ " | \n",
1208
+ " | Returns:\n",
1209
+ " | Module: self\n",
1210
+ " | \n",
1211
+ " | xpu(self: ~T, device: Union[int, torch.device, NoneType] = None) -> ~T\n",
1212
+ " | Moves all model parameters and buffers to the XPU.\n",
1213
+ " | \n",
1214
+ " | This also makes associated parameters and buffers different objects. So\n",
1215
+ " | it should be called before constructing optimizer if the module will\n",
1216
+ " | live on XPU while being optimized.\n",
1217
+ " | \n",
1218
+ " | .. note::\n",
1219
+ " | This method modifies the module in-place.\n",
1220
+ " | \n",
1221
+ " | Arguments:\n",
1222
+ " | device (int, optional): if specified, all parameters will be\n",
1223
+ " | copied to that device\n",
1224
+ " | \n",
1225
+ " | Returns:\n",
1226
+ " | Module: self\n",
1227
+ " | \n",
1228
+ " | zero_grad(self, set_to_none: bool = True) -> None\n",
1229
+ " | Sets gradients of all model parameters to zero. See similar function\n",
1230
+ " | under :class:`torch.optim.Optimizer` for more context.\n",
1231
+ " | \n",
1232
+ " | Args:\n",
1233
+ " | set_to_none (bool): instead of setting to zero, set the grads to None.\n",
1234
+ " | See :meth:`torch.optim.Optimizer.zero_grad` for details.\n",
1235
+ " | \n",
1236
+ " | ----------------------------------------------------------------------\n",
1237
+ " | Data descriptors inherited from torch.nn.modules.module.Module:\n",
1238
+ " | \n",
1239
+ " | __dict__\n",
1240
+ " | dictionary for instance variables (if defined)\n",
1241
+ " | \n",
1242
+ " | __weakref__\n",
1243
+ " | list of weak references to the object (if defined)\n",
1244
+ " | \n",
1245
+ " | ----------------------------------------------------------------------\n",
1246
+ " | Data and other attributes inherited from torch.nn.modules.module.Module:\n",
1247
+ " | \n",
1248
+ " | T_destination = ~T_destination\n",
1249
+ " | \n",
1250
+ " | call_super_init = False\n",
1251
+ " | \n",
1252
+ " | dump_patches = False\n",
1253
+ "\n"
1254
+ ]
1255
+ }
1256
+ ],
1257
+ "source": [
1258
+ "help(eva02_model)"
1259
+ ]
1260
+ },
1261
+ {
1262
+ "cell_type": "markdown",
1263
+ "id": "2f5ac1a7-6f1b-4417-8a67-1b2e32d385dd",
1264
+ "metadata": {},
1265
+ "source": [
1266
+ "# DETR"
1267
+ ]
1268
+ },
1269
+ {
1270
+ "cell_type": "code",
1271
+ "execution_count": 33,
1272
+ "id": "5c3ade1b-18ea-4368-abd9-53be1fdfb610",
1273
+ "metadata": {},
1274
+ "outputs": [
1275
+ {
1276
+ "name": "stdout",
1277
+ "output_type": "stream",
1278
+ "text": [
1279
+ "[2023-08-28 01:51:14,033] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
1280
+ ]
1281
+ },
1282
+ {
1283
+ "name": "stderr",
1284
+ "output_type": "stream",
1285
+ "text": [
1286
+ "The `max_size` parameter is deprecated and will be removed in v4.26. Please specify in `size['longest_edge'] instead`.\n"
1287
+ ]
1288
+ }
1289
+ ],
1290
+ "source": [
1291
+ "from transformers import DetrImageProcessor, DetrForObjectDetection\n",
1292
+ "import torch\n",
1293
+ "from PIL import Image\n",
1294
+ "import requests\n",
1295
+ "\n",
1296
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
1297
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
1298
+ "\n",
1299
+ "processor = DetrImageProcessor.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')\n",
1300
+ "model = DetrForObjectDetection.from_pretrained(\"facebook/detr-resnet-50\", cache_dir='/fsx/proj-fmri/shared/cache')"
1301
+ ]
1302
+ },
1303
+ {
1304
+ "cell_type": "code",
1305
+ "execution_count": 34,
1306
+ "id": "1d5aa2d7-4868-4751-8d90-7c52be028cd9",
1307
+ "metadata": {},
1308
+ "outputs": [],
1309
+ "source": [
1310
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
1311
+ "outputs = model(**inputs)"
1312
+ ]
1313
+ },
1314
+ {
1315
+ "cell_type": "code",
1316
+ "execution_count": 35,
1317
+ "id": "ae6bafc6-cee4-4e59-b7ba-12efc2a65b74",
1318
+ "metadata": {},
1319
+ "outputs": [
1320
+ {
1321
+ "name": "stdout",
1322
+ "output_type": "stream",
1323
+ "text": [
1324
+ "Detected remote with confidence 0.998 at location [40.16, 70.81, 175.55, 117.98]\n",
1325
+ "Detected remote with confidence 0.996 at location [333.24, 72.55, 368.33, 187.66]\n",
1326
+ "Detected couch with confidence 0.995 at location [-0.02, 1.15, 639.73, 473.76]\n",
1327
+ "Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93]\n",
1328
+ "Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72]\n"
1329
+ ]
1330
+ }
1331
+ ],
1332
+ "source": [
1333
+ "# convert outputs (bounding boxes and class logits) to COCO API\n",
1334
+ "# let's only keep detections with score > 0.9\n",
1335
+ "target_sizes = torch.tensor([image.size[::-1]])\n",
1336
+ "results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]\n",
1337
+ "\n",
1338
+ "for score, label, box in zip(results[\"scores\"], results[\"labels\"], results[\"boxes\"]):\n",
1339
+ " box = [round(i, 2) for i in box.tolist()]\n",
1340
+ " print(\n",
1341
+ " f\"Detected {model.config.id2label[label.item()]} with confidence \"\n",
1342
+ " f\"{round(score.item(), 3)} at location {box}\"\n",
1343
+ " )"
1344
+ ]
1345
+ },
1346
+ {
1347
+ "cell_type": "code",
1348
+ "execution_count": 36,
1349
+ "id": "6dcc5934-79d4-4062-8b32-e42b3ebcdc0f",
1350
+ "metadata": {},
1351
+ "outputs": [
1352
+ {
1353
+ "data": {
1354
+ "text/plain": [
1355
+ "DetrImageProcessor {\n",
1356
+ " \"do_normalize\": true,\n",
1357
+ " \"do_pad\": true,\n",
1358
+ " \"do_rescale\": true,\n",
1359
+ " \"do_resize\": true,\n",
1360
+ " \"feature_extractor_type\": \"DetrFeatureExtractor\",\n",
1361
+ " \"format\": \"coco_detection\",\n",
1362
+ " \"image_mean\": [\n",
1363
+ " 0.485,\n",
1364
+ " 0.456,\n",
1365
+ " 0.406\n",
1366
+ " ],\n",
1367
+ " \"image_processor_type\": \"DetrImageProcessor\",\n",
1368
+ " \"image_std\": [\n",
1369
+ " 0.229,\n",
1370
+ " 0.224,\n",
1371
+ " 0.225\n",
1372
+ " ],\n",
1373
+ " \"resample\": 2,\n",
1374
+ " \"rescale_factor\": 0.00392156862745098,\n",
1375
+ " \"size\": {\n",
1376
+ " \"longest_edge\": 1333,\n",
1377
+ " \"shortest_edge\": 800\n",
1378
+ " }\n",
1379
+ "}"
1380
+ ]
1381
+ },
1382
+ "execution_count": 36,
1383
+ "metadata": {},
1384
+ "output_type": "execute_result"
1385
+ }
1386
+ ],
1387
+ "source": [
1388
+ "processor"
1389
+ ]
1390
+ },
1391
+ {
1392
+ "cell_type": "markdown",
1393
+ "id": "db1d89cc-b432-473e-af69-d81c435ac731",
1394
+ "metadata": {},
1395
+ "source": [
1396
+ "# CLIPSeg"
1397
+ ]
1398
+ },
1399
+ {
1400
+ "cell_type": "code",
1401
+ "execution_count": 37,
1402
+ "id": "15db14d1-ee4d-4429-9286-054c4498293b",
1403
+ "metadata": {},
1404
+ "outputs": [],
1405
+ "source": [
1406
+ "from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation\n",
1407
+ "\n",
1408
+ "processor = CLIPSegProcessor.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')\n",
1409
+ "model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd16\",cache_dir='/fsx/proj-fmri/shared/cache')"
1410
+ ]
1411
+ },
1412
+ {
1413
+ "cell_type": "code",
1414
+ "execution_count": 38,
1415
+ "id": "4aa225d4-5a3b-4dbb-ae57-dea2872ff492",
1416
+ "metadata": {},
1417
+ "outputs": [
1418
+ {
1419
+ "ename": "AttributeError",
1420
+ "evalue": "'JpegImageFile' object has no attribute 'shape'",
1421
+ "output_type": "error",
1422
+ "traceback": [
1423
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1424
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1425
+ "Cell \u001b[0;32mIn[38], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mimage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\n",
1426
+ "\u001b[0;31mAttributeError\u001b[0m: 'JpegImageFile' object has no attribute 'shape'"
1427
+ ]
1428
+ }
1429
+ ],
1430
+ "source": [
1431
+ "image.shape"
1432
+ ]
1433
+ },
1434
+ {
1435
+ "cell_type": "code",
1436
+ "execution_count": null,
1437
+ "id": "ad7e2daf-0c7c-4fec-b29e-9ba47a037c6b",
1438
+ "metadata": {},
1439
+ "outputs": [],
1440
+ "source": [
1441
+ "from PIL import Image\n",
1442
+ "import requests\n",
1443
+ "import h5py\n",
1444
+ "\n",
1445
+ "# url = \"https://unsplash.com/photos/8Nc_oQsc2qQ/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjcxMjAwNzI0&force=true&w=640\"\n",
1446
+ "# image = Image.open(requests.get(url, stream=True).raw)\n",
1447
+ "\n",
1448
+ "image_path = \"/fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\"\n",
1449
+ "with h5py.File(image_path, 'r') as file:\n",
1450
+ " image = file['images'][0]\n",
1451
+ "image = np.moveaxis(image, 0, -1).astype(np.float32)\n",
1452
+ "plt.imshow(image)\n",
1453
+ "\n",
1454
+ "prompts = [\"person\",\"animal\",\"object\",\"background\"]\n",
1455
+ "import torch\n",
1456
+ "\n",
1457
+ "# Rescale to [0, 255]\n",
1458
+ "array = (image * 255).astype(np.uint8)\n",
1459
+ "\n",
1460
+ "# Convert to PIL image\n",
1461
+ "image = Image.fromarray(array)\n",
1462
+ "\n",
1463
+ "inputs = processor(text=prompts, images=[image] * len(prompts), padding=\"max_length\", return_tensors=\"pt\")\n",
1464
+ "# predict\n",
1465
+ "with torch.no_grad():\n",
1466
+ " outputs = model(**inputs)\n",
1467
+ "preds = outputs.logits.unsqueeze(1)\n",
1468
+ "print(preds.shape)"
1469
+ ]
1470
+ },
1471
+ {
1472
+ "cell_type": "code",
1473
+ "execution_count": null,
1474
+ "id": "131eb5b7-2f16-4a79-8402-edc1a1d8c348",
1475
+ "metadata": {},
1476
+ "outputs": [],
1477
+ "source": [
1478
+ "preds = ((preds[0] + preds[1] + preds[2] + preds[-1].max() - preds[-1]) / 4)[None]\n",
1479
+ "preds.shape"
1480
+ ]
1481
+ },
1482
+ {
1483
+ "cell_type": "code",
1484
+ "execution_count": null,
1485
+ "id": "e2bf99e7-064d-4c22-997f-aa1a35dbab82",
1486
+ "metadata": {},
1487
+ "outputs": [],
1488
+ "source": [
1489
+ "_, ax = plt.subplots(1, len(prompts) + 1, figsize=(3*(len(prompts) + 1), 4))\n",
1490
+ "[a.axis('off') for a in ax.flatten()]\n",
1491
+ "ax[0].imshow(image)\n",
1492
+ "[ax[i+1].imshow(torch.sigmoid(preds[i][0])) for i in range(1)];\n",
1493
+ "# [ax[i+1].text(0, -15, prompt) for i, prompt in enumerate(prompts)];"
1494
+ ]
1495
+ },
1496
+ {
1497
+ "cell_type": "code",
1498
+ "execution_count": null,
1499
+ "id": "b58b926f-a2b2-423b-b367-18808cf6b4f7",
1500
+ "metadata": {},
1501
+ "outputs": [],
1502
+ "source": []
1503
+ }
1504
+ ],
1505
+ "metadata": {
1506
+ "kernelspec": {
1507
+ "display_name": "Python 3 (ipykernel)",
1508
+ "language": "python",
1509
+ "name": "python3"
1510
+ },
1511
+ "language_info": {
1512
+ "codemirror_mode": {
1513
+ "name": "ipython",
1514
+ "version": 3
1515
+ },
1516
+ "file_extension": ".py",
1517
+ "mimetype": "text/x-python",
1518
+ "name": "python",
1519
+ "nbconvert_exporter": "python",
1520
+ "pygments_lexer": "ipython3",
1521
+ "version": "3.10.8"
1522
+ }
1523
+ },
1524
+ "nbformat": 4,
1525
+ "nbformat_minor": 5
1526
+ }
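The `help(eva02_model)` dump above is, past the model-specific header, essentially the inherited `torch.nn.Module` reference. For orientation, here is a minimal, self-contained sketch (a toy `Sequential`, not anything from this repo) exercising a few of the documented methods — `get_submodule`, `named_parameters`, `register_forward_hook`, and the `state_dict`/`load_state_dict` round trip:

```python
# Toy stand-in model (not from this repo) to exercise the nn.Module APIs documented above.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# get_submodule resolves a dotted path directly instead of scanning named_modules().
print(net.get_submodule("0"))

# named_parameters() yields (name, Parameter) pairs, e.g. for freezing or optimizer groups.
for name, param in net.named_parameters():
    print(name, tuple(param.shape))

# A forward hook fires after forward() and can inspect (or replace) the output.
handle = net[0].register_forward_hook(lambda module, args, output: print("hook:", output.shape))
net(torch.randn(3, 4))
handle.remove()

# state_dict()/load_state_dict() round trip; with strict=True the key sets must match.
result = net.load_state_dict(net.state_dict(), strict=True)
print(result.missing_keys, result.unexpected_keys)  # [] []
```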
src/deepspeed_config_stage1.json ADDED
@@ -0,0 +1 @@
1
+ {"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 1, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "none", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 8, "train_micro_batch_size_per_gpu": 8, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
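One thing to keep in mind when reading this one-liner: DeepSpeed enforces `train_batch_size == train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size`, so the 8 / 8 / 1 values here imply a single-process launch. A small sanity-check sketch (the path is the one added in this commit; this is not part of the training code):

```python
# Hedged sketch: verify the DeepSpeed batch-size identity for the stage-1 config.
import json

with open("src/deepspeed_config_stage1.json") as f:
    cfg = json.load(f)

micro = cfg["train_micro_batch_size_per_gpu"]   # 8
accum = cfg["gradient_accumulation_steps"]      # 1
total = cfg["train_batch_size"]                 # 8

implied_world_size = total // (micro * accum)
assert total == micro * accum * implied_world_size
print(f"ZeRO stage {cfg['zero_optimization']['stage']}, implied world size: {implied_world_size}")
```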
src/deepspeed_config_stage2.json ADDED
@@ -0,0 +1 @@
1
+ {"bf16": {"enabled": false}, "fp16": {"enabled": true}, "zero_optimization": {"stage": 2, "contiguous_gradients": true, "stage3_gather_16bit_weights_on_model_save": true, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_prefetch_bucket_size": 10000000.0, "stage3_param_persistence_threshold": 100000.0, "reduce_bucket_size": 10000000.0, "sub_group_size": 1000000000.0, "offload_optimizer": {"device": "cpu", "nvme_path": "/scratch", "pin_memory": true}, "offload_param": {"device": "none", "nvme_path": "/scratch", "buffer_size": 4000000000.0, "pin_memory": true}}, "aio": {"block_size": 26214400, "queue_depth": 32, "thread_count": 1, "single_submit": false, "overlap_events": true}, "gradient_accumulation_steps": 1, "gradient_clipping": 1.0, "steps_per_print": 20000, "train_batch_size": 256, "train_micro_batch_size_per_gpu": 32, "wall_clock_breakdown": false, "zero_allow_untested_optimizer": true}
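Relative to the stage-1 file, this config switches ZeRO to stage 2, offloads optimizer state to CPU (`offload_optimizer.device: "cpu"`), and raises the batch sizes to 256 global / 32 per GPU (an 8-process launch). A throwaway sketch that prints exactly which keys differ, assuming both files sit at the paths added in this commit:

```python
# Hedged sketch: flatten both configs and print the keys where they differ.
import json

def flatten(d, prefix=""):
    out = {}
    for key, value in d.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            out.update(flatten(value, name + "."))
        else:
            out[name] = value
    return out

with open("src/deepspeed_config_stage1.json") as f:
    stage1 = flatten(json.load(f))
with open("src/deepspeed_config_stage2.json") as f:
    stage2 = flatten(json.load(f))

for key in sorted(stage1.keys() | stage2.keys()):
    if stage1.get(key) != stage2.get(key):
        print(f"{key}: {stage1.get(key)!r} -> {stage2.get(key)!r}")
```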
src/deepspeed_config_stage2_cpuoffload.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": false
4
+ },
5
+ "fp16": {
6
+ "enabled": true
7
+ },
8
+ "zero_optimization": {
9
+ "stage": 2,
10
+ "contiguous_gradients": true,
11
+ "stage3_gather_16bit_weights_on_model_save": true,
12
+ "stage3_max_live_parameters": 1000000000.0,
13
+ "stage3_max_reuse_distance": 1000000000.0,
14
+ "stage3_prefetch_bucket_size": 10000000.0,
15
+ "stage3_param_persistence_threshold": 100000.0,
16
+ "reduce_bucket_size": 10000000.0,
17
+ "sub_group_size": 1000000000.0,
18
+ "offload_optimizer": {
19
+ "device": "cpu",
20
+ "nvme_path": "/scratch",
21
+ "pin_memory": true
22
+ },
23
+ "offload_param": {
24
+ "device": "none",
25
+ "nvme_path": "/scratch",
26
+ "buffer_size": 4000000000.0,
27
+ "pin_memory": true
28
+ }
29
+ },
30
+ "aio": {
31
+ "block_size": 26214400,
32
+ "queue_depth": 32,
33
+ "thread_count": 1,
34
+ "single_submit": false,
35
+ "overlap_events": true
36
+ },
37
+ "gradient_accumulation_steps": 1,
38
+ "gradient_clipping": 1.0,
39
+ "steps_per_print": 20000,
40
+ "train_batch_size": 256,
41
+ "train_micro_batch_size_per_gpu": 32,
42
+ "wall_clock_breakdown": false,
43
+ "zero_allow_untested_optimizer": true
44
+ }
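This file is the same ZeRO-2 + CPU-offload setup as `deepspeed_config_stage2.json`, just pretty-printed. The `accel*.slurm` scripts in this commit suggest it is normally driven through `accelerate launch`; for reference, a hedged sketch of feeding it to DeepSpeed directly (toy model and optimizer; with a micro-batch of 32 and `train_batch_size` 256 the config expects 8 processes):

```python
# Hedged sketch, not the repo's training loop: build a DeepSpeed engine from the
# pretty-printed ZeRO-2 CPU-offload config. Run under `deepspeed` or `accelerate launch`;
# with micro_batch=32 and train_batch_size=256 the config expects an 8-process job.
import torch
import torch.nn as nn
import deepspeed

model = nn.Linear(512, 512)  # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="src/deepspeed_config_stage2_cpuoffload.json",
)

# Training then goes through the engine:
#   outputs = engine(inputs); engine.backward(loss); engine.step()
```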
src/deepspeed_config_stage3.json ADDED
@@ -0,0 +1 @@
1
+ {"fp16": {"enabled": true}, "optimizer": {"type": "AdamW", "params": {"lr": "auto", "weight_decay": "auto"}}, "scheduler": {"type": "WarmupDecayLR", "params": {"warmup_min_lr": "auto", "warmup_max_lr": "auto", "warmup_num_steps": "auto", "total_num_steps": "auto"}}, "zero_optimization": {"stage": 3, "offload_optimizer": {"device": "cpu", "pin_memory": true}, "offload_param": {"device": "cpu", "pin_memory": true}, "overlap_comm": true, "contiguous_gradients": true, "reduce_bucket_size": "auto", "stage3_prefetch_bucket_size": "auto", "stage3_param_persistence_threshold": "auto", "sub_group_size": 1000000000.0, "stage3_max_live_parameters": 1000000000.0, "stage3_max_reuse_distance": 1000000000.0, "stage3_gather_16bit_weights_on_model_save": "auto"}, "gradient_accumulation_steps": 1, "gradient_clipping": "auto", "steps_per_print": 2000, "train_batch_size": 256, "train_micro_batch_size_per_gpu": 32, "wall_clock_breakdown": false, "bf16": {"enabled": false}}
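Unlike the other configs, this stage-3 file leaves the optimizer, scheduler, and several ZeRO knobs as `"auto"`, which is the convention used by the Hugging Face Transformers DeepSpeed integration: the `Trainer` resolves those fields from its own arguments at launch time. A hedged sketch of that wiring (output directory and hyperparameters are placeholders, not values from this repo):

```python
# Hedged sketch: pass the "auto"-style stage-3 config through the HF Trainer integration,
# which fills the "auto" fields (lr, warmup, bucket sizes, ...) from TrainingArguments.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",                              # placeholder
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    fp16=True,
    deepspeed="src/deepspeed_config_stage3.json",  # "auto" fields resolved against these arguments
)
# trainer = Trainer(model=model, args=training_args, train_dataset=train_ds, ...)
```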
src/huggingface_to_s3.ipynb ADDED
@@ -0,0 +1,422 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "cf698d59-1cc2-4859-9c43-9a5d4d924ee1",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Transfer huggingface mindeyev2 dataset to Stability aws s3"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "94c7404c-7a0f-4508-a630-954bc9af11fa",
15
+ "metadata": {},
16
+ "outputs": [
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/shared1000.npy -O /fsx/proj-fmri/shared/mindeyev2_dataset/shared1000.npy\n",
22
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj01.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj01.hdf5\n",
23
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj02.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj02.hdf5\n",
24
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj03.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj03.hdf5\n",
25
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj04.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj04.hdf5\n",
26
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj05.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj05.hdf5\n",
27
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj06.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj06.hdf5\n",
28
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj07.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj07.hdf5\n",
29
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/betas_all_subj08.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/betas_all_subj08.hdf5\n",
30
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/coco_images_224_float16.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/coco_images_224_float16.hdf5\n",
31
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/COCO_73k_subj_indices.hdf5 -O /fsx/proj-fmri/shared/mindeyev2_dataset/COCO_73k_subj_indices.hdf5\n",
32
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/0.tar\n",
33
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/1.tar\n",
34
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/2.tar\n",
35
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/3.tar\n",
36
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/4.tar\n",
37
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/5.tar\n",
38
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/6.tar\n",
39
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/7.tar\n",
40
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/8.tar\n",
41
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/9.tar\n",
42
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/10.tar\n",
43
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/11.tar\n",
44
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/12.tar\n",
45
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/13.tar\n",
46
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/14.tar\n",
47
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/15.tar\n",
48
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/16.tar\n",
49
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/17.tar\n",
50
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/18.tar\n",
51
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/19.tar\n",
52
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/20.tar\n",
53
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/21.tar\n",
54
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/22.tar\n",
55
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/23.tar\n",
56
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/24.tar\n",
57
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/25.tar\n",
58
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/26.tar\n",
59
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/27.tar\n",
60
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/28.tar\n",
61
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/29.tar\n",
62
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/30.tar\n",
63
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/31.tar\n",
64
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/32.tar\n",
65
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/33.tar\n",
66
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/34.tar\n",
67
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/35.tar\n",
68
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/36.tar\n",
69
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj01/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n",
70
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/0.tar\n",
71
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/1.tar\n",
72
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/2.tar\n",
73
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/3.tar\n",
74
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/4.tar\n",
75
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/5.tar\n",
76
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/6.tar\n",
77
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/7.tar\n",
78
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/8.tar\n",
79
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/9.tar\n",
80
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/10.tar\n",
81
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/11.tar\n",
82
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/12.tar\n",
83
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/13.tar\n",
84
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/14.tar\n",
85
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/15.tar\n",
86
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/16.tar\n",
87
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/17.tar\n",
88
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/18.tar\n",
89
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/19.tar\n",
90
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/20.tar\n",
91
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/21.tar\n",
92
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/22.tar\n",
93
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/23.tar\n",
94
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/24.tar\n",
95
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/25.tar\n",
96
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/26.tar\n",
97
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/27.tar\n",
98
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/28.tar\n",
99
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/29.tar\n",
100
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/30.tar\n",
101
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/31.tar\n",
102
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/32.tar\n",
103
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/33.tar\n",
104
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/34.tar\n",
105
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/35.tar\n",
106
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/train/36.tar\n",
107
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj02/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj02/test/0.tar\n",
108
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/0.tar\n",
109
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/1.tar\n",
110
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/2.tar\n",
111
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/3.tar\n",
112
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/4.tar\n",
113
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/5.tar\n",
114
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/6.tar\n",
115
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/7.tar\n",
116
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/8.tar\n",
117
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/9.tar\n",
118
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/10.tar\n",
119
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/11.tar\n",
120
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/12.tar\n",
121
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/13.tar\n",
122
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/14.tar\n",
123
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/15.tar\n",
124
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/16.tar\n",
125
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/17.tar\n",
126
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/18.tar\n",
127
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/19.tar\n",
128
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/20.tar\n",
129
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/21.tar\n",
130
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/22.tar\n",
131
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/23.tar\n",
132
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/24.tar\n",
133
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/25.tar\n",
134
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/26.tar\n",
135
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/27.tar\n",
136
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/28.tar\n",
137
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/29.tar\n",
138
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/30.tar\n",
139
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/31.tar\n",
140
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/32.tar\n",
141
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/33.tar\n",
142
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/34.tar\n",
143
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/35.tar\n",
144
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/train/36.tar\n",
145
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj03/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj03/test/0.tar\n",
146
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/0.tar\n",
147
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/1.tar\n",
148
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/2.tar\n",
149
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/3.tar\n",
150
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/4.tar\n",
151
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/5.tar\n",
152
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/6.tar\n",
153
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/7.tar\n",
154
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/8.tar\n",
155
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/9.tar\n",
156
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/10.tar\n",
157
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/11.tar\n",
158
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/12.tar\n",
159
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/13.tar\n",
160
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/14.tar\n",
161
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/15.tar\n",
162
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/16.tar\n",
163
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/17.tar\n",
164
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/18.tar\n",
165
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/19.tar\n",
166
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/20.tar\n",
167
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/21.tar\n",
168
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/22.tar\n",
169
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/23.tar\n",
170
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/24.tar\n",
171
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/25.tar\n",
172
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/26.tar\n",
173
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/27.tar\n",
174
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/28.tar\n",
175
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/29.tar\n",
176
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/30.tar\n",
177
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/31.tar\n",
178
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/32.tar\n",
179
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/33.tar\n",
180
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/34.tar\n",
181
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/35.tar\n",
182
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/train/36.tar\n",
183
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj04/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj04/test/0.tar\n",
184
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/0.tar\n",
185
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/1.tar\n",
186
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/2.tar\n",
187
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/3.tar\n",
188
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/4.tar\n",
189
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/5.tar\n",
190
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/6.tar\n",
191
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/7.tar\n",
192
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/8.tar\n",
193
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/9.tar\n",
194
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/10.tar\n",
195
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/11.tar\n",
196
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/12.tar\n",
197
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/13.tar\n",
198
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/14.tar\n",
199
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/15.tar\n",
200
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/16.tar\n",
201
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/17.tar\n",
202
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/18.tar\n",
203
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/19.tar\n",
204
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/20.tar\n",
205
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/21.tar\n",
206
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/22.tar\n",
207
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/23.tar\n",
208
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/24.tar\n",
209
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/25.tar\n",
210
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/26.tar\n",
211
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/27.tar\n",
212
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/28.tar\n",
213
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/29.tar\n",
214
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/30.tar\n",
215
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/31.tar\n",
216
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/32.tar\n",
217
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/33.tar\n",
218
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/34.tar\n",
219
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/35.tar\n",
220
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/train/36.tar\n",
221
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj05/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj05/test/0.tar\n",
222
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/0.tar\n",
223
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/1.tar\n",
224
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/2.tar\n",
225
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/3.tar\n",
226
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/4.tar\n",
227
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/5.tar\n",
228
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/6.tar\n",
229
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/7.tar\n",
230
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/8.tar\n",
231
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/9.tar\n",
232
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/10.tar\n",
233
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/11.tar\n",
234
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/12.tar\n",
235
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/13.tar\n",
236
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/14.tar\n",
237
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/15.tar\n",
238
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/16.tar\n",
239
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/17.tar\n",
240
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/18.tar\n",
241
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/19.tar\n",
242
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/20.tar\n",
243
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/21.tar\n",
244
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/22.tar\n",
245
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/23.tar\n",
246
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/24.tar\n",
247
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/25.tar\n",
248
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/26.tar\n",
249
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/27.tar\n",
250
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/28.tar\n",
251
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/29.tar\n",
252
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/30.tar\n",
253
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/31.tar\n",
254
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/32.tar\n",
255
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/33.tar\n",
256
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/34.tar\n",
257
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/35.tar\n",
258
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/train/36.tar\n",
259
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj06/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj06/test/0.tar\n",
260
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/0.tar\n",
261
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/1.tar\n",
262
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/2.tar\n",
263
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/3.tar\n",
264
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/4.tar\n",
265
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/5.tar\n",
266
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/6.tar\n",
267
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/7.tar\n",
268
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/8.tar\n",
269
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/9.tar\n",
270
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/10.tar\n",
271
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/11.tar\n",
272
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/12.tar\n",
273
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/13.tar\n",
274
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/14.tar\n",
275
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/15.tar\n",
276
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/16.tar\n",
277
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/17.tar\n",
278
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/18.tar\n",
279
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/19.tar\n",
280
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/20.tar\n",
281
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/21.tar\n",
282
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/22.tar\n",
283
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/23.tar\n",
284
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/24.tar\n",
285
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/25.tar\n",
286
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/26.tar\n",
287
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/27.tar\n",
288
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/28.tar\n",
289
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/29.tar\n",
290
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/30.tar\n",
291
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/31.tar\n",
292
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/32.tar\n",
293
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/33.tar\n",
294
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/34.tar\n",
295
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/35.tar\n",
296
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/train/36.tar\n",
297
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj07/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj07/test/0.tar\n",
298
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/0.tar\n",
299
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/1.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/1.tar\n",
300
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/2.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/2.tar\n",
301
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/3.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/3.tar\n",
302
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/4.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/4.tar\n",
303
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/5.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/5.tar\n",
304
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/6.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/6.tar\n",
305
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/7.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/7.tar\n",
306
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/8.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/8.tar\n",
307
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/9.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/9.tar\n",
308
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/10.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/10.tar\n",
309
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/11.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/11.tar\n",
310
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/12.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/12.tar\n",
311
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/13.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/13.tar\n",
312
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/14.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/14.tar\n",
313
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/15.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/15.tar\n",
314
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/16.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/16.tar\n",
315
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/17.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/17.tar\n",
316
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/18.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/18.tar\n",
317
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/19.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/19.tar\n",
318
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/20.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/20.tar\n",
319
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/21.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/21.tar\n",
320
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/22.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/22.tar\n",
321
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/23.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/23.tar\n",
322
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/24.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/24.tar\n",
323
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/25.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/25.tar\n",
324
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/26.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/26.tar\n",
325
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/27.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/27.tar\n",
326
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/28.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/28.tar\n",
327
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/29.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/29.tar\n",
328
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/30.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/30.tar\n",
329
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/31.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/31.tar\n",
330
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/32.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/32.tar\n",
331
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/33.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/33.tar\n",
332
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/34.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/34.tar\n",
333
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/35.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/35.tar\n",
334
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/train/36.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/train/36.tar\n",
335
+ "wget --show-progress https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/wds/subj08/test/0.tar -O /fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj08/test/0.tar\n",
336
+ "aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\n"
337
+ ]
338
+ }
339
+ ],
340
+ "source": [
341
+ "import os\n",
342
+ "# from subprocess import call\n",
343
+ "# PS Note: it's faster to print the wget statements and then manually copy paste all them into terminal than to use subprocess call()\n",
344
+ "tmp = '/fsx/proj-fmri/shared/mindeyev2_dataset/' #'/scratch/mindeyev2_dataset/'\n",
345
+ "\n",
346
+ "hf_base_link = 'https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/'\n",
347
+ "\n",
348
+ "os.makedirs(tmp,exist_ok=True)\n",
349
+ "\n",
350
+ "files = [\n",
351
+ " \"shared1000.npy\",\n",
352
+ " \"betas_all_subj01.hdf5\",\n",
353
+ " \"betas_all_subj02.hdf5\",\n",
354
+ " \"betas_all_subj03.hdf5\",\n",
355
+ " \"betas_all_subj04.hdf5\",\n",
356
+ " \"betas_all_subj05.hdf5\",\n",
357
+ " \"betas_all_subj06.hdf5\",\n",
358
+ " \"betas_all_subj07.hdf5\",\n",
359
+ " \"betas_all_subj08.hdf5\",\n",
360
+ " \"coco_images_224_float16.hdf5\",\n",
361
+ " \"COCO_73k_subj_indices.hdf5\",\n",
362
+ "]\n",
363
+ "\n",
364
+ "for f in files: \n",
365
+ " command = f\"wget --show-progress {hf_base_link}{f} -O {tmp}{f}\"\n",
366
+ " print(command)\n",
367
+ " # call(command,shell=True)\n",
368
+ "\n",
369
+ "for sub in range(1,9):\n",
370
+ " subject = f'subj0{sub}'\n",
371
+ "\n",
372
+ " tmp_fol = f'{tmp}wds/{subject}/'\n",
373
+ " os.makedirs(tmp_fol,exist_ok=True)\n",
374
+ " os.makedirs(tmp_fol+'train',exist_ok=True)\n",
375
+ " os.makedirs(tmp_fol+'test',exist_ok=True)\n",
376
+ "\n",
377
+ " for i in range(37):\n",
378
+ " link = f'train/{i}.tar'\n",
379
+ " command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
380
+ " print(command)\n",
381
+ " # call(command,shell=True)\n",
382
+ "\n",
383
+ " link = f'test/0.tar'\n",
384
+ " command = f\"wget --show-progress {hf_base_link}wds/{subject}/{link} -O {tmp}wds/{subject}/{link}\"\n",
385
+ " print(command)\n",
386
+ " # call(command,shell=True)\n",
387
+ "\n",
388
+ "command = \"aws s3 sync /scratch/mindeyev2_dataset s3://proj-fmri/mindeyev2_dataset --region us-west-2\"\n",
389
+ "print(command)"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "id": "30966082-59c2-411c-9b2e-4f4e3f9eb0f3",
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": []
399
+ }
400
+ ],
401
+ "metadata": {
402
+ "kernelspec": {
403
+ "display_name": "Python 3 (ipykernel)",
404
+ "language": "python",
405
+ "name": "python3"
406
+ },
407
+ "language_info": {
408
+ "codemirror_mode": {
409
+ "name": "ipython",
410
+ "version": 3
411
+ },
412
+ "file_extension": ".py",
413
+ "mimetype": "text/x-python",
414
+ "name": "python",
415
+ "nbconvert_exporter": "python",
416
+ "pygments_lexer": "ipython3",
417
+ "version": "3.10.8"
418
+ }
419
+ },
420
+ "nbformat": 4,
421
+ "nbformat_minor": 5
422
+ }
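If you would rather not copy-paste the printed wget commands, the same downloads can be scripted with subprocess. A minimal sketch, reusing the `tmp` and `hf_base_link` values from the cell above and assuming wget is on PATH (file list truncated for brevity):

# Sketch: run the dataset downloads programmatically instead of copy-pasting wget commands.
import os
import subprocess

tmp = '/fsx/proj-fmri/shared/mindeyev2_dataset/'
hf_base_link = 'https://huggingface.co/datasets/pscotti/mindeyev2/resolve/main/'
os.makedirs(tmp, exist_ok=True)

for f in ["shared1000.npy", "coco_images_224_float16.hdf5"]:  # extend with the full file list
    cmd = ["wget", "--show-progress", f"{hf_base_link}{f}", "-O", f"{tmp}{f}"]
    subprocess.run(cmd, check=True)  # check=True raises CalledProcessError if a download fails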
src/models.py ADDED
@@ -0,0 +1,210 @@
1
+ import os
2
+ import numpy as np
3
+ from torchvision import transforms
4
+ import torch
5
+ import torch.nn as nn
6
+ import PIL
7
+ import clip
8
+ import open_clip
9
+ from functools import partial
10
+ import random
11
+ import json
12
+
13
+ # class BrainMLP(nn.Module):
14
+ # def __init__(self, out_dim=257*768, in_dim=15724, clip_size=768, h=4096):
15
+ # super().__init__()
16
+ # self.lin0 = nn.Sequential(
17
+ # nn.Linear(in_dim, h, bias=False),
18
+ # nn.LayerNorm(h),
19
+ # nn.GELU(inplace=True),
20
+ # nn.Dropout(0.5))
21
+ # self.mlp = nn.ModuleList([
22
+ # nn.Sequential(
23
+ # nn.Linear(h, h),
24
+ # nn.LayerNorm(h),
25
+ # nn.GELU(inplace=True),
26
+ # nn.Dropout(0.15)
27
+ # ) for _ in range(4)])
28
+ # self.lin1 = nn.Linear(h, out_dim, bias=True)
29
+ # self.proj = nn.Sequential(
30
+ # nn.LayerNorm(clip_size),
31
+ # nn.GELU(),
32
+ # nn.Linear(clip_size, 2048),
33
+ # nn.LayerNorm(2048),
34
+ # nn.GELU(),
35
+ # nn.Linear(2048, 2048),
36
+ # nn.LayerNorm(2048),
37
+ # nn.GELU(),
38
+ # nn.Linear(2048, clip_size))
39
+ # def forward(self, x):
40
+ # x = self.lin0(x)
41
+ # residual = x
42
+ # for res_block in range(self.n_blocks):
43
+ # x = self.mlp[res_block](x)
44
+ # x += residual
45
+ # residual = x
46
+ # diffusion_prior_input = self.lin1(x.reshape(len(x), -1))
47
+ # disjointed_clip_fmri = self.proj(diffusion_prior_input.reshape(
48
+ # len(x),-1, self.clip_size))
49
+ # return diffusion_prior_input, disjointed_clip_fmri
50
+
51
+
52
+
53
+ class Clipper(torch.nn.Module):
54
+ def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
55
+ hidden_state=False, device=torch.device('cpu')):
56
+ super().__init__()
57
+ assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
58
+ "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
59
+ print(clip_variant, device)
60
+
61
+ if clip_variant=="ViT-L/14" and hidden_state:
62
+ # from transformers import CLIPVisionModelWithProjection
63
+ # image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",cache_dir="/fsx/proj-medarc/fmri/cache")
64
+ from transformers import CLIPVisionModelWithProjection
65
+ sd_cache_dir = '/fsx/proj-fmri/shared/cache/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7'
66
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_cache_dir, subfolder='image_encoder').eval()
67
+ image_encoder = image_encoder.to(device)
68
+ for param in image_encoder.parameters():
69
+ param.requires_grad = False # dont need to calculate gradients
70
+ self.image_encoder = image_encoder
71
+ elif hidden_state:
72
+ raise Exception("hidden_state embeddings only works with ViT-L/14 right now")
73
+
74
+ clip_model, preprocess = clip.load(clip_variant, device=device)
75
+ clip_model.eval() # dont want to train model
76
+ for param in clip_model.parameters():
77
+ param.requires_grad = False # dont need to calculate gradients
78
+
79
+ self.clip = clip_model
80
+ self.clip_variant = clip_variant
81
+ if clip_variant == "RN50x64":
82
+ self.clip_size = (448,448)
83
+ else:
84
+ self.clip_size = (224,224)
85
+
86
+ preproc = transforms.Compose([
87
+ transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC),
88
+ transforms.CenterCrop(size=self.clip_size),
89
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
90
+ ])
91
+ self.preprocess = preproc
92
+ self.hidden_state = hidden_state
93
+ self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
94
+ self.std = np.array([0.26862954, 0.26130258, 0.27577711])
95
+ self.normalize = transforms.Normalize(self.mean, self.std)
96
+ self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
97
+ self.clamp_embs = clamp_embs
98
+ self.norm_embs = norm_embs
99
+ self.device= device
100
+
101
+ def versatile_normalize_embeddings(encoder_output):
102
+ embeds = encoder_output.last_hidden_state
103
+ embeds = image_encoder.vision_model.post_layernorm(embeds)
104
+ embeds = image_encoder.visual_projection(embeds)
105
+ return embeds
106
+ self.versatile_normalize_embeddings = versatile_normalize_embeddings
107
+
108
+ def resize_image(self, image):
109
+ # note: antialias should be False if planning to use Pinkney's Image Variation SD model
110
+ return transforms.Resize(self.clip_size)(image.to(self.device))
111
+
112
+ def embed_image(self, image):
113
+ """Expects images in -1 to 1 range"""
114
+ if self.hidden_state:
115
+ # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
116
+ clip_emb = self.preprocess((image).to(self.device))
117
+ clip_emb = self.image_encoder(clip_emb)
118
+ clip_emb = self.versatile_normalize_embeddings(clip_emb)
119
+ else:
120
+ clip_emb = self.preprocess(image.to(self.device))
121
+ clip_emb = self.clip.encode_image(clip_emb)
122
+ # input is now in CLIP space, but mind-reader preprint further processes embeddings:
123
+ if self.clamp_embs:
124
+ clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
125
+ if self.norm_embs:
126
+ if self.hidden_state:
127
+ # normalize all tokens by cls token's norm
128
+ clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
129
+ else:
130
+ clip_emb = nn.functional.normalize(clip_emb, dim=-1)
131
+ return clip_emb
132
+
133
+ def embed_text(self, text_samples):
134
+ clip_text = clip.tokenize(text_samples).to(self.device)
135
+ clip_text = self.clip.encode_text(clip_text)
136
+ if self.clamp_embs:
137
+ clip_text = torch.clamp(clip_text, -1.5, 1.5)
138
+ if self.norm_embs:
139
+ clip_text = nn.functional.normalize(clip_text, dim=-1)
140
+ return clip_text
141
+
142
+ def embed_curated_annotations(self, annots):
143
+ for i,b in enumerate(annots):
144
+ t = ''
145
+ while t == '':
146
+ rand = torch.randint(5,(1,1))[0][0]
147
+ t = b[0,rand]
148
+ if i==0:
149
+ txt = np.array(t)
150
+ else:
151
+ txt = np.vstack((txt,t))
152
+ txt = txt.flatten()
153
+ return self.embed_text(txt)
154
+
155
+ class OpenClipper(torch.nn.Module):
156
+ def __init__(self, clip_variant, norm_embs=False, device=torch.device('cpu')):
157
+ super().__init__()
158
+ print(clip_variant, device)
159
+ assert clip_variant == 'ViT-H-14' # not setup for other models yet
160
+
161
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14',
162
+ pretrained='laion2b_s32b_b79k', device=device)
163
+ clip_model.eval() # dont want to train model
164
+ for param in clip_model.parameters():
165
+ param.requires_grad = False # dont need to calculate gradients
166
+
167
+ # overwrite preprocess to accept torch inputs instead of PIL Image
168
+ preprocess = transforms.Compose([
169
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
170
+ transforms.CenterCrop(224),
171
+ transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
172
+ ])
173
+
174
+ tokenizer = open_clip.get_tokenizer('ViT-H-14')
175
+
176
+ self.clip = clip_model
177
+ self.norm_embs = norm_embs
178
+ self.preprocess = preprocess
179
+ self.tokenizer = tokenizer
180
+ self.device = device
181
+
182
+ def embed_image(self, image):
183
+ """Expects images in -1 to 1 range"""
184
+ image = self.preprocess(image).to(self.device)
185
+ with torch.no_grad(), torch.cuda.amp.autocast():
186
+ image_features = self.clip.encode_image(image)
187
+ if self.norm_embs:
188
+ image_features = nn.functional.normalize(image_features, dim=-1)
189
+ return image_features
190
+
191
+ def embed_text(self, text_samples):
192
+ text = self.tokenizer(text_samples).to(self.device)
193
+ with torch.no_grad(), torch.cuda.amp.autocast():
194
+ text_features = self.clip.encode_text(text)
195
+ if self.norm_embs:
196
+ text_features = nn.functional.normalize(text_features, dim=-1)
197
+ return text_features
198
+
199
+ def embed_curated_annotations(self, annots):
200
+ for i,b in enumerate(annots):
201
+ t = ''
202
+ while t == '':
203
+ rand = torch.randint(5,(1,1))[0][0]
204
+ t = b[0,rand]
205
+ if i==0:
206
+ txt = np.array(t)
207
+ else:
208
+ txt = np.vstack((txt,t))
209
+ txt = txt.flatten()
210
+ return self.embed_text(txt)
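A minimal usage sketch for the Clipper class above, assuming the hard-coded cache paths are reachable and a CUDA device is available; the random batch is a stand-in for real stimuli scaled to the -1 to 1 range noted in the embed_image docstring:

# Usage sketch for models.Clipper (illustrative only).
import torch
from models import Clipper

device = torch.device("cuda:0")
clip_model = Clipper("ViT-L/14", hidden_state=True, norm_embs=True, device=device)

images = torch.rand(4, 3, 224, 224) * 2 - 1    # stand-in batch in [-1, 1], shape [B, 3, H, W]
with torch.no_grad():
    clip_emb = clip_model.embed_image(images)  # hidden-state CLIP embeddings, [4, 257, 768]
print(clip_emb.shape)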
src/setup.sh ADDED
@@ -0,0 +1,15 @@
1
+ #!/bin/bash
2
+ # Commands to setup a new conda environment and install all the necessary packages
3
+ # After running this, regenerate environment.yaml with "conda env export > environment.yaml".
4
+
5
+ set -e
6
+
7
+ conda create -n fmri python=3.10.8 -y
8
+ conda activate fmri
9
+
10
+ conda install numpy matplotlib tqdm scikit-image jupyterlab -y
11
+
12
+ pip install accelerate webdataset clip pandas matplotlib ftfy regex kornia umap-learn h5py
13
+ pip install torchvision==0.15.2 torch==2.0.1
14
+ pip install diffusers
15
+ pip install deepspeed
src/train2-tryal.ipynb ADDED
@@ -0,0 +1,2409 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n",
13
+ "# from subprocess import call\n",
14
+ "# command = \"jupyter nbconvert Train.ipynb --to python\"\n",
15
+ "# call(command,shell=True)"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "id": "b0f0f4f3",
21
+ "metadata": {},
22
+ "source": [
23
+ "# Import packages & functions"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "5bad764b-45c1-45ce-a716-8d055e09821a",
30
+ "metadata": {
31
+ "tags": []
32
+ },
33
+ "outputs": [
34
+ {
35
+ "name": "stderr",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
39
+ " from .autonotebook import tqdm as notebook_tqdm\n"
40
+ ]
41
+ },
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "[2023-11-19 16:32:39,711] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "import os\n",
52
+ "import sys\n",
53
+ "import json\n",
54
+ "import argparse\n",
55
+ "import numpy as np\n",
56
+ "import math\n",
57
+ "from einops import rearrange\n",
58
+ "import time\n",
59
+ "import random\n",
60
+ "import h5py\n",
61
+ "from tqdm import tqdm\n",
62
+ "\n",
63
+ "import webdataset as wds\n",
64
+ "import gc\n",
65
+ "\n",
66
+ "import matplotlib.pyplot as plt\n",
67
+ "import torch\n",
68
+ "import torch.nn as nn\n",
69
+ "from torchvision import transforms\n",
70
+ "from torchvision.transforms import ToPILImage #CHANGED (added)\n",
71
+ "\n",
72
+ "from accelerate import Accelerator, DeepSpeedPlugin\n",
73
+ "\n",
74
+ "# tf32 data type is faster than standard float32\n",
75
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
76
+ "\n",
77
+ "# custom functions #\n",
78
+ "import utils\n",
79
+ "\n",
80
+ "global_batch_size = 128 #128"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 3,
86
+ "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb",
87
+ "metadata": {
88
+ "tags": []
89
+ },
90
+ "outputs": [
91
+ {
92
+ "name": "stdout",
93
+ "output_type": "stream",
94
+ "text": [
95
+ "LOCAL RANK 0\n"
96
+ ]
97
+ }
98
+ ],
99
+ "source": [
100
+ "### Multi-GPU config ###\n",
101
+ "local_rank = os.getenv('RANK')\n",
102
+ "if local_rank is None: \n",
103
+ " local_rank = 0\n",
104
+ "else:\n",
105
+ " local_rank = int(local_rank)\n",
106
+ "print(\"LOCAL RANK \", local_rank) \n",
107
+ "\n",
108
+ "num_devices = torch.cuda.device_count()\n",
109
+ "if num_devices==0: num_devices = 1\n",
110
+ "\n",
111
+ "accelerator = Accelerator(split_batches=False)\n",
112
+ "\n",
113
+ "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n",
114
+ "\n",
115
+ "# if num_devices <= 1 and utils.is_interactive():\n",
116
+ "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n",
117
+ "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
118
+ "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n",
119
+ "# os.environ[\"RANK\"] = \"0\"\n",
120
+ "# os.environ[\"LOCAL_RANK\"] = \"0\"\n",
121
+ "# os.environ[\"WORLD_SIZE\"] = \"1\"\n",
122
+ "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n",
123
+ "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n",
124
+ "\n",
125
+ "# # alter the deepspeed config according to your global and local batch size\n",
126
+ "# if local_rank == 0:\n",
127
+ "# with open('deepspeed_config_stage2.json', 'r') as file:\n",
128
+ "# config = json.load(file)\n",
129
+ "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n",
130
+ "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n",
131
+ "# with open('deepspeed_config_stage2.json', 'w') as file:\n",
132
+ "# json.dump(config, file)\n",
133
+ "# else:\n",
134
+ "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n",
135
+ "# time.sleep(10)\n",
136
+ "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n",
137
+ "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 4,
143
+ "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c",
144
+ "metadata": {
145
+ "tags": []
146
+ },
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "PID of this process = 2370606\n",
153
+ "device: cuda\n",
154
+ "Distributed environment: NO\n",
155
+ "Num processes: 1\n",
156
+ "Process index: 0\n",
157
+ "Local process index: 0\n",
158
+ "Device: cuda\n",
159
+ "\n",
160
+ "Mixed precision type: no\n",
161
+ "\n",
162
+ "distributed = False num_devices = 1 local rank = 0 world size = 1\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "print(\"PID of this process =\",os.getpid())\n",
168
+ "device = accelerator.device\n",
169
+ "print(\"device:\",device)\n",
170
+ "num_workers = num_devices\n",
171
+ "print(accelerator.state)\n",
172
+ "world_size = accelerator.state.num_processes\n",
173
+ "distributed = not accelerator.state.distributed_type == 'NO'\n",
174
+ "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n",
175
+ "print = accelerator.print # only print if local_rank=0"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "id": "9018b82b-c054-4463-9527-4b0c2a75bda6",
181
+ "metadata": {
182
+ "tags": []
183
+ },
184
+ "source": [
185
+ "# Configurations"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 5,
191
+ "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3",
192
+ "metadata": {
193
+ "tags": []
194
+ },
195
+ "outputs": [
196
+ {
197
+ "name": "stdout",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=captions', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-1', '--mixup_pct=.66', '--num_epochs=30', '--ckpt_interval=999', '--no-use_image_aug']\n"
201
+ ]
202
+ }
203
+ ],
204
+ "source": [
205
+ "# if running this interactively, can specify jupyter_args here for argparser to use\n",
206
+ "if utils.is_interactive():\n",
207
+ " # Example use\n",
208
+ " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n",
209
+ " --model_name=captions \\\n",
210
+ " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n",
211
+ " --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug\"\n",
212
+ " #max_lr=3e-5 originally\n",
213
+ " jupyter_args = jupyter_args.split()\n",
214
+ " print(jupyter_args)\n",
215
+ " \n",
216
+ " from IPython.display import clear_output # function to clear print outputs in cell\n",
217
+ " %load_ext autoreload \n",
218
+ " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n",
219
+ " %autoreload 2 "
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 6,
225
+ "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c",
226
+ "metadata": {
227
+ "tags": []
228
+ },
229
+ "outputs": [
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "global batch_size 128\n",
235
+ "batch_size 128\n"
236
+ ]
237
+ }
238
+ ],
239
+ "source": [
240
+ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n",
241
+ "parser.add_argument(\n",
242
+ " \"--model_name\", type=str, default=\"testing\",\n",
243
+ " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n",
244
+ ")\n",
245
+ "parser.add_argument(\n",
246
+ " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n",
247
+ " help=\"Path to where NSD data is stored / where to download it to\",\n",
248
+ ")\n",
249
+ "parser.add_argument(\n",
250
+ " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n",
251
+ ")\n",
252
+ "parser.add_argument(\n",
253
+ " \"--batch_size\", type=int, default=32,\n",
254
+ " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n",
255
+ ")\n",
256
+ "parser.add_argument(\n",
257
+ " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n",
258
+ " help=\"whether to log to wandb\",\n",
259
+ ")\n",
260
+ "parser.add_argument(\n",
261
+ " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n",
262
+ " help=\"if not using wandb and want to resume from a ckpt\",\n",
263
+ ")\n",
264
+ "parser.add_argument(\n",
265
+ " \"--wandb_project\",type=str,default=\"stability\",\n",
266
+ " help=\"wandb project name\",\n",
267
+ ")\n",
268
+ "parser.add_argument(\n",
269
+ " \"--mixup_pct\",type=float,default=.33,\n",
270
+ " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n",
271
+ ")\n",
272
+ "parser.add_argument(\n",
273
+ " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n",
274
+ " help=\"whether to use image augmentation\",\n",
275
+ ")\n",
276
+ "parser.add_argument(\n",
277
+ " \"--num_epochs\",type=int,default=240,\n",
278
+ " help=\"number of epochs of training\",\n",
279
+ ")\n",
280
+ "parser.add_argument(\n",
281
+ " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n",
282
+ ")\n",
283
+ "parser.add_argument(\n",
284
+ " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n",
285
+ ")\n",
286
+ "parser.add_argument(\n",
287
+ " \"--ckpt_interval\",type=int,default=5,\n",
288
+ " help=\"save backup ckpt and reconstruct every x epochs\",\n",
289
+ ")\n",
290
+ "parser.add_argument(\n",
291
+ " \"--seed\",type=int,default=42,\n",
292
+ ")\n",
293
+ "parser.add_argument(\n",
294
+ " \"--max_lr\",type=float,default=3e-4,\n",
295
+ ")\n",
296
+ "parser.add_argument(\n",
297
+ " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n",
298
+ " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n",
299
+ ")\n",
300
+ "\n",
301
+ "if utils.is_interactive():\n",
302
+ " args = parser.parse_args(jupyter_args)\n",
303
+ "else:\n",
304
+ " args = parser.parse_args()\n",
305
+ "\n",
306
+ "# create global variables without the args prefix\n",
307
+ "for attribute_name in vars(args).keys():\n",
308
+ " globals()[attribute_name] = getattr(args, attribute_name)\n",
309
+ "\n",
310
+ "print(\"global batch_size\", batch_size)\n",
311
+ "batch_size = int(batch_size / num_devices)\n",
312
+ "print(\"batch_size\", batch_size)"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 7,
318
+ "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d",
319
+ "metadata": {
320
+ "tags": []
321
+ },
322
+ "outputs": [],
323
+ "source": [
324
+ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n",
325
+ "if not os.path.exists(outdir):\n",
326
+ " os.makedirs(outdir,exist_ok=True)\n",
327
+ "if use_image_aug:\n",
328
+ " import kornia\n",
329
+ " from kornia.augmentation.container import AugmentationSequential\n",
330
+ " img_augment = AugmentationSequential(\n",
331
+ " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
332
+ " kornia.augmentation.Resize((224, 224)),\n",
333
+ " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n",
334
+ " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
335
+ " kornia.augmentation.RandomGrayscale(p=0.3),\n",
336
+ " same_on_batch=False,\n",
337
+ " data_keys=[\"input\"],\n",
338
+ " )"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 8,
344
+ "id": "e7807ba9-02b6-4bc0-873c-69869abe4091",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "wandb_log = False"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "markdown",
353
+ "id": "42d13c25-1369-4c49-81d4-83d713586096",
354
+ "metadata": {
355
+ "tags": []
356
+ },
357
+ "source": [
358
+ "# Prep data, models, and dataloaders"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "id": "1c023f24-5233-4a15-a2f5-78487b3a8546",
364
+ "metadata": {},
365
+ "source": [
366
+ "## Dataloader"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": 9,
372
+ "id": "81084834-035f-4465-ad59-59e6b806a2f5",
373
+ "metadata": {
374
+ "tags": []
375
+ },
376
+ "outputs": [
377
+ {
378
+ "name": "stdout",
379
+ "output_type": "stream",
380
+ "text": [
381
+ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n",
382
+ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n"
383
+ ]
384
+ }
385
+ ],
386
+ "source": [
387
+ "if subj==1:\n",
388
+ " num_train = 24958\n",
389
+ " num_test = 2770\n",
390
+ "test_batch_size = num_test\n",
391
+ "\n",
392
+ "def my_split_by_node(urls): return urls\n",
393
+ " \n",
394
+ "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n",
395
+ "print(train_url)\n",
396
+ "\n",
397
+ "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n",
398
+ " .shuffle(750, initial=1500, rng=random.Random(42))\\\n",
399
+ " .decode(\"torch\")\\\n",
400
+ " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n",
401
+ " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n",
402
+ "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n",
403
+ "\n",
404
+ "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n",
405
+ "print(test_url)\n",
406
+ "\n",
407
+ "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n",
408
+ " .shuffle(750, initial=1500, rng=random.Random(42))\\\n",
409
+ " .decode(\"torch\")\\\n",
410
+ " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n",
411
+ " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n",
412
+ "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "markdown",
417
+ "id": "203b060a-2dd2-4c35-929b-c576be82eb52",
418
+ "metadata": {},
419
+ "source": [
420
+ "### check dataloaders are working"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 10,
426
+ "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37",
427
+ "metadata": {
428
+ "tags": []
429
+ },
430
+ "outputs": [],
431
+ "source": [
432
+ "# test_indices = []\n",
433
+ "# test_images = []\n",
434
+ "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
435
+ "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n",
436
+ "# test_images = np.append(test_images, behav[:,0,0].numpy())\n",
437
+ "# test_indices = test_indices.astype(np.int16)\n",
438
+ "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n",
439
+ "# print(\"---\\n\")\n",
440
+ "\n",
441
+ "# train_indices = []\n",
442
+ "# train_images = []\n",
443
+ "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
444
+ "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n",
445
+ "# train_images = np.append(train_images, behav[:,0,0].numpy())\n",
446
+ "# train_indices = train_indices.astype(np.int16)\n",
447
+ "# print(train_i, (train_i+1) * batch_size, len(train_indices))\n",
448
+ "\n",
449
+ "# # train_images = np.hstack((train_images, test_images))\n",
450
+ "# # print(\"WARNING: ADDED TEST IMAGES TO TRAIN IMAGES\")"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634",
456
+ "metadata": {},
457
+ "source": [
458
+ "## Load data and images"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 11,
464
+ "id": "039dd330-7339-4f88-8f00-45f95e47baa0",
465
+ "metadata": {
466
+ "tags": []
467
+ },
468
+ "outputs": [
469
+ {
470
+ "name": "stdout",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "subj01 betas loaded into memory\n",
474
+ "voxels torch.Size([27750, 15729])\n",
475
+ "images torch.Size([73000, 3, 224, 224])\n"
476
+ ]
477
+ }
478
+ ],
479
+ "source": [
480
+ "# load betas\n",
481
+ "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n",
482
+ "voxels = f['betas'][:]\n",
483
+ "print(f\"subj0{subj} betas loaded into memory\")\n",
484
+ "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n",
485
+ "if subj==1:\n",
486
+ " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n",
487
+ "print(\"voxels\", voxels.shape)\n",
488
+ "num_voxels = voxels.shape[-1]\n",
489
+ "\n",
490
+ "# load orig images\n",
491
+ "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n",
492
+ "images = f['images'][:]\n",
493
+ "images = torch.Tensor(images).to(\"cpu\").half()\n",
494
+ "print(\"images\", images.shape)"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "markdown",
499
+ "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15",
500
+ "metadata": {},
501
+ "source": [
502
+ "## Load models"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5",
508
+ "metadata": {},
509
+ "source": [
510
+ "### CLIP image embeddings model"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 12,
516
+ "id": "795e2885-bd07-4e27-bed7-181473c06df9",
517
+ "metadata": {
518
+ "tags": []
519
+ },
520
+ "outputs": [],
521
+ "source": [
522
+ "import transformers\n",
523
+ "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n",
524
+ "\n",
525
+ "from PIL import Image"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 13,
531
+ "id": "b0420dc0-199e-4c1a-857d-b1747058b467",
532
+ "metadata": {
533
+ "tags": []
534
+ },
535
+ "outputs": [
536
+ {
537
+ "name": "stdout",
538
+ "output_type": "stream",
539
+ "text": [
540
+ "ViT-L/14 cuda:0\n"
541
+ ]
542
+ }
543
+ ],
544
+ "source": [
545
+ "from models import Clipper\n",
546
+ "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 14,
552
+ "id": "23428fb7-2955-4295-bea1-447cebf9f72e",
553
+ "metadata": {
554
+ "tags": []
555
+ },
556
+ "outputs": [
557
+ {
558
+ "name": "stderr",
559
+ "output_type": "stream",
560
+ "text": [
561
+ "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:08<00:00, 34.47s/it]\n"
562
+ ]
563
+ },
564
+ {
565
+ "data": {
566
+ "text/plain": [
567
+ "'from lavis.models import load_model_and_preprocess\\nfrom lavis.models import model_zoo\\nblip2_model, vis_processors, _ = load_model_and_preprocess(\\n name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\\n\\nclip_seq_dim = 257\\nclip_emb_dim = 1024\\nhidden_dim = 4096'"
568
+ ]
569
+ },
570
+ "execution_count": 14,
571
+ "metadata": {},
572
+ "output_type": "execute_result"
573
+ }
574
+ ],
575
+ "source": [
576
+ "cache_blip2 = \"/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b\"\n",
577
+ "\n",
578
+ "b2_processor = Blip2Processor.from_pretrained(cache_blip2)\n",
579
+ "b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map=\"auto\")\n",
580
+ "\n",
581
+ "#Load in blip2 as well\n",
582
+ "\"\"\"from lavis.models import load_model_and_preprocess\n",
583
+ "from lavis.models import model_zoo\n",
584
+ "blip2_model, vis_processors, _ = load_model_and_preprocess(\n",
585
+ " name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\n",
586
+ "\n",
587
+ "clip_seq_dim = 257\n",
588
+ "clip_emb_dim = 1024\n",
589
+ "hidden_dim = 4096\"\"\""
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 74,
595
+ "id": "b06f3de2-a8da-4ba0-94f0-99096f738d55",
596
+ "metadata": {
597
+ "tags": []
598
+ },
599
+ "outputs": [],
600
+ "source": [
601
+ "def embed_images_b2(images):\n",
602
+ " images = (images * 255).type(torch.uint8)\n",
603
+ " with torch.no_grad():\n",
604
+ " inputs_processed = b2_processor(images, return_tensors=\"pt\").to(\"cuda\", torch.float16)\n",
605
+ " enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])\n",
606
+ " return enc_imgs.last_hidden_state.detach(), inputs_processed\n",
607
+ "\n",
608
+ "def embeds_to_captions_b2(embeds, sample = False, temp = 0.9):\n",
609
+ " with torch.no_grad():\n",
610
+ " input_ids = None #inputs['input_ids']\n",
611
+ " attention_mask = None\n",
612
+ " batch_size = embeds.shape[0]\n",
613
+ " image_embeds = embeds\n",
614
+ " image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)\n",
615
+ "\n",
616
+ " query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n",
617
+ " query_outputs = b2_model.qformer(\n",
618
+ " query_embeds=query_tokens,\n",
619
+ " encoder_hidden_states=image_embeds,\n",
620
+ " encoder_attention_mask=image_attention_mask,\n",
621
+ " return_dict=True,\n",
622
+ " )\n",
623
+ " query_output = query_outputs.last_hidden_state\n",
624
+ "\n",
625
+ " language_model_inputs = b2_model.language_projection(query_output)\n",
626
+ " language_attention_mask = torch.ones(\n",
627
+ " language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device\n",
628
+ " )\n",
629
+ " if input_ids is None:\n",
630
+ " input_ids = (\n",
631
+ " torch.LongTensor([[b2_model.config.text_config.bos_token_id]])\n",
632
+ " .repeat(batch_size, 1)\n",
633
+ " .to(image_embeds.device)\n",
634
+ " )\n",
635
+ " if attention_mask is None:\n",
636
+ " attention_mask = torch.ones_like(input_ids)\n",
637
+ " attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)\n",
638
+ "\n",
639
+ " # concatenate query embeddings with prompt embeddings\n",
640
+ " inputs_embeds = b2_model.get_input_embeddings()(input_ids)\n",
641
+ " inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)\n",
642
+ "\n",
643
+ " outputs = b2_model.language_model.generate(\n",
644
+ " inputs_embeds=inputs_embeds,\n",
645
+ " attention_mask=attention_mask,\n",
646
+ " temperature=temp,\n",
647
+ " do_sample = sample\n",
648
+ " )\n",
649
+ " text = b2_processor.batch_decode(outputs, skip_special_tokens=True)\n",
650
+ " \n",
651
+ " return outputs, text\n"
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": 73,
657
+ "id": "51b29638-2c81-4e9f-b06d-525fdbac44b1",
658
+ "metadata": {
659
+ "tags": []
660
+ },
661
+ "outputs": [
662
+ {
663
+ "data": {
664
+ "text/plain": [
665
+ "tensor([[ 2, 6209, 14, 10, 205, 425, 13, 10, 7297, 1280,\n",
666
+ " 9, 418, 116, 1437, 38, 10728, 33, 117, 1114, 99]],\n",
667
+ " device='cuda:0')"
668
+ ]
669
+ },
670
+ "execution_count": 73,
671
+ "metadata": {},
672
+ "output_type": "execute_result"
673
+ }
674
+ ],
675
+ "source": [
676
+ "b2_model.language_model.generate(do_sample = True, temperature=1)"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": 16,
682
+ "id": "ec0a34d3-76e0-4a47-a9ab-6131ab2ccecd",
683
+ "metadata": {
684
+ "tags": []
685
+ },
686
+ "outputs": [],
687
+ "source": [
688
+ "image_test = images[1:20].permute(0,2,3,1)\n",
689
+ "#raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')\n",
690
+ "# Convert the image to a NumPy array\n",
691
+ "#image_test = np.array(raw_image)\n"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": 17,
697
+ "id": "e04876a4-45c7-4015-8255-8574c8f50f14",
698
+ "metadata": {
699
+ "tags": []
700
+ },
701
+ "outputs": [
702
+ {
703
+ "data": {
704
+ "text/plain": [
705
+ "\"import matplotlib.pyplot as plt\\n# Plotting one of the images (taking the first image as an example)\\nimg_to_plot = inputs_rec['pixel_values'][-1]\\n\\n# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\\nimg_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\\nprint(img_to_plot.shape)\\n\\nplt.imshow(img_to_plot)\\nplt.show()\""
706
+ ]
707
+ },
708
+ "execution_count": 17,
709
+ "metadata": {},
710
+ "output_type": "execute_result"
711
+ }
712
+ ],
713
+ "source": [
714
+ "\"\"\"import matplotlib.pyplot as plt\n",
715
+ "# Plotting one of the images (taking the first image as an example)\n",
716
+ "img_to_plot = inputs_rec['pixel_values'][-1]\n",
717
+ "\n",
718
+ "# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\n",
719
+ "img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\n",
720
+ "print(img_to_plot.shape)\n",
721
+ "\n",
722
+ "plt.imshow(img_to_plot)\n",
723
+ "plt.show()\"\"\""
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": 18,
729
+ "id": "328a17d0-593b-4d1e-812a-10a3b6efea6a",
730
+ "metadata": {
731
+ "tags": []
732
+ },
733
+ "outputs": [],
734
+ "source": [
735
+ "embeds_test, inputs_rec = embed_images_b2(image_test)"
736
+ ]
737
+ },
738
+ {
739
+ "cell_type": "code",
740
+ "execution_count": 19,
741
+ "id": "abe5f8a8-fca9-4083-8596-a913bdb57de7",
742
+ "metadata": {
743
+ "tags": []
744
+ },
745
+ "outputs": [],
746
+ "source": [
747
+ "#inputs_rec['pixel_values'].shape"
748
+ ]
749
+ },
750
+ {
751
+ "cell_type": "code",
752
+ "execution_count": 20,
753
+ "id": "c5f3ca7e-b880-421e-b354-7b6c3df565e9",
754
+ "metadata": {
755
+ "tags": []
756
+ },
757
+ "outputs": [],
758
+ "source": [
759
+ "#out = b2_model.generate(**inputs_rec)\n",
760
+ "#print(b2_processor.decode(out[0], skip_special_tokens=True).strip())"
761
+ ]
762
+ },
763
+ {
764
+ "cell_type": "code",
765
+ "execution_count": 21,
766
+ "id": "fb462016-78d7-46ea-8058-0d608f17ea65",
767
+ "metadata": {
768
+ "tags": []
769
+ },
770
+ "outputs": [
771
+ {
772
+ "name": "stderr",
773
+ "output_type": "stream",
774
+ "text": [
775
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
776
+ " warnings.warn(\n"
777
+ ]
778
+ }
779
+ ],
780
+ "source": [
781
+ "outputs_test, text_test = embeds_to_captions_b2(embeds_test)"
782
+ ]
783
+ },
784
+ {
785
+ "cell_type": "code",
786
+ "execution_count": 22,
787
+ "id": "6a95fcdf-db87-4c02-9728-09f85605fb1c",
788
+ "metadata": {
789
+ "tags": []
790
+ },
791
+ "outputs": [
792
+ {
793
+ "data": {
794
+ "text/plain": [
795
+ "['a cat sitting on a toilet seat\\n',\n",
796
+ " 'a person cutting a pizza on a cutting board\\n',\n",
797
+ " 'a sandwich and a drink on a table\\n',\n",
798
+ " 'a man crossing the street in front of a truck\\n',\n",
799
+ " 'a giraffe standing in front of trees\\n',\n",
800
+ " 'three men standing together\\n',\n",
801
+ " 'a bird standing on a rock next to a body of water\\n',\n",
802
+ " 'two men sitting on a street corner in asia\\n',\n",
803
+ " 'a woman and two children playing tennis on a court\\n',\n",
804
+ " 'a tall brick building with a clock on the side\\n',\n",
805
+ " 'a train is on the tracks\\n',\n",
806
+ " 'a man and woman in the water with a surfboard\\n',\n",
807
+ " 'a living room with a desk and a chair\\n',\n",
808
+ " 'a group of men on a basketball court\\n',\n",
809
+ " 'a man holding an umbrella\\n',\n",
810
+ " 'a man in a red shirt\\n',\n",
811
+ " 'a group of people holding cell phones and wine glasses\\n',\n",
812
+ " 'a laptop computer sitting on a table in front of a television\\n',\n",
813
+ " 'a baseball player is swinging a bat on a field\\n']"
814
+ ]
815
+ },
816
+ "execution_count": 22,
817
+ "metadata": {},
818
+ "output_type": "execute_result"
819
+ }
820
+ ],
821
+ "source": [
822
+ "text_test"
823
+ ]
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": 23,
828
+ "id": "9ac69fbd-55db-435b-bed6-5ae9186450e3",
829
+ "metadata": {
830
+ "tags": []
831
+ },
832
+ "outputs": [],
833
+ "source": [
834
+ "#inputss['pixel_values'].shape"
835
+ ]
836
+ },
837
+ {
838
+ "cell_type": "code",
839
+ "execution_count": 24,
840
+ "id": "0524f498-c8da-4e8a-8970-d75d2d0f6b8b",
841
+ "metadata": {
842
+ "tags": []
843
+ },
844
+ "outputs": [],
845
+ "source": [
846
+ "#image_test.shape"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": 25,
852
+ "id": "5417541b-49eb-4e43-a3e2-d937d9653e04",
853
+ "metadata": {
854
+ "tags": []
855
+ },
856
+ "outputs": [],
857
+ "source": [
858
+ "max_lr = 1e-4"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": 26,
864
+ "id": "da0ce190-1b3e-4c12-9e9f-91cbc076d044",
865
+ "metadata": {
866
+ "tags": []
867
+ },
868
+ "outputs": [],
869
+ "source": [
870
+ "clip_seq_dim = 257 #blip2 image encoder shapes\n",
871
+ "clip_emb_dim = 1408 #blip2 image encoder shapes\n",
872
+ "hidden_dim = 2048"
873
+ ]
874
+ },
875
+ {
876
+ "cell_type": "markdown",
877
+ "id": "5b79bd38-6990-4504-8d45-4a68d57d8885",
878
+ "metadata": {},
879
+ "source": [
880
+ "### SD VAE (blurry images)"
881
+ ]
882
+ },
883
+ {
884
+ "cell_type": "code",
885
+ "execution_count": 40,
886
+ "id": "01baff79-8114-482b-b115-6f05aa8ad691",
887
+ "metadata": {
888
+ "tags": []
889
+ },
890
+ "outputs": [
891
+ {
892
+ "name": "stdout",
893
+ "output_type": "stream",
894
+ "text": [
895
+ "param counts:\n",
896
+ "83,653,863 total\n",
897
+ "0 trainable\n"
898
+ ]
899
+ }
900
+ ],
901
+ "source": [
902
+ "from diffusers import AutoencoderKL\n",
903
+ "autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n",
904
+ "# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n",
905
+ "autoenc.eval()\n",
906
+ "autoenc.requires_grad_(False)\n",
907
+ "autoenc.to(device)\n",
908
+ "utils.count_params(autoenc)"
909
+ ]
910
+ },
911
+ {
912
+ "cell_type": "markdown",
913
+ "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0",
914
+ "metadata": {},
915
+ "source": [
916
+ "### MindEye modules"
917
+ ]
918
+ },
919
+ {
920
+ "cell_type": "code",
921
+ "execution_count": 41,
922
+ "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5",
923
+ "metadata": {
924
+ "tags": []
925
+ },
926
+ "outputs": [
927
+ {
928
+ "data": {
929
+ "text/plain": [
930
+ "MindEyeModule()"
931
+ ]
932
+ },
933
+ "execution_count": 41,
934
+ "metadata": {},
935
+ "output_type": "execute_result"
936
+ }
937
+ ],
938
+ "source": [
939
+ "class MindEyeModule(nn.Module):\n",
940
+ " def __init__(self):\n",
941
+ " super(MindEyeModule, self).__init__()\n",
942
+ " def forward(self, x):\n",
943
+ " return x\n",
944
+ " \n",
945
+ "model = MindEyeModule()\n",
946
+ "model"
947
+ ]
948
+ },
949
+ {
950
+ "cell_type": "code",
951
+ "execution_count": 42,
952
+ "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0",
953
+ "metadata": {
954
+ "tags": []
955
+ },
956
+ "outputs": [
957
+ {
958
+ "name": "stdout",
959
+ "output_type": "stream",
960
+ "text": [
961
+ "param counts:\n",
962
+ "32,215,040 total\n",
963
+ "32,215,040 trainable\n",
964
+ "param counts:\n",
965
+ "32,215,040 total\n",
966
+ "32,215,040 trainable\n",
967
+ "torch.Size([2, 1, 15729]) torch.Size([2, 1, 2048])\n"
968
+ ]
969
+ }
970
+ ],
971
+ "source": [
972
+ "class RidgeRegression(torch.nn.Module):\n",
973
+ " # make sure to add weight_decay when initializing optimizer\n",
974
+ " def __init__(self, input_size, out_features): \n",
975
+ " super(RidgeRegression, self).__init__()\n",
976
+ " self.out_features = out_features\n",
977
+ " self.linear = torch.nn.Linear(input_size, out_features)\n",
978
+ " def forward(self, x):\n",
979
+ " return self.linear(x)\n",
980
+ " \n",
981
+ "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n",
982
+ "utils.count_params(model.ridge)\n",
983
+ "utils.count_params(model)\n",
984
+ "\n",
985
+ "b = torch.randn((2,1,voxels.shape[1]))\n",
986
+ "print(b.shape, model.ridge(b).shape)"
987
+ ]
988
+ },
989
+ {
990
+ "cell_type": "code",
991
+ "execution_count": 43,
992
+ "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd",
993
+ "metadata": {
994
+ "tags": []
995
+ },
996
+ "outputs": [
997
+ {
998
+ "name": "stdout",
999
+ "output_type": "stream",
1000
+ "text": [
1001
+ "param counts:\n",
1002
+ "772,419,072 total\n",
1003
+ "772,419,072 trainable\n",
1004
+ "param counts:\n",
1005
+ "804,634,112 total\n",
1006
+ "804,634,112 trainable\n",
1007
+ "torch.Size([4, 2048])\n",
1008
+ "torch.Size([4, 257, 1408])\n"
1009
+ ]
1010
+ }
1011
+ ],
1012
+ "source": [
1013
+ "from functools import partial\n",
1014
+ "from diffusers.models.vae import Decoder\n",
1015
+ "class BrainNetwork(nn.Module):\n",
1016
+ " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):\n",
1017
+ " super().__init__()\n",
1018
+ " self.blurry_dim = blurry_dim\n",
1019
+ " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n",
1020
+ " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n",
1021
+ " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n",
1022
+ " self.lin0 = nn.Linear(in_dim, h)\n",
1023
+ " self.mlp = nn.ModuleList([\n",
1024
+ " nn.Sequential(\n",
1025
+ " nn.Linear(h, h),\n",
1026
+ " *[item() for item in act_and_norm],\n",
1027
+ " nn.Dropout(drop)\n",
1028
+ " ) for _ in range(n_blocks)\n",
1029
+ " ])\n",
1030
+ " self.lin1 = nn.Linear(h, out_dim, bias=True)\n",
1031
+ " # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)\n",
1032
+ " self.n_blocks = n_blocks\n",
1033
+ " self.clip_size = clip_size\n",
1034
+ " self.clip_proj = nn.Sequential(\n",
1035
+ " nn.LayerNorm(clip_size),\n",
1036
+ " nn.GELU(),\n",
1037
+ " nn.Linear(clip_size, 2048),\n",
1038
+ " nn.LayerNorm(2048),\n",
1039
+ " nn.GELU(),\n",
1040
+ " nn.Linear(2048, 2048),\n",
1041
+ " nn.LayerNorm(2048),\n",
1042
+ " nn.GELU(),\n",
1043
+ " nn.Linear(2048, clip_size)\n",
1044
+ " )\n",
1045
+ " # self.upsampler = Decoder(\n",
1046
+ " # in_channels=64,\n",
1047
+ " # out_channels=4,\n",
1048
+ " # up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n",
1049
+ " # block_out_channels=[64, 128, 256],\n",
1050
+ " # layers_per_block=1,\n",
1051
+ " # )\n",
1052
+ " \n",
1053
+ " def forward(self, x):\n",
1054
+ " x = self.lin0(x)\n",
1055
+ " residual = x\n",
1056
+ " for res_block in range(self.n_blocks):\n",
1057
+ " x = self.mlp[res_block](x)\n",
1058
+ " x += residual\n",
1059
+ " residual = x\n",
1060
+ " x = x.reshape(len(x), -1)\n",
1061
+ " x = self.lin1(x)\n",
1062
+ " # b = self.blin1(x)\n",
1063
+ " # b = self.upsampler(b.reshape(len(b), -1, 7, 7))\n",
1064
+ " c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))\n",
1065
+ " # return c, b\n",
1066
+ " return c\n",
1067
+ "\n",
1068
+ "model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) \n",
1069
+ "utils.count_params(model.backbone)\n",
1070
+ "utils.count_params(model)\n",
1071
+ "\n",
1072
+ "b = torch.randn((4,hidden_dim))\n",
1073
+ "print(b.shape)\n",
1074
+ "clip_ = model.backbone(b)\n",
1075
+ "print(clip_.shape)"
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "execution_count": 44,
1081
+ "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1",
1082
+ "metadata": {
1083
+ "tags": []
1084
+ },
1085
+ "outputs": [
1086
+ {
1087
+ "name": "stdout",
1088
+ "output_type": "stream",
1089
+ "text": [
1090
+ "\n",
1091
+ "Done with model preparations!\n",
1092
+ "param counts:\n",
1093
+ "804,634,112 total\n",
1094
+ "804,634,112 trainable\n"
1095
+ ]
1096
+ }
1097
+ ],
1098
+ "source": [
1099
+ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
1100
+ "opt_grouped_parameters = [\n",
1101
+ " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n",
1102
+ " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
1103
+ " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n",
1104
+ "]\n",
1105
+ "\n",
1106
+ "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))\n",
1107
+ "\n",
1108
+ "if lr_scheduler_type == 'linear':\n",
1109
+ " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
1110
+ " optimizer,\n",
1111
+ " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n",
1112
+ " last_epoch=-1\n",
1113
+ " )\n",
1114
+ "elif lr_scheduler_type == 'cycle':\n",
1115
+ " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n",
1116
+ " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
1117
+ " optimizer, \n",
1118
+ " max_lr=max_lr,\n",
1119
+ " total_steps=total_steps,\n",
1120
+ " final_div_factor=1000,\n",
1121
+ " last_epoch=-1, pct_start=2/num_epochs\n",
1122
+ " )\n",
1123
+ " \n",
1124
+ "def save_ckpt(tag): \n",
1125
+ " ckpt_path = outdir+f'/{tag}.pth'\n",
1126
+ " print(f'saving {ckpt_path}',flush=True)\n",
1127
+ " unwrapped_model = accelerator.unwrap_model(model)\n",
1128
+ " try:\n",
1129
+ " torch.save({\n",
1130
+ " 'epoch': epoch,\n",
1131
+ " 'model_state_dict': unwrapped_model.state_dict(),\n",
1132
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
1133
+ " 'lr_scheduler': lr_scheduler.state_dict(),\n",
1134
+ " 'train_losses': losses,\n",
1135
+ " 'test_losses': test_losses,\n",
1136
+ " 'lrs': lrs,\n",
1137
+ " }, ckpt_path)\n",
1138
+ " except:\n",
1139
+ " print(\"Couldn't save... moving on to prevent crashing.\")\n",
1140
+ " del unwrapped_model\n",
1141
+ " \n",
1142
+ "print(\"\\nDone with model preparations!\")\n",
1143
+ "utils.count_params(model)"
1144
+ ]
1145
+ },
1146
+ {
1147
+ "cell_type": "markdown",
1148
+ "id": "983f458b-35b8-49f2-b6db-80296cece730",
1149
+ "metadata": {},
1150
+ "source": [
1151
+ "# Weights and Biases"
1152
+ ]
1153
+ },
1154
+ {
1155
+ "cell_type": "code",
1156
+ "execution_count": 32,
1157
+ "id": "0a25a662-daa8-4de9-9233-8364800fcb6b",
1158
+ "metadata": {
1159
+ "tags": []
1160
+ },
1161
+ "outputs": [
1162
+ {
1163
+ "name": "stdout",
1164
+ "output_type": "stream",
1165
+ "text": [
1166
+ "wandb mindeyev2 run captions\n"
1167
+ ]
1168
+ },
1169
+ {
1170
+ "name": "stderr",
1171
+ "output_type": "stream",
1172
+ "text": [
1173
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
1174
+ ]
1175
+ },
1176
+ {
1177
+ "name": "stdout",
1178
+ "output_type": "stream",
1179
+ "text": [
1180
+ "wandb_config:\n",
1181
+ " {'model_name': 'captions', 'batch_size': 128, 'num_epochs': 30, 'use_image_aug': False, 'max_lr': 0.0001, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}\n"
1182
+ ]
1183
+ },
1184
+ {
1185
+ "data": {
1186
+ "text/html": [
1187
+ "wandb version 0.16.0 is available! To upgrade, please run:\n",
1188
+ " $ pip install wandb --upgrade"
1189
+ ],
1190
+ "text/plain": [
1191
+ "<IPython.core.display.HTML object>"
1192
+ ]
1193
+ },
1194
+ "metadata": {},
1195
+ "output_type": "display_data"
1196
+ },
1197
+ {
1198
+ "data": {
1199
+ "text/html": [
1200
+ "Tracking run with wandb version 0.15.5"
1201
+ ],
1202
+ "text/plain": [
1203
+ "<IPython.core.display.HTML object>"
1204
+ ]
1205
+ },
1206
+ "metadata": {},
1207
+ "output_type": "display_data"
1208
+ },
1209
+ {
1210
+ "data": {
1211
+ "text/html": [
1212
+ "Run data is saved locally in <code>/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231119_163615-o1xwsqre</code>"
1213
+ ],
1214
+ "text/plain": [
1215
+ "<IPython.core.display.HTML object>"
1216
+ ]
1217
+ },
1218
+ "metadata": {},
1219
+ "output_type": "display_data"
1220
+ },
1221
+ {
1222
+ "data": {
1223
+ "text/html": [
1224
+ "Syncing run <strong><a href='https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre' target=\"_blank\">captions</a></strong> to <a href='https://stability.wandb.io/ckadirt/mindeyev2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
1225
+ ],
1226
+ "text/plain": [
1227
+ "<IPython.core.display.HTML object>"
1228
+ ]
1229
+ },
1230
+ "metadata": {},
1231
+ "output_type": "display_data"
1232
+ },
1233
+ {
1234
+ "data": {
1235
+ "text/html": [
1236
+ " View project at <a href='https://stability.wandb.io/ckadirt/mindeyev2' target=\"_blank\">https://stability.wandb.io/ckadirt/mindeyev2</a>"
1237
+ ],
1238
+ "text/plain": [
1239
+ "<IPython.core.display.HTML object>"
1240
+ ]
1241
+ },
1242
+ "metadata": {},
1243
+ "output_type": "display_data"
1244
+ },
1245
+ {
1246
+ "data": {
1247
+ "text/html": [
1248
+ " View run at <a href='https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre' target=\"_blank\">https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre</a>"
1249
+ ],
1250
+ "text/plain": [
1251
+ "<IPython.core.display.HTML object>"
1252
+ ]
1253
+ },
1254
+ "metadata": {},
1255
+ "output_type": "display_data"
1256
+ }
1257
+ ],
1258
+ "source": [
1259
+ "# params for wandb\n",
1260
+ "if local_rank==0 and True: # only use main process for wandb logging\n",
1261
+ " import wandb\n",
1262
+ " \n",
1263
+ " wandb_project = 'mindeyev2'\n",
1264
+ " wandb_run = model_name\n",
1265
+ " wandb_notes = ''\n",
1266
+ " \n",
1267
+ " print(f\"wandb {wandb_project} run {wandb_run}\")\n",
1268
+ " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n",
1269
+ " wandb_config = {\n",
1270
+ " \"model_name\": model_name,\n",
1271
+ " \"batch_size\": batch_size,\n",
1272
+ " \"num_epochs\": num_epochs,\n",
1273
+ " \"use_image_aug\": use_image_aug,\n",
1274
+ " \"max_lr\": max_lr,\n",
1275
+ " \"lr_scheduler_type\": lr_scheduler_type,\n",
1276
+ " \"mixup_pct\": mixup_pct,\n",
1277
+ " \"num_train\": num_train,\n",
1278
+ " \"num_test\": num_test,\n",
1279
+ " \"seed\": seed,\n",
1280
+ " \"distributed\": distributed,\n",
1281
+ " \"num_devices\": num_devices,\n",
1282
+ " \"world_size\": world_size,\n",
1283
+ " }\n",
1284
+ " print(\"wandb_config:\\n\",wandb_config)\n",
1285
+ " if False: # wandb_auto_resume\n",
1286
+ " print(\"wandb_id:\",model_name)\n",
1287
+ " wandb.init(\n",
1288
+ " id = model_name,\n",
1289
+ " project=wandb_project,\n",
1290
+ " name=wandb_run,\n",
1291
+ " config=wandb_config,\n",
1292
+ " notes=wandb_notes,\n",
1293
+ " resume=\"allow\",\n",
1294
+ " )\n",
1295
+ " else:\n",
1296
+ " wandb.init(\n",
1297
+ " project=wandb_project,\n",
1298
+ " name=wandb_run,\n",
1299
+ " config=wandb_config,\n",
1300
+ " notes=wandb_notes,\n",
1301
+ " )\n",
1302
+ "else:\n",
1303
+ " wandb_log = False"
1304
+ ]
1305
+ },
1306
+ {
1307
+ "cell_type": "code",
1308
+ "execution_count": 33,
1309
+ "id": "4e5de216-5318-4b45-ac02-113f03105adc",
1310
+ "metadata": {},
1311
+ "outputs": [
1312
+ {
1313
+ "data": {
1314
+ "text/html": [
1315
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ff0000; text-decoration-color: #ff0000\">╭──────────────────────────────────────────────────────────────────────────────────────────────────╮</span>\n",
1316
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span> n++ <span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span>\n",
1317
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span> <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">▲</span> <span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span>\n",
1318
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">╰───────────────────────────────────────────────────────────────────────────────���──────────────────╯</span>\n",
1319
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">SyntaxError: </span>invalid syntax\n",
1320
+ "</pre>\n"
1321
+ ],
1322
+ "text/plain": [
1323
+ "\u001b[91m╭──────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n",
1324
+ "\u001b[91m│\u001b[0m n++ \u001b[91m│\u001b[0m\n",
1325
+ "\u001b[91m│\u001b[0m \u001b[1;91m▲\u001b[0m \u001b[91m│\u001b[0m\n",
1326
+ "\u001b[91m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
1327
+ "\u001b[1;91mSyntaxError: \u001b[0minvalid syntax\n"
1328
+ ]
1329
+ },
1330
+ "metadata": {},
1331
+ "output_type": "display_data"
1332
+ }
1333
+ ],
1334
+ "source": []
1335
+ },
1336
+ {
1337
+ "cell_type": "markdown",
1338
+ "id": "5b0ae095-3203-4eb8-8606-acc2db6ccf20",
1339
+ "metadata": {},
1340
+ "source": [
1341
+ "# More custom functions"
1342
+ ]
1343
+ },
1344
+ {
1345
+ "cell_type": "code",
1346
+ "execution_count": 34,
1347
+ "id": "827ead88-7eb3-47cc-82da-31565063b927",
1348
+ "metadata": {
1349
+ "tags": []
1350
+ },
1351
+ "outputs": [],
1352
+ "source": [
1353
+ "# using the same preprocessing as was used in MindEye + BrainDiffuser\n",
1354
+ "pixcorr_preprocess = transforms.Compose([\n",
1355
+ " transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
1356
+ "])\n",
1357
+ "def pixcorr(images,brains):\n",
1358
+ " # Flatten images while keeping the batch dimension\n",
1359
+ " all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)\n",
1360
+ " all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)\n",
1361
+ " corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()\n",
1362
+ " return corrmean"
1363
+ ]
1364
+ },
1365
+ {
1366
+ "cell_type": "markdown",
1367
+ "id": "d5690151-2131-4918-b750-e869cbd1a8a8",
1368
+ "metadata": {},
1369
+ "source": [
1370
+ "# Main"
1371
+ ]
1372
+ },
1373
+ {
1374
+ "cell_type": "code",
1375
+ "execution_count": 51,
1376
+ "id": "12de6387-6e18-4e4b-b5ce-a847d625330a",
1377
+ "metadata": {
1378
+ "tags": []
1379
+ },
1380
+ "outputs": [],
1381
+ "source": [
1382
+ "epoch = 0\n",
1383
+ "losses, test_losses, lrs = [], [], []\n",
1384
+ "best_test_loss = 1e9\n",
1385
+ "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
1386
+ "\n",
1387
+ "# Optionally resume from checkpoint #\n",
1388
+ "if resume_from_ckpt:\n",
1389
+ " print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
1390
+ " try:\n",
1391
+ " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
1392
+ " except:\n",
1393
+ " print('last.pth failed... trying last_backup.pth')\n",
1394
+ " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
1395
+ " epoch = checkpoint['epoch']\n",
1396
+ " print(\"Epoch\",epoch)\n",
1397
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1398
+ " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
1399
+ " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n",
1400
+ " del checkpoint\n",
1401
+ "elif wandb_log:\n",
1402
+ " if wandb.run.resumed:\n",
1403
+ " print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
1404
+ " try:\n",
1405
+ " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
1406
+ " except:\n",
1407
+ " print('last.pth failed... trying last_backup.pth')\n",
1408
+ " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
1409
+ " epoch = checkpoint['epoch']\n",
1410
+ " print(\"Epoch\",epoch)\n",
1411
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1412
+ " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
1413
+ " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n",
1414
+ " del checkpoint\n",
1415
+ "torch.cuda.empty_cache()"
1416
+ ]
1417
+ },
1418
+ {
1419
+ "cell_type": "code",
1420
+ "execution_count": 36,
1421
+ "id": "b4755749-2d99-4e98-ad98-3df661746058",
1422
+ "metadata": {
1423
+ "tags": []
1424
+ },
1425
+ "outputs": [],
1426
+ "source": [
1427
+ "checkpoint = torch.load('/fsx/proj-fmri/ckadirt/MindEyeV2/train_logs/caption_clip_0.5_bz/last.pth', map_location='cpu')"
1428
+ ]
1429
+ },
1430
+ {
1431
+ "cell_type": "code",
1432
+ "execution_count": 45,
1433
+ "id": "cd3dc793-5a20-4b48-959c-bc64430c8c02",
1434
+ "metadata": {
1435
+ "tags": []
1436
+ },
1437
+ "outputs": [
1438
+ {
1439
+ "data": {
1440
+ "text/plain": [
1441
+ "<All keys matched successfully>"
1442
+ ]
1443
+ },
1444
+ "execution_count": 45,
1445
+ "metadata": {},
1446
+ "output_type": "execute_result"
1447
+ }
1448
+ ],
1449
+ "source": [
1450
+ "model.load_state_dict(checkpoint['model_state_dict'])"
1451
+ ]
1452
+ },
1453
+ {
1454
+ "cell_type": "code",
1455
+ "execution_count": 46,
1456
+ "id": "0faa2c6a-00da-4b66-b5e5-8c4864768805",
1457
+ "metadata": {
1458
+ "tags": []
1459
+ },
1460
+ "outputs": [
1461
+ {
1462
+ "data": {
1463
+ "text/plain": [
1464
+ "MindEyeModule(\n",
1465
+ " (ridge): RidgeRegression(\n",
1466
+ " (linear): Linear(in_features=15729, out_features=2048, bias=True)\n",
1467
+ " )\n",
1468
+ " (backbone): BrainNetwork(\n",
1469
+ " (lin0): Linear(in_features=2048, out_features=2048, bias=True)\n",
1470
+ " (mlp): ModuleList(\n",
1471
+ " (0-3): 4 x Sequential(\n",
1472
+ " (0): Linear(in_features=2048, out_features=2048, bias=True)\n",
1473
+ " (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
1474
+ " (2): GELU(approximate='none')\n",
1475
+ " (3): Dropout(p=0.15, inplace=False)\n",
1476
+ " )\n",
1477
+ " )\n",
1478
+ " (lin1): Linear(in_features=2048, out_features=361856, bias=True)\n",
1479
+ " (clip_proj): Sequential(\n",
1480
+ " (0): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)\n",
1481
+ " (1): GELU(approximate='none')\n",
1482
+ " (2): Linear(in_features=1408, out_features=2048, bias=True)\n",
1483
+ " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
1484
+ " (4): GELU(approximate='none')\n",
1485
+ " (5): Linear(in_features=2048, out_features=2048, bias=True)\n",
1486
+ " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
1487
+ " (7): GELU(approximate='none')\n",
1488
+ " (8): Linear(in_features=2048, out_features=1408, bias=True)\n",
1489
+ " )\n",
1490
+ " )\n",
1491
+ ")"
1492
+ ]
1493
+ },
1494
+ "execution_count": 46,
1495
+ "metadata": {},
1496
+ "output_type": "execute_result"
1497
+ }
1498
+ ],
1499
+ "source": [
1500
+ "model"
1501
+ ]
1502
+ },
1503
+ {
1504
+ "cell_type": "code",
1505
+ "execution_count": 47,
1506
+ "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4",
1507
+ "metadata": {
1508
+ "tags": []
1509
+ },
1510
+ "outputs": [],
1511
+ "source": [
1512
+ "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n",
1513
+ "model, optimizer, train_dl, test_dl, lr_scheduler\n",
1514
+ ")"
1515
+ ]
1516
+ },
1517
+ {
1518
+ "cell_type": "code",
1519
+ "execution_count": null,
1520
+ "id": "bfeeda32-82ca-4364-bce1-eaa41b4f3e25",
1521
+ "metadata": {
1522
+ "tags": []
1523
+ },
1524
+ "outputs": [],
1525
+ "source": [
1526
+ "\"\"\"transform = transforms.Compose(\n",
1527
+ " [\n",
1528
+ " transforms.Resize(\n",
1529
+ " (224, 224),\n",
1530
+ " ),\n",
1531
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
1532
+ " ]\n",
1533
+ " )\n",
1534
+ "\n",
1535
+ "def tensor_2_embed(image): \n",
1536
+ " image_for_blip2 = transform(image)\n",
1537
+ " \n",
1538
+ " #Generate embeddings\n",
1539
+ " with blip2_model.maybe_autocast():\n",
1540
+ " blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n",
1541
+ " \n",
1542
+ " return blip2_target\n",
1543
+ "\n",
1544
+ "def embed_2_caption(image_embeds, model):\n",
1545
+ " image_embeds = image_embeds.float()\n",
1546
+ " image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n",
1547
+ " image.device)\n",
1548
+ "\n",
1549
+ " query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n",
1550
+ " query_output = model.Qformer.bert(\n",
1551
+ " query_embeds=query_tokens,\n",
1552
+ " encoder_hidden_states=image_embeds,\n",
1553
+ " encoder_attention_mask=image_atts,\n",
1554
+ " return_dict=True)\n",
1555
+ "\n",
1556
+ " inputs_t5 = model.t5_proj(query_output.last_hidden_state)\n",
1557
+ " atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n",
1558
+ " prompt = model.prompt\n",
1559
+ " input_tokens = model.t5_tokenizer(\n",
1560
+ " prompt, padding=\"longest\", return_tensors=\"pt\"\n",
1561
+ " ).to(image.device)\n",
1562
+ " encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n",
1563
+ " \n",
1564
+ " with model.maybe_autocast(dtype=torch.bfloat16):\n",
1565
+ " inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n",
1566
+ " inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n",
1567
+ "\n",
1568
+ " outputs = model.t5_model.generate(\n",
1569
+ " inputs_embeds=inputs_embeds,\n",
1570
+ " attention_mask=encoder_atts)\n",
1571
+ " output_text = model.t5_tokenizer.batch_decode(\n",
1572
+ " outputs, skip_special_tokens=True)\n",
1573
+ " \n",
1574
+ " return output_text\"\"\""
1575
+ ]
1576
+ },
1577
+ {
1578
+ "cell_type": "code",
1579
+ "execution_count": 48,
1580
+ "id": "636b4684-df9a-4e29-8683-86fb035ba690",
1581
+ "metadata": {
1582
+ "tags": []
1583
+ },
1584
+ "outputs": [],
1585
+ "source": [
1586
+ "wandb_log = False"
1587
+ ]
1588
+ },
1589
+ {
1590
+ "cell_type": "code",
1591
+ "execution_count": 49,
1592
+ "id": "0847b380-2edb-4a56-9b33-fdc4c0c3f8d3",
1593
+ "metadata": {
1594
+ "tags": []
1595
+ },
1596
+ "outputs": [],
1597
+ "source": [
1598
+ "predicted_embeddings = None"
1599
+ ]
1600
+ },
1601
+ {
1602
+ "cell_type": "code",
1603
+ "execution_count": 52,
1604
+ "id": "60be0d5f-3e94-4612-9373-61b53d836393",
1605
+ "metadata": {
1606
+ "tags": []
1607
+ },
1608
+ "outputs": [
1609
+ {
1610
+ "name": "stdout",
1611
+ "output_type": "stream",
1612
+ "text": [
1613
+ "captions starting with epoch 0 / 30\n"
1614
+ ]
1615
+ },
1616
+ {
1617
+ "name": "stderr",
1618
+ "output_type": "stream",
1619
+ "text": [
1620
+ " 0%| | 0/30 [00:17<?, ?it/s]"
1621
+ ]
1622
+ },
1623
+ {
1624
+ "name": "stdout",
1625
+ "output_type": "stream",
1626
+ "text": [
1627
+ "\n",
1628
+ "===Finished!===\n",
1629
+ "\n",
1630
+ "saving /fsx/proj-fmri/ckadirt/MindEyeV2/train_logs/captions/last.pth\n"
1631
+ ]
1632
+ },
1633
+ {
1634
+ "name": "stderr",
1635
+ "output_type": "stream",
1636
+ "text": [
1637
+ "\n"
1638
+ ]
1639
+ }
1640
+ ],
1641
+ "source": [
1642
+ "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
1643
+ "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
1644
+ "test_image, test_voxel = None, None\n",
1645
+ "mse = nn.MSELoss()\n",
1646
+ "for epoch in progress_bar:\n",
1647
+ " model.train()\n",
1648
+ " \n",
1649
+ " fwd_percent_correct = 0.\n",
1650
+ " bwd_percent_correct = 0.\n",
1651
+ " test_fwd_percent_correct = 0.\n",
1652
+ " test_bwd_percent_correct = 0.\n",
1653
+ "\n",
1654
+ " loss_clip_total = 0.\n",
1655
+ " loss_blurry_total = 0.\n",
1656
+ " test_loss_clip_total = 0.\n",
1657
+ " test_loss_blurry_total = 0.\n",
1658
+ "\n",
1659
+ " blurry_pixcorr = 0.\n",
1660
+ " test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n",
1661
+ " \n",
1662
+ " \"\"\"for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
1663
+ " if epoch == 0:\n",
1664
+ " lrs.append(0)\n",
1665
+ " break\n",
1666
+ " with torch.cuda.amp.autocast():\n",
1667
+ " optimizer.zero_grad()\n",
1668
+ "\n",
1669
+ " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
1670
+ " \n",
1671
+ " image = images[behav[:,0,0].cpu().long()].to(device).float()\n",
1672
+ "\n",
1673
+ " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
1674
+ " \n",
1675
+ " if use_image_aug: image = img_augment(image)\n",
1676
+ " # clip_target = clip_model.embed_image(image)\n",
1677
+ " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n",
1678
+ " assert not torch.any(torch.isnan(clip_target))\n",
1679
+ " \n",
1680
+ " if epoch < int(mixup_pct * num_epochs):\n",
1681
+ " voxel, perm, betas, select = utils.mixco(voxel)\n",
1682
+ "\n",
1683
+ " voxel_ridge = model.ridge(voxel)\n",
1684
+ " \n",
1685
+ " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
1686
+ " clip_voxels = model.backbone(voxel_ridge)\n",
1687
+ " \n",
1688
+ " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
1689
+ " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
1690
+ "\n",
1691
+ " if epoch < int(mixup_pct * num_epochs): \n",
1692
+ " loss_clip = utils.mixco_nce(\n",
1693
+ " clip_voxels_norm,\n",
1694
+ " clip_target_norm,\n",
1695
+ " temp=.006, \n",
1696
+ " perm=perm, betas=betas, select=select)\n",
1697
+ " else:\n",
1698
+ " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n",
1699
+ " loss_clip = utils.soft_clip_loss(\n",
1700
+ " clip_voxels_norm,\n",
1701
+ " clip_target_norm,\n",
1702
+ " temp=epoch_temp)\n",
1703
+ " \n",
1704
+ " loss_mse= mse(clip_voxels, clip_target)\n",
1705
+ "\n",
1706
+ " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc) \n",
1707
+ "\n",
1708
+ " loss_clip_total += loss_clip.item()\n",
1709
+ " # loss_blurry_total += loss_blurry.item()\n",
1710
+ "\n",
1711
+ " # loss = loss_blurry + loss_clip\n",
1712
+ " loss = 0.7 * loss_clip + 0.3 * loss_mse\n",
1713
+ " if (train_i % 10 == 0):\n",
1714
+ " print(train_i, loss)\n",
1715
+ " # print(batch_size)\n",
1716
+ " utils.check_loss(loss)\n",
1717
+ " accelerator.backward(loss)\n",
1718
+ " optimizer.step()\n",
1719
+ " \n",
1720
+ " losses.append(loss.item())\n",
1721
+ " lrs.append(optimizer.param_groups[0]['lr'])\n",
1722
+ " \n",
1723
+ " # forward and backward top 1 accuracy \n",
1724
+ " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n",
1725
+ " fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
1726
+ " bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
1727
+ "\n",
1728
+ " # with torch.no_grad():\n",
1729
+ " # # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()\n",
1730
+ " # random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)\n",
1731
+ " # blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)\n",
1732
+ " # blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)\n",
1733
+ "\n",
1734
+ " if lr_scheduler_type is not None:\n",
1735
+ " lr_scheduler.step()\"\"\"\n",
1736
+ " \n",
1737
+ " model.eval()\n",
1738
+ " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
1739
+ " with torch.cuda.amp.autocast():\n",
1740
+ " with torch.no_grad(): \n",
1741
+ " # all test samples should be loaded per batch such that test_i should never exceed 0\n",
1742
+ " if len(behav) != num_test: print(\"!\",len(behav),num_test)\n",
1743
+ " \n",
1744
+ " ## Average same-image repeats ##\n",
1745
+ " if test_image is None:\n",
1746
+ " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
1747
+ " \n",
1748
+ " image = behav[:,0,0].cpu().long()\n",
1749
+ " \n",
1750
+ " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n",
1751
+ " for im in unique_image:\n",
1752
+ " locs = torch.where(im == image)[0]\n",
1753
+ " if test_image is None:\n",
1754
+ " test_image = images[im][None]\n",
1755
+ " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n",
1756
+ " else:\n",
1757
+ " test_image = torch.vstack((test_image, images[im][None]))\n",
1758
+ " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n",
1759
+ " \n",
1760
+ " # sample of batch_size\n",
1761
+ " random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]\n",
1762
+ " voxel = test_voxel[random_indices].to(device)\n",
1763
+ " image = test_image[random_indices].to(device)\n",
1764
+ " assert len(image) == batch_size\n",
1765
+ " \n",
1766
+ " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
1767
+ " \n",
1768
+ " # clip_target = clip_model.embed_image(image.float())\n",
1769
+ " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n",
1770
+ " \n",
1771
+ " voxel_ridge = model.ridge(voxel)\n",
1772
+ " \n",
1773
+ " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
1774
+ " clip_voxels = model.backbone(voxel_ridge)\n",
1775
+ " \n",
1776
+ " predicted_embeddings = clip_voxels\n",
1777
+ " break\n",
1778
+ " \n",
1779
+ " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
1780
+ " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
1781
+ " \n",
1782
+ " # loss_clip = utils.soft_clip_loss(\n",
1783
+ " # clip_voxels_norm,\n",
1784
+ " # clip_target_norm,\n",
1785
+ " # temp=.006)\n",
1786
+ " \n",
1787
+ " loss_clip = mse(clip_voxels, clip_target)\n",
1788
+ "\n",
1789
+ " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
1790
+ " \n",
1791
+ " # loss = loss_blurry + loss_clip\n",
1792
+ " loss = loss_clip\n",
1793
+ " \n",
1794
+ " utils.check_loss(loss)\n",
1795
+ " \n",
1796
+ " test_losses.append(loss.item())\n",
1797
+ " \n",
1798
+ " # forward and backward top 1 accuracy \n",
1799
+ " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n",
1800
+ " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
1801
+ " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
1802
+ "\n",
1803
+ " # # halving the batch size because the decoder is computationally heavy\n",
1804
+ " # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)\n",
1805
+ " # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))\n",
1806
+ " # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)\n",
1807
+ "\n",
1808
+ " #Find captions and print next to images\n",
1809
+ " #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)\n",
1810
+ " #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)\n",
1811
+ "\n",
1812
+ " #true_embed1 = tensor_2_embed(image[[0]])\n",
1813
+ " #true_embed2 = tensor_2_embed(image[[1]])\n",
1814
+ "\n",
1815
+ " # print(clip_voxels[[0]].shape)\n",
1816
+ " # print(true_embed1.shape)\n",
1817
+ " \n",
1818
+ " #true_caption1 = embed_2_caption(true_embed1, blip2_model)\n",
1819
+ " #true_caption2 = embed_2_caption(true_embed2, blip2_model)\n",
1820
+ " \n",
1821
+ " # transform blurry recon latents to images and plot it\n",
1822
+ " #fig, axes = plt.subplots(2, 2, figsize=(8, 4))\n",
1823
+ " #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))\n",
1824
+ " #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))\n",
1825
+ " #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')\n",
1826
+ " #axes[0,0].set_title(caption1)\n",
1827
+ " #axes[0,1].set_title(caption2)\n",
1828
+ " #axes[1,0].set_title(true_caption1)\n",
1829
+ " #axes[1,1].set_title(true_caption2)\n",
1830
+ "\n",
1831
+ " #plt.show()\n",
1832
+ " \n",
1833
+ " # # transform blurry recon latents to images and plot it\n",
1834
+ " # fig, axes = plt.subplots(1, 4, figsize=(8, 4))\n",
1835
+ " # axes[0].imshow(utils.torch_to_Image(image[[0]]))\n",
1836
+ " # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))\n",
1837
+ " # axes[2].imshow(utils.torch_to_Image(image[[1]]))\n",
1838
+ " # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))\n",
1839
+ " # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')\n",
1840
+ " # axes[0].set_title(caption1)\n",
1841
+ " # axes[3].set_title(caption2)\n",
1842
+ " # plt.show()\n",
1843
+ " \n",
1844
+ " break\n",
1845
+ " if local_rank==0: \n",
1846
+ " # if utils.is_interactive(): clear_output(wait=True)\n",
1847
+ " assert (test_i+1) == 1\n",
1848
+ " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n",
1849
+ " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n",
1850
+ " \"train/lr\": lrs[-1],\n",
1851
+ " \"train/num_steps\": len(losses),\n",
1852
+ " \"test/num_steps\": len(test_losses),\n",
1853
+ " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n",
1854
+ " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n",
1855
+ " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n",
1856
+ " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n",
1857
+ " \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n",
1858
+ " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n",
1859
+ " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n",
1860
+ " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n",
1861
+ " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n",
1862
+ " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n",
1863
+ " }\n",
1864
+ " progress_bar.set_postfix(**logs)\n",
1865
+ " \n",
1866
+ " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n",
1867
+ " jj=-1\n",
1868
+ " for j in [0,1,2,3,4,5,6,7]:\n",
1869
+ " jj+=1\n",
1870
+ " axes[jj].imshow(utils.torch_to_Image(image[j]))\n",
1871
+ " axes[jj].axis('off')\n",
1872
+ "\n",
1873
+ " if wandb_log:\n",
1874
+ " generated_captions = embeds_to_captions_b2(clip_voxels[0:8])\n",
1875
+ " print(generated_captions[1])\n",
1876
+ " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\" + \"\\n\".join(generated_captions[1]))\n",
1877
+ " plt.close()\n",
1878
+ " # Save model checkpoint and reconstruct\n",
1879
+ " if epoch % ckpt_interval == 0:\n",
1880
+ " if not utils.is_interactive():\n",
1881
+ " save_ckpt(f'last')\n",
1882
+ " \n",
1883
+ " if wandb_log: wandb.log(logs)\n",
1884
+ "\n",
1885
+ " # wait for other GPUs to catch up if needed\n",
1886
+ " accelerator.wait_for_everyone()\n",
1887
+ " torch.cuda.empty_cache()\n",
1888
+ " gc.collect()\n",
1889
+ "\n",
1890
+ "print(\"\\n===Finished!===\\n\")\n",
1891
+ "if ckpt_saving:\n",
1892
+ " save_ckpt(f'last')\n",
1893
+ "if not utils.is_interactive():\n",
1894
+ " sys.exit(0)"
1895
+ ]
1896
+ },
1897
+ {
1898
+ "cell_type": "code",
1899
+ "execution_count": 54,
1900
+ "id": "f5b47c76-a97a-48ee-b4b3-051c17aebac4",
1901
+ "metadata": {
1902
+ "tags": []
1903
+ },
1904
+ "outputs": [
1905
+ {
1906
+ "data": {
1907
+ "text/plain": [
1908
+ "torch.Size([128, 257, 1408])"
1909
+ ]
1910
+ },
1911
+ "execution_count": 54,
1912
+ "metadata": {},
1913
+ "output_type": "execute_result"
1914
+ }
1915
+ ],
1916
+ "source": [
1917
+ "predicted_embeddings.shape"
1918
+ ]
1919
+ },
1920
+ {
1921
+ "cell_type": "code",
1922
+ "execution_count": 55,
1923
+ "id": "92d0029f-079f-4710-bf43-bc9e3fd08d5e",
1924
+ "metadata": {
1925
+ "tags": []
1926
+ },
1927
+ "outputs": [
1928
+ {
1929
+ "name": "stderr",
1930
+ "output_type": "stream",
1931
+ "text": [
1932
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
1933
+ " warnings.warn(\n"
1934
+ ]
1935
+ },
1936
+ {
1937
+ "name": "stdout",
1938
+ "output_type": "stream",
1939
+ "text": [
1940
+ "['a group of people are sitting around a table\\n', 'a man is holding a glass of water in front of a television\\n', 'a man is riding a skateboard on a hill\\n', 'a group of people standing around a bike\\n', 'a building with a sign that says \"the house\"\\n', 'a plate of food with vegetables and meat\\n', 'a white cup with a small bottle of wine\\n', 'a group of people playing baseball and one is holding a ball\\n']\n"
1941
+ ]
1942
+ }
1943
+ ],
1944
+ "source": [
1945
+ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8])\n",
1946
+ "print(generated_captions[1])"
1947
+ ]
1948
+ },
1949
+ {
1950
+ "cell_type": "code",
1951
+ "execution_count": 75,
1952
+ "id": "88750a6d-0b61-4943-a7e5-1d675bbb4f8f",
1953
+ "metadata": {
1954
+ "tags": []
1955
+ },
1956
+ "outputs": [
1957
+ {
1958
+ "name": "stdout",
1959
+ "output_type": "stream",
1960
+ "text": [
1961
+ "['a group of people are sitting at a table with food and drinks\\n', 'a man in a kitchen with a large screen\\n', 'a man on a surfboard with his legs in the air\\n', 'a group of people are standing on the beach in front of a boat\\n', 'a building with a sign that says \"home of the person\"\\n', 'a vegetable salad with a variety of vegetables and other ingredients\\n', 'a white cup with a small amount of coffee and a bottle of wine\\n', 'a group of people playing baseball and soccer\\n']\n"
1962
+ ]
1963
+ }
1964
+ ],
1965
+ "source": [
1966
+ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)\n",
1967
+ "print(generated_captions[1])"
1968
+ ]
1969
+ },
1970
+ {
1971
+ "cell_type": "code",
1972
+ "execution_count": 95,
1973
+ "id": "d99e7583-0f26-41c1-8035-a1aa3b1c2d55",
1974
+ "metadata": {
1975
+ "tags": []
1976
+ },
1977
+ "outputs": [],
1978
+ "source": [
1979
+ "def concatenate_lists_any_depth(list1, list2):\n",
1980
+ " \"\"\"\n",
1981
+ " Concatenates two lists of potentially varying depths, forming a new list of lists.\n",
1982
+ "\n",
1983
+ " Args:\n",
1984
+ " list1 (list): The first list to concatenate. Elements can be of any type.\n",
1985
+ " list2 (list): The second list to concatenate. Elements can be of any type.\n",
1986
+ "\n",
1987
+ " Returns:\n",
1988
+ " list: A new list containing lists of elements from the original lists.\n",
1989
+ " \"\"\"\n",
1990
+ " # Ensure that both lists have the same length\n",
1991
+ " if len(list1) != len(list2):\n",
1992
+ " raise ValueError(\"Lists must be of the same length\")\n",
1993
+ "\n",
1994
+ " concatenated_list = []\n",
1995
+ "\n",
1996
+ " for a, b in zip(list1, list2):\n",
1997
+ " # If the elements are not lists, convert them to lists\n",
1998
+ " if not isinstance(a, list):\n",
1999
+ " a = [a]\n",
2000
+ " if not isinstance(b, list):\n",
2001
+ " b = [b]\n",
2002
+ "\n",
2003
+ " # Concatenate the lists\n",
2004
+ " concatenated_list.append(a + b)\n",
2005
+ "\n",
2006
+ " return concatenated_list"
2007
+ ]
2008
+ },
2009
+ {
2010
+ "cell_type": "code",
2011
+ "execution_count": 96,
2012
+ "id": "ed8167ea-a3ab-438a-aa85-f1309047199c",
2013
+ "metadata": {
2014
+ "tags": []
2015
+ },
2016
+ "outputs": [],
2017
+ "source": [
2018
+ "def sample_several(embeddings, num=10, temp=0.3):\n",
2019
+ " # embeddings shape = batch, 257, 1408\n",
2020
+ " results = None # Initialize results as None\n",
2021
+ "\n",
2022
+ " for i in range(num): # Iterate from 0 to num-1\n",
2023
+ " if results is None:\n",
2024
+ " # For the first iteration, assign the results directly\n",
2025
+ " results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]\n",
2026
+ " else:\n",
2027
+ " # For subsequent iterations, combine the new results with the existing ones\n",
2028
+ " new_results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]\n",
2029
+ " results = concatenate_lists_any_depth(results, new_results)\n",
2030
+ "\n",
2031
+ " return results # Return the combined results\n"
2032
+ ]
2033
+ },
2034
+ {
2035
+ "cell_type": "code",
2036
+ "execution_count": 77,
2037
+ "id": "6700e130-8ae4-4475-a5b4-972fd8b9717a",
2038
+ "metadata": {
2039
+ "tags": []
2040
+ },
2041
+ "outputs": [
2042
+ {
2043
+ "name": "stdout",
2044
+ "output_type": "stream",
2045
+ "text": [
2046
+ "['a group of people sitting on a bench in front of a building\\n', 'a woman is using a computer to make a video\\n', 'a man in a black shirt is sitting on a surfboard\\n', 'a group of people on the beach with a bike and some other things\\n', 'a large building with a sign that says \"the old farmhouse\"\\n', 'a plate with many different types of vegetables\\n', 'a white cup with a bottle of wine and a small bottle of wine\\n', 'a group of people are playing baseball in a field\\n']\n"
2047
+ ]
2048
+ }
2049
+ ],
2050
+ "source": [
2051
+ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)\n",
2052
+ "print(generated_captions[1])"
2053
+ ]
2054
+ },
2055
+ {
2056
+ "cell_type": "code",
2057
+ "execution_count": 99,
2058
+ "id": "f0e111e3-6134-4a63-a6d7-17b3441be8c8",
2059
+ "metadata": {
2060
+ "tags": []
2061
+ },
2062
+ "outputs": [
2063
+ {
2064
+ "data": {
2065
+ "text/plain": [
2066
+ "[['people are sitting at a table with a bunch of chairs\\n',\n",
2067
+ " 'several people in the yard with some food\\n',\n",
2068
+ " 'people sitting on a bench near a water fountain\\n',\n",
2069
+ " 'a group of people are sitting around a table\\n',\n",
2070
+ " 'a group of people in a room with several people in the foreground\\n',\n",
2071
+ " 'a group of people sitting around a table with food\\n',\n",
2072
+ " 'the people in the background are sitting on the edge of a table\\n',\n",
2073
+ " 'beverages and food are served at a family picnic\\n',\n",
2074
+ " 'a group of people eating in a restaurant\\n',\n",
2075
+ " 'a group of people sitting around a table\\n',\n",
2076
+ " 'people are sitting at a table next to a tree\\n',\n",
2077
+ " 'people are sitting around a table with a lot of food\\n'],\n",
2078
+ " ['a person is holding a newspaper in a restaurant\\n',\n",
2079
+ " 'the man is holding a cup of coffee in front of a television\\n',\n",
2080
+ " 'a woman is preparing to cook in a kitchen\\n',\n",
2081
+ " 'a man working in an office setting with a computer and a man in a chair\\n',\n",
2082
+ " 'a person is using a smartphone in a restaurant\\n',\n",
2083
+ " 'a man is holding a glass of water in front of a television\\n',\n",
2084
+ " 'a man in a kitchen with a knife and a cup of coffee\\n',\n",
2085
+ " 'the kitchen at the new york times\\n',\n",
2086
+ " 'a man is holding a knife and cutting a piece of pizza\\n',\n",
2087
+ " 'a man is reading a book while another is working on a computer\\n',\n",
2088
+ " 'a person is using a computer to make a presentation\\n',\n",
2089
+ " 'a man is holding up a box of food\\n'],\n",
2090
+ " ['a man in a suit and a woman wearing a helmet on a surfboard\\n',\n",
2091
+ " 'a person is on the ground while holding onto a skateboard\\n',\n",
2092
+ " 'a man in a beach chair riding a skateboard\\n',\n",
2093
+ " 'a woman is standing on a surfboard in the ocean while holding a skateboard\\n',\n",
2094
+ " 'a man is riding on a surfboard\\n',\n",
2095
+ " 'a man on his knees in a surfboard with his leg up\\n',\n",
2096
+ " 'a man is doing a trick on a skateboard\\n',\n",
2097
+ " 'a man is riding a skateboard on a wave\\n',\n",
2098
+ " 'a person is sitting on a surfboard while another person is riding on it\\n',\n",
2099
+ " 'a man in a jumpsuit is holding onto a surfboard\\n',\n",
2100
+ " 'a man is jumping on a surfboard while another is sitting on it\\n',\n",
2101
+ " 'a man is sitting on a surfboard while he is riding\\n'],\n",
2102
+ " ['a picture of a man riding a bike next to a bike\\n',\n",
2103
+ " 'a group of people standing on a street with a bike\\n',\n",
2104
+ " 'people are sitting around a picnic table and a bike is being ridden\\n',\n",
2105
+ " 'a group of people are on a beach with a bike and a car\\n',\n",
2106
+ " \"the world's largest boat race is underway in the bay of britain\\n\",\n",
2107
+ " 'a man and his bike standing on the side of a road\\n',\n",
2108
+ " 'a motorcycle is sitting on top of a hill with a boat and a bicycle\\n',\n",
2109
+ " 'a bunch of people are standing around a large park\\n',\n",
2110
+ " 'a group of people standing around a table with bicycles\\n',\n",
2111
+ " 'a man with his bike and helmet in the air\\n',\n",
2112
+ " 'the sun is shining brightly and there are people walking around\\n',\n",
2113
+ " 'a group of people standing on a beach next to a boat\\n'],\n",
2114
+ " ['the home has a large yellow sign\\n',\n",
2115
+ " 'the building has two small windows and a sign\\n',\n",
2116
+ " 'a view of a home with a building in the background\\n',\n",
2117
+ " 'the house has been built in the style of a traditional english cottage\\n',\n",
2118
+ " 'the house is on the corner of a street\\n',\n",
2119
+ " 'the home is an old style building with a white door\\n',\n",
2120
+ " 'the house is in a residential area with many buildings\\n',\n",
2121
+ " 'the building is white and has a red roof\\n',\n",
2122
+ " 'the old building is now a park and recreation center\\n',\n",
2123
+ " 'a large building with a lot of windows and a lot of people\\n',\n",
2124
+ " 'a large house with a white door and a blue sign\\n',\n",
2125
+ " 'the house is in a residential area with a front and back door\\n'],\n",
2126
+ " ['the vegetables are arranged in a square shape on the table\\n',\n",
2127
+ " 'a plate full of vegetables and fruit with a knife and fork\\n',\n",
2128
+ " 'a plate of various vegetables with a knife\\n',\n",
2129
+ " 'a plate with several different types of food\\n',\n",
2130
+ " 'a plate of food with various vegetables and meat\\n',\n",
2131
+ " 'a picture of some vegetables and a plate of food\\n',\n",
2132
+ " 'a close up of several types of food on a table\\n',\n",
2133
+ " 'a plate of food with a variety of vegetables\\n',\n",
2134
+ " 'a large plate with many different types of food\\n',\n",
2135
+ " 'a plate of vegetables and meat on a table\\n',\n",
2136
+ " 'a plate with lots of different types of vegetables\\n',\n",
2137
+ " 'a close up of some food with a knife\\n'],\n",
2138
+ " ['a white cup with a green tea bag and a small bottle of alcohol\\n',\n",
2139
+ " 'a bottle of wine with two glasses and a spoon\\n',\n",
2140
+ " 'the chocolate bar is sitting next to the bottle of wine\\n',\n",
2141
+ " 'a white and black cup and a bottle of wine\\n',\n",
2142
+ " 'a white cup sitting next to some drinks\\n',\n",
2143
+ " 'a bottle of wine and a bottle of champagne on a table\\n',\n",
2144
+ " 'a white cup with two pills and a small bottle of wine\\n',\n",
2145
+ " 'a bottle of wine and a cup of coffee next to a bottle of wine\\n',\n",
2146
+ " 'a bottle of wine and a bottle of beer in a glass\\n',\n",
2147
+ " 'a bottle of wine, a bottle of beer and a wine bottle\\n',\n",
2148
+ " 'a bottle of wine and a cup with some food\\n',\n",
2149
+ " 'a glass of wine and a pair of glasses on a table\\n'],\n",
2150
+ " ['a group of people in white and blue uniforms playing baseball\\n',\n",
2151
+ " 'a group of people playing baseball in a field\\n',\n",
2152
+ " 'a group of people playing a game of baseball\\n',\n",
2153
+ " 'a group of people standing on a field and one is holding a tennis ball\\n',\n",
2154
+ " 'a group of people in uniform playing baseball\\n',\n",
2155
+ " 'two men and a woman in the middle of a game\\n',\n",
2156
+ " 'a group of people playing baseball in the grass\\n',\n",
2157
+ " 'a group of men and women are playing baseball\\n',\n",
2158
+ " 'july 15th, 2011 - june 20th, 2012 - june 17th,',\n",
2159
+ " 'the team is playing soccer and one is holding a ball\\n',\n",
2160
+ " 'people are playing baseball with each other and one is holding a ball\\n',\n",
2161
+ " 'the women are laughing and the man is running\\n']]"
2162
+ ]
2163
+ },
2164
+ "execution_count": 99,
2165
+ "metadata": {},
2166
+ "output_type": "execute_result"
2167
+ }
2168
+ ],
2169
+ "source": [
2170
+ "several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.5)\n",
2171
+ "several"
2172
+ ]
2173
+ },
2174
+ {
2175
+ "cell_type": "code",
2176
+ "execution_count": 100,
2177
+ "id": "7ced031a-f259-4797-afd7-876fa62cdcfd",
2178
+ "metadata": {
2179
+ "tags": []
2180
+ },
2181
+ "outputs": [
2182
+ {
2183
+ "data": {
2184
+ "text/plain": [
2185
+ "[['a group of people are sitting around a table\\n',\n",
2186
+ " 'a group of people are sitting around a table with food\\n',\n",
2187
+ " 'a group of people sitting at a table with food\\n',\n",
2188
+ " 'a group of people are sitting on the ground in front of a table\\n',\n",
2189
+ " 'a group of people sitting around a table with a person and a dog\\n',\n",
2190
+ " 'a group of people are sitting on the ground and eating\\n',\n",
2191
+ " 'the group is sitting around a table with food\\n',\n",
2192
+ " 'people are sitting around a table with food\\n',\n",
2193
+ " 'a group of people sitting around a table with food\\n',\n",
2194
+ " 'the people are eating in front of a table\\n',\n",
2195
+ " 'a group of people are sitting on a bench in a field\\n',\n",
2196
+ " 'a group of people are sitting on a bench\\n'],\n",
2197
+ " ['a man is using a computer and a phone\\n',\n",
2198
+ " 'a person in a kitchen with a large screen\\n',\n",
2199
+ " 'a man is preparing food in a kitchen\\n',\n",
2200
+ " 'a man is standing in front of a computer and a woman is sitting behind him\\n',\n",
2201
+ " 'a man is using a computer to play a game\\n',\n",
2202
+ " 'a man is using a computer to play a game\\n',\n",
2203
+ " 'a man in a kitchen with a large television\\n',\n",
2204
+ " 'a man is holding a glass of water in front of a television\\n',\n",
2205
+ " 'the man is holding a bottle of water and a glass\\n',\n",
2206
+ " 'a man is using a computer to make a video\\n',\n",
2207
+ " 'a man is serving food at a restaurant\\n',\n",
2208
+ " 'a man is holding a drink in his hand\\n'],\n",
2209
+ " ['a man with a skateboard is riding on a wave\\n',\n",
2210
+ " 'a man is riding a skateboard on a hill\\n',\n",
2211
+ " 'a man is riding a skateboard on a hill\\n',\n",
2212
+ " 'a person is sitting on a surfboard while another person is riding on it\\n',\n",
2213
+ " 'a man is riding a surfboard on a wave\\n',\n",
2214
+ " 'a man with a skateboard is on top of a hill\\n',\n",
2215
+ " 'a person in a surfboard is riding a wave\\n',\n",
2216
+ " 'a man on a surfboard is riding on a wave\\n',\n",
2217
+ " 'a man in a suit and a woman in a bikini are playing on a surf board\\n',\n",
2218
+ " 'a man is riding a skateboard while wearing a helmet\\n',\n",
2219
+ " 'a man on the surf board with his legs in the air\\n',\n",
2220
+ " 'a man in a suit is playing a game with a skateboard\\n'],\n",
2221
+ " ['a group of people standing on a beach with a bike\\n',\n",
2222
+ " 'a group of people standing on a beach with a bike\\n',\n",
2223
+ " 'a group of people standing on a road with a bike and a car\\n',\n",
2224
+ " 'a group of people in the water with two bikes\\n',\n",
2225
+ " 'the bike is in the middle of the road and there are two people on the side of the',\n",
2226
+ " 'a group of people standing around a car with a bike\\n',\n",
2227
+ " 'a man is standing on a bike with a skateboard\\n',\n",
2228
+ " 'a group of people riding bicycles on a road\\n',\n",
2229
+ " 'a bicycle is in the middle of a field with a person on it\\n',\n",
2230
+ " 'a man is standing on a bicycle with a helmet and a skateboard\\n',\n",
2231
+ " 'a photo of a bicycle with a man on it\\n',\n",
2232
+ " 'a group of people riding bicycles on a road\\n'],\n",
2233
+ " ['a building with a sign that says \"the old man\"\\n',\n",
2234
+ " 'a house with a sign that says \"the house that james bond built\"\\n',\n",
2235
+ " 'a building with a sign that says \"the house\"\\n',\n",
2236
+ " 'a house with a sign that says \"museum\"\\n',\n",
2237
+ " 'a building with a sign that says \"the home of the person\"\\n',\n",
2238
+ " 'a building with a sign that says \"the museum of american history\"\\n',\n",
2239
+ " 'a white building with a sign on the side\\n',\n",
2240
+ " 'a brown house with a white roof and a green sign\\n',\n",
2241
+ " 'a house with a large sign on the side\\n',\n",
2242
+ " 'a building with a sign that says \"the building\"\\n',\n",
2243
+ " 'the building is in the middle of the street\\n',\n",
2244
+ " 'the front of an old building with a sign\\n'],\n",
2245
+ " ['a plate of different types of vegetables and meat\\n',\n",
2246
+ " 'a close up of some vegetables and meat\\n',\n",
2247
+ " 'a plate with a variety of different foods on it\\n',\n",
2248
+ " 'a plate of vegetables and meat with a green border\\n',\n",
2249
+ " 'a plate of vegetables with a variety of toppings\\n',\n",
2250
+ " 'a plate of food with different types of vegetables\\n',\n",
2251
+ " 'a plate of food with various vegetables and meat\\n',\n",
2252
+ " 'a plate of vegetables with some green leaves on it\\n',\n",
2253
+ " 'a bunch of vegetables and mushrooms on a plate\\n',\n",
2254
+ " 'a bunch of vegetables and fruit on a table\\n',\n",
2255
+ " 'a plate of vegetables and other items on a table\\n',\n",
2256
+ " 'a close up of some vegetables and meat\\n'],\n",
2257
+ " ['a white cup with a spoon and a spoon\\n',\n",
2258
+ " 'a bottle of wine and a bottle of champagne\\n',\n",
2259
+ " 'a white cup with a small bottle and a small bottle of wine\\n',\n",
2260
+ " 'a white cup with a small bottle of wine and a small bottle of water\\n',\n",
2261
+ " 'the bottle is open and the bottle is next to a cup\\n',\n",
2262
+ " 'the white cup with a small bottle of wine and a small bottle of wine\\n',\n",
2263
+ " 'a white cup with a black handle and a pair of scissors\\n',\n",
2264
+ " 'a bottle of wine and a bottle of wine glasses\\n',\n",
2265
+ " 'a bottle of wine and a bottle of champagne\\n',\n",
2266
+ " 'a white and black cup with a small spoon next to it\\n',\n",
2267
+ " 'a white cup with a small bottle of wine\\n',\n",
2268
+ " 'a white cup with a spoon and a bottle of wine\\n'],\n",
2269
+ " ['a group of people playing baseball and soccer\\n',\n",
2270
+ " 'a group of people are playing baseball in the grass\\n',\n",
2271
+ " 'a group of people playing baseball and running\\n',\n",
2272
+ " 'a group of people playing baseball and soccer\\n',\n",
2273
+ " 'a group of people playing soccer on a field\\n',\n",
2274
+ " 'a group of people are playing baseball in the grass\\n',\n",
2275
+ " 'a group of people playing baseball with a man in the background\\n',\n",
2276
+ " 'a group of people playing baseball and one is holding a ball\\n',\n",
2277
+ " 'a group of people playing baseball in front of a field\\n',\n",
2278
+ " 'a group of people playing baseball on a field\\n',\n",
2279
+ " 'a group of people playing baseball with one person in the background\\n',\n",
2280
+ " 'a group of people are playing baseball and one is holding a ball\\n']]"
2281
+ ]
2282
+ },
2283
+ "execution_count": 100,
2284
+ "metadata": {},
2285
+ "output_type": "execute_result"
2286
+ }
2287
+ ],
2288
+ "source": [
2289
+ "several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.3)\n",
2290
+ "several"
2291
+ ]
2292
+ },
2293
+ {
2294
+ "cell_type": "code",
2295
+ "execution_count": null,
2296
+ "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a",
2297
+ "metadata": {
2298
+ "tags": []
2299
+ },
2300
+ "outputs": [],
2301
+ "source": [
2302
+ "plt.plot(losses)\n",
2303
+ "plt.show()\n",
2304
+ "plt.plot(test_losses)\n",
2305
+ "plt.show()"
2306
+ ]
2307
+ },
2308
+ {
2309
+ "cell_type": "code",
2310
+ "execution_count": null,
2311
+ "id": "ccfccd4f-764d-4624-842c-f931676eb43b",
2312
+ "metadata": {},
2313
+ "outputs": [],
2314
+ "source": [
2315
+ "print('test')"
2316
+ ]
2317
+ },
2318
+ {
2319
+ "cell_type": "code",
2320
+ "execution_count": null,
2321
+ "id": "f1a60e19-c440-4c9c-a634-30186209012f",
2322
+ "metadata": {},
2323
+ "outputs": [],
2324
+ "source": [
2325
+ "def tensor_2_embed_old(tensor):\n",
2326
+ " embed_array = torch.zeros((tensor.shape[0],257, 1024)) \n",
2327
+ " to_pil = ToPILImage()\n",
2328
+ " for sample in range(tensor.shape[0]):\n",
2329
+ " PIL_image = to_pil(tensor[sample])\n",
2330
+ " image_for_blip2 = vis_processors[\"eval\"](PIL_image).unsqueeze(0).to(device)\n",
2331
+ " #Generate embeddings\n",
2332
+ " with blip2_model.maybe_autocast():\n",
2333
+ " blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n",
2334
+ " embed_array[sample] = blip2_target\n",
2335
+ " \n",
2336
+ " return embed_array"
2337
+ ]
2338
+ },
2339
+ {
2340
+ "cell_type": "code",
2341
+ "execution_count": null,
2342
+ "id": "d39ddada-47f7-4111-92fa-0dd98e8a83d6",
2343
+ "metadata": {},
2344
+ "outputs": [],
2345
+ "source": []
2346
+ },
2347
+ {
2348
+ "cell_type": "code",
2349
+ "execution_count": null,
2350
+ "id": "ec8ed96a-61fa-4c20-8da2-fcd9d0a2ed38",
2351
+ "metadata": {},
2352
+ "outputs": [],
2353
+ "source": []
2354
+ },
2355
+ {
2356
+ "cell_type": "code",
2357
+ "execution_count": null,
2358
+ "id": "6228eb1a-e8e7-4500-b7bc-d0c57bcac4c6",
2359
+ "metadata": {},
2360
+ "outputs": [],
2361
+ "source": []
2362
+ }
2363
+ ],
2364
+ "metadata": {
2365
+ "kernelspec": {
2366
+ "display_name": "Python 3 (ipykernel)",
2367
+ "language": "python",
2368
+ "name": "python3"
2369
+ },
2370
+ "language_info": {
2371
+ "codemirror_mode": {
2372
+ "name": "ipython",
2373
+ "version": 3
2374
+ },
2375
+ "file_extension": ".py",
2376
+ "mimetype": "text/x-python",
2377
+ "name": "python",
2378
+ "nbconvert_exporter": "python",
2379
+ "pygments_lexer": "ipython3",
2380
+ "version": "3.10.8"
2381
+ },
2382
+ "toc": {
2383
+ "base_numbering": 1,
2384
+ "nav_menu": {},
2385
+ "number_sections": true,
2386
+ "sideBar": true,
2387
+ "skip_h1_title": false,
2388
+ "title_cell": "Table of Contents",
2389
+ "title_sidebar": "Contents",
2390
+ "toc_cell": false,
2391
+ "toc_position": {
2392
+ "height": "calc(100% - 180px)",
2393
+ "left": "10px",
2394
+ "top": "150px",
2395
+ "width": "165px"
2396
+ },
2397
+ "toc_section_display": true,
2398
+ "toc_window_display": true
2399
+ },
2400
+ "toc-autonumbering": true,
2401
+ "vscode": {
2402
+ "interpreter": {
2403
+ "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376"
2404
+ }
2405
+ }
2406
+ },
2407
+ "nbformat": 4,
2408
+ "nbformat_minor": 5
2409
+ }
src/train2.ipynb ADDED
@@ -0,0 +1,1856 @@
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n",
13
+ "# from subprocess import call\n",
14
+ "# command = \"jupyter nbconvert Train.ipynb --to python\"\n",
15
+ "# call(command,shell=True)"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "markdown",
20
+ "id": "b0f0f4f3",
21
+ "metadata": {},
22
+ "source": [
23
+ "# Import packages & functions"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 2,
29
+ "id": "5bad764b-45c1-45ce-a716-8d055e09821a",
30
+ "metadata": {
31
+ "tags": []
32
+ },
33
+ "outputs": [
34
+ {
35
+ "name": "stderr",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
39
+ " from .autonotebook import tqdm as notebook_tqdm\n"
40
+ ]
41
+ },
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "[2023-11-19 16:32:39,711] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "import os\n",
52
+ "import sys\n",
53
+ "import json\n",
54
+ "import argparse\n",
55
+ "import numpy as np\n",
56
+ "import math\n",
57
+ "from einops import rearrange\n",
58
+ "import time\n",
59
+ "import random\n",
60
+ "import h5py\n",
61
+ "from tqdm import tqdm\n",
62
+ "\n",
63
+ "import webdataset as wds\n",
64
+ "import gc\n",
65
+ "\n",
66
+ "import matplotlib.pyplot as plt\n",
67
+ "import torch\n",
68
+ "import torch.nn as nn\n",
69
+ "from torchvision import transforms\n",
70
+ "from torchvision.transforms import ToPILImage #CHANGED (added)\n",
71
+ "\n",
72
+ "from accelerate import Accelerator, DeepSpeedPlugin\n",
73
+ "\n",
74
+ "# tf32 data type is faster than standard float32\n",
75
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
76
+ "\n",
77
+ "# custom functions #\n",
78
+ "import utils\n",
79
+ "\n",
80
+ "global_batch_size = 128 #128"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 3,
86
+ "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb",
87
+ "metadata": {
88
+ "tags": []
89
+ },
90
+ "outputs": [
91
+ {
92
+ "name": "stdout",
93
+ "output_type": "stream",
94
+ "text": [
95
+ "LOCAL RANK 0\n"
96
+ ]
97
+ }
98
+ ],
99
+ "source": [
100
+ "### Multi-GPU config ###\n",
101
+ "local_rank = os.getenv('RANK')\n",
102
+ "if local_rank is None: \n",
103
+ " local_rank = 0\n",
104
+ "else:\n",
105
+ " local_rank = int(local_rank)\n",
106
+ "print(\"LOCAL RANK \", local_rank) \n",
107
+ "\n",
108
+ "num_devices = torch.cuda.device_count()\n",
109
+ "if num_devices==0: num_devices = 1\n",
110
+ "\n",
111
+ "accelerator = Accelerator(split_batches=False)\n",
112
+ "\n",
113
+ "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n",
114
+ "\n",
115
+ "# if num_devices <= 1 and utils.is_interactive():\n",
116
+ "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n",
117
+ "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
118
+ "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n",
119
+ "# os.environ[\"RANK\"] = \"0\"\n",
120
+ "# os.environ[\"LOCAL_RANK\"] = \"0\"\n",
121
+ "# os.environ[\"WORLD_SIZE\"] = \"1\"\n",
122
+ "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n",
123
+ "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n",
124
+ "\n",
125
+ "# # alter the deepspeed config according to your global and local batch size\n",
126
+ "# if local_rank == 0:\n",
127
+ "# with open('deepspeed_config_stage2.json', 'r') as file:\n",
128
+ "# config = json.load(file)\n",
129
+ "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n",
130
+ "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n",
131
+ "# with open('deepspeed_config_stage2.json', 'w') as file:\n",
132
+ "# json.dump(config, file)\n",
133
+ "# else:\n",
134
+ "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n",
135
+ "# time.sleep(10)\n",
136
+ "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n",
137
+ "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 4,
143
+ "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c",
144
+ "metadata": {
145
+ "tags": []
146
+ },
147
+ "outputs": [
148
+ {
149
+ "name": "stdout",
150
+ "output_type": "stream",
151
+ "text": [
152
+ "PID of this process = 2370606\n",
153
+ "device: cuda\n",
154
+ "Distributed environment: NO\n",
155
+ "Num processes: 1\n",
156
+ "Process index: 0\n",
157
+ "Local process index: 0\n",
158
+ "Device: cuda\n",
159
+ "\n",
160
+ "Mixed precision type: no\n",
161
+ "\n",
162
+ "distributed = False num_devices = 1 local rank = 0 world size = 1\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "print(\"PID of this process =\",os.getpid())\n",
168
+ "device = accelerator.device\n",
169
+ "print(\"device:\",device)\n",
170
+ "num_workers = num_devices\n",
171
+ "print(accelerator.state)\n",
172
+ "world_size = accelerator.state.num_processes\n",
173
+ "distributed = not accelerator.state.distributed_type == 'NO'\n",
174
+ "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n",
175
+ "print = accelerator.print # only print if local_rank=0"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "id": "9018b82b-c054-4463-9527-4b0c2a75bda6",
181
+ "metadata": {
182
+ "tags": []
183
+ },
184
+ "source": [
185
+ "# Configurations"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 5,
191
+ "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3",
192
+ "metadata": {
193
+ "tags": []
194
+ },
195
+ "outputs": [
196
+ {
197
+ "name": "stdout",
198
+ "output_type": "stream",
199
+ "text": [
200
+ "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=captions', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-1', '--mixup_pct=.66', '--num_epochs=30', '--ckpt_interval=999', '--no-use_image_aug']\n"
201
+ ]
202
+ }
203
+ ],
204
+ "source": [
205
+ "# if running this interactively, can specify jupyter_args here for argparser to use\n",
206
+ "if utils.is_interactive():\n",
207
+ " # Example use\n",
208
+ " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n",
209
+ " --model_name=captions \\\n",
210
+ " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n",
211
+ " --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug\"\n",
212
+ " #max_lr=3e-5 originally\n",
213
+ " jupyter_args = jupyter_args.split()\n",
214
+ " print(jupyter_args)\n",
215
+ " \n",
216
+ " from IPython.display import clear_output # function to clear print outputs in cell\n",
217
+ " %load_ext autoreload \n",
218
+ " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n",
219
+ " %autoreload 2 "
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 6,
225
+ "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c",
226
+ "metadata": {
227
+ "tags": []
228
+ },
229
+ "outputs": [
230
+ {
231
+ "name": "stdout",
232
+ "output_type": "stream",
233
+ "text": [
234
+ "global batch_size 128\n",
235
+ "batch_size 128\n"
236
+ ]
237
+ }
238
+ ],
239
+ "source": [
240
+ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n",
241
+ "parser.add_argument(\n",
242
+ " \"--model_name\", type=str, default=\"testing\",\n",
243
+ " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n",
244
+ ")\n",
245
+ "parser.add_argument(\n",
246
+ " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n",
247
+ " help=\"Path to where NSD data is stored / where to download it to\",\n",
248
+ ")\n",
249
+ "parser.add_argument(\n",
250
+ " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n",
251
+ ")\n",
252
+ "parser.add_argument(\n",
253
+ " \"--batch_size\", type=int, default=32,\n",
254
+ " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n",
255
+ ")\n",
256
+ "parser.add_argument(\n",
257
+ " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n",
258
+ " help=\"whether to log to wandb\",\n",
259
+ ")\n",
260
+ "parser.add_argument(\n",
261
+ " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n",
262
+ " help=\"if not using wandb and want to resume from a ckpt\",\n",
263
+ ")\n",
264
+ "parser.add_argument(\n",
265
+ " \"--wandb_project\",type=str,default=\"stability\",\n",
266
+ " help=\"wandb project name\",\n",
267
+ ")\n",
268
+ "parser.add_argument(\n",
269
+ " \"--mixup_pct\",type=float,default=.33,\n",
270
+ " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n",
271
+ ")\n",
272
+ "parser.add_argument(\n",
273
+ " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n",
274
+ " help=\"whether to use image augmentation\",\n",
275
+ ")\n",
276
+ "parser.add_argument(\n",
277
+ " \"--num_epochs\",type=int,default=240,\n",
278
+ " help=\"number of epochs of training\",\n",
279
+ ")\n",
280
+ "parser.add_argument(\n",
281
+ " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n",
282
+ ")\n",
283
+ "parser.add_argument(\n",
284
+ " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n",
285
+ ")\n",
286
+ "parser.add_argument(\n",
287
+ " \"--ckpt_interval\",type=int,default=5,\n",
288
+ " help=\"save backup ckpt and reconstruct every x epochs\",\n",
289
+ ")\n",
290
+ "parser.add_argument(\n",
291
+ " \"--seed\",type=int,default=42,\n",
292
+ ")\n",
293
+ "parser.add_argument(\n",
294
+ " \"--max_lr\",type=float,default=3e-4,\n",
295
+ ")\n",
296
+ "parser.add_argument(\n",
297
+ " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n",
298
+ " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n",
299
+ ")\n",
300
+ "\n",
301
+ "if utils.is_interactive():\n",
302
+ " args = parser.parse_args(jupyter_args)\n",
303
+ "else:\n",
304
+ " args = parser.parse_args()\n",
305
+ "\n",
306
+ "# create global variables without the args prefix\n",
307
+ "for attribute_name in vars(args).keys():\n",
308
+ " globals()[attribute_name] = getattr(args, attribute_name)\n",
309
+ "\n",
310
+ "print(\"global batch_size\", batch_size)\n",
311
+ "batch_size = int(batch_size / num_devices)\n",
312
+ "print(\"batch_size\", batch_size)"
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": 7,
318
+ "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d",
319
+ "metadata": {
320
+ "tags": []
321
+ },
322
+ "outputs": [],
323
+ "source": [
324
+ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n",
325
+ "if not os.path.exists(outdir):\n",
326
+ " os.makedirs(outdir,exist_ok=True)\n",
327
+ "if use_image_aug:\n",
328
+ " import kornia\n",
329
+ " from kornia.augmentation.container import AugmentationSequential\n",
330
+ " img_augment = AugmentationSequential(\n",
331
+ " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n",
332
+ " kornia.augmentation.Resize((224, 224)),\n",
333
+ " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n",
334
+ " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n",
335
+ " kornia.augmentation.RandomGrayscale(p=0.3),\n",
336
+ " same_on_batch=False,\n",
337
+ " data_keys=[\"input\"],\n",
338
+ " )"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 8,
344
+ "id": "e7807ba9-02b6-4bc0-873c-69869abe4091",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "wandb_log = False"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "markdown",
353
+ "id": "42d13c25-1369-4c49-81d4-83d713586096",
354
+ "metadata": {
355
+ "tags": []
356
+ },
357
+ "source": [
358
+ "# Prep data, models, and dataloaders"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "id": "1c023f24-5233-4a15-a2f5-78487b3a8546",
364
+ "metadata": {},
365
+ "source": [
366
+ "## Dataloader"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": 9,
372
+ "id": "81084834-035f-4465-ad59-59e6b806a2f5",
373
+ "metadata": {
374
+ "tags": []
375
+ },
376
+ "outputs": [
377
+ {
378
+ "name": "stdout",
379
+ "output_type": "stream",
380
+ "text": [
381
+ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n",
382
+ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n"
383
+ ]
384
+ }
385
+ ],
386
+ "source": [
387
+ "if subj==1:\n",
388
+ " num_train = 24958\n",
389
+ " num_test = 2770\n",
390
+ "test_batch_size = num_test\n",
391
+ "\n",
392
+ "def my_split_by_node(urls): return urls\n",
393
+ " \n",
394
+ "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n",
395
+ "print(train_url)\n",
396
+ "\n",
397
+ "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n",
398
+ " .shuffle(750, initial=1500, rng=random.Random(42))\\\n",
399
+ " .decode(\"torch\")\\\n",
400
+ " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n",
401
+ " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n",
402
+ "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n",
403
+ "\n",
404
+ "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n",
405
+ "print(test_url)\n",
406
+ "\n",
407
+ "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n",
408
+ " .shuffle(750, initial=1500, rng=random.Random(42))\\\n",
409
+ " .decode(\"torch\")\\\n",
410
+ " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n",
411
+ " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n",
412
+ "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "markdown",
417
+ "id": "203b060a-2dd2-4c35-929b-c576be82eb52",
418
+ "metadata": {},
419
+ "source": [
420
+ "### check dataloaders are working"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": 10,
426
+ "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37",
427
+ "metadata": {
428
+ "tags": []
429
+ },
430
+ "outputs": [],
431
+ "source": [
432
+ "# test_indices = []\n",
433
+ "# test_images = []\n",
434
+ "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
435
+ "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n",
436
+ "# test_images = np.append(test_images, behav[:,0,0].numpy())\n",
437
+ "# test_indices = test_indices.astype(np.int16)\n",
438
+ "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n",
439
+ "# print(\"---\\n\")\n",
440
+ "\n",
441
+ "# train_indices = []\n",
442
+ "# train_images = []\n",
443
+ "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
444
+ "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n",
445
+ "# train_images = np.append(train_images, behav[:,0,0].numpy())\n",
446
+ "# train_indices = train_indices.astype(np.int16)\n",
447
+ "# print(train_i, (train_i+1) * batch_size, len(train_indices))\n",
448
+ "\n",
449
+ "# # train_images = np.hstack((train_images, test_images))\n",
450
+ "# # print(\"WARNING: ADDED TEST IMAGES TO TRAIN IMAGES\")"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "markdown",
455
+ "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634",
456
+ "metadata": {},
457
+ "source": [
458
+ "## Load data and images"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 11,
464
+ "id": "039dd330-7339-4f88-8f00-45f95e47baa0",
465
+ "metadata": {
466
+ "tags": []
467
+ },
468
+ "outputs": [
469
+ {
470
+ "name": "stdout",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "subj01 betas loaded into memory\n",
474
+ "voxels torch.Size([27750, 15729])\n",
475
+ "images torch.Size([73000, 3, 224, 224])\n"
476
+ ]
477
+ }
478
+ ],
479
+ "source": [
480
+ "# load betas\n",
481
+ "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n",
482
+ "voxels = f['betas'][:]\n",
483
+ "print(f\"subj0{subj} betas loaded into memory\")\n",
484
+ "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n",
485
+ "if subj==1:\n",
486
+ " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n",
487
+ "print(\"voxels\", voxels.shape)\n",
488
+ "num_voxels = voxels.shape[-1]\n",
489
+ "\n",
490
+ "# load orig images\n",
491
+ "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n",
492
+ "images = f['images'][:]\n",
493
+ "images = torch.Tensor(images).to(\"cpu\").half()\n",
494
+ "print(\"images\", images.shape)"
495
+ ]
496
+ },
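A minimal sketch, not part of the original notebook, tying the WebDataset loaders above to the tensors just loaded: the training loop further below indexes `voxels` with `behav[:,0,5]` (betas row) and `images` with `behav[:,0,0]` (73k COCO image id).

```python
# Illustrative only: pull one training batch and index the preloaded tensors
# the same way the main training loop does.
behav, past_behav, future_behav, olds_behav = next(iter(train_dl))
voxel_batch = voxels[behav[:, 0, 5].cpu().long()]   # [batch_size, 15729] half-precision betas
image_batch = images[behav[:, 0, 0].cpu().long()]   # [batch_size, 3, 224, 224] half-precision images
print(voxel_batch.shape, image_batch.shape)
```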
497
+ {
498
+ "cell_type": "markdown",
499
+ "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15",
500
+ "metadata": {},
501
+ "source": [
502
+ "## Load models"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "markdown",
507
+ "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5",
508
+ "metadata": {},
509
+ "source": [
510
+ "### CLIP image embeddings model"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": 12,
516
+ "id": "795e2885-bd07-4e27-bed7-181473c06df9",
517
+ "metadata": {
518
+ "tags": []
519
+ },
520
+ "outputs": [],
521
+ "source": [
522
+ "import transformers\n",
523
+ "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n",
524
+ "\n",
525
+ "from PIL import Image"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 13,
531
+ "id": "b0420dc0-199e-4c1a-857d-b1747058b467",
532
+ "metadata": {
533
+ "tags": []
534
+ },
535
+ "outputs": [
536
+ {
537
+ "name": "stdout",
538
+ "output_type": "stream",
539
+ "text": [
540
+ "ViT-L/14 cuda:0\n"
541
+ ]
542
+ }
543
+ ],
544
+ "source": [
545
+ "from models import Clipper\n",
546
+ "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 14,
552
+ "id": "23428fb7-2955-4295-bea1-447cebf9f72e",
553
+ "metadata": {
554
+ "tags": []
555
+ },
556
+ "outputs": [
557
+ {
558
+ "name": "stderr",
559
+ "output_type": "stream",
560
+ "text": [
561
+ "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:08<00:00, 34.47s/it]\n"
562
+ ]
563
+ },
564
+ {
565
+ "data": {
566
+ "text/plain": [
567
+ "'from lavis.models import load_model_and_preprocess\\nfrom lavis.models import model_zoo\\nblip2_model, vis_processors, _ = load_model_and_preprocess(\\n name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\\n\\nclip_seq_dim = 257\\nclip_emb_dim = 1024\\nhidden_dim = 4096'"
568
+ ]
569
+ },
570
+ "execution_count": 14,
571
+ "metadata": {},
572
+ "output_type": "execute_result"
573
+ }
574
+ ],
575
+ "source": [
576
+ "cache_blip2 = \"/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b\"\n",
577
+ "\n",
578
+ "b2_processor = Blip2Processor.from_pretrained(cache_blip2)\n",
579
+ "b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map=\"auto\")\n",
580
+ "\n",
581
+ "#Load in blip2 as well\n",
582
+ "\"\"\"from lavis.models import load_model_and_preprocess\n",
583
+ "from lavis.models import model_zoo\n",
584
+ "blip2_model, vis_processors, _ = load_model_and_preprocess(\n",
585
+ " name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\n",
586
+ "\n",
587
+ "clip_seq_dim = 257\n",
588
+ "clip_emb_dim = 1024\n",
589
+ "hidden_dim = 4096\"\"\""
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": 15,
595
+ "id": "b06f3de2-a8da-4ba0-94f0-99096f738d55",
596
+ "metadata": {
597
+ "tags": []
598
+ },
599
+ "outputs": [],
600
+ "source": [
601
+ "def embed_images_b2(images):\n",
602
+ " images = (images * 255).type(torch.uint8)\n",
603
+ " with torch.no_grad():\n",
604
+ " inputs_processed = b2_processor(images, return_tensors=\"pt\").to(\"cuda\", torch.float16)\n",
605
+ " enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])\n",
606
+ " return enc_imgs.last_hidden_state.detach(), inputs_processed\n",
607
+ "\n",
608
+ "def embeds_to_captions_b2(embeds):\n",
609
+ " with torch.no_grad():\n",
610
+ " input_ids = None #inputs['input_ids']\n",
611
+ " attention_mask = None\n",
612
+ " batch_size = embeds.shape[0]\n",
613
+ " image_embeds = embeds\n",
614
+ " image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)\n",
615
+ "\n",
616
+ " query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n",
617
+ " query_outputs = b2_model.qformer(\n",
618
+ " query_embeds=query_tokens,\n",
619
+ " encoder_hidden_states=image_embeds,\n",
620
+ " encoder_attention_mask=image_attention_mask,\n",
621
+ " return_dict=True,\n",
622
+ " )\n",
623
+ " query_output = query_outputs.last_hidden_state\n",
624
+ "\n",
625
+ " language_model_inputs = b2_model.language_projection(query_output)\n",
626
+ " language_attention_mask = torch.ones(\n",
627
+ " language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device\n",
628
+ " )\n",
629
+ " if input_ids is None:\n",
630
+ " input_ids = (\n",
631
+ " torch.LongTensor([[b2_model.config.text_config.bos_token_id]])\n",
632
+ " .repeat(batch_size, 1)\n",
633
+ " .to(image_embeds.device)\n",
634
+ " )\n",
635
+ " if attention_mask is None:\n",
636
+ " attention_mask = torch.ones_like(input_ids)\n",
637
+ " attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)\n",
638
+ "\n",
639
+ " # concatenate query embeddings with prompt embeddings\n",
640
+ " inputs_embeds = b2_model.get_input_embeddings()(input_ids)\n",
641
+ " inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)\n",
642
+ "\n",
643
+ " outputs = b2_model.language_model.generate(\n",
644
+ " inputs_embeds=inputs_embeds,\n",
645
+ " attention_mask=attention_mask,\n",
646
+ " )\n",
647
+ " text = b2_processor.batch_decode(outputs, skip_special_tokens=True)\n",
648
+ " \n",
649
+ " return outputs, text\n"
650
+ ]
651
+ },
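A minimal round-trip sketch for the two helpers above, mirroring the test cells that follow (the helpers expect HWC inputs, hence the permute). Passing `max_new_tokens` through to the inner `generate` call would bound caption length and avoid the `max_length` warning captured further below; that keyword is not wired up here and is only a suggested tweak.

```python
# Illustrative usage of the BLIP-2 helpers defined above (assumes b2_model, b2_processor
# and the full `images` tensor are already loaded).
sample = images[:4].permute(0, 2, 3, 1)        # CHW -> HWC, as in the test cell below
embeds, _ = embed_images_b2(sample)            # [4, 257, 1408] vision-tower features
_, captions = embeds_to_captions_b2(embeds)    # short greedy captions (default max_length=20)
for cap in captions:
    print(cap.strip())
```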
652
+ {
653
+ "cell_type": "code",
654
+ "execution_count": 16,
655
+ "id": "ec0a34d3-76e0-4a47-a9ab-6131ab2ccecd",
656
+ "metadata": {
657
+ "tags": []
658
+ },
659
+ "outputs": [],
660
+ "source": [
661
+ "image_test = images[1:20].permute(0,2,3,1)\n",
662
+ "#raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')\n",
663
+ "# Convert the image to a NumPy array\n",
664
+ "#image_test = np.array(raw_image)\n"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": 17,
670
+ "id": "e04876a4-45c7-4015-8255-8574c8f50f14",
671
+ "metadata": {
672
+ "tags": []
673
+ },
674
+ "outputs": [
675
+ {
676
+ "data": {
677
+ "text/plain": [
678
+ "\"import matplotlib.pyplot as plt\\n# Plotting one of the images (taking the first image as an example)\\nimg_to_plot = inputs_rec['pixel_values'][-1]\\n\\n# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\\nimg_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\\nprint(img_to_plot.shape)\\n\\nplt.imshow(img_to_plot)\\nplt.show()\""
679
+ ]
680
+ },
681
+ "execution_count": 17,
682
+ "metadata": {},
683
+ "output_type": "execute_result"
684
+ }
685
+ ],
686
+ "source": [
687
+ "\"\"\"import matplotlib.pyplot as plt\n",
688
+ "# Plotting one of the images (taking the first image as an example)\n",
689
+ "img_to_plot = inputs_rec['pixel_values'][-1]\n",
690
+ "\n",
691
+ "# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\n",
692
+ "img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\n",
693
+ "print(img_to_plot.shape)\n",
694
+ "\n",
695
+ "plt.imshow(img_to_plot)\n",
696
+ "plt.show()\"\"\""
697
+ ]
698
+ },
699
+ {
700
+ "cell_type": "code",
701
+ "execution_count": 18,
702
+ "id": "328a17d0-593b-4d1e-812a-10a3b6efea6a",
703
+ "metadata": {
704
+ "tags": []
705
+ },
706
+ "outputs": [],
707
+ "source": [
708
+ "embeds_test, inputs_rec = embed_images_b2(image_test)"
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "code",
713
+ "execution_count": 19,
714
+ "id": "abe5f8a8-fca9-4083-8596-a913bdb57de7",
715
+ "metadata": {
716
+ "tags": []
717
+ },
718
+ "outputs": [],
719
+ "source": [
720
+ "#inputs_rec['pixel_values'].shape"
721
+ ]
722
+ },
723
+ {
724
+ "cell_type": "code",
725
+ "execution_count": 20,
726
+ "id": "c5f3ca7e-b880-421e-b354-7b6c3df565e9",
727
+ "metadata": {
728
+ "tags": []
729
+ },
730
+ "outputs": [],
731
+ "source": [
732
+ "#out = b2_model.generate(**inputs_rec)\n",
733
+ "#print(b2_processor.decode(out[0], skip_special_tokens=True).strip())"
734
+ ]
735
+ },
736
+ {
737
+ "cell_type": "code",
738
+ "execution_count": 21,
739
+ "id": "fb462016-78d7-46ea-8058-0d608f17ea65",
740
+ "metadata": {
741
+ "tags": []
742
+ },
743
+ "outputs": [
744
+ {
745
+ "name": "stderr",
746
+ "output_type": "stream",
747
+ "text": [
748
+ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
749
+ " warnings.warn(\n"
750
+ ]
751
+ }
752
+ ],
753
+ "source": [
754
+ "outputs_test, text_test = embeds_to_captions_b2(embeds_test)"
755
+ ]
756
+ },
757
+ {
758
+ "cell_type": "code",
759
+ "execution_count": 22,
760
+ "id": "6a95fcdf-db87-4c02-9728-09f85605fb1c",
761
+ "metadata": {
762
+ "tags": []
763
+ },
764
+ "outputs": [
765
+ {
766
+ "data": {
767
+ "text/plain": [
768
+ "['a cat sitting on a toilet seat\\n',\n",
769
+ " 'a person cutting a pizza on a cutting board\\n',\n",
770
+ " 'a sandwich and a drink on a table\\n',\n",
771
+ " 'a man crossing the street in front of a truck\\n',\n",
772
+ " 'a giraffe standing in front of trees\\n',\n",
773
+ " 'three men standing together\\n',\n",
774
+ " 'a bird standing on a rock next to a body of water\\n',\n",
775
+ " 'two men sitting on a street corner in asia\\n',\n",
776
+ " 'a woman and two children playing tennis on a court\\n',\n",
777
+ " 'a tall brick building with a clock on the side\\n',\n",
778
+ " 'a train is on the tracks\\n',\n",
779
+ " 'a man and woman in the water with a surfboard\\n',\n",
780
+ " 'a living room with a desk and a chair\\n',\n",
781
+ " 'a group of men on a basketball court\\n',\n",
782
+ " 'a man holding an umbrella\\n',\n",
783
+ " 'a man in a red shirt\\n',\n",
784
+ " 'a group of people holding cell phones and wine glasses\\n',\n",
785
+ " 'a laptop computer sitting on a table in front of a television\\n',\n",
786
+ " 'a baseball player is swinging a bat on a field\\n']"
787
+ ]
788
+ },
789
+ "execution_count": 22,
790
+ "metadata": {},
791
+ "output_type": "execute_result"
792
+ }
793
+ ],
794
+ "source": [
795
+ "text_test"
796
+ ]
797
+ },
798
+ {
799
+ "cell_type": "code",
800
+ "execution_count": 23,
801
+ "id": "9ac69fbd-55db-435b-bed6-5ae9186450e3",
802
+ "metadata": {
803
+ "tags": []
804
+ },
805
+ "outputs": [],
806
+ "source": [
807
+ "#inputss['pixel_values'].shape"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "execution_count": 24,
813
+ "id": "0524f498-c8da-4e8a-8970-d75d2d0f6b8b",
814
+ "metadata": {
815
+ "tags": []
816
+ },
817
+ "outputs": [],
818
+ "source": [
819
+ "#image_test.shape"
820
+ ]
821
+ },
822
+ {
823
+ "cell_type": "code",
824
+ "execution_count": 25,
825
+ "id": "5417541b-49eb-4e43-a3e2-d937d9653e04",
826
+ "metadata": {
827
+ "tags": []
828
+ },
829
+ "outputs": [],
830
+ "source": [
831
+ "max_lr = 1e-4"
832
+ ]
833
+ },
834
+ {
835
+ "cell_type": "code",
836
+ "execution_count": 26,
837
+ "id": "da0ce190-1b3e-4c12-9e9f-91cbc076d044",
838
+ "metadata": {
839
+ "tags": []
840
+ },
841
+ "outputs": [],
842
+ "source": [
843
+ "clip_seq_dim = 257 #blip2 image encoder shapes\n",
844
+ "clip_emb_dim = 1408 #blip2 image encoder shapes\n",
845
+ "hidden_dim = 2048"
846
+ ]
847
+ },
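These hard-coded dimensions mirror the BLIP-2 vision tower used above: a 224x224 input with 14x14 patches gives 16*16 = 256 patch tokens plus one class token, and the encoder width is 1408; `hidden_dim` is this notebook's choice of ridge bottleneck. A quick sanity check (the patch size is an assumption about the blip2-opt-2.7b vision config; the last line checks the embeddings computed above directly):

```python
# Illustrative checks for the hard-coded dims.
assert (224 // 14) ** 2 + 1 == clip_seq_dim                        # 256 patch tokens + [CLS] = 257
assert b2_model.vision_model.config.hidden_size == clip_emb_dim    # 1408
assert tuple(embeds_test.shape[1:]) == (clip_seq_dim, clip_emb_dim)
```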
848
+ {
849
+ "cell_type": "markdown",
850
+ "id": "5b79bd38-6990-4504-8d45-4a68d57d8885",
851
+ "metadata": {},
852
+ "source": [
853
+ "### SD VAE (blurry images)"
854
+ ]
855
+ },
856
+ {
857
+ "cell_type": "code",
858
+ "execution_count": 27,
859
+ "id": "01baff79-8114-482b-b115-6f05aa8ad691",
860
+ "metadata": {
861
+ "tags": []
862
+ },
863
+ "outputs": [
864
+ {
865
+ "name": "stdout",
866
+ "output_type": "stream",
867
+ "text": [
868
+ "param counts:\n",
869
+ "83,653,863 total\n",
870
+ "0 trainable\n"
871
+ ]
872
+ }
873
+ ],
874
+ "source": [
875
+ "from diffusers import AutoencoderKL\n",
876
+ "autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n",
877
+ "# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n",
878
+ "autoenc.eval()\n",
879
+ "autoenc.requires_grad_(False)\n",
880
+ "autoenc.to(device)\n",
881
+ "utils.count_params(autoenc)"
882
+ ]
883
+ },
884
+ {
885
+ "cell_type": "markdown",
886
+ "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0",
887
+ "metadata": {},
888
+ "source": [
889
+ "### MindEye modules"
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "execution_count": 28,
895
+ "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5",
896
+ "metadata": {
897
+ "tags": []
898
+ },
899
+ "outputs": [
900
+ {
901
+ "data": {
902
+ "text/plain": [
903
+ "MindEyeModule()"
904
+ ]
905
+ },
906
+ "execution_count": 28,
907
+ "metadata": {},
908
+ "output_type": "execute_result"
909
+ }
910
+ ],
911
+ "source": [
912
+ "class MindEyeModule(nn.Module):\n",
913
+ " def __init__(self):\n",
914
+ " super(MindEyeModule, self).__init__()\n",
915
+ " def forward(self, x):\n",
916
+ " return x\n",
917
+ " \n",
918
+ "model = MindEyeModule()\n",
919
+ "model"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "code",
924
+ "execution_count": 29,
925
+ "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0",
926
+ "metadata": {
927
+ "tags": []
928
+ },
929
+ "outputs": [
930
+ {
931
+ "name": "stdout",
932
+ "output_type": "stream",
933
+ "text": [
934
+ "param counts:\n",
935
+ "32,215,040 total\n",
936
+ "32,215,040 trainable\n",
937
+ "param counts:\n",
938
+ "32,215,040 total\n",
939
+ "32,215,040 trainable\n",
940
+ "torch.Size([2, 1, 15729]) torch.Size([2, 1, 2048])\n"
941
+ ]
942
+ }
943
+ ],
944
+ "source": [
945
+ "class RidgeRegression(torch.nn.Module):\n",
946
+ " # make sure to add weight_decay when initializing optimizer\n",
947
+ " def __init__(self, input_size, out_features): \n",
948
+ " super(RidgeRegression, self).__init__()\n",
949
+ " self.out_features = out_features\n",
950
+ " self.linear = torch.nn.Linear(input_size, out_features)\n",
951
+ " def forward(self, x):\n",
952
+ " return self.linear(x)\n",
953
+ " \n",
954
+ "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n",
955
+ "utils.count_params(model.ridge)\n",
956
+ "utils.count_params(model)\n",
957
+ "\n",
958
+ "b = torch.randn((2,1,voxels.shape[1]))\n",
959
+ "print(b.shape, model.ridge(b).shape)"
960
+ ]
961
+ },
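Despite its name, the module above is a plain linear layer; the ridge behaviour comes from the non-zero `weight_decay` assigned to its parameters when AdamW is configured below. A sketch of the explicit objective this stands in for (AdamW's decoupled decay is not literally the same as adding an L2 term to the loss, but it plays the same regularizing role):

```python
# Sketch only: the explicit ridge objective that weight decay on model.ridge approximates.
import torch
import torch.nn.functional as F

def ridge_objective(linear, x, target, lam=1e-2):   # lam mirrors the 1e-2 weight_decay used below
    pred = linear(x)
    l2 = sum((p ** 2).sum() for p in linear.parameters() if p.ndim > 1)  # weights only, no bias
    return F.mse_loss(pred, target) + lam * l2

x = torch.randn(8, voxels.shape[1])
target = torch.randn(8, hidden_dim)
print(ridge_objective(model.ridge.linear, x, target))
```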
962
+ {
963
+ "cell_type": "code",
964
+ "execution_count": 30,
965
+ "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd",
966
+ "metadata": {
967
+ "tags": []
968
+ },
969
+ "outputs": [
970
+ {
971
+ "name": "stdout",
972
+ "output_type": "stream",
973
+ "text": [
974
+ "param counts:\n",
975
+ "772,419,072 total\n",
976
+ "772,419,072 trainable\n",
977
+ "param counts:\n",
978
+ "804,634,112 total\n",
979
+ "804,634,112 trainable\n",
980
+ "torch.Size([4, 2048])\n",
981
+ "torch.Size([4, 257, 1408])\n"
982
+ ]
983
+ }
984
+ ],
985
+ "source": [
986
+ "from functools import partial\n",
987
+ "from diffusers.models.vae import Decoder\n",
988
+ "class BrainNetwork(nn.Module):\n",
989
+ " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):\n",
990
+ " super().__init__()\n",
991
+ " self.blurry_dim = blurry_dim\n",
992
+ " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n",
993
+ " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n",
994
+ " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n",
995
+ " self.lin0 = nn.Linear(in_dim, h)\n",
996
+ " self.mlp = nn.ModuleList([\n",
997
+ " nn.Sequential(\n",
998
+ " nn.Linear(h, h),\n",
999
+ " *[item() for item in act_and_norm],\n",
1000
+ " nn.Dropout(drop)\n",
1001
+ " ) for _ in range(n_blocks)\n",
1002
+ " ])\n",
1003
+ " self.lin1 = nn.Linear(h, out_dim, bias=True)\n",
1004
+ " # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)\n",
1005
+ " self.n_blocks = n_blocks\n",
1006
+ " self.clip_size = clip_size\n",
1007
+ " self.clip_proj = nn.Sequential(\n",
1008
+ " nn.LayerNorm(clip_size),\n",
1009
+ " nn.GELU(),\n",
1010
+ " nn.Linear(clip_size, 2048),\n",
1011
+ " nn.LayerNorm(2048),\n",
1012
+ " nn.GELU(),\n",
1013
+ " nn.Linear(2048, 2048),\n",
1014
+ " nn.LayerNorm(2048),\n",
1015
+ " nn.GELU(),\n",
1016
+ " nn.Linear(2048, clip_size)\n",
1017
+ " )\n",
1018
+ " # self.upsampler = Decoder(\n",
1019
+ " # in_channels=64,\n",
1020
+ " # out_channels=4,\n",
1021
+ " # up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n",
1022
+ " # block_out_channels=[64, 128, 256],\n",
1023
+ " # layers_per_block=1,\n",
1024
+ " # )\n",
1025
+ " \n",
1026
+ " def forward(self, x):\n",
1027
+ " x = self.lin0(x)\n",
1028
+ " residual = x\n",
1029
+ " for res_block in range(self.n_blocks):\n",
1030
+ " x = self.mlp[res_block](x)\n",
1031
+ " x += residual\n",
1032
+ " residual = x\n",
1033
+ " x = x.reshape(len(x), -1)\n",
1034
+ " x = self.lin1(x)\n",
1035
+ " # b = self.blin1(x)\n",
1036
+ " # b = self.upsampler(b.reshape(len(b), -1, 7, 7))\n",
1037
+ " c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))\n",
1038
+ " # return c, b\n",
1039
+ " return c\n",
1040
+ "\n",
1041
+ "model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) \n",
1042
+ "utils.count_params(model.backbone)\n",
1043
+ "utils.count_params(model)\n",
1044
+ "\n",
1045
+ "b = torch.randn((4,hidden_dim))\n",
1046
+ "print(b.shape)\n",
1047
+ "clip_ = model.backbone(b)\n",
1048
+ "print(clip_.shape)"
1049
+ ]
1050
+ },
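A shape walk-through of the backbone's forward pass: `lin1` produces one flat vector per sample that is then reinterpreted as a 257-token sequence before the per-token `clip_proj` head, while the residual MLP blocks in between keep the width at 2048. Illustrative only; it calls the submodules directly rather than `forward`:

```python
# Illustrative shape trace through the BrainNetwork defined above.
b = torch.randn(4, hidden_dim)                        # [4, 2048] ridge outputs
h = model.backbone.lin0(b)                            # [4, 2048]; the residual blocks keep this shape
flat = model.backbone.lin1(h)                         # [4, 257*1408] = [4, 361856]
tokens = flat.reshape(len(flat), -1, clip_emb_dim)    # [4, 257, 1408]
out = model.backbone.clip_proj(tokens)                # [4, 257, 1408] per-token projection
print(flat.shape, tokens.shape, out.shape)
```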
1051
+ {
1052
+ "cell_type": "code",
1053
+ "execution_count": 31,
1054
+ "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1",
1055
+ "metadata": {
1056
+ "tags": []
1057
+ },
1058
+ "outputs": [
1059
+ {
1060
+ "name": "stdout",
1061
+ "output_type": "stream",
1062
+ "text": [
1063
+ "\n",
1064
+ "Done with model preparations!\n",
1065
+ "param counts:\n",
1066
+ "804,634,112 total\n",
1067
+ "804,634,112 trainable\n"
1068
+ ]
1069
+ }
1070
+ ],
1071
+ "source": [
1072
+ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
1073
+ "opt_grouped_parameters = [\n",
1074
+ " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n",
1075
+ " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n",
1076
+ " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n",
1077
+ "]\n",
1078
+ "\n",
1079
+ "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))\n",
1080
+ "\n",
1081
+ "if lr_scheduler_type == 'linear':\n",
1082
+ " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n",
1083
+ " optimizer,\n",
1084
+ " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n",
1085
+ " last_epoch=-1\n",
1086
+ " )\n",
1087
+ "elif lr_scheduler_type == 'cycle':\n",
1088
+ " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n",
1089
+ " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n",
1090
+ " optimizer, \n",
1091
+ " max_lr=max_lr,\n",
1092
+ " total_steps=total_steps,\n",
1093
+ " final_div_factor=1000,\n",
1094
+ " last_epoch=-1, pct_start=2/num_epochs\n",
1095
+ " )\n",
1096
+ " \n",
1097
+ "def save_ckpt(tag): \n",
1098
+ " ckpt_path = outdir+f'/{tag}.pth'\n",
1099
+ " print(f'saving {ckpt_path}',flush=True)\n",
1100
+ " unwrapped_model = accelerator.unwrap_model(model)\n",
1101
+ " try:\n",
1102
+ " torch.save({\n",
1103
+ " 'epoch': epoch,\n",
1104
+ " 'model_state_dict': unwrapped_model.state_dict(),\n",
1105
+ " 'optimizer_state_dict': optimizer.state_dict(),\n",
1106
+ " 'lr_scheduler': lr_scheduler.state_dict(),\n",
1107
+ " 'train_losses': losses,\n",
1108
+ " 'test_losses': test_losses,\n",
1109
+ " 'lrs': lrs,\n",
1110
+ " }, ckpt_path)\n",
1111
+ " except:\n",
1112
+ " print(\"Couldn't save... moving on to prevent crashing.\")\n",
1113
+ " del unwrapped_model\n",
1114
+ " \n",
1115
+ "print(\"\\nDone with model preparations!\")\n",
1116
+ "utils.count_params(model)"
1117
+ ]
1118
+ },
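With the interactive settings used here (num_train = 24958, batch_size = 128, one device, 30 epochs), the OneCycle schedule above works out as follows; `pct_start=2/num_epochs` amounts to roughly a two-epoch warmup.

```python
# Illustrative arithmetic for the OneCycleLR configuration above.
steps_per_epoch = num_train * num_devices // batch_size      # 24958 // 128 = 194
total = int(num_epochs * steps_per_epoch)                    # 30 * 194 = 5820 optimizer steps
warmup = int(total * (2 / num_epochs))                       # ~388 steps, i.e. about 2 epochs
print(steps_per_epoch, total, warmup)
```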
1119
+ {
1120
+ "cell_type": "markdown",
1121
+ "id": "983f458b-35b8-49f2-b6db-80296cece730",
1122
+ "metadata": {},
1123
+ "source": [
1124
+ "# Weights and Biases"
1125
+ ]
1126
+ },
1127
+ {
1128
+ "cell_type": "code",
1129
+ "execution_count": 32,
1130
+ "id": "0a25a662-daa8-4de9-9233-8364800fcb6b",
1131
+ "metadata": {
1132
+ "tags": []
1133
+ },
1134
+ "outputs": [
1135
+ {
1136
+ "name": "stdout",
1137
+ "output_type": "stream",
1138
+ "text": [
1139
+ "wandb mindeyev2 run captions\n"
1140
+ ]
1141
+ },
1142
+ {
1143
+ "name": "stderr",
1144
+ "output_type": "stream",
1145
+ "text": [
1146
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
1147
+ ]
1148
+ },
1149
+ {
1150
+ "name": "stdout",
1151
+ "output_type": "stream",
1152
+ "text": [
1153
+ "wandb_config:\n",
1154
+ " {'model_name': 'captions', 'batch_size': 128, 'num_epochs': 30, 'use_image_aug': False, 'max_lr': 0.0001, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}\n"
1155
+ ]
1156
+ },
1157
+ {
1158
+ "data": {
1159
+ "text/html": [
1160
+ "wandb version 0.16.0 is available! To upgrade, please run:\n",
1161
+ " $ pip install wandb --upgrade"
1162
+ ],
1163
+ "text/plain": [
1164
+ "<IPython.core.display.HTML object>"
1165
+ ]
1166
+ },
1167
+ "metadata": {},
1168
+ "output_type": "display_data"
1169
+ },
1170
+ {
1171
+ "data": {
1172
+ "text/html": [
1173
+ "Tracking run with wandb version 0.15.5"
1174
+ ],
1175
+ "text/plain": [
1176
+ "<IPython.core.display.HTML object>"
1177
+ ]
1178
+ },
1179
+ "metadata": {},
1180
+ "output_type": "display_data"
1181
+ },
1182
+ {
1183
+ "data": {
1184
+ "text/html": [
1185
+ "Run data is saved locally in <code>/fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231119_163615-o1xwsqre</code>"
1186
+ ],
1187
+ "text/plain": [
1188
+ "<IPython.core.display.HTML object>"
1189
+ ]
1190
+ },
1191
+ "metadata": {},
1192
+ "output_type": "display_data"
1193
+ },
1194
+ {
1195
+ "data": {
1196
+ "text/html": [
1197
+ "Syncing run <strong><a href='https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre' target=\"_blank\">captions</a></strong> to <a href='https://stability.wandb.io/ckadirt/mindeyev2' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
1198
+ ],
1199
+ "text/plain": [
1200
+ "<IPython.core.display.HTML object>"
1201
+ ]
1202
+ },
1203
+ "metadata": {},
1204
+ "output_type": "display_data"
1205
+ },
1206
+ {
1207
+ "data": {
1208
+ "text/html": [
1209
+ " View project at <a href='https://stability.wandb.io/ckadirt/mindeyev2' target=\"_blank\">https://stability.wandb.io/ckadirt/mindeyev2</a>"
1210
+ ],
1211
+ "text/plain": [
1212
+ "<IPython.core.display.HTML object>"
1213
+ ]
1214
+ },
1215
+ "metadata": {},
1216
+ "output_type": "display_data"
1217
+ },
1218
+ {
1219
+ "data": {
1220
+ "text/html": [
1221
+ " View run at <a href='https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre' target=\"_blank\">https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre</a>"
1222
+ ],
1223
+ "text/plain": [
1224
+ "<IPython.core.display.HTML object>"
1225
+ ]
1226
+ },
1227
+ "metadata": {},
1228
+ "output_type": "display_data"
1229
+ }
1230
+ ],
1231
+ "source": [
1232
+ "# params for wandb\n",
1233
+ "if local_rank==0 and True: # only use main process for wandb logging\n",
1234
+ " import wandb\n",
1235
+ " \n",
1236
+ " wandb_project = 'mindeyev2'\n",
1237
+ " wandb_run = model_name\n",
1238
+ " wandb_notes = ''\n",
1239
+ " \n",
1240
+ " print(f\"wandb {wandb_project} run {wandb_run}\")\n",
1241
+ " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n",
1242
+ " wandb_config = {\n",
1243
+ " \"model_name\": model_name,\n",
1244
+ " \"batch_size\": batch_size,\n",
1245
+ " \"num_epochs\": num_epochs,\n",
1246
+ " \"use_image_aug\": use_image_aug,\n",
1247
+ " \"max_lr\": max_lr,\n",
1248
+ " \"lr_scheduler_type\": lr_scheduler_type,\n",
1249
+ " \"mixup_pct\": mixup_pct,\n",
1250
+ " \"num_train\": num_train,\n",
1251
+ " \"num_test\": num_test,\n",
1252
+ " \"seed\": seed,\n",
1253
+ " \"distributed\": distributed,\n",
1254
+ " \"num_devices\": num_devices,\n",
1255
+ " \"world_size\": world_size,\n",
1256
+ " }\n",
1257
+ " print(\"wandb_config:\\n\",wandb_config)\n",
1258
+ " if False: # wandb_auto_resume\n",
1259
+ " print(\"wandb_id:\",model_name)\n",
1260
+ " wandb.init(\n",
1261
+ " id = model_name,\n",
1262
+ " project=wandb_project,\n",
1263
+ " name=wandb_run,\n",
1264
+ " config=wandb_config,\n",
1265
+ " notes=wandb_notes,\n",
1266
+ " resume=\"allow\",\n",
1267
+ " )\n",
1268
+ " else:\n",
1269
+ " wandb.init(\n",
1270
+ " project=wandb_project,\n",
1271
+ " name=wandb_run,\n",
1272
+ " config=wandb_config,\n",
1273
+ " notes=wandb_notes,\n",
1274
+ " )\n",
1275
+ "else:\n",
1276
+ " wandb_log = False"
1277
+ ]
1278
+ },
1279
+ {
1280
+ "cell_type": "code",
1281
+ "execution_count": 33,
1282
+ "id": "4e5de216-5318-4b45-ac02-113f03105adc",
1283
+ "metadata": {},
1284
+ "outputs": [
1285
+ {
1286
+ "data": {
1287
+ "text/html": [
1288
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #ff0000; text-decoration-color: #ff0000\">╭──────────────────────────────────────────────────────────────────────────────────────────────────╮</span>\n",
1289
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span> n++ <span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span>\n",
1290
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span> <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">▲</span> <span style=\"color: #ff0000; text-decoration-color: #ff0000\">│</span>\n",
1291
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
1292
+ "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">SyntaxError: </span>invalid syntax\n",
1293
+ "</pre>\n"
1294
+ ],
1295
+ "text/plain": [
1296
+ "\u001b[91m╭──────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n",
1297
+ "\u001b[91m│\u001b[0m n++ \u001b[91m│\u001b[0m\n",
1298
+ "\u001b[91m���\u001b[0m \u001b[1;91m▲\u001b[0m \u001b[91m│\u001b[0m\n",
1299
+ "\u001b[91m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
1300
+ "\u001b[1;91mSyntaxError: \u001b[0minvalid syntax\n"
1301
+ ]
1302
+ },
1303
+ "metadata": {},
1304
+ "output_type": "display_data"
1305
+ }
1306
+ ],
1307
+ "source": [
1308
+ "n++"
1309
+ ]
1310
+ },
1311
+ {
1312
+ "cell_type": "markdown",
1313
+ "id": "5b0ae095-3203-4eb8-8606-acc2db6ccf20",
1314
+ "metadata": {},
1315
+ "source": [
1316
+ "# More custom functions"
1317
+ ]
1318
+ },
1319
+ {
1320
+ "cell_type": "code",
1321
+ "execution_count": null,
1322
+ "id": "827ead88-7eb3-47cc-82da-31565063b927",
1323
+ "metadata": {
1324
+ "tags": []
1325
+ },
1326
+ "outputs": [],
1327
+ "source": [
1328
+ "# using the same preprocessing as was used in MindEye + BrainDiffuser\n",
1329
+ "pixcorr_preprocess = transforms.Compose([\n",
1330
+ " transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n",
1331
+ "])\n",
1332
+ "def pixcorr(images,brains):\n",
1333
+ " # Flatten images while keeping the batch dimension\n",
1334
+ " all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)\n",
1335
+ " all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)\n",
1336
+ " corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()\n",
1337
+ " return corrmean"
1338
+ ]
1339
+ },
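A self-contained usage sketch for `pixcorr` (random tensors, so identical inputs give 1 and unrelated inputs land near 0):

```python
# Illustrative only: pixcorr expects two [N, 3, H, W] batches in the same value range.
import torch
imgs_a = torch.rand(4, 3, 224, 224)
imgs_b = torch.rand(4, 3, 224, 224)
print(pixcorr(imgs_a, imgs_a))   # identical batches -> 1.0
print(pixcorr(imgs_a, imgs_b))   # unrelated noise -> approximately 0
```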
1340
+ {
1341
+ "cell_type": "markdown",
1342
+ "id": "d5690151-2131-4918-b750-e869cbd1a8a8",
1343
+ "metadata": {},
1344
+ "source": [
1345
+ "# Main"
1346
+ ]
1347
+ },
1348
+ {
1349
+ "cell_type": "code",
1350
+ "execution_count": null,
1351
+ "id": "12de6387-6e18-4e4b-b5ce-a847d625330a",
1352
+ "metadata": {
1353
+ "tags": []
1354
+ },
1355
+ "outputs": [],
1356
+ "source": [
1357
+ "epoch = 0\n",
1358
+ "losses, test_losses, lrs = [], [], []\n",
1359
+ "best_test_loss = 1e9\n",
1360
+ "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n",
1361
+ "\n",
1362
+ "# Optionally resume from checkpoint #\n",
1363
+ "if resume_from_ckpt:\n",
1364
+ " print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
1365
+ " try:\n",
1366
+ " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
1367
+ " except:\n",
1368
+ " print('last.pth failed... trying last_backup.pth')\n",
1369
+ " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
1370
+ " epoch = checkpoint['epoch']\n",
1371
+ " print(\"Epoch\",epoch)\n",
1372
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1373
+ " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
1374
+ " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n",
1375
+ " del checkpoint\n",
1376
+ "elif wandb_log:\n",
1377
+ " if wandb.run.resumed:\n",
1378
+ " print(\"\\n---resuming from last.pth ckpt---\\n\")\n",
1379
+ " try:\n",
1380
+ " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n",
1381
+ " except:\n",
1382
+ " print('last.pth failed... trying last_backup.pth')\n",
1383
+ " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n",
1384
+ " epoch = checkpoint['epoch']\n",
1385
+ " print(\"Epoch\",epoch)\n",
1386
+ " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
1387
+ " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n",
1388
+ " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n",
1389
+ " del checkpoint\n",
1390
+ "torch.cuda.empty_cache()"
1391
+ ]
1392
+ },
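For the loss schedule driven by `mixup_pct` in the training loop below: with mixup_pct = 0.66 and 30 epochs, epochs 0-18 use the MixCo/BiMixCo contrastive objective and epochs 19-29 switch to SoftCLIP, with the soft-loss temperature annealed between 0.004 and 0.0075 over those remaining 11 epochs.

```python
# Illustrative check of the BiMixCo -> SoftCLIP switch used in the training loop below.
switch_epoch = int(mixup_pct * num_epochs)                 # int(0.66 * 30) = 19
print(f"MixCo/BiMixCo epochs : 0-{switch_epoch - 1}")
print(f"SoftCLIP epochs      : {switch_epoch}-{num_epochs - 1}")
print(f"{len(soft_loss_temps)} annealed temperatures, "
      f"{float(soft_loss_temps[0]):.4f} -> {float(soft_loss_temps[-1]):.4f}")
```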
1393
+ {
1394
+ "cell_type": "code",
1395
+ "execution_count": null,
1396
+ "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4",
1397
+ "metadata": {
1398
+ "tags": []
1399
+ },
1400
+ "outputs": [],
1401
+ "source": [
1402
+ "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n",
1403
+ "model, optimizer, train_dl, test_dl, lr_scheduler\n",
1404
+ ")"
1405
+ ]
1406
+ },
1407
+ {
1408
+ "cell_type": "code",
1409
+ "execution_count": null,
1410
+ "id": "bfeeda32-82ca-4364-bce1-eaa41b4f3e25",
1411
+ "metadata": {
1412
+ "tags": []
1413
+ },
1414
+ "outputs": [],
1415
+ "source": [
1416
+ "\"\"\"transform = transforms.Compose(\n",
1417
+ " [\n",
1418
+ " transforms.Resize(\n",
1419
+ " (224, 224),\n",
1420
+ " ),\n",
1421
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
1422
+ " ]\n",
1423
+ " )\n",
1424
+ "\n",
1425
+ "def tensor_2_embed(image): \n",
1426
+ " image_for_blip2 = transform(image)\n",
1427
+ " \n",
1428
+ " #Generate embeddings\n",
1429
+ " with blip2_model.maybe_autocast():\n",
1430
+ " blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n",
1431
+ " \n",
1432
+ " return blip2_target\n",
1433
+ "\n",
1434
+ "def embed_2_caption(image_embeds, model):\n",
1435
+ " image_embeds = image_embeds.float()\n",
1436
+ " image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n",
1437
+ " image.device)\n",
1438
+ "\n",
1439
+ " query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n",
1440
+ " query_output = model.Qformer.bert(\n",
1441
+ " query_embeds=query_tokens,\n",
1442
+ " encoder_hidden_states=image_embeds,\n",
1443
+ " encoder_attention_mask=image_atts,\n",
1444
+ " return_dict=True)\n",
1445
+ "\n",
1446
+ " inputs_t5 = model.t5_proj(query_output.last_hidden_state)\n",
1447
+ " atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n",
1448
+ " prompt = model.prompt\n",
1449
+ " input_tokens = model.t5_tokenizer(\n",
1450
+ " prompt, padding=\"longest\", return_tensors=\"pt\"\n",
1451
+ " ).to(image.device)\n",
1452
+ " encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n",
1453
+ " \n",
1454
+ " with model.maybe_autocast(dtype=torch.bfloat16):\n",
1455
+ " inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n",
1456
+ " inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n",
1457
+ "\n",
1458
+ " outputs = model.t5_model.generate(\n",
1459
+ " inputs_embeds=inputs_embeds,\n",
1460
+ " attention_mask=encoder_atts)\n",
1461
+ " output_text = model.t5_tokenizer.batch_decode(\n",
1462
+ " outputs, skip_special_tokens=True)\n",
1463
+ " \n",
1464
+ " return output_text\"\"\""
1465
+ ]
1466
+ },
1467
+ {
1468
+ "cell_type": "code",
1469
+ "execution_count": null,
1470
+ "id": "636b4684-df9a-4e29-8683-86fb035ba690",
1471
+ "metadata": {
1472
+ "tags": []
1473
+ },
1474
+ "outputs": [],
1475
+ "source": [
1476
+ "wandb_log = True"
1477
+ ]
1478
+ },
1479
+ {
1480
+ "cell_type": "code",
1481
+ "execution_count": null,
1482
+ "id": "60be0d5f-3e94-4612-9373-61b53d836393",
1483
+ "metadata": {
1484
+ "tags": []
1485
+ },
1486
+ "outputs": [],
1487
+ "source": [
1488
+ "print(f\"{model_name} starting with epoch {epoch} / {num_epochs}\")\n",
1489
+ "progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))\n",
1490
+ "test_image, test_voxel = None, None\n",
1491
+ "mse = nn.MSELoss()\n",
1492
+ "for epoch in progress_bar:\n",
1493
+ " model.train()\n",
1494
+ " \n",
1495
+ " fwd_percent_correct = 0.\n",
1496
+ " bwd_percent_correct = 0.\n",
1497
+ " test_fwd_percent_correct = 0.\n",
1498
+ " test_bwd_percent_correct = 0.\n",
1499
+ "\n",
1500
+ " loss_clip_total = 0.\n",
1501
+ " loss_blurry_total = 0.\n",
1502
+ " test_loss_clip_total = 0.\n",
1503
+ " test_loss_blurry_total = 0.\n",
1504
+ "\n",
1505
+ " blurry_pixcorr = 0.\n",
1506
+ " test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1\n",
1507
+ " \n",
1508
+ " for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n",
1509
+ " if epoch == 0:\n",
1510
+ " lrs.append(0)\n",
1511
+ " break\n",
1512
+ " with torch.cuda.amp.autocast():\n",
1513
+ " optimizer.zero_grad()\n",
1514
+ "\n",
1515
+ " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
1516
+ " \n",
1517
+ " image = images[behav[:,0,0].cpu().long()].to(device).float()\n",
1518
+ "\n",
1519
+ " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
1520
+ " \n",
1521
+ " if use_image_aug: image = img_augment(image)\n",
1522
+ " # clip_target = clip_model.embed_image(image)\n",
1523
+ " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n",
1524
+ " assert not torch.any(torch.isnan(clip_target))\n",
1525
+ " \n",
1526
+ " if epoch < int(mixup_pct * num_epochs):\n",
1527
+ " voxel, perm, betas, select = utils.mixco(voxel)\n",
1528
+ "\n",
1529
+ " voxel_ridge = model.ridge(voxel)\n",
1530
+ " \n",
1531
+ " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
1532
+ " clip_voxels = model.backbone(voxel_ridge)\n",
1533
+ " \n",
1534
+ " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
1535
+ " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
1536
+ "\n",
1537
+ " if epoch < int(mixup_pct * num_epochs): \n",
1538
+ " loss_clip = utils.mixco_nce(\n",
1539
+ " clip_voxels_norm,\n",
1540
+ " clip_target_norm,\n",
1541
+ " temp=.006, \n",
1542
+ " perm=perm, betas=betas, select=select)\n",
1543
+ " else:\n",
1544
+ " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n",
1545
+ " loss_clip = utils.soft_clip_loss(\n",
1546
+ " clip_voxels_norm,\n",
1547
+ " clip_target_norm,\n",
1548
+ " temp=epoch_temp)\n",
1549
+ " \n",
1550
+ " loss_mse= mse(clip_voxels, clip_target)\n",
1551
+ "\n",
1552
+ " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc) \n",
1553
+ "\n",
1554
+ " loss_clip_total += loss_clip.item()\n",
1555
+ " # loss_blurry_total += loss_blurry.item()\n",
1556
+ "\n",
1557
+ " # loss = loss_blurry + loss_clip\n",
1558
+ " loss = 0.7 * loss_clip + 0.3 * loss_mse\n",
1559
+ " if (train_i % 10 == 0):\n",
1560
+ " print(train_i, loss)\n",
1561
+ " # print(batch_size)\n",
1562
+ " utils.check_loss(loss)\n",
1563
+ " accelerator.backward(loss)\n",
1564
+ " optimizer.step()\n",
1565
+ " \n",
1566
+ " losses.append(loss.item())\n",
1567
+ " lrs.append(optimizer.param_groups[0]['lr'])\n",
1568
+ " \n",
1569
+ " # forward and backward top 1 accuracy \n",
1570
+ " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n",
1571
+ " fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
1572
+ " bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
1573
+ "\n",
1574
+ " # with torch.no_grad():\n",
1575
+ " # # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()\n",
1576
+ " # random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)\n",
1577
+ " # blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)\n",
1578
+ " # blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)\n",
1579
+ "\n",
1580
+ " if lr_scheduler_type is not None:\n",
1581
+ " lr_scheduler.step()\n",
1582
+ " \n",
1583
+ " model.eval()\n",
1584
+ " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n",
1585
+ " with torch.cuda.amp.autocast():\n",
1586
+ " with torch.no_grad(): \n",
1587
+ " # all test samples should be loaded per batch such that test_i should never exceed 0\n",
1588
+ " if len(behav) != num_test: print(\"!\",len(behav),num_test)\n",
1589
+ " \n",
1590
+ " ## Average same-image repeats ##\n",
1591
+ " if test_image is None:\n",
1592
+ " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n",
1593
+ " \n",
1594
+ " image = behav[:,0,0].cpu().long()\n",
1595
+ " \n",
1596
+ " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n",
1597
+ " for im in unique_image:\n",
1598
+ " locs = torch.where(im == image)[0]\n",
1599
+ " if test_image is None:\n",
1600
+ " test_image = images[im][None]\n",
1601
+ " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n",
1602
+ " else:\n",
1603
+ " test_image = torch.vstack((test_image, images[im][None]))\n",
1604
+ " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n",
1605
+ " \n",
1606
+ " # sample of batch_size\n",
1607
+ " random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]\n",
1608
+ " voxel = test_voxel[random_indices].to(device)\n",
1609
+ " image = test_image[random_indices].to(device)\n",
1610
+ " assert len(image) == batch_size\n",
1611
+ " \n",
1612
+ " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n",
1613
+ " \n",
1614
+ " # clip_target = clip_model.embed_image(image.float())\n",
1615
+ " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n",
1616
+ " \n",
1617
+ " voxel_ridge = model.ridge(voxel)\n",
1618
+ " \n",
1619
+ " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n",
1620
+ " clip_voxels = model.backbone(voxel_ridge)\n",
1621
+ " \n",
1622
+ " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n",
1623
+ " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n",
1624
+ " \n",
1625
+ " # loss_clip = utils.soft_clip_loss(\n",
1626
+ " # clip_voxels_norm,\n",
1627
+ " # clip_target_norm,\n",
1628
+ " # temp=.006)\n",
1629
+ " \n",
1630
+ " loss_clip = mse(clip_voxels, clip_target)\n",
1631
+ "\n",
1632
+ " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n",
1633
+ " \n",
1634
+ " # loss = loss_blurry + loss_clip\n",
1635
+ " loss = loss_clip\n",
1636
+ " \n",
1637
+ " utils.check_loss(loss)\n",
1638
+ " \n",
1639
+ " test_losses.append(loss.item())\n",
1640
+ " \n",
1641
+ " # forward and backward top 1 accuracy \n",
1642
+ " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n",
1643
+ " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n",
1644
+ " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n",
1645
+ "\n",
1646
+ " # # halving the batch size because the decoder is computationally heavy\n",
1647
+ " # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)\n",
1648
+ " # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))\n",
1649
+ " # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)\n",
1650
+ "\n",
1651
+ " #Find captions and print next to images\n",
1652
+ " #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)\n",
1653
+ " #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)\n",
1654
+ "\n",
1655
+ " #true_embed1 = tensor_2_embed(image[[0]])\n",
1656
+ " #true_embed2 = tensor_2_embed(image[[1]])\n",
1657
+ "\n",
1658
+ " # print(clip_voxels[[0]].shape)\n",
1659
+ " # print(true_embed1.shape)\n",
1660
+ " \n",
1661
+ " #true_caption1 = embed_2_caption(true_embed1, blip2_model)\n",
1662
+ " #true_caption2 = embed_2_caption(true_embed2, blip2_model)\n",
1663
+ " \n",
1664
+ " # transform blurry recon latents to images and plot it\n",
1665
+ " #fig, axes = plt.subplots(2, 2, figsize=(8, 4))\n",
1666
+ " #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))\n",
1667
+ " #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))\n",
1668
+ " #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')\n",
1669
+ " #axes[0,0].set_title(caption1)\n",
1670
+ " #axes[0,1].set_title(caption2)\n",
1671
+ " #axes[1,0].set_title(true_caption1)\n",
1672
+ " #axes[1,1].set_title(true_caption2)\n",
1673
+ "\n",
1674
+ " #plt.show()\n",
1675
+ " \n",
1676
+ " # # transform blurry recon latents to images and plot it\n",
1677
+ " # fig, axes = plt.subplots(1, 4, figsize=(8, 4))\n",
1678
+ " # axes[0].imshow(utils.torch_to_Image(image[[0]]))\n",
1679
+ " # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))\n",
1680
+ " # axes[2].imshow(utils.torch_to_Image(image[[1]]))\n",
1681
+ " # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))\n",
1682
+ " # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')\n",
1683
+ " # axes[0].set_title(caption1)\n",
1684
+ " # axes[3].set_title(caption2)\n",
1685
+ " # plt.show()\n",
1686
+ " \n",
1687
+ "\n",
1688
+ " if local_rank==0: \n",
1689
+ " # if utils.is_interactive(): clear_output(wait=True)\n",
1690
+ " assert (test_i+1) == 1\n",
1691
+ " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n",
1692
+ " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n",
1693
+ " \"train/lr\": lrs[-1],\n",
1694
+ " \"train/num_steps\": len(losses),\n",
1695
+ " \"test/num_steps\": len(test_losses),\n",
1696
+ " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n",
1697
+ " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n",
1698
+ " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n",
1699
+ " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n",
1700
+ " \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n",
1701
+ " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n",
1702
+ " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n",
1703
+ " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n",
1704
+ " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n",
1705
+ " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n",
1706
+ " }\n",
1707
+ " progress_bar.set_postfix(**logs)\n",
1708
+ " \n",
1709
+ " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n",
1710
+ " jj=-1\n",
1711
+ " for j in [0,1,2,3,4,5,6,7]:\n",
1712
+ " jj+=1\n",
1713
+ " axes[jj].imshow(utils.torch_to_Image(image[j]))\n",
1714
+ " axes[jj].axis('off')\n",
1715
+ "\n",
1716
+ " if wandb_log:\n",
1717
+ " generated_captions = embeds_to_captions_b2(clip_voxels[0:8])\n",
1718
+ " print(generated_captions[1])\n",
1719
+ " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\" + \"\\n\".join(generated_captions[1]))\n",
1720
+ " plt.close()\n",
1721
+ " # Save model checkpoint and reconstruct\n",
1722
+ " if epoch % ckpt_interval == 0:\n",
1723
+ " if not utils.is_interactive():\n",
1724
+ " save_ckpt(f'last')\n",
1725
+ " \n",
1726
+ " if wandb_log: wandb.log(logs)\n",
1727
+ "\n",
1728
+ " # wait for other GPUs to catch up if needed\n",
1729
+ " accelerator.wait_for_everyone()\n",
1730
+ " torch.cuda.empty_cache()\n",
1731
+ " gc.collect()\n",
1732
+ "\n",
1733
+ "print(\"\\n===Finished!===\\n\")\n",
1734
+ "if ckpt_saving:\n",
1735
+ " save_ckpt(f'last')\n",
1736
+ "if not utils.is_interactive():\n",
1737
+ " sys.exit(0)"
1738
+ ]
1739
+ },
1740
+ {
1741
+ "cell_type": "code",
1742
+ "execution_count": null,
1743
+ "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a",
1744
+ "metadata": {
1745
+ "tags": []
1746
+ },
1747
+ "outputs": [],
1748
+ "source": [
1749
+ "plt.plot(losses)\n",
1750
+ "plt.show()\n",
1751
+ "plt.plot(test_losses)\n",
1752
+ "plt.show()"
1753
+ ]
1754
+ },
1755
+ {
1756
+ "cell_type": "code",
1757
+ "execution_count": null,
1758
+ "id": "ccfccd4f-764d-4624-842c-f931676eb43b",
1759
+ "metadata": {},
1760
+ "outputs": [],
1761
+ "source": [
1762
+ "print('test')"
1763
+ ]
1764
+ },
1765
+ {
1766
+ "cell_type": "code",
1767
+ "execution_count": null,
1768
+ "id": "f1a60e19-c440-4c9c-a634-30186209012f",
1769
+ "metadata": {},
1770
+ "outputs": [],
1771
+ "source": [
1772
+ "def tensor_2_embed_old(tensor):\n",
1773
+ " embed_array = torch.zeros((tensor.shape[0],257, 1024)) \n",
1774
+ " to_pil = ToPILImage()\n",
1775
+ " for sample in range(tensor.shape[0]):\n",
1776
+ " PIL_image = to_pil(tensor[sample])\n",
1777
+ " image_for_blip2 = vis_processors[\"eval\"](PIL_image).unsqueeze(0).to(device)\n",
1778
+ " #Generate embeddings\n",
1779
+ " with blip2_model.maybe_autocast():\n",
1780
+ " blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n",
1781
+ " embed_array[sample] = blip2_target\n",
1782
+ " \n",
1783
+ " return embed_array"
1784
+ ]
1785
+ },
1786
+ {
1787
+ "cell_type": "code",
1788
+ "execution_count": null,
1789
+ "id": "d39ddada-47f7-4111-92fa-0dd98e8a83d6",
1790
+ "metadata": {},
1791
+ "outputs": [],
1792
+ "source": []
1793
+ },
1794
+ {
1795
+ "cell_type": "code",
1796
+ "execution_count": null,
1797
+ "id": "ec8ed96a-61fa-4c20-8da2-fcd9d0a2ed38",
1798
+ "metadata": {},
1799
+ "outputs": [],
1800
+ "source": []
1801
+ },
1802
+ {
1803
+ "cell_type": "code",
1804
+ "execution_count": null,
1805
+ "id": "6228eb1a-e8e7-4500-b7bc-d0c57bcac4c6",
1806
+ "metadata": {},
1807
+ "outputs": [],
1808
+ "source": []
1809
+ }
1810
+ ],
1811
+ "metadata": {
1812
+ "kernelspec": {
1813
+ "display_name": "Python 3 (ipykernel)",
1814
+ "language": "python",
1815
+ "name": "python3"
1816
+ },
1817
+ "language_info": {
1818
+ "codemirror_mode": {
1819
+ "name": "ipython",
1820
+ "version": 3
1821
+ },
1822
+ "file_extension": ".py",
1823
+ "mimetype": "text/x-python",
1824
+ "name": "python",
1825
+ "nbconvert_exporter": "python",
1826
+ "pygments_lexer": "ipython3",
1827
+ "version": "3.10.8"
1828
+ },
1829
+ "toc": {
1830
+ "base_numbering": 1,
1831
+ "nav_menu": {},
1832
+ "number_sections": true,
1833
+ "sideBar": true,
1834
+ "skip_h1_title": false,
1835
+ "title_cell": "Table of Contents",
1836
+ "title_sidebar": "Contents",
1837
+ "toc_cell": false,
1838
+ "toc_position": {
1839
+ "height": "calc(100% - 180px)",
1840
+ "left": "10px",
1841
+ "top": "150px",
1842
+ "width": "165px"
1843
+ },
1844
+ "toc_section_display": true,
1845
+ "toc_window_display": true
1846
+ },
1847
+ "toc-autonumbering": true,
1848
+ "vscode": {
1849
+ "interpreter": {
1850
+ "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376"
1851
+ }
1852
+ }
1853
+ },
1854
+ "nbformat": 4,
1855
+ "nbformat_minor": 5
1856
+ }
src/train2.py ADDED
@@ -0,0 +1,1141 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[1]:
5
+
6
+
7
+ # # Code to convert this notebook to .py if you want to run it via command line or with Slurm
8
+ # from subprocess import call
9
+ # command = "jupyter nbconvert Train.ipynb --to python"
10
+ # call(command,shell=True)
11
+
12
+
13
+ # # Import packages & functions
14
+
15
+ # In[2]:
16
+
17
+
18
+ import os
19
+ import sys
20
+ import json
21
+ import argparse
22
+ import numpy as np
23
+ import math
24
+ from einops import rearrange
25
+ import time
26
+ import random
27
+ import h5py
28
+ from tqdm import tqdm
29
+
30
+ import webdataset as wds
31
+ import gc
32
+
33
+ import matplotlib.pyplot as plt
34
+ import torch
35
+ import torch.nn as nn
36
+ from torchvision import transforms
37
+ from torchvision.transforms import ToPILImage #CHANGED (added)
38
+
39
+ from accelerate import Accelerator, DeepSpeedPlugin
40
+
41
+ # tf32 data type is faster than standard float32
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+
44
+ # custom functions #
45
+ import utils
46
+
47
+ global_batch_size = 128 #128
48
+
49
+
50
+ # In[3]:
51
+
52
+
53
+ ### Multi-GPU config ###
54
+ local_rank = os.getenv('RANK')
55
+ if local_rank is None:
56
+ local_rank = 0
57
+ else:
58
+ local_rank = int(local_rank)
59
+ print("LOCAL RANK ", local_rank)
60
+
61
+ num_devices = torch.cuda.device_count()
62
+ if num_devices==0: num_devices = 1
63
+
64
+ accelerator = Accelerator(split_batches=False)
65
+
66
+ ### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above "accelerator = " line) ###
67
+
68
+ # if num_devices <= 1 and utils.is_interactive():
69
+ # # can emulate a distributed environment for deepspeed to work in jupyter notebook
70
+ # os.environ["MASTER_ADDR"] = "localhost"
71
+ # os.environ["MASTER_PORT"] = str(np.random.randint(10000)+9000)
72
+ # os.environ["RANK"] = "0"
73
+ # os.environ["LOCAL_RANK"] = "0"
74
+ # os.environ["WORLD_SIZE"] = "1"
75
+ # os.environ["GLOBAL_BATCH_SIZE"] = str(global_batch_size) # set this to your batch size!
76
+ # global_batch_size = os.environ["GLOBAL_BATCH_SIZE"]
77
+
78
+ # # alter the deepspeed config according to your global and local batch size
79
+ # if local_rank == 0:
80
+ # with open('deepspeed_config_stage2.json', 'r') as file:
81
+ # config = json.load(file)
82
+ # config['train_batch_size'] = int(os.environ["GLOBAL_BATCH_SIZE"])
83
+ # config['train_micro_batch_size_per_gpu'] = int(os.environ["GLOBAL_BATCH_SIZE"]) // num_devices
84
+ # with open('deepspeed_config_stage2.json', 'w') as file:
85
+ # json.dump(config, file)
86
+ # else:
87
+ # # give some time for the local_rank=0 gpu to prep new deepspeed config file
88
+ # time.sleep(10)
89
+ # deepspeed_plugin = DeepSpeedPlugin("deepspeed_config_stage2.json")
90
+ # accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)
91
+
92
+
93
+ # In[4]:
94
+
95
+
96
+ print("PID of this process =",os.getpid())
97
+ device = accelerator.device
98
+ print("device:",device)
99
+ num_workers = num_devices
100
+ print(accelerator.state)
101
+ world_size = accelerator.state.num_processes
102
+ distributed = accelerator.state.distributed_type != 'NO'
103
+ print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
104
+ print = accelerator.print # only print if local_rank=0
105
+
106
+
107
+ # # Configurations
108
+
109
+ # In[5]:
110
+
111
+
112
+ # if running this interactively, can specify jupyter_args here for argparser to use
113
+ if utils.is_interactive():
114
+ # Example use
115
+ jupyter_args = f"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \
116
+ --model_name=captions \
117
+ --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \
118
+ --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug"
119
+ #max_lr=3e-5 originally
120
+ jupyter_args = jupyter_args.split()
121
+ print(jupyter_args)
122
+
123
+ from IPython.display import clear_output # function to clear print outputs in cell
124
+ get_ipython().run_line_magic('load_ext', 'autoreload')
125
+ # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
126
+ get_ipython().run_line_magic('autoreload', '2')
127
+
128
+
129
+ # In[6]:
130
+
131
+
132
+ parser = argparse.ArgumentParser(description="Model Training Configuration")
133
+ parser.add_argument(
134
+ "--model_name", type=str, default="testing",
135
+ help="name of model, used for ckpt saving and wandb logging (if enabled)",
136
+ )
137
+ parser.add_argument(
138
+ "--data_path", type=str, default="/fsx/proj-fmri/shared/natural-scenes-dataset",
139
+ help="Path to where NSD data is stored / where to download it to",
140
+ )
141
+ parser.add_argument(
142
+ "--subj",type=int, default=1, choices=[1,2,5,7],
143
+ )
144
+ parser.add_argument(
145
+ "--batch_size", type=int, default=32,
146
+ help="Batch size can be increased by 10x if only training v2c and not diffusion diffuser",
147
+ )
148
+ parser.add_argument(
149
+ "--wandb_log",action=argparse.BooleanOptionalAction,default=False,
150
+ help="whether to log to wandb",
151
+ )
152
+ parser.add_argument(
153
+ "--resume_from_ckpt",action=argparse.BooleanOptionalAction,default=False,
154
+ help="if not using wandb and want to resume from a ckpt",
155
+ )
156
+ parser.add_argument(
157
+ "--wandb_project",type=str,default="stability",
158
+ help="wandb project name",
159
+ )
160
+ parser.add_argument(
161
+ "--mixup_pct",type=float,default=.33,
162
+ help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
163
+ )
164
+ parser.add_argument(
165
+ "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
166
+ help="whether to use image augmentation",
167
+ )
168
+ parser.add_argument(
169
+ "--num_epochs",type=int,default=100,
170
+ help="number of epochs of training",
171
+ )
172
+ parser.add_argument(
173
+ "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
174
+ )
175
+ parser.add_argument(
176
+ "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
177
+ )
178
+ parser.add_argument(
179
+ "--ckpt_interval",type=int,default=5,
180
+ help="save backup ckpt and reconstruct every x epochs",
181
+ )
182
+ parser.add_argument(
183
+ "--seed",type=int,default=42,
184
+ )
185
+ parser.add_argument(
186
+ "--max_lr",type=float,default=3e-4,
187
+ )
188
+ parser.add_argument(
189
+ "--n_samples_save",type=int,default=0,choices=[0,1],
190
+ help="Number of reconstructions for monitoring progress, 0 will speed up training",
191
+ )
192
+ parser.add_argument(
193
+ "--clip_mse_ratio",type=float,default=0.7,
194
+ help="Weighting of the contrastive CLIP loss vs. the MSE loss on predicted embeddings: loss = ratio*clip + (1-ratio)*mse",
195
+ )
196
+
197
+ if utils.is_interactive():
198
+ args = parser.parse_args(jupyter_args)
199
+ else:
200
+ args = parser.parse_args()
201
+
202
+ # create global variables without the args prefix
203
+ for attribute_name in vars(args).keys():
204
+ globals()[attribute_name] = getattr(args, attribute_name)
205
+
206
+ print("global batch_size", batch_size)
207
+ batch_size = int(batch_size / num_devices)
208
+ print("batch_size", batch_size)
209
+
210
+
211
+ # In[7]:
212
+
213
+
214
+ outdir = os.path.abspath(f'../train_logs/{model_name}')
215
+ if not os.path.exists(outdir):
216
+ os.makedirs(outdir,exist_ok=True)
217
+ if use_image_aug:
218
+ import kornia
219
+ from kornia.augmentation.container import AugmentationSequential
220
+ img_augment = AugmentationSequential(
221
+ kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),
222
+ kornia.augmentation.Resize((224, 224)),
223
+ kornia.augmentation.RandomHorizontalFlip(p=0.3),
224
+ kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),
225
+ kornia.augmentation.RandomGrayscale(p=0.3),
226
+ same_on_batch=False,
227
+ data_keys=["input"],
228
+ )
229
+
230
+
231
+ # In[8]:
232
+
233
+
234
+ wandb_log = True
235
+
236
+
237
+ # # Prep data, models, and dataloaders
238
+
239
+ # ## Dataloader
240
+
241
+ # In[9]:
242
+
243
+
244
+ if subj==1:
245
+ num_train = 24958
246
+ num_test = 2770
247
+ test_batch_size = num_test
248
+
249
+ def my_split_by_node(urls): return urls
250
+
251
+ train_url = f"{data_path}/wds/subj0{subj}/train/" + "{0..36}.tar"
252
+ print(train_url)
253
+
254
+ train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\
255
+ .shuffle(750, initial=1500, rng=random.Random(42))\
256
+ .decode("torch")\
257
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
258
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
259
+ train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)
260
+
261
+ test_url = f"{data_path}/wds/subj0{subj}/test/" + "0.tar"
262
+ print(test_url)
263
+
264
+ test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\
265
+ .shuffle(750, initial=1500, rng=random.Random(42))\
266
+ .decode("torch")\
267
+ .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
268
+ .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
269
+ test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)
270
+
271
+
272
+ # ### check dataloaders are working
273
+
274
+ # In[10]:
275
+
276
+
277
+ # test_indices = []
278
+ # test_images = []
279
+ # for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
280
+ # test_indices = np.append(test_indices, behav[:,0,5].numpy())
281
+ # test_images = np.append(test_images, behav[:,0,0].numpy())
282
+ # test_indices = test_indices.astype(np.int16)
283
+ # print(test_i, (test_i+1) * test_batch_size, len(test_indices))
284
+ # print("---\n")
285
+
286
+ # train_indices = []
287
+ # train_images = []
288
+ # for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
289
+ # train_indices = np.append(train_indices, behav[:,0,5].long().numpy())
290
+ # train_images = np.append(train_images, behav[:,0,0].numpy())
291
+ # train_indices = train_indices.astype(np.int16)
292
+ # print(train_i, (train_i+1) * batch_size, len(train_indices))
293
+
294
+ # # train_images = np.hstack((train_images, test_images))
295
+ # # print("WARNING: ADDED TEST IMAGES TO TRAIN IMAGES")
296
+
297
+
298
+ # ## Load data and images
299
+
300
+ # In[11]:
301
+
302
+
303
+ # load betas
304
+ f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')
305
+ voxels = f['betas'][:]
306
+ print(f"subj0{subj} betas loaded into memory")
307
+ voxels = torch.Tensor(voxels).to("cpu").half()
308
+ if subj==1:
309
+ voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))
310
+ print("voxels", voxels.shape)
311
+ num_voxels = voxels.shape[-1]
312
+
313
+ # load orig images
314
+ f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
315
+ images = f['images'][:]
316
+ images = torch.Tensor(images).to("cpu").half()
317
+ print("images", images.shape)
318
+
319
+
320
+ # ## Load models
321
+
322
+ # ### CLIP image embeddings model
323
+
324
+ # In[12]:
325
+
326
+
327
+ import transformers
328
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
329
+
330
+ from PIL import Image
331
+
332
+
333
+ # In[13]:
334
+
335
+
336
+ from models import Clipper
337
+ clip_model = Clipper("ViT-L/14", device=torch.device(f"cuda:{local_rank}"), hidden_state=True, norm_embs=True)
338
+
339
+
340
+ # In[14]:
341
+
342
+
343
+ cache_blip2 = "/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b"
344
+
345
+ b2_processor = Blip2Processor.from_pretrained(cache_blip2)
346
+ b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map="auto")
347
+
348
+ #Load in blip2 as well
349
+ """from lavis.models import load_model_and_preprocess
350
+ from lavis.models import model_zoo
351
+ blip2_model, vis_processors, _ = load_model_and_preprocess(
352
+ name="blip2_t5", model_type="pretrain_flant5xl_vitL", is_eval=True, device=device)
353
+
354
+ clip_seq_dim = 257
355
+ clip_emb_dim = 1024
356
+ hidden_dim = 4096"""
357
+
358
+
359
+ # In[15]:
360
+
361
+
362
+ def embed_images_b2(images):
363
+ images = (images * 255).type(torch.uint8)
364
+ with torch.no_grad():
365
+ inputs_processed = b2_processor(images, return_tensors="pt").to("cuda", torch.float16)
366
+ enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])
367
+ return enc_imgs.last_hidden_state.detach(), inputs_processed
368
+
369
+ def embeds_to_captions_b2(embeds):
370
+ with torch.no_grad():
371
+ input_ids = None #inputs['input_ids']
372
+ attention_mask = None
373
+ batch_size = embeds.shape[0]
374
+ image_embeds = embeds
375
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
376
+
377
+ query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1)
378
+ query_outputs = b2_model.qformer(
379
+ query_embeds=query_tokens,
380
+ encoder_hidden_states=image_embeds,
381
+ encoder_attention_mask=image_attention_mask,
382
+ return_dict=True,
383
+ )
384
+ query_output = query_outputs.last_hidden_state
385
+
386
+ language_model_inputs = b2_model.language_projection(query_output)
387
+ language_attention_mask = torch.ones(
388
+ language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
389
+ )
390
+ if input_ids is None:
391
+ input_ids = (
392
+ torch.LongTensor([[b2_model.config.text_config.bos_token_id]])
393
+ .repeat(batch_size, 1)
394
+ .to(image_embeds.device)
395
+ )
396
+ if attention_mask is None:
397
+ attention_mask = torch.ones_like(input_ids)
398
+ attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
399
+
400
+ # concatenate query embeddings with prompt embeddings
401
+ inputs_embeds = b2_model.get_input_embeddings()(input_ids)
402
+ inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
403
+
404
+ outputs = b2_model.language_model.generate(
405
+ inputs_embeds=inputs_embeds,
406
+ attention_mask=attention_mask,
407
+ )
408
+ text = b2_processor.batch_decode(outputs, skip_special_tokens=True)
409
+
410
+ return outputs, text
411
+
412
+
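A minimal sketch of how these two helpers chain together, assuming only that the predicted embeddings already have the BLIP-2 ViT-g shape of (batch, 257, 1408) used throughout this script; the inputs here are toy random tensors (standing in for clip_voxels at eval time), so the decoded text is meaningless:

import torch

dummy_embeds = torch.randn(2, 257, 1408, device=device, dtype=torch.float16)  # stands in for embed_images_b2(...)[0] or clip_voxels
_, captions = embeds_to_captions_b2(dummy_embeds)  # Q-Former + OPT decode the embeddings into text
print(captions)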
413
+ # In[16]:
414
+
415
+
416
+ image_test = images[1:20].permute(0,2,3,1)
417
+ #raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')
418
+ # Convert the image to a NumPy array
419
+ #image_test = np.array(raw_image)
420
+
421
+
422
+ # In[17]:
423
+
424
+
425
+ """import matplotlib.pyplot as plt
426
+ # Plotting one of the images (taking the first image as an example)
427
+ img_to_plot = inputs_rec['pixel_values'][-1]
428
+
429
+ # Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])
430
+ img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')
431
+ print(img_to_plot.shape)
432
+
433
+ plt.imshow(img_to_plot)
434
+ plt.show()"""
435
+
436
+
437
+ # In[18]:
438
+
439
+
440
+ embeds_test, inputs_rec = embed_images_b2(image_test)
441
+
442
+
443
+ # In[19]:
444
+
445
+
446
+ #inputs_rec['pixel_values'].shape
447
+
448
+
449
+ # In[20]:
450
+
451
+
452
+ #out = b2_model.generate(**inputs_rec)
453
+ #print(b2_processor.decode(out[0], skip_special_tokens=True).strip())
454
+
455
+
456
+ # In[21]:
457
+
458
+
459
+ outputs_test, text_test = embeds_to_captions_b2(embeds_test)
460
+
461
+
462
+ # In[22]:
463
+
464
+
465
+ text_test
466
+
467
+
468
+ # In[23]:
469
+
470
+
471
+ #inputss['pixel_values'].shape
472
+
473
+
474
+ # In[24]:
475
+
476
+
477
+ #image_test.shape
478
+
479
+
480
+ # In[25]:
481
+
482
+
483
+
484
+
485
+ # In[26]:
486
+
487
+
488
+ clip_seq_dim = 257 #blip2 image encoder shapes
489
+ clip_emb_dim = 1408 #blip2 image encoder shapes
490
+ hidden_dim = 2048
491
+
492
+
493
+ # ### SD VAE (blurry images)
494
+
495
+ # In[27]:
496
+
497
+
498
+ from diffusers import AutoencoderKL
499
+ autoenc = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, cache_dir="/fsx/proj-fmri/shared/cache")
500
+ # autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')["model_state_dict"])
501
+ autoenc.eval()
502
+ autoenc.requires_grad_(False)
503
+ autoenc.to(device)
504
+ utils.count_params(autoenc)
505
+
506
+
507
+ # ### MindEye modules
508
+
509
+ # In[28]:
510
+
511
+
512
+ class MindEyeModule(nn.Module):
513
+ def __init__(self):
514
+ super(MindEyeModule, self).__init__()
515
+ def forward(self, x):
516
+ return x
517
+
518
+ model = MindEyeModule()
519
+ model
520
+
521
+
522
+ # In[29]:
523
+
524
+
525
+ class RidgeRegression(torch.nn.Module):
526
+ # make sure to add weight_decay when initializing optimizer
527
+ def __init__(self, input_size, out_features):
528
+ super(RidgeRegression, self).__init__()
529
+ self.out_features = out_features
530
+ self.linear = torch.nn.Linear(input_size, out_features)
531
+ def forward(self, x):
532
+ return self.linear(x)
533
+
534
+ model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)
535
+ utils.count_params(model.ridge)
536
+ utils.count_params(model)
537
+
538
+ b = torch.randn((2,1,voxels.shape[1]))
539
+ print(b.shape, model.ridge(b).shape)
540
+
541
+
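A brief note on the naming, since the module itself is only a linear projection:

# The "ridge" behaviour comes from the optimizer rather than the module: the parameter groups
# defined further below give model.ridge weight_decay=1e-2, and for a plain nn.Linear under
# AdamW that decoupled weight decay plays the role of the L2 shrinkage penalty of ridge regression.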
542
+ # In[30]:
543
+
544
+
545
+ from functools import partial
546
+ from diffusers.models.vae import Decoder
547
+ class BrainNetwork(nn.Module):
548
+ def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.35, blurry_dim=16):
549
+ super().__init__()
550
+ self.blurry_dim = blurry_dim
551
+ norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)
552
+ act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU
553
+ act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)
554
+ self.lin0 = nn.Linear(in_dim, h)
555
+ self.mlp = nn.ModuleList([
556
+ nn.Sequential(
557
+ nn.Linear(h, h),
558
+ *[item() for item in act_and_norm],
559
+ nn.Dropout(drop)
560
+ ) for _ in range(n_blocks)
561
+ ])
562
+ self.lin1 = nn.Linear(h, out_dim, bias=True)
563
+ # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)
564
+ self.n_blocks = n_blocks
565
+ self.clip_size = clip_size
566
+ self.clip_proj = nn.Sequential(
567
+ nn.LayerNorm(clip_size),
568
+ nn.GELU(),
569
+ nn.Linear(clip_size, 2048),
570
+ nn.LayerNorm(2048),
571
+ nn.GELU(),
572
+ nn.Linear(2048, 2048),
573
+ nn.LayerNorm(2048),
574
+ nn.GELU(),
575
+ nn.Linear(2048, clip_size)
576
+ )
577
+ # self.upsampler = Decoder(
578
+ # in_channels=64,
579
+ # out_channels=4,
580
+ # up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
581
+ # block_out_channels=[64, 128, 256],
582
+ # layers_per_block=1,
583
+ # )
584
+
585
+ def forward(self, x):
586
+ x = self.lin0(x)
587
+ residual = x
588
+ for res_block in range(self.n_blocks):
589
+ x = self.mlp[res_block](x)
590
+ x += residual
591
+ residual = x
592
+ x = x.reshape(len(x), -1)
593
+ x = self.lin1(x)
594
+ # b = self.blin1(x)
595
+ # b = self.upsampler(b.reshape(len(b), -1, 7, 7))
596
+ c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))
597
+ # return c, b
598
+ return c
599
+
600
+ model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7)
601
+ utils.count_params(model.backbone)
602
+ utils.count_params(model)
603
+
604
+ b = torch.randn((4,hidden_dim))
605
+ print(b.shape)
606
+ clip_ = model.backbone(b)
607
+ print(clip_.shape)
608
+
609
+
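A comment sketch of the tensor shapes implied by the configuration above (hidden_dim=2048, clip_seq_dim=257, clip_emb_dim=1408):

# ridge:    voxels (batch, num_voxels)  ->  voxel_ridge (batch, 2048)
# backbone: lin0 and the residual MLP blocks keep (batch, 2048); lin1 expands to (batch, 257*1408);
#           the result is reshaped to (batch, 257, 1408) and refined token-wise by clip_proj,
#           so clip_voxels can be compared directly against the BLIP-2 clip_target embeddings.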
610
+ # In[31]:
611
+
612
+
613
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
614
+ opt_grouped_parameters = [
615
+ {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},
616
+ {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
617
+ {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
618
+ ]
619
+
620
+ optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))
621
+
622
+ if lr_scheduler_type == 'linear':
623
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
624
+ optimizer,
625
+ total_iters=int(num_epochs*(num_train*num_devices//batch_size)),
626
+ last_epoch=-1
627
+ )
628
+ elif lr_scheduler_type == 'cycle':
629
+ total_steps=int(num_epochs*(num_train*num_devices//batch_size))
630
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
631
+ optimizer,
632
+ max_lr=max_lr,
633
+ total_steps=total_steps,
634
+ final_div_factor=1000,
635
+ last_epoch=-1, pct_start=2/num_epochs
636
+ )
637
+
638
+ def save_ckpt(tag):
639
+ ckpt_path = outdir+f'/{tag}.pth'
640
+ print(f'saving {ckpt_path}',flush=True)
641
+ unwrapped_model = accelerator.unwrap_model(model)
642
+ try:
643
+ torch.save({
644
+ 'epoch': epoch,
645
+ 'model_state_dict': unwrapped_model.state_dict(),
646
+ 'optimizer_state_dict': optimizer.state_dict(),
647
+ 'lr_scheduler': lr_scheduler.state_dict(),
648
+ 'train_losses': losses,
649
+ 'test_losses': test_losses,
650
+ 'lrs': lrs,
651
+ }, ckpt_path)
652
+ except:
653
+ print("Couldn't save... moving on to prevent crashing.")
654
+ del unwrapped_model
655
+
656
+ print("\nDone with model preparations!")
657
+ utils.count_params(model)
658
+
659
+
660
+ # # Weights and Biases
661
+
662
+ # In[32]:
663
+
664
+
665
+ # params for wandb
666
+ if local_rank==0 and True: # only use main process for wandb logging
667
+ import wandb
668
+
669
+ wandb_project = 'mindeyev2'
670
+ wandb_run = model_name
671
+ wandb_notes = ''
672
+
673
+ print(f"wandb {wandb_project} run {wandb_run}")
674
+ wandb.login(host='https://stability.wandb.io')#, relogin=True)
675
+ wandb_config = {
676
+ "model_name": model_name,
677
+ "batch_size": batch_size,
678
+ "num_epochs": num_epochs,
679
+ "use_image_aug": use_image_aug,
680
+ "max_lr": max_lr,
681
+ "lr_scheduler_type": lr_scheduler_type,
682
+ "mixup_pct": mixup_pct,
683
+ "num_train": num_train,
684
+ "num_test": num_test,
685
+ "seed": seed,
686
+ "distributed": distributed,
687
+ "num_devices": num_devices,
688
+ "world_size": world_size,
689
+ }
690
+ print("wandb_config:\n",wandb_config)
691
+ if False: # wandb_auto_resume
692
+ print("wandb_id:",model_name)
693
+ wandb.init(
694
+ id = model_name,
695
+ project=wandb_project,
696
+ name=wandb_run,
697
+ config=wandb_config,
698
+ notes=wandb_notes,
699
+ resume="allow",
700
+ )
701
+ else:
702
+ wandb.init(
703
+ project=wandb_project,
704
+ name=wandb_run,
705
+ config=wandb_config,
706
+ notes=wandb_notes,
707
+ )
708
+ else:
709
+ wandb_log = False
710
+
711
+
712
+ # # More custom functions
713
+
714
+ # In[33]:
715
+
716
+
717
+ # using the same preprocessing as was used in MindEye + BrainDiffuser
718
+ pixcorr_preprocess = transforms.Compose([
719
+ transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
720
+ ])
721
+ def pixcorr(images,brains):
722
+ # Flatten images while keeping the batch dimension
723
+ all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
724
+ all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
725
+ corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
726
+ return corrmean
727
+
728
+
729
+ # # Main
730
+
731
+ # In[34]:
732
+
733
+
734
+ epoch = 0
735
+ losses, test_losses, lrs = [], [], []
736
+ best_test_loss = 1e9
737
+ soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))
738
+
739
+ # Optionally resume from checkpoint #
740
+ if resume_from_ckpt:
741
+ print("\n---resuming from last.pth ckpt---\n")
742
+ try:
743
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
744
+ except:
745
+ print('last.pth failed... trying last_backup.pth')
746
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
747
+ epoch = checkpoint['epoch']
748
+ print("Epoch",epoch)
749
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
750
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
751
+ model.load_state_dict(checkpoint['model_state_dict'])
752
+ del checkpoint
753
+ elif wandb_log:
754
+ if wandb.run.resumed:
755
+ print("\n---resuming from last.pth ckpt---\n")
756
+ try:
757
+ checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')
758
+ except:
759
+ print('last.pth failed... trying last_backup.pth')
760
+ checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')
761
+ epoch = checkpoint['epoch']
762
+ print("Epoch",epoch)
763
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
764
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
765
+ model.load_state_dict(checkpoint['model_state_dict'])
766
+ del checkpoint
767
+ torch.cuda.empty_cache()
768
+
769
+
770
+ # In[35]:
771
+
772
+
773
+ model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(
774
+ model, optimizer, train_dl, test_dl, lr_scheduler
775
+ )
776
+
777
+
778
+ # In[36]:
779
+
780
+
781
+ """transform = transforms.Compose(
782
+ [
783
+ transforms.Resize(
784
+ (224, 224),
785
+ ),
786
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
787
+ ]
788
+ )
789
+
790
+ def tensor_2_embed(image):
791
+ image_for_blip2 = transform(image)
792
+
793
+ #Generate embeddings
794
+ with blip2_model.maybe_autocast():
795
+ blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))
796
+
797
+ return blip2_target
798
+
799
+ def embed_2_caption(image_embeds, model):
800
+ image_embeds = image_embeds.float()
801
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
802
+ image.device)
803
+
804
+ query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)
805
+ query_output = model.Qformer.bert(
806
+ query_embeds=query_tokens,
807
+ encoder_hidden_states=image_embeds,
808
+ encoder_attention_mask=image_atts,
809
+ return_dict=True)
810
+
811
+ inputs_t5 = model.t5_proj(query_output.last_hidden_state)
812
+ atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
813
+ prompt = model.prompt
814
+ input_tokens = model.t5_tokenizer(
815
+ prompt, padding="longest", return_tensors="pt"
816
+ ).to(image.device)
817
+ encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
818
+
819
+ with model.maybe_autocast(dtype=torch.bfloat16):
820
+ inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)
821
+ inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
822
+
823
+ outputs = model.t5_model.generate(
824
+ inputs_embeds=inputs_embeds,
825
+ attention_mask=encoder_atts)
826
+ output_text = model.t5_tokenizer.batch_decode(
827
+ outputs, skip_special_tokens=True)
828
+
829
+ return output_text"""
830
+
831
+
832
+ # In[37]:
833
+
834
+
835
+ wandb_log = True
836
+
837
+
838
+ # In[ ]:
839
+
840
+
841
+ print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
842
+ progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
843
+ test_image, test_voxel = None, None
844
+ mse = nn.MSELoss()
845
+ for epoch in progress_bar:
846
+ model.train()
847
+
848
+ fwd_percent_correct = 0.
849
+ bwd_percent_correct = 0.
850
+ test_fwd_percent_correct = 0.
851
+ test_bwd_percent_correct = 0.
852
+
853
+ loss_clip_total = 0.
854
+ loss_blurry_total = 0.
855
+ test_loss_clip_total = 0.
856
+ test_loss_blurry_total = 0.
857
+
858
+ blurry_pixcorr = 0.
859
+ test_blurry_pixcorr = 0. # needs >.456 to beat low-level subj01 results in mindeye v1
860
+
861
+ for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):
862
+ if epoch == 0:
863
+ lrs.append(0)
864
+ break
865
+ with torch.cuda.amp.autocast():
866
+ optimizer.zero_grad()
867
+
868
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
869
+
870
+ image = images[behav[:,0,0].cpu().long()].to(device).float()
871
+
872
+ # blurry_image_enc = autoenc.encode(image).latent_dist.mode()
873
+
874
+ if use_image_aug: image = img_augment(image)
875
+ # clip_target = clip_model.embed_image(image)
876
+ clip_target = embed_images_b2(image)[0].to(device) #####CHANGED
877
+ assert not torch.any(torch.isnan(clip_target))
878
+
879
+ if epoch < int(mixup_pct * num_epochs):
880
+ voxel, perm, betas, select = utils.mixco(voxel)
881
+
882
+ voxel_ridge = model.ridge(voxel)
883
+
884
+ # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
885
+ clip_voxels = model.backbone(voxel_ridge)
886
+
887
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
888
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
889
+
890
+ if epoch < int(mixup_pct * num_epochs):
891
+ loss_clip = utils.mixco_nce(
892
+ clip_voxels_norm,
893
+ clip_target_norm,
894
+ temp=.006,
895
+ perm=perm, betas=betas, select=select)
896
+ else:
897
+ epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
898
+ loss_clip = utils.soft_clip_loss(
899
+ clip_voxels_norm,
900
+ clip_target_norm,
901
+ temp=epoch_temp)
902
+
903
+ loss_mse = mse(clip_voxels, clip_target)
904
+
905
+ # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
906
+
907
+ loss_clip_total += loss_clip.item()
908
+ # loss_blurry_total += loss_blurry.item()
909
+
910
+ # loss = loss_blurry + loss_clip
911
+ loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse)
912
+ if (train_i % 10 == 0):
913
+ print(train_i, loss)
914
+ # print(batch_size)
915
+ utils.check_loss(loss)
916
+ accelerator.backward(loss)
917
+ optimizer.step()
918
+
919
+ losses.append(loss.item())
920
+ lrs.append(optimizer.param_groups[0]['lr'])
921
+
922
+ # forward and backward top 1 accuracy
923
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
924
+ fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
925
+ bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
926
+
927
+ # with torch.no_grad():
928
+ # # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()
929
+ # random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)
930
+ # blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)
931
+ # blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)
932
+
933
+ if lr_scheduler_type is not None:
934
+ lr_scheduler.step()
935
+
936
+ model.eval()
937
+ for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):
938
+ with torch.cuda.amp.autocast():
939
+ with torch.no_grad():
940
+ # all test samples should be loaded per batch such that test_i should never exceed 0
941
+ if len(behav) != num_test: print("!",len(behav),num_test)
942
+
943
+ ## Average same-image repeats ##
944
+ if test_image is None:
945
+ voxel = voxels[behav[:,0,5].cpu().long()].to(device)
946
+
947
+ image = behav[:,0,0].cpu().long()
948
+
949
+ unique_image, sort_indices = torch.unique(image, return_inverse=True)
950
+ for im in unique_image:
951
+ locs = torch.where(im == image)[0]
952
+ if test_image is None:
953
+ test_image = images[im][None]
954
+ test_voxel = torch.mean(voxel[locs],axis=0)[None]
955
+ else:
956
+ test_image = torch.vstack((test_image, images[im][None]))
957
+ test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))
958
+
959
+ # sample of batch_size
960
+ random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]
961
+ voxel = test_voxel[random_indices].to(device)
962
+ image = test_image[random_indices].to(device)
963
+ assert len(image) == batch_size
964
+
965
+ # blurry_image_enc = autoenc.encode(image).latent_dist.mode()
966
+
967
+ # clip_target = clip_model.embed_image(image.float())
968
+ clip_target = embed_images_b2(image)[0].to(device) #####CHANGED
969
+
970
+ voxel_ridge = model.ridge(voxel)
971
+
972
+ # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
973
+ clip_voxels = model.backbone(voxel_ridge)
974
+
975
+ clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)
976
+ clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)
977
+
978
+ loss_clip = utils.soft_clip_loss(
979
+ clip_voxels_norm,
980
+ clip_target_norm,
981
+ temp=.006)
982
+
983
+ loss_mse = mse(clip_voxels, clip_target)
984
+
985
+ # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)
986
+
987
+ # loss = loss_blurry + loss_clip
988
+ loss = (clip_mse_ratio * loss_clip) + ((1 - clip_mse_ratio) * loss_mse)
989
+
990
+ utils.check_loss(loss)
991
+
992
+ test_losses.append(loss.item())
993
+
994
+ # forward and backward top 1 accuracy
995
+ labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device)
996
+ test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)
997
+ test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)
998
+
999
+ # # halving the batch size because the decoder is computationally heavy
1000
+ # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)
1001
+ # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))
1002
+ # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)
1003
+
1004
+ #Find captions and print next to images
1005
+ #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)
1006
+ #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)
1007
+
1008
+ #true_embed1 = tensor_2_embed(image[[0]])
1009
+ #true_embed2 = tensor_2_embed(image[[1]])
1010
+
1011
+ # print(clip_voxels[[0]].shape)
1012
+ # print(true_embed1.shape)
1013
+
1014
+ #true_caption1 = embed_2_caption(true_embed1, blip2_model)
1015
+ #true_caption2 = embed_2_caption(true_embed2, blip2_model)
1016
+
1017
+ # transform blurry recon latents to images and plot it
1018
+ #fig, axes = plt.subplots(2, 2, figsize=(8, 4))
1019
+ #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))
1020
+ #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))
1021
+ #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')
1022
+ #axes[0,0].set_title(caption1)
1023
+ #axes[0,1].set_title(caption2)
1024
+ #axes[1,0].set_title(true_caption1)
1025
+ #axes[1,1].set_title(true_caption2)
1026
+
1027
+ #plt.show()
1028
+
1029
+ # # transform blurry recon latents to images and plot it
1030
+ # fig, axes = plt.subplots(1, 4, figsize=(8, 4))
1031
+ # axes[0].imshow(utils.torch_to_Image(image[[0]]))
1032
+ # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))
1033
+ # axes[2].imshow(utils.torch_to_Image(image[[1]]))
1034
+ # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))
1035
+ # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')
1036
+ # axes[0].set_title(caption1)
1037
+ # axes[3].set_title(caption2)
1038
+ # plt.show()
1039
+
1040
+
1041
+ if local_rank==0:
1042
+ # if utils.is_interactive(): clear_output(wait=True)
1043
+ assert (test_i+1) == 1
1044
+ logs = {"train/loss": np.mean(losses[-(train_i+1):]),
1045
+ "test/loss": np.mean(test_losses[-(test_i+1):]),
1046
+ "train/lr": lrs[-1],
1047
+ "train/num_steps": len(losses),
1048
+ "test/num_steps": len(test_losses),
1049
+ "train/fwd_pct_correct": fwd_percent_correct / (train_i + 1),
1050
+ "train/bwd_pct_correct": bwd_percent_correct / (train_i + 1),
1051
+ "test/test_fwd_pct_correct": test_fwd_percent_correct / (test_i + 1),
1052
+ "test/test_bwd_pct_correct": test_bwd_percent_correct / (test_i + 1),
1053
+ "train/loss_clip_total": loss_clip_total / (train_i + 1),
1054
+ "train/loss_blurry_total": loss_blurry_total / (train_i + 1),
1055
+ "test/loss_clip_total": test_loss_clip_total / (test_i + 1),
1056
+ "test/loss_blurry_total": test_loss_blurry_total / (test_i + 1),
1057
+ "train/blurry_pixcorr": blurry_pixcorr / (train_i + 1),
1058
+ "test/blurry_pixcorr": test_blurry_pixcorr / (test_i + 1),
1059
+ }
1060
+ progress_bar.set_postfix(**logs)
1061
+
1062
+ fig, axes = plt.subplots(1, 8, figsize=(10, 4))
1063
+ jj=-1
1064
+ for j in [0,1,2,3,4,5,6,7]:
1065
+ jj+=1
1066
+ axes[jj].imshow(utils.torch_to_Image(image[j]))
1067
+ axes[jj].axis('off')
1068
+
1069
+ if wandb_log:
1070
+ generated_captions = embeds_to_captions_b2(clip_voxels[0:8])
1071
+ print(generated_captions[1])
1072
+ logs[f"test/recons"] = wandb.Image(fig, caption=f"epoch{epoch:03d}" + "\n".join(generated_captions[1]))
1073
+ plt.close()
1074
+ # Save model checkpoint and reconstruct
1075
+ if epoch % ckpt_interval == 0:
1076
+ if not utils.is_interactive():
1077
+ save_ckpt(f'last')
1078
+
1079
+ if wandb_log: wandb.log(logs)
1080
+
1081
+ # wait for other GPUs to catch up if needed
1082
+ accelerator.wait_for_everyone()
1083
+ torch.cuda.empty_cache()
1084
+ gc.collect()
1085
+
1086
+ print("\n===Finished!===\n")
1087
+ if ckpt_saving:
1088
+ save_ckpt(f'last')
1089
+ if not utils.is_interactive():
1090
+ sys.exit(0)
1091
+
1092
+
1093
+ # In[ ]:
1094
+
1095
+
1096
+ plt.plot(losses)
1097
+ plt.show()
1098
+ plt.plot(test_losses)
1099
+ plt.show()
1100
+
1101
+
1102
+ # In[ ]:
1103
+
1104
+
1105
+ print('test')
1106
+
1107
+
1108
+ # In[ ]:
1109
+
1110
+
1111
+ def tensor_2_embed_old(tensor):
1112
+ embed_array = torch.zeros((tensor.shape[0],257, 1024))
1113
+ to_pil = ToPILImage()
1114
+ for sample in range(tensor.shape[0]):
1115
+ PIL_image = to_pil(tensor[sample])
1116
+ image_for_blip2 = vis_processors["eval"](PIL_image).unsqueeze(0).to(device)
1117
+ #Generate embeddings
1118
+ with blip2_model.maybe_autocast():
1119
+ blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))
1120
+ embed_array[sample] = blip2_target
1121
+
1122
+ return embed_array
1123
+
1124
+
1125
+ # In[ ]:
1126
+
1127
+
1128
+
1129
+
1130
+
1131
+ # In[ ]:
1132
+
1133
+
1134
+
1135
+
1136
+
1137
+ # In[ ]:
1138
+
1139
+
1140
+
1141
+
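Written out, the objective optimized by train2.py above is

    loss = clip_mse_ratio * loss_clip + (1 - clip_mse_ratio) * mse(clip_voxels, clip_target)

where loss_clip is the MixCo InfoNCE term for the first mixup_pct of epochs and the SoftCLIP term afterwards, both computed on flattened, L2-normalized embeddings.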
src/utils.py ADDED
@@ -0,0 +1,368 @@
1
+ import numpy as np
2
+ from torchvision import transforms
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import PIL
7
+ import random
8
+ import os
9
+ import matplotlib.pyplot as plt
10
+ import pandas as pd
11
+ import math
12
+ import webdataset as wds
13
+ import tempfile
14
+ from torchvision.utils import make_grid
15
+ # from diffusers.utils import randn_tensor
16
+
17
+ import json
18
+ from torchmetrics.image.fid import FrechetInceptionDistance
19
+ from PIL import Image
20
+ import requests
21
+ import io
22
+ import time
23
+
24
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+
26
+ def is_interactive():
27
+ import __main__ as main
28
+ return not hasattr(main, '__file__')
29
+
30
+ def seed_everything(seed=0, cudnn_deterministic=True):
31
+ random.seed(seed)
32
+ os.environ['PYTHONHASHSEED'] = str(seed)
33
+ np.random.seed(seed)
34
+ torch.manual_seed(seed)
35
+ torch.cuda.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ if cudnn_deterministic:
38
+ torch.backends.cudnn.deterministic = True
39
+ else:
40
+ ## needs to be False to use conv3D
41
+ print('Note: not using cudnn.deterministic')
42
+
43
+ def np_to_Image(x):
44
+ if x.ndim==4:
45
+ x=x[0]
46
+ return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))
47
+
48
+ def torch_to_Image(x):
49
+ if x.ndim==4:
50
+ x=x[0]
51
+ return transforms.ToPILImage()(x)
52
+
53
+ def Image_to_torch(x):
54
+ try:
55
+ x = (transforms.ToTensor()(x)[:3].unsqueeze(0)-.5)/.5
56
+ except:
57
+ x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5
58
+ return x
59
+
60
+ def torch_to_matplotlib(x,device=device):
61
+ if torch.mean(x)>10:
62
+ x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
63
+ else:
64
+ x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
65
+ if device=='cpu':
66
+ return x[0]
67
+ else:
68
+ return x.cpu().numpy()[0]
69
+
70
+ def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
71
+ #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
72
+ numerator = A @ B.T
73
+ A_l2 = torch.mul(A, A).sum(axis=dim)
74
+ B_l2 = torch.mul(B, B).sum(axis=dim)
75
+ denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
76
+ return torch.div(numerator, denominator)
77
+
78
+ def batchwise_pearson_correlation(Z, B):
79
+ # Calculate means
80
+ Z_mean = torch.mean(Z, dim=1, keepdim=True)
81
+ B_mean = torch.mean(B, dim=1, keepdim=True)
82
+
83
+ # Subtract means
84
+ Z_centered = Z - Z_mean
85
+ B_centered = B - B_mean
86
+
87
+ # Calculate Pearson correlation coefficient
88
+ numerator = Z_centered @ B_centered.T
89
+ Z_centered_norm = torch.linalg.norm(Z_centered, dim=1, keepdim=True)
90
+ B_centered_norm = torch.linalg.norm(B_centered, dim=1, keepdim=True)
91
+ denominator = Z_centered_norm @ B_centered_norm.T
92
+
93
+ pearson_correlation = (numerator / denominator)
94
+ return pearson_correlation
95
+
96
+ def batchwise_cosine_similarity(Z,B):
97
+ # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
98
+ B = B.T
99
+ Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1).
100
+ B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b).
101
+ cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
102
+ return cosine_similarity
103
+
104
+ def topk(similarities,labels,k=5):
105
+ if k > similarities.shape[0]:
106
+ k = similarities.shape[0]
107
+ topsum=0
108
+ for i in range(k):
109
+ topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels)
110
+ return topsum
111
+
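A quick sanity sketch of how topk pairs with batchwise_cosine_similarity (toy self-retrieval case, so top-1 accuracy comes out as 1.0):

import torch

e = torch.nn.functional.normalize(torch.randn(6, 64), dim=-1)
sims = batchwise_cosine_similarity(e, e)  # (6, 6); each row's largest entry sits on the diagonal
labels = torch.arange(6)
print(topk(sims, labels, k=1))            # -> 1.0: every sample retrieves its own pair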
112
+ def get_non_diagonals(a):
113
+ a = torch.triu(a,diagonal=1)+torch.tril(a,diagonal=-1)
114
+ # make diagonals -1
115
+ a=a.fill_diagonal_(-1)
116
+ return a
117
+
118
+ def gather_features(image_features, voxel_features, accelerator):
119
+ all_image_features = accelerator.gather(image_features.contiguous())
120
+ if voxel_features is not None:
121
+ all_voxel_features = accelerator.gather(voxel_features.contiguous())
122
+ return all_image_features, all_voxel_features
123
+ return all_image_features
124
+
125
+ def soft_clip_loss(preds, targs, temp=0.125): #, distributed=False, accelerator=None):
126
+ # if not distributed:
127
+ clip_clip = (targs @ targs.T)/temp
128
+ brain_clip = (preds @ targs.T)/temp
129
+ # else:
130
+ # all_targs = gather_features(targs, None, accelerator)
131
+ # clip_clip = (targs @ all_targs.T)/temp
132
+ # brain_clip = (preds @ all_targs.T)/temp
133
+
134
+ loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
135
+ loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
136
+
137
+ loss = (loss1 + loss2)/2
138
+ return loss
139
+
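A minimal usage sketch for soft_clip_loss, assuming the same conventions as the training scripts (flattened, L2-normalized predictions and targets, and the post-mixup temperature of .006):

import torch

preds = torch.nn.functional.normalize(torch.randn(8, 257*1408), dim=-1)  # toy brain-predicted embeddings
targs = torch.nn.functional.normalize(torch.randn(8, 257*1408), dim=-1)  # toy image embeddings
loss = soft_clip_loss(preds, targs, temp=.006)  # target-target similarities act as soft labels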
140
+ def soft_siglip_loss(preds, targs, temp, bias):
141
+ temp = torch.exp(temp)
142
+
143
+ logits = (preds @ targs.T) * temp + bias
144
+ # diagonals (aka paired samples) should be >0 and off-diagonals <0
145
+ labels = (targs @ targs.T) - 1 + (torch.eye(len(targs)).to(targs.dtype).to(targs.device))
146
+
147
+ loss1 = -torch.sum(nn.functional.logsigmoid(logits * labels[:len(preds)])) / len(preds)
148
+ loss2 = -torch.sum(nn.functional.logsigmoid(logits.T * labels[:,:len(preds)])) / len(preds)
149
+ loss = (loss1 + loss2)/2
150
+ return loss
151
+
152
+ def mixco_hard_siglip_loss(preds, targs, temp, bias, perm, betas):
153
+ temp = torch.exp(temp)
154
+
155
+ probs = torch.diag(betas)
156
+ probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas
157
+
158
+ logits = (preds @ targs.T) * temp + bias
159
+ labels = probs * 2 - 1
160
+ #labels = torch.eye(len(targs)).to(targs.dtype).to(targs.device) * 2 - 1
161
+
162
+ loss1 = -torch.sum(nn.functional.logsigmoid(logits * labels)) / len(preds)
163
+ loss2 = -torch.sum(nn.functional.logsigmoid(logits.T * labels)) / len(preds)
164
+ loss = (loss1 + loss2)/2
165
+ return loss
166
+
167
+ def mixco(voxels, beta=0.15, s_thresh=0.5, perm=None, betas=None, select=None):
168
+ if perm is None:
169
+ perm = torch.randperm(voxels.shape[0])
170
+ voxels_shuffle = voxels[perm].to(voxels.device,dtype=voxels.dtype)
171
+ if betas is None:
172
+ betas = torch.distributions.Beta(beta, beta).sample([voxels.shape[0]]).to(voxels.device,dtype=voxels.dtype)
173
+ if select is None:
174
+ select = (torch.rand(voxels.shape[0]) <= s_thresh).to(voxels.device)
175
+ betas_shape = [-1] + [1]*(len(voxels.shape)-1)
176
+ voxels[select] = voxels[select] * betas[select].reshape(*betas_shape) + \
177
+ voxels_shuffle[select] * (1 - betas[select]).reshape(*betas_shape)
178
+ betas[~select] = 1
179
+ return voxels, perm, betas, select
180
+
181
+ def mixco_clip_target(clip_target, perm, select, betas):
182
+ clip_target_shuffle = clip_target[perm]
183
+ clip_target[select] = clip_target[select] * betas[select].reshape(-1, 1) + \
184
+ clip_target_shuffle[select] * (1 - betas[select]).reshape(-1, 1)
185
+ return clip_target
186
+
187
+ def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
188
+ accelerator=None, local_rank=None, bidirectional=True):
189
+ brain_clip = (preds @ targs.T)/temp
190
+
191
+ if perm is not None and betas is not None and select is not None:
192
+ probs = torch.diag(betas)
193
+ probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas
194
+
195
+ loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
196
+ if bidirectional:
197
+ loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
198
+ loss = (loss + loss2)/2
199
+ return loss
200
+ else:
201
+ loss = F.cross_entropy(brain_clip, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
202
+ if bidirectional:
203
+ loss2 = F.cross_entropy(brain_clip.T, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
204
+ loss = (loss + loss2)/2
205
+ return loss
206
+
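A sketch of how mixco and mixco_nce are meant to be paired during the BiMixCo phase, with toy sizes standing in for the real voxel and embedding dimensions:

import torch

voxel = torch.randn(8, 15724)                                        # toy batch of fMRI betas
voxel, perm, betas, select = mixco(voxel)                            # blend ~half the batch with shuffled partners
preds = torch.nn.functional.normalize(torch.randn(8, 512), dim=-1)   # toy model outputs for the mixed voxels
targs = torch.nn.functional.normalize(torch.randn(8, 512), dim=-1)   # toy CLIP/BLIP-2 targets
loss = mixco_nce(preds, targs, temp=.006, perm=perm, betas=betas, select=select)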
207
+ def count_params(model):
208
+ total = sum(p.numel() for p in model.parameters())
209
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
210
+ print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))
211
+
212
+ def image_grid(imgs, rows, cols):
213
+ w, h = imgs[0].size
214
+ grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
215
+ for i, img in enumerate(imgs):
216
+ grid.paste(img, box=(i%cols*w, i//cols*h))
217
+ return grid
218
+
219
+ def check_loss(loss):
220
+ if loss.isnan().any():
221
+ raise ValueError('NaN loss')
222
+
223
+ def cosine_anneal(start, end, steps):
224
+ return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1)))
225
+
226
+ def resize(img, img_size=128):
227
+ if img.ndim == 3: img = img[None]
228
+ return nn.functional.interpolate(img, size=(img_size, img_size), mode='nearest')
229
+
230
+ def patchify(img, patch_size=16):
231
+ B, C, H, W = img.size()
232
+ patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
233
+ patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)
234
+ return patches.permute(0, 2, 1, 3, 4)
235
+
+ def unpatchify(patches):
+     # Inverse of patchify; assumes the N patches form a square grid (sqrt(N) per side).
+     B, N, C, H, W = patches.shape  # B=batch size, N=number of patches, C=channels, H=height, W=width
+     patches = patches.view(B, int(N**0.5), int(N**0.5), C, H, W)
+     patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()
+     return patches.view(B, C, H*int(N**0.5), W*int(N**0.5))
+
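+ # Editor's note: a small round-trip sanity check (not part of the original upload) showing
+ # that unpatchify(patchify(x)) reproduces x when H and W are multiples of patch_size.
+ def _demo_patchify_roundtrip():
+     import torch
+     x = torch.randn(2, 3, 128, 128)
+     patches = patchify(x, patch_size=16)   # -> (2, 64, 3, 16, 16)
+     recon = unpatchify(patches)            # -> (2, 3, 128, 128)
+     assert torch.allclose(x, recon), "patchify/unpatchify should be exact inverses here"
+     return patches.shape, recon.shape
+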
+ import braceexpand
+ def get_dataloaders(
+     batch_size,
+     image_var='images',
+     num_devices=None,
+     num_workers=None,
+     train_url=None,
+     val_url=None,
+     meta_url=None,
+     num_train=None,
+     num_val=None,
+     cache_dir="/scratch/tmp/wds-cache",
+     seed=0,
+     voxels_key="nsdgeneral.npy",
+     val_batch_size=None,
+     to_tuple=["voxels", "images", "trial"],
+     local_rank=0,
+     world_size=1,
+ ):
+     print("Getting dataloaders...")
+     assert image_var == 'images'
+
+     def my_split_by_node(urls):
+         return urls
+
+     train_url = list(braceexpand.braceexpand(train_url))
+     val_url = list(braceexpand.braceexpand(val_url))
+
+     if num_devices is None:
+         num_devices = torch.cuda.device_count()
+
+     if num_workers is None:
+         num_workers = num_devices
+
+     if num_train is None:
+         metadata = json.load(open(meta_url))
+         num_train = metadata['totals']['train']
+     if num_val is None:
+         metadata = json.load(open(meta_url))
+         num_val = metadata['totals']['val']
+
+     if val_batch_size is None:
+         val_batch_size = batch_size
+
+     global_batch_size = batch_size * num_devices
+     num_batches = math.floor(num_train / global_batch_size)
+     num_worker_batches = math.floor(num_batches / num_workers)
+     if num_worker_batches == 0: num_worker_batches = 1
+
+     print("\nnum_train",num_train)
+     print("global_batch_size",global_batch_size)
+     print("batch_size",batch_size)
+     print("num_workers",num_workers)
+     print("num_batches",num_batches)
+     print("num_worker_batches", num_worker_batches)
+
+     # train_url = train_url[local_rank:world_size]
+     train_data = wds.WebDataset(train_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
+         .shuffle(500, initial=500, rng=random.Random(42))\
+         .decode("torch")\
+         .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
+         .to_tuple(*to_tuple)#\
+         # .batched(batch_size, partial=True)#\
+         # .with_epoch(num_worker_batches)
+
+     # BATCH SIZE SHOULD BE NONE!!! FOR TRAIN AND VAL | resampled=True for train | .batched(val_batch_size, partial=False)
+     train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)
+
+     # Validation
+     print("val_batch_size",val_batch_size)
+     val_data = wds.WebDataset(val_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
+         .shuffle(500, initial=500, rng=random.Random(42))\
+         .decode("torch")\
+         .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
+         .to_tuple(*to_tuple)#\
+         # .batched(val_batch_size, partial=True)
+     val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1, shuffle=False, drop_last=True)
+
+     return train_dl, val_dl, num_train, num_val
+
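+ # Editor's note: hypothetical call site (not part of the original upload) showing the expected
+ # arguments; the brace-expanded shard URLs and metadata path below are placeholders only.
+ def _demo_get_dataloaders():
+     train_dl, val_dl, num_train, num_val = get_dataloaders(
+         batch_size=32,
+         num_devices=1,
+         num_workers=1,
+         train_url="/path/to/webdataset/train/train_subj01_{0..17}.tar",  # placeholder shards
+         val_url="/path/to/webdataset/val/val_subj01_0.tar",              # placeholder shard
+         meta_url="/path/to/webdataset/metadata_subj01.json",             # placeholder metadata
+         voxels_key="nsdgeneral.npy",
+         to_tuple=["voxels", "images", "trial"],
+     )
+     # Each batch yields the tuple named in to_tuple.
+     for voxels, images, trial in train_dl:
+         print(voxels.shape, images.shape, trial.shape)
+         break
+     return num_train, num_val
+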
+ pixcorr_preprocess = transforms.Compose([
+     transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
+ ])
+ def pixcorr(images, brains):
+     all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
+     all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
+     corrmean = torch.diag(batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
+     return corrmean
+
+ pixcorr_origsize_nanmean_preprocess = transforms.Compose([
+     transforms.Resize(128, interpolation=transforms.InterpolationMode.BILINEAR),
+ ])
+ def pixcorr_origsize_nanmean(images, brains):
+     all_images_flattened = pixcorr_origsize_nanmean_preprocess(images).reshape(len(images), -1)
+     all_brain_recons_flattened = brains.view(len(brains), -1)  # assumes brains are already 128x128
+     corrmean = torch.nanmean(torch.diag(batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)))
+     return corrmean
+
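+ # Editor's note: a minimal self-check (not part of the original upload). Pixel correlation of a
+ # batch against itself should be ~1; against unrelated noise it should be near 0.
+ def _demo_pixcorr():
+     import torch
+     imgs = torch.rand(4, 3, 425, 425)
+     noise = torch.rand(4, 3, 425, 425)
+     print("self correlation:", pixcorr(imgs, imgs).item())    # expect ~1.0
+     print("noise correlation:", pixcorr(imgs, noise).item())  # expect ~0.0
+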
+ def select_annotations(annots, random=False):
+     """
+     There are 5 annotations per image. Select one of them for each image.
+     """
+     for i, b in enumerate(annots):
+         t = ''
+         if random:
+             # select random non-empty annotation
+             while t == '':
+                 rand = torch.randint(5, (1,1))[0][0]
+                 t = b[rand]
+         else:
+             # select first non-empty annotation
+             for j in range(5):
+                 if b[j] != '':
+                     t = b[j]
+                     break
+         if i == 0:
+             txt = np.array(t)
+         else:
+             txt = np.vstack((txt, t))
+     txt = txt.flatten()
+     return txt
+
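+ # Editor's note: tiny illustration (not part of the original upload) of the expected input:
+ # an iterable of 5-caption groups, e.g. decoded COCO annotations.
+ def _demo_select_annotations():
+     annots = [
+         ['', 'a cat on a couch', 'a sleeping cat', '', ''],
+         ['a red bus', '', '', '', 'a bus parked on the street'],
+     ]
+     print(select_annotations(annots))               # first non-empty caption per image
+     print(select_annotations(annots, random=True))  # random non-empty caption per image
+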
+ def add_saturation(image, alpha=2):
+     gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
+     gray_image = gray_image.unsqueeze(1).expand_as(image)
+     saturated_image = alpha * image + (1 - alpha) * gray_image
+     return torch.clamp(saturated_image, 0, 1)
train_mem_logs/error.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c5b11d49c3a54be008689ed7a339164c4740c698ccf562bb069ae274ee9f834
+ size 8108517524
train_mem_logs/error_tensors.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ef06b797eafcaa4038c81293ce5a4154707134f813b82976d81726e021169fd
+ size 3534675888
train_mem_logs/test/last.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5457bfd8fb50a44c250c8688aa2846ff9bf1ca8c7c03c9cfaff90b565f327b3e
+ size 16742654526
train_mem_logs/test_mem/last.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e82bf58463b16ffe2049b7788c4de136829f0d6c847335e345630d11e84e2ee3
+ size 8066548705