{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f4d95fac-ac1d-473c-ab96-650f76e6aaf5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# # Code to convert this notebook to .py if you want to run it via command line or with Slurm\n", "# from subprocess import call\n", "# command = \"jupyter nbconvert Train.ipynb --to python\"\n", "# call(command,shell=True)" ] }, { "cell_type": "markdown", "id": "b0f0f4f3", "metadata": {}, "source": [ "# Import packages & functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bad764b-45c1-45ce-a716-8d055e09821a", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2023-11-19 16:32:39,711] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ], "source": [ "import os\n", "import sys\n", "import json\n", "import argparse\n", "import numpy as np\n", "import math\n", "from einops import rearrange\n", "import time\n", "import random\n", "import h5py\n", "from tqdm import tqdm\n", "\n", "import webdataset as wds\n", "import gc\n", "\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from torchvision import transforms\n", "from torchvision.transforms import ToPILImage #CHANGED (added)\n", "\n", "from accelerate import Accelerator, DeepSpeedPlugin\n", "\n", "# tf32 data type is faster than standard float32\n", "torch.backends.cuda.matmul.allow_tf32 = True\n", "\n", "# custom functions #\n", "import utils\n", "\n", "global_batch_size = 128 #128" ] }, { "cell_type": "code", "execution_count": 3, "id": "cc5d2e32-6027-4a19-bef4-5ca068db35bb", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LOCAL RANK 0\n" ] } ], "source": [ "### Multi-GPU config ###\n", "local_rank = os.getenv('RANK')\n", "if local_rank is None: \n", " local_rank = 0\n", "else:\n", " local_rank = int(local_rank)\n", "print(\"LOCAL RANK \", local_rank) \n", "\n", "num_devices = torch.cuda.device_count()\n", "if num_devices==0: num_devices = 1\n", "\n", "accelerator = Accelerator(split_batches=False)\n", "\n", "### UNCOMMENT BELOW STUFF TO USE DEEPSPEED (also comment out the immediately above \"accelerator = \" line) ###\n", "\n", "# if num_devices <= 1 and utils.is_interactive():\n", "# # can emulate a distributed environment for deepspeed to work in jupyter notebook\n", "# os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", "# os.environ[\"MASTER_PORT\"] = str(np.random.randint(10000)+9000)\n", "# os.environ[\"RANK\"] = \"0\"\n", "# os.environ[\"LOCAL_RANK\"] = \"0\"\n", "# os.environ[\"WORLD_SIZE\"] = \"1\"\n", "# os.environ[\"GLOBAL_BATCH_SIZE\"] = str(global_batch_size) # set this to your batch size!\n", "# global_batch_size = os.environ[\"GLOBAL_BATCH_SIZE\"]\n", "\n", "# # alter the deepspeed config according to your global and local batch size\n", "# if local_rank == 0:\n", "# with open('deepspeed_config_stage2.json', 'r') as file:\n", "# config = json.load(file)\n", "# config['train_batch_size'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"])\n", "# config['train_micro_batch_size_per_gpu'] = int(os.environ[\"GLOBAL_BATCH_SIZE\"]) // num_devices\n", "# 
with open('deepspeed_config_stage2.json', 'w') as file:\n", "# json.dump(config, file)\n", "# else:\n", "# # give some time for the local_rank=0 gpu to prep new deepspeed config file\n", "# time.sleep(10)\n", "# deepspeed_plugin = DeepSpeedPlugin(\"deepspeed_config_stage2.json\")\n", "# accelerator = Accelerator(split_batches=False, deepspeed_plugin=deepspeed_plugin)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b767ab6f-d4a9-47a5-b3bf-f56bf6760c0c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PID of this process = 2370606\n", "device: cuda\n", "Distributed environment: NO\n", "Num processes: 1\n", "Process index: 0\n", "Local process index: 0\n", "Device: cuda\n", "\n", "Mixed precision type: no\n", "\n", "distributed = False num_devices = 1 local rank = 0 world size = 1\n" ] } ], "source": [ "print(\"PID of this process =\",os.getpid())\n", "device = accelerator.device\n", "print(\"device:\",device)\n", "num_workers = num_devices\n", "print(accelerator.state)\n", "world_size = accelerator.state.num_processes\n", "distributed = not accelerator.state.distributed_type == 'NO'\n", "print(\"distributed =\",distributed, \"num_devices =\", num_devices, \"local rank =\", local_rank, \"world size =\", world_size)\n", "print = accelerator.print # only print if local_rank=0" ] }, { "cell_type": "markdown", "id": "9018b82b-c054-4463-9527-4b0c2a75bda6", "metadata": { "tags": [] }, "source": [ "# Configurations" ] }, { "cell_type": "code", "execution_count": 5, "id": "2b61fec7-72a0-4b67-86da-1375f1d9fbd3", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset', '--model_name=captions', '--subj=1', '--batch_size=128', '--n_samples_save=0', '--max_lr=3e-1', '--mixup_pct=.66', '--num_epochs=30', '--ckpt_interval=999', '--no-use_image_aug']\n" ] } ], "source": [ "# if running this interactively, can specify jupyter_args here for argparser to use\n", "if utils.is_interactive():\n", " # Example use\n", " jupyter_args = f\"--data_path=/fsx/proj-fmri/shared/mindeyev2_dataset \\\n", " --model_name=captions \\\n", " --subj=1 --batch_size={global_batch_size} --n_samples_save=0 \\\n", " --max_lr=3e-1 --mixup_pct=.66 --num_epochs=30 --ckpt_interval=999 --no-use_image_aug\"\n", " #max_lr=3e-5 originally\n", " jupyter_args = jupyter_args.split()\n", " print(jupyter_args)\n", " \n", " from IPython.display import clear_output # function to clear print outputs in cell\n", " %load_ext autoreload \n", " # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions\n", " %autoreload 2 " ] }, { "cell_type": "code", "execution_count": 6, "id": "2028bdf0-2f41-46d9-b6e7-86b870dbf16c", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global batch_size 128\n", "batch_size 128\n" ] } ], "source": [ "parser = argparse.ArgumentParser(description=\"Model Training Configuration\")\n", "parser.add_argument(\n", " \"--model_name\", type=str, default=\"testing\",\n", " help=\"name of model, used for ckpt saving and wandb logging (if enabled)\",\n", ")\n", "parser.add_argument(\n", " \"--data_path\", type=str, default=\"/fsx/proj-fmri/shared/natural-scenes-dataset\",\n", " help=\"Path to where NSD data is stored / where to download it to\",\n", ")\n", "parser.add_argument(\n", " \"--subj\",type=int, default=1, choices=[1,2,5,7],\n", ")\n", "parser.add_argument(\n", " 
\"--batch_size\", type=int, default=32,\n", " help=\"Batch size can be increased by 10x if only training v2c and not diffusion diffuser\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_log\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"whether to log to wandb\",\n", ")\n", "parser.add_argument(\n", " \"--resume_from_ckpt\",action=argparse.BooleanOptionalAction,default=False,\n", " help=\"if not using wandb and want to resume from a ckpt\",\n", ")\n", "parser.add_argument(\n", " \"--wandb_project\",type=str,default=\"stability\",\n", " help=\"wandb project name\",\n", ")\n", "parser.add_argument(\n", " \"--mixup_pct\",type=float,default=.33,\n", " help=\"proportion of way through training when to switch from BiMixCo to SoftCLIP\",\n", ")\n", "parser.add_argument(\n", " \"--use_image_aug\",action=argparse.BooleanOptionalAction,default=True,\n", " help=\"whether to use image augmentation\",\n", ")\n", "parser.add_argument(\n", " \"--num_epochs\",type=int,default=240,\n", " help=\"number of epochs of training\",\n", ")\n", "parser.add_argument(\n", " \"--lr_scheduler_type\",type=str,default='cycle',choices=['cycle','linear'],\n", ")\n", "parser.add_argument(\n", " \"--ckpt_saving\",action=argparse.BooleanOptionalAction,default=True,\n", ")\n", "parser.add_argument(\n", " \"--ckpt_interval\",type=int,default=5,\n", " help=\"save backup ckpt and reconstruct every x epochs\",\n", ")\n", "parser.add_argument(\n", " \"--seed\",type=int,default=42,\n", ")\n", "parser.add_argument(\n", " \"--max_lr\",type=float,default=3e-4,\n", ")\n", "parser.add_argument(\n", " \"--n_samples_save\",type=int,default=0,choices=[0,1],\n", " help=\"Number of reconstructions for monitoring progress, 0 will speed up training\",\n", ")\n", "\n", "if utils.is_interactive():\n", " args = parser.parse_args(jupyter_args)\n", "else:\n", " args = parser.parse_args()\n", "\n", "# create global variables without the args prefix\n", "for attribute_name in vars(args).keys():\n", " globals()[attribute_name] = getattr(args, attribute_name)\n", "\n", "print(\"global batch_size\", batch_size)\n", "batch_size = int(batch_size / num_devices)\n", "print(\"batch_size\", batch_size)" ] }, { "cell_type": "code", "execution_count": 7, "id": "60cd7f2c-37fd-426b-a0c6-633e51bc4c4d", "metadata": { "tags": [] }, "outputs": [], "source": [ "outdir = os.path.abspath(f'../train_logs/{model_name}')\n", "if not os.path.exists(outdir):\n", " os.makedirs(outdir,exist_ok=True)\n", "if use_image_aug:\n", " import kornia\n", " from kornia.augmentation.container import AugmentationSequential\n", " img_augment = AugmentationSequential(\n", " kornia.augmentation.RandomResizedCrop((224,224), (0.6,1), p=0.3),\n", " kornia.augmentation.Resize((224, 224)),\n", " kornia.augmentation.RandomHorizontalFlip(p=0.3),\n", " kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.3),\n", " kornia.augmentation.RandomGrayscale(p=0.3),\n", " same_on_batch=False,\n", " data_keys=[\"input\"],\n", " )" ] }, { "cell_type": "code", "execution_count": 8, "id": "e7807ba9-02b6-4bc0-873c-69869abe4091", "metadata": {}, "outputs": [], "source": [ "wandb_log = False" ] }, { "cell_type": "markdown", "id": "42d13c25-1369-4c49-81d4-83d713586096", "metadata": { "tags": [] }, "source": [ "# Prep data, models, and dataloaders" ] }, { "cell_type": "markdown", "id": "1c023f24-5233-4a15-a2f5-78487b3a8546", "metadata": {}, "source": [ "## Dataloader" ] }, { "cell_type": "code", "execution_count": 9, "id": 
"81084834-035f-4465-ad59-59e6b806a2f5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/train/{0..36}.tar\n", "/fsx/proj-fmri/shared/mindeyev2_dataset/wds/subj01/test/0.tar\n" ] } ], "source": [ "if subj==1:\n", " num_train = 24958\n", " num_test = 2770\n", "test_batch_size = num_test\n", "\n", "def my_split_by_node(urls): return urls\n", " \n", "train_url = f\"{data_path}/wds/subj0{subj}/train/\" + \"{0..36}.tar\"\n", "print(train_url)\n", "\n", "train_data = wds.WebDataset(train_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, drop_last=False, pin_memory=True)\n", "\n", "test_url = f\"{data_path}/wds/subj0{subj}/test/\" + \"0.tar\"\n", "print(test_url)\n", "\n", "test_data = wds.WebDataset(test_url,resampled=False,nodesplitter=my_split_by_node)\\\n", " .shuffle(750, initial=1500, rng=random.Random(42))\\\n", " .decode(\"torch\")\\\n", " .rename(behav=\"behav.npy\", past_behav=\"past_behav.npy\", future_behav=\"future_behav.npy\", olds_behav=\"olds_behav.npy\")\\\n", " .to_tuple(*[\"behav\", \"past_behav\", \"future_behav\", \"olds_behav\"])\n", "test_dl = torch.utils.data.DataLoader(test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, pin_memory=True)" ] }, { "cell_type": "markdown", "id": "203b060a-2dd2-4c35-929b-c576be82eb52", "metadata": {}, "source": [ "### check dataloaders are working" ] }, { "cell_type": "code", "execution_count": 10, "id": "e7a9c68c-c3c9-4080-bd99-067c4486dc37", "metadata": { "tags": [] }, "outputs": [], "source": [ "# test_indices = []\n", "# test_images = []\n", "# for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", "# test_indices = np.append(test_indices, behav[:,0,5].numpy())\n", "# test_images = np.append(test_images, behav[:,0,0].numpy())\n", "# test_indices = test_indices.astype(np.int16)\n", "# print(test_i, (test_i+1) * test_batch_size, len(test_indices))\n", "# print(\"---\\n\")\n", "\n", "# train_indices = []\n", "# train_images = []\n", "# for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", "# train_indices = np.append(train_indices, behav[:,0,5].long().numpy())\n", "# train_images = np.append(train_images, behav[:,0,0].numpy())\n", "# train_indices = train_indices.astype(np.int16)\n", "# print(train_i, (train_i+1) * batch_size, len(train_indices))\n", "\n", "# # train_images = np.hstack((train_images, test_images))\n", "# # print(\"WARNING: ADDED TEST IMAGES TO TRAIN IMAGES\")" ] }, { "cell_type": "markdown", "id": "45fad12c-f9fb-4408-8fd4-9bca324ad634", "metadata": {}, "source": [ "## Load data and images" ] }, { "cell_type": "code", "execution_count": 11, "id": "039dd330-7339-4f88-8f00-45f95e47baa0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "subj01 betas loaded into memory\n", "voxels torch.Size([27750, 15729])\n", "images torch.Size([73000, 3, 224, 224])\n" ] } ], "source": [ "# load betas\n", "f = h5py.File(f'{data_path}/betas_all_subj0{subj}.hdf5', 'r')\n", "voxels = f['betas'][:]\n", "print(f\"subj0{subj} betas loaded 
into memory\")\n", "voxels = torch.Tensor(voxels).to(\"cpu\").half()\n", "if subj==1:\n", " voxels = torch.hstack((voxels, torch.zeros((len(voxels), 5))))\n", "print(\"voxels\", voxels.shape)\n", "num_voxels = voxels.shape[-1]\n", "\n", "# load orig images\n", "f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')\n", "images = f['images'][:]\n", "images = torch.Tensor(images).to(\"cpu\").half()\n", "print(\"images\", images.shape)" ] }, { "cell_type": "markdown", "id": "10ec4517-dbdf-4ece-98f6-4714d5de4e15", "metadata": {}, "source": [ "## Load models" ] }, { "cell_type": "markdown", "id": "48d6160e-1ee8-4da7-a755-9dbb452a6fa5", "metadata": {}, "source": [ "### CLIP image embeddings model" ] }, { "cell_type": "code", "execution_count": 12, "id": "795e2885-bd07-4e27-bed7-181473c06df9", "metadata": { "tags": [] }, "outputs": [], "source": [ "import transformers\n", "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n", "\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": 13, "id": "b0420dc0-199e-4c1a-857d-b1747058b467", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ViT-L/14 cuda:0\n" ] } ], "source": [ "from models import Clipper\n", "clip_model = Clipper(\"ViT-L/14\", device=torch.device(f\"cuda:{local_rank}\"), hidden_state=True, norm_embs=True)" ] }, { "cell_type": "code", "execution_count": 14, "id": "23428fb7-2955-4295-bea1-447cebf9f72e", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:08<00:00, 34.47s/it]\n" ] }, { "data": { "text/plain": [ "'from lavis.models import load_model_and_preprocess\\nfrom lavis.models import model_zoo\\nblip2_model, vis_processors, _ = load_model_and_preprocess(\\n name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\\n\\nclip_seq_dim = 257\\nclip_emb_dim = 1024\\nhidden_dim = 4096'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cache_blip2 = \"/fsx/proj-fmri/shared/cache/models--Salesforce--blip2-opt-2.7b/snapshots/6e723d92ee91ebcee4ba74d7017632f11ff4217b\"\n", "\n", "b2_processor = Blip2Processor.from_pretrained(cache_blip2)\n", "b2_model = Blip2ForConditionalGeneration.from_pretrained(cache_blip2, torch_dtype=torch.float16, device_map=\"auto\")\n", "\n", "#Load in blip2 as well\n", "\"\"\"from lavis.models import load_model_and_preprocess\n", "from lavis.models import model_zoo\n", "blip2_model, vis_processors, _ = load_model_and_preprocess(\n", " name=\"blip2_t5\", model_type=\"pretrain_flant5xl_vitL\", is_eval=True, device=device)\n", "\n", "clip_seq_dim = 257\n", "clip_emb_dim = 1024\n", "hidden_dim = 4096\"\"\"" ] }, { "cell_type": "code", "execution_count": 74, "id": "b06f3de2-a8da-4ba0-94f0-99096f738d55", "metadata": { "tags": [] }, "outputs": [], "source": [ "def embed_images_b2(images):\n", " images = (images * 255).type(torch.uint8)\n", " with torch.no_grad():\n", " inputs_processed = b2_processor(images, return_tensors=\"pt\").to(\"cuda\", torch.float16)\n", " enc_imgs = b2_model.vision_model.forward(inputs_processed['pixel_values'])\n", " return enc_imgs.last_hidden_state.detach(), inputs_processed\n", "\n", "def embeds_to_captions_b2(embeds, sample = False, temp = 0.9):\n", " with torch.no_grad():\n", " input_ids = None 
#inputs['input_ids']\n", " attention_mask = None\n", " batch_size = embeds.shape[0]\n", " image_embeds = embeds\n", " image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)\n", "\n", " query_tokens = b2_model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n", " query_outputs = b2_model.qformer(\n", " query_embeds=query_tokens,\n", " encoder_hidden_states=image_embeds,\n", " encoder_attention_mask=image_attention_mask,\n", " return_dict=True,\n", " )\n", " query_output = query_outputs.last_hidden_state\n", "\n", " language_model_inputs = b2_model.language_projection(query_output)\n", " language_attention_mask = torch.ones(\n", " language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device\n", " )\n", " if input_ids is None:\n", " input_ids = (\n", " torch.LongTensor([[b2_model.config.text_config.bos_token_id]])\n", " .repeat(batch_size, 1)\n", " .to(image_embeds.device)\n", " )\n", " if attention_mask is None:\n", " attention_mask = torch.ones_like(input_ids)\n", " attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)\n", "\n", " # concatenate query embeddings with prompt embeddings\n", " inputs_embeds = b2_model.get_input_embeddings()(input_ids)\n", " inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)\n", "\n", " outputs = b2_model.language_model.generate(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=attention_mask,\n", " temperature=temp,\n", " do_sample = sample\n", " )\n", " text = b2_processor.batch_decode(outputs, skip_special_tokens=True)\n", " \n", " return outputs, text\n" ] }, { "cell_type": "code", "execution_count": 73, "id": "51b29638-2c81-4e9f-b06d-525fdbac44b1", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 2, 6209, 14, 10, 205, 425, 13, 10, 7297, 1280,\n", " 9, 418, 116, 1437, 38, 10728, 33, 117, 1114, 99]],\n", " device='cuda:0')" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b2_model.language_model.generate(do_sample = True, temperature=1)" ] }, { "cell_type": "code", "execution_count": 16, "id": "ec0a34d3-76e0-4a47-a9ab-6131ab2ccecd", "metadata": { "tags": [] }, "outputs": [], "source": [ "image_test = images[1:20].permute(0,2,3,1)\n", "#raw_image = Image.open('/fsx/proj-fmri/shared/controlNetData/target/img_t1.jpg').convert('RGB')\n", "# Convert the image to a NumPy array\n", "#image_test = np.array(raw_image)\n" ] }, { "cell_type": "code", "execution_count": 17, "id": "e04876a4-45c7-4015-8255-8574c8f50f14", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "\"import matplotlib.pyplot as plt\\n# Plotting one of the images (taking the first image as an example)\\nimg_to_plot = inputs_rec['pixel_values'][-1]\\n\\n# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\\nimg_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\\nprint(img_to_plot.shape)\\n\\nplt.imshow(img_to_plot)\\nplt.show()\"" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"import matplotlib.pyplot as plt\n", "# Plotting one of the images (taking the first image as an example)\n", "img_to_plot = inputs_rec['pixel_values'][-1]\n", "\n", "# Transpose the image for correct display (PyTorch: [C, H, W], Matplotlib: [H, W, C])\n", "img_to_plot = img_to_plot.permute(1, 2, 0).to(torch.float32).to('cpu')\n", 
"print(img_to_plot.shape)\n", "\n", "plt.imshow(img_to_plot)\n", "plt.show()\"\"\"" ] }, { "cell_type": "code", "execution_count": 18, "id": "328a17d0-593b-4d1e-812a-10a3b6efea6a", "metadata": { "tags": [] }, "outputs": [], "source": [ "embeds_test, inputs_rec = embed_images_b2(image_test)" ] }, { "cell_type": "code", "execution_count": 19, "id": "abe5f8a8-fca9-4083-8596-a913bdb57de7", "metadata": { "tags": [] }, "outputs": [], "source": [ "#inputs_rec['pixel_values'].shape" ] }, { "cell_type": "code", "execution_count": 20, "id": "c5f3ca7e-b880-421e-b354-7b6c3df565e9", "metadata": { "tags": [] }, "outputs": [], "source": [ "#out = b2_model.generate(**inputs_rec)\n", "#print(b2_processor.decode(out[0], skip_special_tokens=True).strip())" ] }, { "cell_type": "code", "execution_count": 21, "id": "fb462016-78d7-46ea-8058-0d608f17ea65", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n", " warnings.warn(\n" ] } ], "source": [ "outputs_test, text_test = embeds_to_captions_b2(embeds_test)" ] }, { "cell_type": "code", "execution_count": 22, "id": "6a95fcdf-db87-4c02-9728-09f85605fb1c", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "['a cat sitting on a toilet seat\\n',\n", " 'a person cutting a pizza on a cutting board\\n',\n", " 'a sandwich and a drink on a table\\n',\n", " 'a man crossing the street in front of a truck\\n',\n", " 'a giraffe standing in front of trees\\n',\n", " 'three men standing together\\n',\n", " 'a bird standing on a rock next to a body of water\\n',\n", " 'two men sitting on a street corner in asia\\n',\n", " 'a woman and two children playing tennis on a court\\n',\n", " 'a tall brick building with a clock on the side\\n',\n", " 'a train is on the tracks\\n',\n", " 'a man and woman in the water with a surfboard\\n',\n", " 'a living room with a desk and a chair\\n',\n", " 'a group of men on a basketball court\\n',\n", " 'a man holding an umbrella\\n',\n", " 'a man in a red shirt\\n',\n", " 'a group of people holding cell phones and wine glasses\\n',\n", " 'a laptop computer sitting on a table in front of a television\\n',\n", " 'a baseball player is swinging a bat on a field\\n']" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text_test" ] }, { "cell_type": "code", "execution_count": 23, "id": "9ac69fbd-55db-435b-bed6-5ae9186450e3", "metadata": { "tags": [] }, "outputs": [], "source": [ "#inputss['pixel_values'].shape" ] }, { "cell_type": "code", "execution_count": 24, "id": "0524f498-c8da-4e8a-8970-d75d2d0f6b8b", "metadata": { "tags": [] }, "outputs": [], "source": [ "#image_test.shape" ] }, { "cell_type": "code", "execution_count": 25, "id": "5417541b-49eb-4e43-a3e2-d937d9653e04", "metadata": { "tags": [] }, "outputs": [], "source": [ "max_lr = 1e-4" ] }, { "cell_type": "code", "execution_count": 26, "id": "da0ce190-1b3e-4c12-9e9f-91cbc076d044", "metadata": { "tags": [] }, "outputs": [], "source": [ "clip_seq_dim = 257 #blip2 image encoder shapes\n", "clip_emb_dim = 1408 #blip2 image encoder shapes\n", "hidden_dim = 2048" ] }, { "cell_type": "markdown", "id": "5b79bd38-6990-4504-8d45-4a68d57d8885", "metadata": {}, "source": [ "### SD VAE (blurry images)" ] }, { 
"cell_type": "code", "execution_count": 40, "id": "01baff79-8114-482b-b115-6f05aa8ad691", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "83,653,863 total\n", "0 trainable\n" ] } ], "source": [ "from diffusers import AutoencoderKL\n", "autoenc = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16, cache_dir=\"/fsx/proj-fmri/shared/cache\")\n", "# autoenc.load_state_dict(torch.load('../train_logs/sdxl_vae_normed/best.pth')[\"model_state_dict\"])\n", "autoenc.eval()\n", "autoenc.requires_grad_(False)\n", "autoenc.to(device)\n", "utils.count_params(autoenc)" ] }, { "cell_type": "markdown", "id": "260e5e4a-f697-4b2c-88fc-01f6a54886c0", "metadata": {}, "source": [ "### MindEye modules" ] }, { "cell_type": "code", "execution_count": 41, "id": "c44c271b-173f-472e-b059-a2eda0f4c4c5", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule()" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class MindEyeModule(nn.Module):\n", " def __init__(self):\n", " super(MindEyeModule, self).__init__()\n", " def forward(self, x):\n", " return x\n", " \n", "model = MindEyeModule()\n", "model" ] }, { "cell_type": "code", "execution_count": 42, "id": "038a5d61-4769-40b9-a004-f4e7b5b38bb0", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "32,215,040 total\n", "32,215,040 trainable\n", "param counts:\n", "32,215,040 total\n", "32,215,040 trainable\n", "torch.Size([2, 1, 15729]) torch.Size([2, 1, 2048])\n" ] } ], "source": [ "class RidgeRegression(torch.nn.Module):\n", " # make sure to add weight_decay when initializing optimizer\n", " def __init__(self, input_size, out_features): \n", " super(RidgeRegression, self).__init__()\n", " self.out_features = out_features\n", " self.linear = torch.nn.Linear(input_size, out_features)\n", " def forward(self, x):\n", " return self.linear(x)\n", " \n", "model.ridge = RidgeRegression(voxels.shape[1], out_features=hidden_dim)\n", "utils.count_params(model.ridge)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((2,1,voxels.shape[1]))\n", "print(b.shape, model.ridge(b).shape)" ] }, { "cell_type": "code", "execution_count": 43, "id": "3602c333-d029-465c-8fb4-c3ccffdba6fd", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "param counts:\n", "772,419,072 total\n", "772,419,072 trainable\n", "param counts:\n", "804,634,112 total\n", "804,634,112 trainable\n", "torch.Size([4, 2048])\n", "torch.Size([4, 257, 1408])\n" ] } ], "source": [ "from functools import partial\n", "from diffusers.models.vae import Decoder\n", "class BrainNetwork(nn.Module):\n", " def __init__(self, out_dim=768, in_dim=15724, clip_size=768, h=4096, n_blocks=4, norm_type='ln', act_first=False, drop=.15, blurry_dim=16):\n", " super().__init__()\n", " self.blurry_dim = blurry_dim\n", " norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h)\n", " act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU\n", " act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn)\n", " self.lin0 = nn.Linear(in_dim, h)\n", " self.mlp = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(h, h),\n", " *[item() for item in act_and_norm],\n", " nn.Dropout(drop)\n", " ) for _ in range(n_blocks)\n", " ])\n", " self.lin1 = nn.Linear(h, out_dim, 
bias=True)\n", " # self.blin1 = nn.Linear(out_dim, blurry_dim, bias=True)\n", " self.n_blocks = n_blocks\n", " self.clip_size = clip_size\n", " self.clip_proj = nn.Sequential(\n", " nn.LayerNorm(clip_size),\n", " nn.GELU(),\n", " nn.Linear(clip_size, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, 2048),\n", " nn.LayerNorm(2048),\n", " nn.GELU(),\n", " nn.Linear(2048, clip_size)\n", " )\n", " # self.upsampler = Decoder(\n", " # in_channels=64,\n", " # out_channels=4,\n", " # up_block_types=[\"UpDecoderBlock2D\",\"UpDecoderBlock2D\",\"UpDecoderBlock2D\"],\n", " # block_out_channels=[64, 128, 256],\n", " # layers_per_block=1,\n", " # )\n", " \n", " def forward(self, x):\n", " x = self.lin0(x)\n", " residual = x\n", " for res_block in range(self.n_blocks):\n", " x = self.mlp[res_block](x)\n", " x += residual\n", " residual = x\n", " x = x.reshape(len(x), -1)\n", " x = self.lin1(x)\n", " # b = self.blin1(x)\n", " # b = self.upsampler(b.reshape(len(b), -1, 7, 7))\n", " c = self.clip_proj(x.reshape(len(x), -1, self.clip_size))\n", " # return c, b\n", " return c\n", "\n", "model.backbone = BrainNetwork(h=2048, in_dim=hidden_dim, clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, blurry_dim=64*7*7) \n", "utils.count_params(model.backbone)\n", "utils.count_params(model)\n", "\n", "b = torch.randn((4,hidden_dim))\n", "print(b.shape)\n", "clip_ = model.backbone(b)\n", "print(clip_.shape)" ] }, { "cell_type": "code", "execution_count": 44, "id": "e14d0482-dc42-43b9-9ce1-953c32f2c9c1", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Done with model preparations!\n", "param counts:\n", "804,634,112 total\n", "804,634,112 trainable\n" ] } ], "source": [ "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n", "opt_grouped_parameters = [\n", " {'params': [p for n, p in model.ridge.named_parameters()], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},\n", " {'params': [p for n, p in model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},\n", "]\n", "\n", "optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr, betas=(0.9, 0.95))\n", "\n", "if lr_scheduler_type == 'linear':\n", " lr_scheduler = torch.optim.lr_scheduler.LinearLR(\n", " optimizer,\n", " total_iters=int(num_epochs*(num_train*num_devices//batch_size)),\n", " last_epoch=-1\n", " )\n", "elif lr_scheduler_type == 'cycle':\n", " total_steps=int(num_epochs*(num_train*num_devices//batch_size))\n", " lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(\n", " optimizer, \n", " max_lr=max_lr,\n", " total_steps=total_steps,\n", " final_div_factor=1000,\n", " last_epoch=-1, pct_start=2/num_epochs\n", " )\n", " \n", "def save_ckpt(tag): \n", " ckpt_path = outdir+f'/{tag}.pth'\n", " print(f'saving {ckpt_path}',flush=True)\n", " unwrapped_model = accelerator.unwrap_model(model)\n", " try:\n", " torch.save({\n", " 'epoch': epoch,\n", " 'model_state_dict': unwrapped_model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'lr_scheduler': lr_scheduler.state_dict(),\n", " 'train_losses': losses,\n", " 'test_losses': test_losses,\n", " 'lrs': lrs,\n", " }, ckpt_path)\n", " except:\n", " print(\"Couldn't save... 
moving on to prevent crashing.\")\n", " del unwrapped_model\n", " \n", "print(\"\\nDone with model preparations!\")\n", "utils.count_params(model)" ] }, { "cell_type": "markdown", "id": "983f458b-35b8-49f2-b6db-80296cece730", "metadata": {}, "source": [ "# Weights and Biases" ] }, { "cell_type": "code", "execution_count": 32, "id": "0a25a662-daa8-4de9-9233-8364800fcb6b", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "wandb mindeyev2 run captions\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mckadirt\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "wandb_config:\n", " {'model_name': 'captions', 'batch_size': 128, 'num_epochs': 30, 'use_image_aug': False, 'max_lr': 0.0001, 'lr_scheduler_type': 'cycle', 'mixup_pct': 0.66, 'num_train': 24958, 'num_test': 2770, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1}\n" ] }, { "data": { "text/html": [ "wandb version 0.16.0 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.15.5" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /fsx/proj-fmri/ckadirt/MindEyeV2/src/wandb/run-20231119_163615-o1xwsqre" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run captions to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://stability.wandb.io/ckadirt/mindeyev2" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://stability.wandb.io/ckadirt/mindeyev2/runs/o1xwsqre" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# params for wandb\n", "if local_rank==0 and True: # only use main process for wandb logging\n", " import wandb\n", " \n", " wandb_project = 'mindeyev2'\n", " wandb_run = model_name\n", " wandb_notes = ''\n", " \n", " print(f\"wandb {wandb_project} run {wandb_run}\")\n", " wandb.login(host='https://stability.wandb.io')#, relogin=True)\n", " wandb_config = {\n", " \"model_name\": model_name,\n", " \"batch_size\": batch_size,\n", " \"num_epochs\": num_epochs,\n", " \"use_image_aug\": use_image_aug,\n", " \"max_lr\": max_lr,\n", " \"lr_scheduler_type\": lr_scheduler_type,\n", " \"mixup_pct\": mixup_pct,\n", " \"num_train\": num_train,\n", " \"num_test\": num_test,\n", " \"seed\": seed,\n", " \"distributed\": distributed,\n", " \"num_devices\": num_devices,\n", " \"world_size\": world_size,\n", " }\n", " print(\"wandb_config:\\n\",wandb_config)\n", " if False: # wandb_auto_resume\n", " print(\"wandb_id:\",model_name)\n", " wandb.init(\n", " id = model_name,\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " resume=\"allow\",\n", " )\n", " else:\n", " wandb.init(\n", " project=wandb_project,\n", " name=wandb_run,\n", " config=wandb_config,\n", " notes=wandb_notes,\n", " )\n", "else:\n", " wandb_log = False" ] }, { "cell_type": "code", "execution_count": 33, "id": "4e5de216-5318-4b45-ac02-113f03105adc", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n" ], "text/plain": [ "\u001b[91m╭──────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", "\u001b[91m│\u001b[0m n++ \u001b[91m│\u001b[0m\n", "\u001b[91m│\u001b[0m \u001b[1;91m▲\u001b[0m \u001b[91m│\u001b[0m\n", "\u001b[91m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mSyntaxError: \u001b[0minvalid syntax\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [] }, { "cell_type": "markdown", "id": "5b0ae095-3203-4eb8-8606-acc2db6ccf20", "metadata": {}, "source": [ "# More custom functions" ] }, { "cell_type": "code", "execution_count": 34, "id": "827ead88-7eb3-47cc-82da-31565063b927", "metadata": { "tags": [] }, "outputs": [], "source": [ "# using the same preprocessing as was used in MindEye + BrainDiffuser\n", "pixcorr_preprocess = transforms.Compose([\n", " transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),\n", "])\n", "def pixcorr(images,brains):\n", " # Flatten images while keeping the batch dimension\n", " all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)\n", " all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)\n", " corrmean = torch.diag(utils.batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()\n", " return corrmean" ] }, { "cell_type": "markdown", "id": "d5690151-2131-4918-b750-e869cbd1a8a8", "metadata": {}, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 51, "id": "12de6387-6e18-4e4b-b5ce-a847d625330a", "metadata": { "tags": [] }, "outputs": [], "source": [ "epoch = 0\n", "losses, test_losses, lrs = [], [], []\n", "best_test_loss = 1e9\n", "soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))\n", "\n", "# Optionally resume from checkpoint #\n", "if resume_from_ckpt:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "elif wandb_log:\n", " if wandb.run.resumed:\n", " print(\"\\n---resuming from last.pth ckpt---\\n\")\n", " try:\n", " checkpoint = torch.load(outdir+'/last.pth', map_location='cpu')\n", " except:\n", " print('last.pth failed... 
trying last_backup.pth')\n", " checkpoint = torch.load(outdir+'/last_backup.pth', map_location='cpu')\n", " epoch = checkpoint['epoch']\n", " print(\"Epoch\",epoch)\n", " optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", " lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])\n", " diffusion_diffuser.load_state_dict(checkpoint['model_state_dict'])\n", " del checkpoint\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 36, "id": "b4755749-2d99-4e98-ad98-3df661746058", "metadata": { "tags": [] }, "outputs": [], "source": [ "checkpoint = torch.load('/fsx/proj-fmri/ckadirt/MindEyeV2/train_logs/caption_clip_0.5_bz/last.pth', map_location='cpu')" ] }, { "cell_type": "code", "execution_count": 45, "id": "cd3dc793-5a20-4b48-959c-bc64430c8c02", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.load_state_dict(checkpoint['model_state_dict'])" ] }, { "cell_type": "code", "execution_count": 46, "id": "0faa2c6a-00da-4b66-b5e5-8c4864768805", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "MindEyeModule(\n", " (ridge): RidgeRegression(\n", " (linear): Linear(in_features=15729, out_features=2048, bias=True)\n", " )\n", " (backbone): BrainNetwork(\n", " (lin0): Linear(in_features=2048, out_features=2048, bias=True)\n", " (mlp): ModuleList(\n", " (0-3): 4 x Sequential(\n", " (0): Linear(in_features=2048, out_features=2048, bias=True)\n", " (1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.15, inplace=False)\n", " )\n", " )\n", " (lin1): Linear(in_features=2048, out_features=361856, bias=True)\n", " (clip_proj): Sequential(\n", " (0): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)\n", " (1): GELU(approximate='none')\n", " (2): Linear(in_features=1408, out_features=2048, bias=True)\n", " (3): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", " (4): GELU(approximate='none')\n", " (5): Linear(in_features=2048, out_features=2048, bias=True)\n", " (6): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n", " (7): GELU(approximate='none')\n", " (8): Linear(in_features=2048, out_features=1408, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "code", "execution_count": 47, "id": "99f09f76-4481-4133-b09a-a22b10dbc0c4", "metadata": { "tags": [] }, "outputs": [], "source": [ "model, optimizer, train_dl, test_dl, lr_scheduler = accelerator.prepare(\n", "model, optimizer, train_dl, test_dl, lr_scheduler\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "bfeeda32-82ca-4364-bce1-eaa41b4f3e25", "metadata": { "tags": [] }, "outputs": [], "source": [ "\"\"\"transform = transforms.Compose(\n", " [\n", " transforms.Resize(\n", " (224, 224),\n", " ),\n", " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n", " ]\n", " )\n", "\n", "def tensor_2_embed(image): \n", " image_for_blip2 = transform(image)\n", " \n", " #Generate embeddings\n", " with blip2_model.maybe_autocast():\n", " blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n", " \n", " return blip2_target\n", "\n", "def embed_2_caption(image_embeds, model):\n", " image_embeds = image_embeds.float()\n", " image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(\n", " image.device)\n", 
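"    # (reference only, lavis API) run the Q-Former over the frozen ViT image features, project to T5, then generate a caption\n",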
"\n", " query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)\n", " query_output = model.Qformer.bert(\n", " query_embeds=query_tokens,\n", " encoder_hidden_states=image_embeds,\n", " encoder_attention_mask=image_atts,\n", " return_dict=True)\n", "\n", " inputs_t5 = model.t5_proj(query_output.last_hidden_state)\n", " atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)\n", " prompt = model.prompt\n", " input_tokens = model.t5_tokenizer(\n", " prompt, padding=\"longest\", return_tensors=\"pt\"\n", " ).to(image.device)\n", " encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)\n", " \n", " with model.maybe_autocast(dtype=torch.bfloat16):\n", " inputs_embeds = model.t5_model.encoder.embed_tokens(input_tokens.input_ids)\n", " inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)\n", "\n", " outputs = model.t5_model.generate(\n", " inputs_embeds=inputs_embeds,\n", " attention_mask=encoder_atts)\n", " output_text = model.t5_tokenizer.batch_decode(\n", " outputs, skip_special_tokens=True)\n", " \n", " return output_text\"\"\"" ] }, { "cell_type": "code", "execution_count": 48, "id": "636b4684-df9a-4e29-8683-86fb035ba690", "metadata": { "tags": [] }, "outputs": [], "source": [ "wandb_log = False" ] }, { "cell_type": "code", "execution_count": 49, "id": "0847b380-2edb-4a56-9b33-fdc4c0c3f8d3", "metadata": { "tags": [] }, "outputs": [], "source": [ "predicted_embeddings = None" ] }, { "cell_type": "code", "execution_count": 52, "id": "60be0d5f-3e94-4612-9373-61b53d836393", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "captions starting with epoch 0 / 30\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/30 [00:17.456 to beat low-level subj01 results in mindeye v1\n", " \n", " \"\"\"for train_i, (behav, past_behav, future_behav, old_behav) in enumerate(train_dl):\n", " if epoch == 0:\n", " lrs.append(0)\n", " break\n", " with torch.cuda.amp.autocast():\n", " optimizer.zero_grad()\n", "\n", " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", " \n", " image = images[behav[:,0,0].cpu().long()].to(device).float()\n", "\n", " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n", " \n", " if use_image_aug: image = img_augment(image)\n", " # clip_target = clip_model.embed_image(image)\n", " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n", " assert not torch.any(torch.isnan(clip_target))\n", " \n", " if epoch < int(mixup_pct * num_epochs):\n", " voxel, perm, betas, select = utils.mixco(voxel)\n", "\n", " voxel_ridge = model.ridge(voxel)\n", " \n", " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n", " clip_voxels = model.backbone(voxel_ridge)\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", "\n", " if epoch < int(mixup_pct * num_epochs): \n", " loss_clip = utils.mixco_nce(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=.006, \n", " perm=perm, betas=betas, select=select)\n", " else:\n", " epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]\n", " loss_clip = utils.soft_clip_loss(\n", " clip_voxels_norm,\n", " clip_target_norm,\n", " temp=epoch_temp)\n", " \n", " loss_mse= mse(clip_voxels, clip_target)\n", "\n", " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc) \n", "\n", " loss_clip_total += loss_clip.item()\n", " # loss_blurry_total += loss_blurry.item()\n", "\n", 
" # loss = loss_blurry + loss_clip\n", " loss = 0.7 * loss_clip + 0.3 * loss_mse\n", " if (train_i % 10 == 0):\n", " print(train_i, loss)\n", " # print(batch_size)\n", " utils.check_loss(loss)\n", " accelerator.backward(loss)\n", " optimizer.step()\n", " \n", " losses.append(loss.item())\n", " lrs.append(optimizer.param_groups[0]['lr'])\n", " \n", " # forward and backward top 1 accuracy \n", " labels = torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n", " bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n", "\n", " # with torch.no_grad():\n", " # # only doing pixcorr eval on a subset (8) of the samples per batch because its costly & slow to compute autoenc.decode()\n", " # random_samps = np.random.choice(np.arange(len(voxel)), size=8, replace=False)\n", " # blurry_recon_images = autoenc.decode(blurry_image_enc_[random_samps]).sample.clamp(0,1)\n", " # blurry_pixcorr += pixcorr(image[random_samps], blurry_recon_images)\n", "\n", " if lr_scheduler_type is not None:\n", " lr_scheduler.step()\"\"\"\n", " \n", " model.eval()\n", " for test_i, (behav, past_behav, future_behav, old_behav) in enumerate(test_dl):\n", " with torch.cuda.amp.autocast():\n", " with torch.no_grad(): \n", " # all test samples should be loaded per batch such that test_i should never exceed 0\n", " if len(behav) != num_test: print(\"!\",len(behav),num_test)\n", " \n", " ## Average same-image repeats ##\n", " if test_image is None:\n", " voxel = voxels[behav[:,0,5].cpu().long()].to(device)\n", " \n", " image = behav[:,0,0].cpu().long()\n", " \n", " unique_image, sort_indices = torch.unique(image, return_inverse=True)\n", " for im in unique_image:\n", " locs = torch.where(im == image)[0]\n", " if test_image is None:\n", " test_image = images[im][None]\n", " test_voxel = torch.mean(voxel[locs],axis=0)[None]\n", " else:\n", " test_image = torch.vstack((test_image, images[im][None]))\n", " test_voxel = torch.vstack((test_voxel, torch.mean(voxel[locs],axis=0)[None]))\n", " \n", " # sample of batch_size\n", " random_indices = torch.arange(len(test_voxel))[:batch_size] #torch.randperm(len(test_voxel))[:300]\n", " voxel = test_voxel[random_indices].to(device)\n", " image = test_image[random_indices].to(device)\n", " assert len(image) == batch_size\n", " \n", " # blurry_image_enc = autoenc.encode(image).latent_dist.mode()\n", " \n", " # clip_target = clip_model.embed_image(image.float())\n", " clip_target = embed_images_b2(image)[0].to(device) #####CHANGED\n", " \n", " voxel_ridge = model.ridge(voxel)\n", " \n", " # clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)\n", " clip_voxels = model.backbone(voxel_ridge)\n", " \n", " predicted_embeddings = clip_voxels\n", " break\n", " \n", " clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1)\n", " clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1)\n", " \n", " # loss_clip = utils.soft_clip_loss(\n", " # clip_voxels_norm,\n", " # clip_target_norm,\n", " # temp=.006)\n", " \n", " loss_clip = mse(clip_voxels, clip_target)\n", "\n", " # loss_blurry = mse(blurry_image_enc_, blurry_image_enc)\n", " \n", " # loss = loss_blurry + loss_clip\n", " loss = loss_clip\n", " \n", " utils.check_loss(loss)\n", " \n", " test_losses.append(loss.item())\n", " \n", " # forward and backward top 1 accuracy \n", " labels = 
torch.arange(len(clip_target_norm)).to(clip_voxels_norm.device) \n", " test_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1)\n", " test_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1)\n", "\n", " # # halving the batch size because the decoder is computationally heavy\n", " # blurry_recon_images = autoenc.decode(blurry_image_enc_[:len(voxel)//2]).sample.clamp(0,1)\n", " # blurry_recon_images = torch.vstack((blurry_recon_images, autoenc.decode(blurry_image_enc_[len(voxel)//2:]).sample.clamp(0,1)))\n", " # test_blurry_pixcorr += pixcorr(image, blurry_recon_images)\n", "\n", " #Find captions and print next to images\n", " #caption1 = embed_2_caption(clip_voxels[[0]], blip2_model)\n", " #caption2 = embed_2_caption(clip_voxels[[1]], blip2_model)\n", "\n", " #true_embed1 = tensor_2_embed(image[[0]])\n", " #true_embed2 = tensor_2_embed(image[[1]])\n", "\n", " # print(clip_voxels[[0]].shape)\n", " # print(true_embed1.shape)\n", " \n", " #true_caption1 = embed_2_caption(true_embed1, blip2_model)\n", " #true_caption2 = embed_2_caption(true_embed2, blip2_model)\n", " \n", " # transform blurry recon latents to images and plot it\n", " #fig, axes = plt.subplots(2, 2, figsize=(8, 4))\n", " #axes[0,0].imshow(utils.torch_to_Image(image[[0]]))\n", " #axes[0,1].imshow(utils.torch_to_Image(image[[1]]))\n", " #axes[0,0].axis('off'); axes[0,1].axis('off'); axes[1,0].axis('off'); axes[1,1].axis('off')\n", " #axes[0,0].set_title(caption1)\n", " #axes[0,1].set_title(caption2)\n", " #axes[1,0].set_title(true_caption1)\n", " #axes[1,1].set_title(true_caption2)\n", "\n", " #plt.show()\n", " \n", " # # transform blurry recon latents to images and plot it\n", " # fig, axes = plt.subplots(1, 4, figsize=(8, 4))\n", " # axes[0].imshow(utils.torch_to_Image(image[[0]]))\n", " # axes[1].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[0]]).sample.clamp(0,1)))\n", " # axes[2].imshow(utils.torch_to_Image(image[[1]]))\n", " # axes[3].imshow(utils.torch_to_Image(autoenc.decode(blurry_image_enc_[[1]]).sample.clamp(0,1)))\n", " # axes[0].axis('off'); axes[1].axis('off'); axes[2].axis('off'); axes[3].axis('off')\n", " # axes[0].set_title(caption1)\n", " # axes[3].set_title(caption2)\n", " # plt.show()\n", " \n", " break\n", " if local_rank==0: \n", " # if utils.is_interactive(): clear_output(wait=True)\n", " assert (test_i+1) == 1\n", " logs = {\"train/loss\": np.mean(losses[-(train_i+1):]),\n", " \"test/loss\": np.mean(test_losses[-(test_i+1):]),\n", " \"train/lr\": lrs[-1],\n", " \"train/num_steps\": len(losses),\n", " \"test/num_steps\": len(test_losses),\n", " \"train/fwd_pct_correct\": fwd_percent_correct / (train_i + 1),\n", " \"train/bwd_pct_correct\": bwd_percent_correct / (train_i + 1),\n", " \"test/test_fwd_pct_correct\": test_fwd_percent_correct / (test_i + 1),\n", " \"test/test_bwd_pct_correct\": test_bwd_percent_correct / (test_i + 1),\n", " \"train/loss_clip_total\": loss_clip_total / (train_i + 1),\n", " \"train/loss_blurry_total\": loss_blurry_total / (train_i + 1),\n", " \"test/loss_clip_total\": test_loss_clip_total / (test_i + 1),\n", " \"test/loss_blurry_total\": test_loss_blurry_total / (test_i + 1),\n", " \"train/blurry_pixcorr\": blurry_pixcorr / (train_i + 1),\n", " \"test/blurry_pixcorr\": test_blurry_pixcorr / (test_i + 1),\n", " }\n", " progress_bar.set_postfix(**logs)\n", " \n", " fig, axes = plt.subplots(1, 8, figsize=(10, 4))\n", " jj=-1\n", " for j in 
[0,1,2,3,4,5,6,7]:\n", " jj+=1\n", " axes[jj].imshow(utils.torch_to_Image(image[j]))\n", " axes[jj].axis('off')\n", "\n", " if wandb_log:\n", " generated_captions = embeds_to_captions_b2(clip_voxels[0:8])\n", " print(generated_captions[1])\n", " logs[f\"test/recons\"] = wandb.Image(fig, caption=f\"epoch{epoch:03d}\" + \"\\n\".join(generated_captions[1]))\n", " plt.close()\n", " # Save model checkpoint and reconstruct\n", " if epoch % ckpt_interval == 0:\n", " if not utils.is_interactive():\n", " save_ckpt(f'last')\n", " \n", " if wandb_log: wandb.log(logs)\n", "\n", " # wait for other GPUs to catch up if needed\n", " accelerator.wait_for_everyone()\n", " torch.cuda.empty_cache()\n", " gc.collect()\n", "\n", "print(\"\\n===Finished!===\\n\")\n", "if ckpt_saving:\n", " save_ckpt(f'last')\n", "if not utils.is_interactive():\n", " sys.exit(0)" ] }, { "cell_type": "code", "execution_count": 54, "id": "f5b47c76-a97a-48ee-b4b3-051c17aebac4", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "torch.Size([128, 257, 1408])" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicted_embeddings.shape" ] }, { "cell_type": "code", "execution_count": 55, "id": "92d0029f-079f-4710-bf43-bc9e3fd08d5e", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/admin/home-ckadirt/miniconda3/envs/mindeye/lib/python3.10/site-packages/transformers/generation/utils.py:1260: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "['a group of people are sitting around a table\\n', 'a man is holding a glass of water in front of a television\\n', 'a man is riding a skateboard on a hill\\n', 'a group of people standing around a bike\\n', 'a building with a sign that says \"the house\"\\n', 'a plate of food with vegetables and meat\\n', 'a white cup with a small bottle of wine\\n', 'a group of people playing baseball and one is holding a ball\\n']\n" ] } ], "source": [ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8])\n", "print(generated_captions[1])" ] }, { "cell_type": "code", "execution_count": 75, "id": "88750a6d-0b61-4943-a7e5-1d675bbb4f8f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['a group of people are sitting at a table with food and drinks\\n', 'a man in a kitchen with a large screen\\n', 'a man on a surfboard with his legs in the air\\n', 'a group of people are standing on the beach in front of a boat\\n', 'a building with a sign that says \"home of the person\"\\n', 'a vegetable salad with a variety of vegetables and other ingredients\\n', 'a white cup with a small amount of coffee and a bottle of wine\\n', 'a group of people playing baseball and soccer\\n']\n" ] } ], "source": [ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)\n", "print(generated_captions[1])" ] }, { "cell_type": "code", "execution_count": 95, "id": "d99e7583-0f26-41c1-8035-a1aa3b1c2d55", "metadata": { "tags": [] }, "outputs": [], "source": [ "def concatenate_lists_any_depth(list1, list2):\n", " \"\"\"\n", " Concatenates two lists of potentially varying depths, forming a new list of lists.\n", "\n", " Args:\n", " list1 (list): The first list to concatenate. 
Elements can be of any type.\n", " list2 (list): The second list to concatenate. Elements can be of any type.\n", "\n", " Returns:\n", " list: A new list containing lists of elements from the original lists.\n", " \"\"\"\n", " # Ensure that both lists have the same length\n", " if len(list1) != len(list2):\n", " raise ValueError(\"Lists must be of the same length\")\n", "\n", " concatenated_list = []\n", "\n", " for a, b in zip(list1, list2):\n", " # If the elements are not lists, convert them to lists\n", " if not isinstance(a, list):\n", " a = [a]\n", " if not isinstance(b, list):\n", " b = [b]\n", "\n", " # Concatenate the lists\n", " concatenated_list.append(a + b)\n", "\n", " return concatenated_list" ] }, { "cell_type": "code", "execution_count": 96, "id": "ed8167ea-a3ab-438a-aa85-f1309047199c", "metadata": { "tags": [] }, "outputs": [], "source": [ "def sample_several(embeddings, num=10, temp=0.3):\n", " # embeddings shape = batch, 257, 1408\n", " results = None # Initialize results as None\n", "\n", " for i in range(num): # Iterate from 0 to num-1\n", " if results is None:\n", " # For the first iteration, assign the results directly\n", " results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]\n", " else:\n", " # For subsequent iterations, combine the new results with the existing ones\n", " new_results = embeds_to_captions_b2(embeddings, sample=True, temp=temp)[1]\n", " results = concatenate_lists_any_depth(results, new_results)\n", "\n", " return results # Return the combined results\n" ] }, { "cell_type": "code", "execution_count": 77, "id": "6700e130-8ae4-4475-a5b4-972fd8b9717a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['a group of people sitting on a bench in front of a building\\n', 'a woman is using a computer to make a video\\n', 'a man in a black shirt is sitting on a surfboard\\n', 'a group of people on the beach with a bike and some other things\\n', 'a large building with a sign that says \"the old farmhouse\"\\n', 'a plate with many different types of vegetables\\n', 'a white cup with a bottle of wine and a small bottle of wine\\n', 'a group of people are playing baseball in a field\\n']\n" ] } ], "source": [ "generated_captions = embeds_to_captions_b2(predicted_embeddings[0:8], sample = True, temp = 0.3)\n", "print(generated_captions[1])" ] }, { "cell_type": "code", "execution_count": 99, "id": "f0e111e3-6134-4a63-a6d7-17b3441be8c8", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[['people are sitting at a table with a bunch of chairs\\n',\n", " 'several people in the yard with some food\\n',\n", " 'people sitting on a bench near a water fountain\\n',\n", " 'a group of people are sitting around a table\\n',\n", " 'a group of people in a room with several people in the foreground\\n',\n", " 'a group of people sitting around a table with food\\n',\n", " 'the people in the background are sitting on the edge of a table\\n',\n", " 'beverages and food are served at a family picnic\\n',\n", " 'a group of people eating in a restaurant\\n',\n", " 'a group of people sitting around a table\\n',\n", " 'people are sitting at a table next to a tree\\n',\n", " 'people are sitting around a table with a lot of food\\n'],\n", " ['a person is holding a newspaper in a restaurant\\n',\n", " 'the man is holding a cup of coffee in front of a television\\n',\n", " 'a woman is preparing to cook in a kitchen\\n',\n", " 'a man working in an office setting with a computer and a man in a chair\\n',\n", 
" 'a person is using a smartphone in a restaurant\\n',\n", " 'a man is holding a glass of water in front of a television\\n',\n", " 'a man in a kitchen with a knife and a cup of coffee\\n',\n", " 'the kitchen at the new york times\\n',\n", " 'a man is holding a knife and cutting a piece of pizza\\n',\n", " 'a man is reading a book while another is working on a computer\\n',\n", " 'a person is using a computer to make a presentation\\n',\n", " 'a man is holding up a box of food\\n'],\n", " ['a man in a suit and a woman wearing a helmet on a surfboard\\n',\n", " 'a person is on the ground while holding onto a skateboard\\n',\n", " 'a man in a beach chair riding a skateboard\\n',\n", " 'a woman is standing on a surfboard in the ocean while holding a skateboard\\n',\n", " 'a man is riding on a surfboard\\n',\n", " 'a man on his knees in a surfboard with his leg up\\n',\n", " 'a man is doing a trick on a skateboard\\n',\n", " 'a man is riding a skateboard on a wave\\n',\n", " 'a person is sitting on a surfboard while another person is riding on it\\n',\n", " 'a man in a jumpsuit is holding onto a surfboard\\n',\n", " 'a man is jumping on a surfboard while another is sitting on it\\n',\n", " 'a man is sitting on a surfboard while he is riding\\n'],\n", " ['a picture of a man riding a bike next to a bike\\n',\n", " 'a group of people standing on a street with a bike\\n',\n", " 'people are sitting around a picnic table and a bike is being ridden\\n',\n", " 'a group of people are on a beach with a bike and a car\\n',\n", " \"the world's largest boat race is underway in the bay of britain\\n\",\n", " 'a man and his bike standing on the side of a road\\n',\n", " 'a motorcycle is sitting on top of a hill with a boat and a bicycle\\n',\n", " 'a bunch of people are standing around a large park\\n',\n", " 'a group of people standing around a table with bicycles\\n',\n", " 'a man with his bike and helmet in the air\\n',\n", " 'the sun is shining brightly and there are people walking around\\n',\n", " 'a group of people standing on a beach next to a boat\\n'],\n", " ['the home has a large yellow sign\\n',\n", " 'the building has two small windows and a sign\\n',\n", " 'a view of a home with a building in the background\\n',\n", " 'the house has been built in the style of a traditional english cottage\\n',\n", " 'the house is on the corner of a street\\n',\n", " 'the home is an old style building with a white door\\n',\n", " 'the house is in a residential area with many buildings\\n',\n", " 'the building is white and has a red roof\\n',\n", " 'the old building is now a park and recreation center\\n',\n", " 'a large building with a lot of windows and a lot of people\\n',\n", " 'a large house with a white door and a blue sign\\n',\n", " 'the house is in a residential area with a front and back door\\n'],\n", " ['the vegetables are arranged in a square shape on the table\\n',\n", " 'a plate full of vegetables and fruit with a knife and fork\\n',\n", " 'a plate of various vegetables with a knife\\n',\n", " 'a plate with several different types of food\\n',\n", " 'a plate of food with various vegetables and meat\\n',\n", " 'a picture of some vegetables and a plate of food\\n',\n", " 'a close up of several types of food on a table\\n',\n", " 'a plate of food with a variety of vegetables\\n',\n", " 'a large plate with many different types of food\\n',\n", " 'a plate of vegetables and meat on a table\\n',\n", " 'a plate with lots of different types of vegetables\\n',\n", " 'a close up of some food with a 
knife\\n'],\n", " ['a white cup with a green tea bag and a small bottle of alcohol\\n',\n", " 'a bottle of wine with two glasses and a spoon\\n',\n", " 'the chocolate bar is sitting next to the bottle of wine\\n',\n", " 'a white and black cup and a bottle of wine\\n',\n", " 'a white cup sitting next to some drinks\\n',\n", " 'a bottle of wine and a bottle of champagne on a table\\n',\n", " 'a white cup with two pills and a small bottle of wine\\n',\n", " 'a bottle of wine and a cup of coffee next to a bottle of wine\\n',\n", " 'a bottle of wine and a bottle of beer in a glass\\n',\n", " 'a bottle of wine, a bottle of beer and a wine bottle\\n',\n", " 'a bottle of wine and a cup with some food\\n',\n", " 'a glass of wine and a pair of glasses on a table\\n'],\n", " ['a group of people in white and blue uniforms playing baseball\\n',\n", " 'a group of people playing baseball in a field\\n',\n", " 'a group of people playing a game of baseball\\n',\n", " 'a group of people standing on a field and one is holding a tennis ball\\n',\n", " 'a group of people in uniform playing baseball\\n',\n", " 'two men and a woman in the middle of a game\\n',\n", " 'a group of people playing baseball in the grass\\n',\n", " 'a group of men and women are playing baseball\\n',\n", " 'july 15th, 2011 - june 20th, 2012 - june 17th,',\n", " 'the team is playing soccer and one is holding a ball\\n',\n", " 'people are playing baseball with each other and one is holding a ball\\n',\n", " 'the women are laughing and the man is running\\n']]" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.5)\n", "several" ] }, { "cell_type": "code", "execution_count": 100, "id": "7ced031a-f259-4797-afd7-876fa62cdcfd", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[['a group of people are sitting around a table\\n',\n", " 'a group of people are sitting around a table with food\\n',\n", " 'a group of people sitting at a table with food\\n',\n", " 'a group of people are sitting on the ground in front of a table\\n',\n", " 'a group of people sitting around a table with a person and a dog\\n',\n", " 'a group of people are sitting on the ground and eating\\n',\n", " 'the group is sitting around a table with food\\n',\n", " 'people are sitting around a table with food\\n',\n", " 'a group of people sitting around a table with food\\n',\n", " 'the people are eating in front of a table\\n',\n", " 'a group of people are sitting on a bench in a field\\n',\n", " 'a group of people are sitting on a bench\\n'],\n", " ['a man is using a computer and a phone\\n',\n", " 'a person in a kitchen with a large screen\\n',\n", " 'a man is preparing food in a kitchen\\n',\n", " 'a man is standing in front of a computer and a woman is sitting behind him\\n',\n", " 'a man is using a computer to play a game\\n',\n", " 'a man is using a computer to play a game\\n',\n", " 'a man in a kitchen with a large television\\n',\n", " 'a man is holding a glass of water in front of a television\\n',\n", " 'the man is holding a bottle of water and a glass\\n',\n", " 'a man is using a computer to make a video\\n',\n", " 'a man is serving food at a restaurant\\n',\n", " 'a man is holding a drink in his hand\\n'],\n", " ['a man with a skateboard is riding on a wave\\n',\n", " 'a man is riding a skateboard on a hill\\n',\n", " 'a man is riding a skateboard on a hill\\n',\n", " 'a person is sitting on a surfboard while another person is riding 
on it\\n',\n", " 'a man is riding a surfboard on a wave\\n',\n", " 'a man with a skateboard is on top of a hill\\n',\n", " 'a person in a surfboard is riding a wave\\n',\n", " 'a man on a surfboard is riding on a wave\\n',\n", " 'a man in a suit and a woman in a bikini are playing on a surf board\\n',\n", " 'a man is riding a skateboard while wearing a helmet\\n',\n", " 'a man on the surf board with his legs in the air\\n',\n", " 'a man in a suit is playing a game with a skateboard\\n'],\n", " ['a group of people standing on a beach with a bike\\n',\n", " 'a group of people standing on a beach with a bike\\n',\n", " 'a group of people standing on a road with a bike and a car\\n',\n", " 'a group of people in the water with two bikes\\n',\n", " 'the bike is in the middle of the road and there are two people on the side of the',\n", " 'a group of people standing around a car with a bike\\n',\n", " 'a man is standing on a bike with a skateboard\\n',\n", " 'a group of people riding bicycles on a road\\n',\n", " 'a bicycle is in the middle of a field with a person on it\\n',\n", " 'a man is standing on a bicycle with a helmet and a skateboard\\n',\n", " 'a photo of a bicycle with a man on it\\n',\n", " 'a group of people riding bicycles on a road\\n'],\n", " ['a building with a sign that says \"the old man\"\\n',\n", " 'a house with a sign that says \"the house that james bond built\"\\n',\n", " 'a building with a sign that says \"the house\"\\n',\n", " 'a house with a sign that says \"museum\"\\n',\n", " 'a building with a sign that says \"the home of the person\"\\n',\n", " 'a building with a sign that says \"the museum of american history\"\\n',\n", " 'a white building with a sign on the side\\n',\n", " 'a brown house with a white roof and a green sign\\n',\n", " 'a house with a large sign on the side\\n',\n", " 'a building with a sign that says \"the building\"\\n',\n", " 'the building is in the middle of the street\\n',\n", " 'the front of an old building with a sign\\n'],\n", " ['a plate of different types of vegetables and meat\\n',\n", " 'a close up of some vegetables and meat\\n',\n", " 'a plate with a variety of different foods on it\\n',\n", " 'a plate of vegetables and meat with a green border\\n',\n", " 'a plate of vegetables with a variety of toppings\\n',\n", " 'a plate of food with different types of vegetables\\n',\n", " 'a plate of food with various vegetables and meat\\n',\n", " 'a plate of vegetables with some green leaves on it\\n',\n", " 'a bunch of vegetables and mushrooms on a plate\\n',\n", " 'a bunch of vegetables and fruit on a table\\n',\n", " 'a plate of vegetables and other items on a table\\n',\n", " 'a close up of some vegetables and meat\\n'],\n", " ['a white cup with a spoon and a spoon\\n',\n", " 'a bottle of wine and a bottle of champagne\\n',\n", " 'a white cup with a small bottle and a small bottle of wine\\n',\n", " 'a white cup with a small bottle of wine and a small bottle of water\\n',\n", " 'the bottle is open and the bottle is next to a cup\\n',\n", " 'the white cup with a small bottle of wine and a small bottle of wine\\n',\n", " 'a white cup with a black handle and a pair of scissors\\n',\n", " 'a bottle of wine and a bottle of wine glasses\\n',\n", " 'a bottle of wine and a bottle of champagne\\n',\n", " 'a white and black cup with a small spoon next to it\\n',\n", " 'a white cup with a small bottle of wine\\n',\n", " 'a white cup with a spoon and a bottle of wine\\n'],\n", " ['a group of people playing baseball and soccer\\n',\n", " 'a group of 
people are playing baseball in the grass\n',\n", "  'a group of people playing baseball and running\n',\n", "  'a group of people playing baseball and soccer\n',\n", "  'a group of people playing soccer on a field\n',\n", "  'a group of people are playing baseball in the grass\n',\n", "  'a group of people playing baseball with a man in the background\n',\n", "  'a group of people playing baseball and one is holding a ball\n',\n", "  'a group of people playing baseball in front of a field\n',\n", "  'a group of people playing baseball on a field\n',\n", "  'a group of people playing baseball with one person in the background\n',\n", "  'a group of people are playing baseball and one is holding a ball\n']]" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "several = sample_several(predicted_embeddings[0:8], num = 12, temp = 0.3)\n", "several" ] },
{ "cell_type": "code", "execution_count": null, "id": "93e87fde-815d-4452-9915-f5f5dacf7c2a", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Plot the training and test loss curves recorded during training\n", "plt.plot(losses)\n", "plt.show()\n", "plt.plot(test_losses)\n", "plt.show()" ] },
{ "cell_type": "code", "execution_count": null, "id": "f1a60e19-c440-4c9c-a634-30186209012f", "metadata": {}, "outputs": [], "source": [ "# Legacy helper kept for reference: converts a batch of image tensors into\n", "# visual-encoder embeddings one sample at a time. Note that it allocates a\n", "# (batch, 257, 1024) buffer, whereas the BLIP-2 embeddings used elsewhere in\n", "# this notebook are (batch, 257, 1408); adjust the hidden size if you reuse it.\n", "def tensor_2_embed_old(tensor):\n", "    embed_array = torch.zeros((tensor.shape[0], 257, 1024))\n", "    to_pil = ToPILImage()\n", "    for sample in range(tensor.shape[0]):\n", "        PIL_image = to_pil(tensor[sample])\n", "        image_for_blip2 = vis_processors[\"eval\"](PIL_image).unsqueeze(0).to(device)\n", "        # Generate embeddings with the BLIP-2 visual encoder\n", "        with blip2_model.maybe_autocast():\n", "            blip2_target = blip2_model.ln_vision(blip2_model.visual_encoder(image_for_blip2))\n", "        embed_array[sample] = blip2_target\n", "    return embed_array" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": true }, "toc-autonumbering": true, "vscode": { "interpreter": { "hash": "62aae01ef0cf7b6af841ab1c8ce59175c4332e693ab3d00bc32ceffb78a35376" } } }, "nbformat": 4, "nbformat_minor": 5 }