import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import PIL
import random
import os
import matplotlib.pyplot as plt
import pandas as pd
import math
import json
import io
import time
import tempfile
import requests
import braceexpand
import webdataset as wds
from torchvision.utils import make_grid
# from diffusers.utils import randn_tensor
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def is_interactive():
    import __main__ as main
    return not hasattr(main, '__file__')

def seed_everything(seed=0, cudnn_deterministic=True):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        # needs to be False to use conv3D
        print('Note: not using cudnn.deterministic')

def np_to_Image(x):
    """Convert a numpy array in [-1, 1] with shape (C, H, W) to a PIL Image."""
    if x.ndim == 4:
        x = x[0]
    return PIL.Image.fromarray((x.transpose(1, 2, 0) * 127.5 + 128).clip(0, 255).astype('uint8'))

def torch_to_Image(x):
    if x.ndim == 4:
        x = x[0]
    return transforms.ToPILImage()(x)

def Image_to_torch(x):
    """Convert a PIL Image (or a list of them) to a normalized tensor in [-1, 1]."""
    try:
        x = (transforms.ToTensor()(x)[:3].unsqueeze(0) - .5) / .5
    except Exception:
        # x may be a list/batch of images; fall back to the first one
        x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0) - .5) / .5
    return x

def torch_to_matplotlib(x, device=device):
    # Heuristic: a large mean suggests the tensor is already in [0, 255] rather than [0, 1]
    if torch.mean(x) > 10:
        x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
    else:
        x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
    if device == 'cpu':
        return x[0]
    else:
        return x.cpu().numpy()[0]

def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
    # https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
    numerator = A @ B.T
    A_l2 = torch.mul(A, A).sum(axis=dim)
    B_l2 = torch.mul(B, B).sum(axis=dim)
    denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
    return torch.div(numerator, denominator)

def batchwise_pearson_correlation(Z, B):
    # Calculate means
    Z_mean = torch.mean(Z, dim=1, keepdim=True)
    B_mean = torch.mean(B, dim=1, keepdim=True)
    # Subtract means
    Z_centered = Z - Z_mean
    B_centered = B - B_mean
    # Calculate Pearson correlation coefficient
    numerator = Z_centered @ B_centered.T
    Z_centered_norm = torch.linalg.norm(Z_centered, dim=1, keepdim=True)
    B_centered_norm = torch.linalg.norm(B_centered, dim=1, keepdim=True)
    denominator = Z_centered_norm @ B_centered_norm.T
    pearson_correlation = (numerator / denominator)
    return pearson_correlation
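# Illustrative sketch (not part of the original API; shapes are assumed):
# the batchwise similarity helpers return an (n, n) matrix where paired
# samples sit on the diagonal, so per-row argmax should recover the pairing.
def _demo_batchwise_similarity():
    Z = torch.randn(4, 16)                 # 4 samples, 16-dim embeddings
    B = Z + 0.1 * torch.randn(4, 16)       # noisy copies of Z, paired row-by-row
    sims = batchwise_pearson_correlation(Z, B)  # (4, 4) correlation matrix
    # With small noise, each row should correlate most with its own pair
    assert (sims.argmax(dim=1) == torch.arange(4)).all()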
def batchwise_cosine_similarity(Z, B):
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    B = B.T
    Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
    B_norm = torch.linalg.norm(B, dim=0, keepdim=True)  # Size (1, b).
    cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
    return cosine_similarity

def topk(similarities, labels, k=5):
    """Top-k retrieval accuracy: fraction of rows whose true label is among the k highest similarities."""
    if k > similarities.shape[0]:
        k = similarities.shape[0]
    topsum = 0
    for i in range(k):
        topsum += torch.sum(torch.argsort(similarities, axis=1)[:, -(i + 1)] == labels) / len(labels)
    return topsum

def get_non_diagonals(a):
    a = torch.triu(a, diagonal=1) + torch.tril(a, diagonal=-1)
    # set diagonals to -1
    a = a.fill_diagonal_(-1)
    return a

def gather_features(image_features, voxel_features, accelerator):
    all_image_features = accelerator.gather(image_features.contiguous())
    if voxel_features is not None:
        all_voxel_features = accelerator.gather(voxel_features.contiguous())
        return all_image_features, all_voxel_features
    return all_image_features

def soft_clip_loss(preds, targs, temp=0.125):
    # Soft CLIP loss: targets are softmaxed CLIP-space similarities rather than one-hot labels.
    clip_clip = (targs @ targs.T) / temp
    brain_clip = (preds @ targs.T) / temp
    # Distributed variant, kept for reference (would take distributed=False, accelerator=None args):
    # all_targs = gather_features(targs, None, accelerator)
    # clip_clip = (targs @ all_targs.T) / temp
    # brain_clip = (preds @ all_targs.T) / temp
    loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
    loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
    loss = (loss1 + loss2) / 2
    return loss

def soft_siglip_loss(preds, targs, temp, bias):
    temp = torch.exp(temp)
    logits = (preds @ targs.T) * temp + bias
    # diagonals (aka paired samples) should be >0 and off-diagonals <0
    labels = (targs @ targs.T) - 1 + (torch.eye(len(targs)).to(targs.dtype).to(targs.device))
    loss1 = -torch.sum(nn.functional.logsigmoid(logits * labels[:len(preds)])) / len(preds)
    loss2 = -torch.sum(nn.functional.logsigmoid(logits.T * labels[:, :len(preds)])) / len(preds)
    loss = (loss1 + loss2) / 2
    return loss

def mixco_hard_siglip_loss(preds, targs, temp, bias, perm, betas):
    temp = torch.exp(temp)
    probs = torch.diag(betas)
    probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas
    logits = (preds @ targs.T) * temp + bias
    labels = probs * 2 - 1
    # labels = torch.eye(len(targs)).to(targs.dtype).to(targs.device) * 2 - 1
    loss1 = -torch.sum(nn.functional.logsigmoid(logits * labels)) / len(preds)
    loss2 = -torch.sum(nn.functional.logsigmoid(logits.T * labels)) / len(preds)
    loss = (loss1 + loss2) / 2
    return loss

def mixco(voxels, beta=0.15, s_thresh=0.5, perm=None, betas=None, select=None):
    """MixCo augmentation: mix a random subset of samples with shuffled partners, in place."""
    if perm is None:
        perm = torch.randperm(voxels.shape[0])
    voxels_shuffle = voxels[perm].to(voxels.device, dtype=voxels.dtype)
    if betas is None:
        betas = torch.distributions.Beta(beta, beta).sample([voxels.shape[0]]).to(voxels.device, dtype=voxels.dtype)
    if select is None:
        select = (torch.rand(voxels.shape[0]) <= s_thresh).to(voxels.device)
    betas_shape = [-1] + [1] * (len(voxels.shape) - 1)
    voxels[select] = voxels[select] * betas[select].reshape(*betas_shape) + \
        voxels_shuffle[select] * (1 - betas[select]).reshape(*betas_shape)
    betas[~select] = 1
    return voxels, perm, betas, select

def mixco_clip_target(clip_target, perm, select, betas):
    clip_target_shuffle = clip_target[perm]
    clip_target[select] = clip_target[select] * betas[select].reshape(-1, 1) + \
        clip_target_shuffle[select] * (1 - betas[select]).reshape(-1, 1)
    return clip_target
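# Illustrative sketch (hypothetical `model` and shapes, not part of the original
# API): a typical MixCo training step mixes the voxels, keeps the mixing metadata
# (perm/betas/select), and passes it to mixco_nce (defined below) so the
# contrastive labels place beta mass on the true pair and (1 - beta) on the mix
# partner. mixco_clip_target applies the same mixing to targets when needed.
def _demo_mixco_step(voxels, clip_target, model):
    voxels, perm, betas, select = mixco(voxels)
    preds = nn.functional.normalize(model(voxels).flatten(1), dim=-1)
    targs = nn.functional.normalize(clip_target.flatten(1), dim=-1)
    return mixco_nce(preds, targs, perm=perm, betas=betas, select=select)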
def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None,
              distributed=False, accelerator=None, local_rank=None, bidirectional=True):
    brain_clip = (preds @ targs.T) / temp
    if perm is not None and betas is not None and select is not None:
        # MixCo soft labels: beta weight on the true pair, (1 - beta) on the mix partner
        probs = torch.diag(betas)
        probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas
        loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
        if bidirectional:
            loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
            loss = (loss + loss2) / 2
        return loss
    else:
        loss = F.cross_entropy(brain_clip, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
        if bidirectional:
            loss2 = F.cross_entropy(brain_clip.T, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
            loss = (loss + loss2) / 2
        return loss

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def check_loss(loss):
    if loss.isnan().any():
        raise ValueError('NaN loss')

def cosine_anneal(start, end, steps):
    return end + (start - end) / 2 * (1 + torch.cos(torch.pi * torch.arange(steps) / (steps - 1)))

def resize(img, img_size=128):
    if img.ndim == 3:
        img = img[None]
    return nn.functional.interpolate(img, size=(img_size, img_size), mode='nearest')

def patchify(img, patch_size=16):
    B, C, H, W = img.size()
    patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)
    return patches.permute(0, 2, 1, 3, 4)

def unpatchify(patches):
    B, N, C, H, W = patches.shape  # B=Batch size, N=Number of patches, C=Channels, H=Height, W=Width
    patches = patches.view(B, int(N**0.5), int(N**0.5), C, H, W)
    patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()
    return patches.view(B, C, H * int(N**0.5), W * int(N**0.5))

def get_dataloaders(
    batch_size,
    image_var='images',
    num_devices=None,
    num_workers=None,
    train_url=None,
    val_url=None,
    meta_url=None,
    num_train=None,
    num_val=None,
    cache_dir="/scratch/tmp/wds-cache",
    seed=0,
    voxels_key="nsdgeneral.npy",
    val_batch_size=None,
    to_tuple=["voxels", "images", "trial"],
    local_rank=0,
    world_size=1,
):
    print("Getting dataloaders...")
    assert image_var == 'images'

    def my_split_by_node(urls):
        return urls

    train_url = list(braceexpand.braceexpand(train_url))
    val_url = list(braceexpand.braceexpand(val_url))
    if num_devices is None:
        num_devices = torch.cuda.device_count()
    if num_workers is None:
        num_workers = num_devices
    if num_train is None:
        metadata = json.load(open(meta_url))
        num_train = metadata['totals']['train']
    if num_val is None:
        metadata = json.load(open(meta_url))
        num_val = metadata['totals']['val']
    if val_batch_size is None:
        val_batch_size = batch_size

    global_batch_size = batch_size * num_devices
    num_batches = math.floor(num_train / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    if num_worker_batches == 0:
        num_worker_batches = 1

    print("\nnum_train", num_train)
    print("global_batch_size", global_batch_size)
    print("batch_size", batch_size)
    print("num_workers", num_workers)
    print("num_batches", num_batches)
    print("num_worker_batches", num_worker_batches)

    # train_url = train_url[local_rank:world_size]
    train_data = wds.WebDataset(train_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
        .shuffle(500, initial=500, rng=random.Random(42))\
        .decode("torch")\
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
        .to_tuple(*to_tuple)
    # Alternative kept for reference: if batching is done inside the WebDataset
    # pipeline, the DataLoader batch_size should be None for both train and val,
    # and train would use resampled=True:
    # .batched(batch_size, partial=True)
    # .with_epoch(num_worker_batches)
    train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)

    # Validation
    print("val_batch_size", val_batch_size)
    val_data = wds.WebDataset(val_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)\
        .shuffle(500, initial=500, rng=random.Random(42))\
        .decode("torch")\
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")\
        .to_tuple(*to_tuple)
    # .batched(val_batch_size, partial=False)
    val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1,
                                         shuffle=False, drop_last=True)

    return train_dl, val_dl, num_train, num_val
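# Illustrative sketch (hypothetical shard paths and metadata file, not from the
# original source): get_dataloaders takes brace-expandable WebDataset shard
# patterns and yields tuples in the order given by to_tuple.
def _demo_get_dataloaders():
    train_dl, val_dl, num_train, num_val = get_dataloaders(
        batch_size=32,
        train_url="/data/train/train_subj01_{0..17}.tar",  # hypothetical shard pattern
        val_url="/data/val/val_subj01_0.tar",              # hypothetical shard
        meta_url="/data/metadata_subj01.json",             # hypothetical metadata file
        voxels_key="nsdgeneral.npy",
    )
    for voxels, images, trial in train_dl:
        break  # one batch: voxels (B, num_voxels), images (B, 3, H, W), trial (B, ...)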
pixcorr_preprocess = transforms.Compose([
    transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])

def pixcorr(images, brains):
    """Mean pixelwise correlation between ground-truth images and reconstructions."""
    all_images_flattened = pixcorr_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = pixcorr_preprocess(brains).view(len(brains), -1)
    corrmean = torch.diag(batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)).mean()
    return corrmean

pixcorr_origsize_nanmean_preprocess = transforms.Compose([
    transforms.Resize(128, interpolation=transforms.InterpolationMode.BILINEAR),
])

def pixcorr_origsize_nanmean(images, brains):
    all_images_flattened = pixcorr_origsize_nanmean_preprocess(images).reshape(len(images), -1)
    all_brain_recons_flattened = brains.view(len(brains), -1)  # assuming it's already 128 size
    corrmean = torch.nanmean(torch.diag(batchwise_pearson_correlation(all_images_flattened, all_brain_recons_flattened)))
    return corrmean

def select_annotations(annots, random=False):
    """
    There are 5 annotations per image. Select one of them for each image.
    Note: the `random` parameter intentionally shadows the `random` module inside this function.
    """
    for i, b in enumerate(annots):
        t = ''
        if random:
            # select random non-empty annotation
            while t == '':
                rand = torch.randint(5, (1, 1))[0][0]
                t = b[rand]
        else:
            # select first non-empty annotation
            for j in range(5):
                if b[j] != '':
                    t = b[j]
                    break
        if i == 0:
            txt = np.array(t)
        else:
            txt = np.vstack((txt, t))
    txt = txt.flatten()
    return txt

def add_saturation(image, alpha=2):
    # ITU-R BT.601 luma weights give the grayscale baseline; alpha > 1 pushes away from it
    gray_image = 0.2989 * image[:, 0, :, :] + 0.5870 * image[:, 1, :, :] + 0.1140 * image[:, 2, :, :]
    gray_image = gray_image.unsqueeze(1).expand_as(image)
    saturated_image = alpha * image + (1 - alpha) * gray_image
    return torch.clamp(saturated_image, 0, 1)
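# Illustrative sketch (assumed shapes, not part of the original API): patchify
# and unpatchify are exact inverses for square images evenly divisible by the
# patch size, since both are pure rearrangements of the same pixels.
def _demo_patchify_roundtrip():
    img = torch.rand(2, 3, 128, 128)        # (B, C, H, W)
    patches = patchify(img, patch_size=16)  # (2, 64, 3, 16, 16): 8x8 grid of patches
    recon = unpatchify(patches)             # (2, 3, 128, 128)
    assert torch.equal(img, recon)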