# mindeyev2old2 / src / utils.py
# ckadirt's picture
# Upload folder using huggingface_hub
# b8ea2b2 verified
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import PIL
import random
import os
import matplotlib.pyplot as plt
import pandas as pd
import math
import webdataset as wds
import tempfile
from torchvision.utils import make_grid
# from diffusers.utils import randn_tensor
import json
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import Image
import requests
import io
import time
# Module-wide default torch device: CUDA when torch reports it available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def is_interactive():
    """Report whether the ``__main__`` module has no ``__file__`` attribute.

    Used as a heuristic for "running interactively": returns True when
    ``__main__`` lacks ``__file__`` and False otherwise.
    """
    import __main__ as entry_module
    has_script_file = hasattr(entry_module, '__file__')
    return not has_script_file
def seed_everything(seed=0, cudnn_deterministic=True):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to every seeding call.
        cudnn_deterministic: When True, sets
            ``torch.backends.cudnn.deterministic = True``. When False the flag
            is left untouched and a note is printed instead.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not cudnn_deterministic:
        ## needs to be False to use conv3D
        print('Note: not using cudnn.deterministic')
        return
    torch.backends.cudnn.deterministic = True
def np_to_Image(x):
    """Convert a NumPy array into a PIL image.

    A 4-D input is reduced to its first element. The remaining array is
    transposed with axes (1, 2, 0), mapped through ``v * 127.5 + 128``,
    clipped to [0, 255] and cast to uint8.
    (presumably the input is channel-first with values in roughly [-1, 1] —
    confirm against callers)
    """
    frame = x[0] if x.ndim == 4 else x
    channels_last = frame.transpose(1, 2, 0)
    pixels = (channels_last * 127.5 + 128).clip(0, 255).astype('uint8')
    return PIL.Image.fromarray(pixels)
def torch_to_Image(x):
    """Convert a tensor into a PIL image via ``transforms.ToPILImage``.

    A 4-D input is reduced to its first element before conversion.
    """
    to_pil = transforms.ToPILImage()
    single = x[0] if x.ndim == 4 else x
    return to_pil(single)
def Image_to_torch(x):
    """Convert an image (or a sequence whose first item is an image) to a tensor.

    The image goes through ``transforms.ToTensor``, is cut to its first three
    channels, gets a leading batch dimension and is rescaled with
    ``(v - 0.5) / 0.5``.

    If converting ``x`` directly fails, ``x[0]`` is converted instead.

    Returns:
        Tensor with a leading dimension of size 1.
    """
    to_tensor = transforms.ToTensor()

    def _normalize(image):
        return (to_tensor(image)[:3].unsqueeze(0) - .5) / .5

    try:
        return _normalize(x)
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        return _normalize(x[0])
def torch_to_matplotlib(x,device=device):
    """Turn a batch tensor into a uint8, channels-last image for plotting.

    The input is permuted with (0, 2, 3, 1). If its mean is greater than 10
    the values are used as-is, otherwise they are multiplied by 255; either
    way they are clamped to [0, 255] and cast to uint8. Only the first batch
    element is returned.

    Args:
        x: Tensor permuted with 4 axes (presumably (batch, channels, H, W) —
            confirm against callers).
        device: When this compares equal to the string ``'cpu'`` the result is
            returned as a tensor; otherwise it is moved to CPU and returned as
            a NumPy array.
    """
    channels_last = x.permute(0, 2, 3, 1)
    if not torch.mean(x) > 10:
        channels_last = channels_last * 255
    as_bytes = channels_last.clamp(0, 255).to(torch.uint8)
    if device == 'cpu':
        return as_bytes[0]
    return as_bytes.cpu().numpy()[0]
def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
    """Cosine similarity between every row of ``A`` and every row of ``B``.

    Args:
        A, B: Matrices multiplied as ``A @ B.T``.
        dim: Axis summed over when computing squared norms.
        eps: Lower bound applied to the denominator.

    Returns:
        ``A @ B.T`` divided elementwise by the outer product of the row norms.
    """
    #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
    sq_norm_a = (A * A).sum(dim=dim)
    sq_norm_b = (B * B).sum(dim=dim)
    norm_products = torch.sqrt(torch.outer(sq_norm_a, sq_norm_b))
    floor = torch.tensor(eps)
    dots = A @ B.T
    return dots / torch.max(norm_products, floor)
def batchwise_pearson_correlation(Z, B):
    """Pearson correlation between every row of ``Z`` and every row of ``B``.

    Each row is mean-centered along dim 1; the result is the matrix of dot
    products of centered rows divided by the product of their norms.

    Returns:
        Matrix with one row per row of ``Z`` and one column per row of ``B``.
    """
    # Center every row around its own mean.
    z_dev = Z - Z.mean(dim=1, keepdim=True)
    b_dev = B - B.mean(dim=1, keepdim=True)

    # Row lengths of the centered data, kept as column vectors.
    z_len = torch.linalg.norm(z_dev, dim=1, keepdim=True)
    b_len = torch.linalg.norm(b_dev, dim=1, keepdim=True)

    covariance = z_dev @ b_dev.T
    scale = z_len @ b_len.T
    return covariance / scale
def batchwise_cosine_similarity(Z,B):
    """Cosine similarity between rows of ``Z`` and rows of ``B``.

    Returns:
        The transposed similarity matrix: entry ``[j, i]`` relates row ``j``
        of ``B`` to row ``i`` of ``Z``.
    """
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    B_t = B.T
    row_norms = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
    col_norms = torch.linalg.norm(B_t, dim=0, keepdim=True)  # Size (1, b).
    dots = Z @ B_t
    scale = row_norms @ col_norms
    return (dots / scale).T
def topk(similarities,labels,k=5):
    """Top-k accuracy of a similarity matrix against integer labels.

    For each of the ``k`` highest-scoring columns per row, the fraction of
    rows whose column index equals the row's label is accumulated.

    Args:
        similarities: 2-D score matrix; rows are ranked along dim 1.
        labels: Expected column index for each row.
        k: Number of top-ranked columns to consider. Clamped so it never
            exceeds either dimension of ``similarities``.

    Returns:
        Accumulated fraction as a tensor (the int ``0`` if ``k`` ends up 0).
    """
    # Clamp to the number of columns as well: the ranking is indexed along
    # dim 1, so a larger k would index out of range on tall matrices.
    k = min(k, similarities.shape[0], similarities.shape[1])
    # The ranking does not depend on the loop variable, so sort once.
    ranked = torch.argsort(similarities, dim=1)
    topsum = 0
    for i in range(k):
        topsum += torch.sum(ranked[:, -(i + 1)] == labels) / len(labels)
    return topsum
def get_non_diagonals(a):
    """Return a new tensor equal to ``a`` off the diagonal, with -1 on it.

    The result is built from the strict upper and strict lower triangles of
    ``a``; the input tensor itself is not modified.
    """
    upper = torch.triu(a, diagonal=1)
    lower = torch.tril(a, diagonal=-1)
    off_diag = upper + lower
    # make diagonals -1
    off_diag.fill_diagonal_(-1)
    return off_diag
def gather_features(image_features, voxel_features, accelerator):
    """Gather feature tensors through ``accelerator.gather``.

    Args:
        image_features: Tensor made contiguous and gathered.
        voxel_features: Optional tensor gathered the same way; may be None.
        accelerator: Object exposing a ``gather(tensor)`` method.

    Returns:
        The gathered image features alone when ``voxel_features`` is None,
        otherwise the tuple ``(gathered_images, gathered_voxels)``.
    """
    gathered_images = accelerator.gather(image_features.contiguous())
    if voxel_features is None:
        return gathered_images
    gathered_voxels = accelerator.gather(voxel_features.contiguous())
    return gathered_images, gathered_voxels
def soft_clip_loss(preds, targs, temp=0.125): #, distributed=False, accelerator=None):
    """Contrastive loss with soft targets taken from target self-similarity.

    The soft target distribution is the row-wise softmax of
    ``targs @ targs.T / temp``. The loss is the cross-entropy between that
    distribution and the log-softmax of ``preds @ targs.T / temp``, averaged
    over both the matrix and its transpose.

    Args:
        preds: Predicted embeddings, one row per sample.
        targs: Target embeddings, one row per sample.
        temp: Temperature dividing both similarity matrices.

    Returns:
        Scalar loss tensor.
    """
    # if not distributed:
    target_probs = ((targs @ targs.T) / temp).softmax(-1)
    logits = (preds @ targs.T) / temp
    # else:
    #     all_targs = gather_features(targs, None, accelerator)
    #     clip_clip = (targs @ all_targs.T)/temp
    #     brain_clip = (preds @ all_targs.T)/temp
    forward = -(logits.log_softmax(-1) * target_probs).sum(-1).mean()
    backward = -(logits.T.log_softmax(-1) * target_probs).sum(-1).mean()
    return (forward + backward) / 2
def soft_siglip_loss(preds, targs, temp, bias):
    """Sigmoid-based contrastive loss with soft labels from target similarity.

    Args:
        preds: Predicted embeddings, one row per sample.
        targs: Target embeddings, one row per sample.
        temp: Tensor holding the log-temperature; ``exp(temp)`` scales the logits.
        bias: Value added to the scaled logits.

    Returns:
        Mean of the row-wise and column-wise losses, each normalized by
        ``len(preds)``.
    """
    n_preds = len(preds)
    scale = torch.exp(temp)
    logits = (preds @ targs.T) * scale + bias
    # diagonals (aka paired samples) should be >0 and off-diagonals <0
    identity = torch.eye(len(targs)).to(targs.dtype).to(targs.device)
    labels = (targs @ targs.T) - 1 + identity
    row_terms = nn.functional.logsigmoid(logits * labels[:n_preds])
    col_terms = nn.functional.logsigmoid(logits.T * labels[:, :n_preds])
    row_loss = -torch.sum(row_terms) / n_preds
    col_loss = -torch.sum(col_terms) / n_preds
    return (row_loss + col_loss) / 2
def mixco_hard_siglip_loss(preds, targs, temp, bias, perm, betas):
    """Sigmoid-based contrastive loss with labels built from mixing weights.

    The label matrix starts as ``diag(betas)``; entry ``[i, perm[i]]`` is then
    set to ``1 - betas[i]``. It is mapped from [0, 1] to [-1, 1] before being
    multiplied with the logits.

    Args:
        preds: Predicted embeddings, one row per sample.
        targs: Target embeddings, one row per sample.
        temp: Tensor holding the log-temperature; ``exp(temp)`` scales the logits.
        bias: Value added to the scaled logits.
        perm: Index tensor giving, per row, the column of its mixing partner.
        betas: 1-D tensor of mixing weights.

    Returns:
        Mean of the row-wise and column-wise losses, each normalized by
        ``len(preds)``.
    """
    n_preds = len(preds)
    scale = torch.exp(temp)
    probs = torch.diag(betas)
    row_index = torch.arange(preds.shape[0]).to(preds.device)
    probs[row_index, perm] = 1 - betas
    labels = probs * 2 - 1
    #labels = torch.eye(len(targs)).to(targs.dtype).to(targs.device) * 2 - 1
    logits = (preds @ targs.T) * scale + bias
    row_loss = -torch.sum(nn.functional.logsigmoid(logits * labels)) / n_preds
    col_loss = -torch.sum(nn.functional.logsigmoid(logits.T * labels)) / n_preds
    return (row_loss + col_loss) / 2
def mixco(voxels, beta=0.15, s_thresh=0.5, perm=None, betas=None, select=None):
    """Blend selected samples of ``voxels`` with a permuted partner, in place.

    For each selected sample ``i`` the new value is
    ``voxels[i] * betas[i] + voxels[perm[i]] * (1 - betas[i])``.
    Unselected samples are left unchanged and their ``betas`` entry is set to 1.

    Note: ``voxels`` and ``betas`` are modified in place (including when they
    are supplied by the caller).

    Args:
        voxels: Batch tensor; the first dimension indexes samples.
        beta: Both concentration parameters of the Beta distribution used to
            sample ``betas`` when none are given.
        s_thresh: A sample is selected when a uniform draw is <= this value.
        perm: Optional partner indices; sampled with ``torch.randperm`` if None.
        betas: Optional per-sample mixing weights; sampled if None.
        select: Optional boolean mask of samples to mix; sampled if None.

    Returns:
        Tuple ``(voxels, perm, betas, select)``.
    """
    n_samples = voxels.shape[0]
    if perm is None:
        perm = torch.randperm(n_samples)
    partners = voxels[perm].to(voxels.device,dtype=voxels.dtype)
    if betas is None:
        sampler = torch.distributions.Beta(beta, beta)
        betas = sampler.sample([n_samples]).to(voxels.device,dtype=voxels.dtype)
    if select is None:
        select = (torch.rand(n_samples) <= s_thresh).to(voxels.device)

    # Shape that broadcasts one weight per sample over all trailing dims.
    weight_shape = [-1] + [1] * (voxels.dim() - 1)
    own_weight = betas[select].reshape(*weight_shape)
    partner_weight = (1 - betas[select]).reshape(*weight_shape)
    voxels[select] = voxels[select] * own_weight + partners[select] * partner_weight
    betas[~select] = 1
    return voxels, perm, betas, select
def mixco_clip_target(clip_target, perm, select, betas):
    """Apply a given mixing (perm / select / betas) to target rows, in place.

    Selected rows become
    ``clip_target[i] * betas[i] + clip_target[perm[i]] * (1 - betas[i])``;
    weights are reshaped to ``(-1, 1)`` before multiplying.

    Note: ``clip_target`` is modified in place and also returned.
    """
    partners = clip_target[perm]
    own_weight = betas[select].reshape(-1, 1)
    partner_weight = (1 - betas[select]).reshape(-1, 1)
    clip_target[select] = clip_target[select] * own_weight + partners[select] * partner_weight
    return clip_target
def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
              accelerator=None, local_rank=None, bidirectional=True):
    """Contrastive (InfoNCE-style) loss, optionally with mixing soft labels.

    When ``perm``, ``betas`` and ``select`` are all provided, the target
    distribution is ``diag(betas)`` with entry ``[i, perm[i]]`` set to
    ``1 - betas[i]``. Otherwise ordinary cross-entropy against the diagonal
    is used.

    Args:
        preds: Predicted embeddings, one row per sample.
        targs: Target embeddings, one row per sample.
        temp: Temperature dividing the similarity logits.
        perm, betas, select: Mixing description; soft labels are used only
            when none of the three is None.
        distributed, accelerator, local_rank: Accepted but not used by this
            function.
        bidirectional: When True, average the loss with its counterpart
            computed on the transposed logits.

    Returns:
        Scalar loss tensor.
    """
    logits = (preds @ targs.T)/temp

    if perm is None or betas is None or select is None:
        # Hard labels: sample i is paired with target i.
        loss = F.cross_entropy(logits, torch.arange(logits.shape[0]).to(logits.device))
        if bidirectional:
            reverse = F.cross_entropy(logits.T, torch.arange(logits.shape[0]).to(logits.device))
            loss = (loss + reverse)/2
        return loss

    probs = torch.diag(betas)
    row_index = torch.arange(preds.shape[0]).to(preds.device)
    probs[row_index, perm] = 1 - betas
    loss = -(logits.log_softmax(-1) * probs).sum(-1).mean()
    if bidirectional:
        reverse = -(logits.T.log_softmax(-1) * probs.T).sum(-1).mean()
        loss = (loss + reverse)/2
    return loss
def count_params(model):
    """Print the total and trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` is true.
    Nothing is returned.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))
def image_grid(imgs, rows, cols):
    """Paste PIL images into a single RGB grid image, filling row by row.

    The cell size is taken from the first image. Image ``i`` is placed at
    column ``i % cols`` and row ``i // cols``.

    Args:
        imgs: Sequence of PIL images; must be non-empty.
        rows: Number of grid rows (sets the canvas height).
        cols: Number of grid columns (sets the canvas width and wrapping).

    Returns:
        New RGB image of size ``(cols * w, rows * h)``.
    """
    cell_w, cell_h = imgs[0].size
    canvas = PIL.Image.new('RGB', size=(cols*cell_w, rows*cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col*cell_w, row*cell_h))
    return canvas
def check_loss(loss):
    """Raise if ``loss`` contains any NaN value.

    Raises:
        ValueError: When at least one element of ``loss`` is NaN.
    """
    has_nan = loss.isnan().any()
    if not has_nan:
        return
    raise ValueError('NaN loss')
def cosine_anneal(start, end, steps):
    """Build a cosine schedule going from ``start`` to ``end`` in ``steps`` values.

    Value ``i`` is ``end + (start - end) / 2 * (1 + cos(pi * i / (steps - 1)))``,
    so the first value is ``start`` and the last is ``end``.

    Args:
        start: First value of the schedule.
        end: Last value of the schedule.
        steps: Number of values to produce.

    Returns:
        1-D tensor of length ``steps``. With ``steps == 1`` the single value is
        ``start``.
    """
    # Guard the denominator: with a single step ``steps - 1`` is zero and the
    # schedule would be 0/0 = NaN.
    span = max(steps - 1, 1)
    phase = torch.pi * torch.arange(steps) / span
    return end + (start - end)/2 * (1 + torch.cos(phase))
def resize(img, img_size=128):
    """Resize an image tensor to ``img_size`` x ``img_size`` with nearest-neighbor.

    A 3-D input gets a leading batch dimension first, so the result always
    carries a batch dimension.
    """
    batch = img[None] if img.ndim == 3 else img
    target = (img_size, img_size)
    return nn.functional.interpolate(batch, size=target, mode='nearest')
def patchify(img, patch_size=16):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        img: 4-D tensor unpacked as (B, C, H, W).
        patch_size: Side length and stride of each patch.

    Returns:
        Tensor of shape (B, num_patches, C, patch_size, patch_size), with
        patches ordered row by row.
    """
    batch, channels, _, _ = img.size()
    row_strips = img.unfold(2, patch_size, patch_size)
    tiles = row_strips.unfold(3, patch_size, patch_size)
    flat = tiles.contiguous().view(batch, channels, -1, patch_size, patch_size)
    return flat.permute(0, 2, 1, 3, 4)
def unpatchify(patches):
    """Reassemble row-ordered square-grid patches into full images.

    Args:
        patches: Tensor of shape (B, N, C, H, W); the grid side is taken as
            ``int(N ** 0.5)`` (N is presumably a perfect square — confirm
            against callers).

    Returns:
        Tensor of shape (B, C, H * side, W * side).
    """
    B, N, C, H, W = patches.shape  # B=Batch size, N=Number of patches, C=Channels, H=Height, W=Width
    side = int(N**0.5)
    grid = patches.view(B, side, side, C, H, W)
    # Interleave grid rows with patch rows and grid columns with patch columns.
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(B, C, H*side, W*side)
import braceexpand
def get_dataloaders(
    batch_size,
    image_var='images',
    num_devices=None,
    num_workers=None,
    train_url=None,
    val_url=None,
    meta_url=None,
    num_train=None,
    num_val=None,
    cache_dir="/scratch/tmp/wds-cache",
    seed=0,
    voxels_key="nsdgeneral.npy",
    val_batch_size=None,
    to_tuple=["voxels", "images", "trial"],
    local_rank=0,
    world_size=1,
):
    """Build train and validation DataLoaders over WebDataset shards.

    Args:
        batch_size: Batch size of the training DataLoader.
        image_var: Must be ``'images'`` (asserted).
        num_devices: Defaults to ``torch.cuda.device_count()``. Only used for
            the printed batch statistics.
        num_workers: Defaults to ``num_devices``. Only used for the printed
            batch statistics; both DataLoaders are created with
            ``num_workers=1``.
        train_url: Shard URL pattern for training; brace-expanded.
        val_url: Shard URL pattern for validation; brace-expanded.
        meta_url: Path of a JSON file with ``totals.train`` / ``totals.val``;
            read only when ``num_train`` or ``num_val`` is None.
        num_train: Number of training samples; read from metadata if None.
        num_val: Number of validation samples; read from metadata if None.
        cache_dir: Cache directory handed to ``wds.WebDataset``.
        seed: Accepted but not used by this function.
        voxels_key: Sample key renamed to ``voxels``.
        val_batch_size: Validation batch size; defaults to ``batch_size``.
        to_tuple: Names of the fields each sample is turned into, in order.
        local_rank: Accepted but not used by this function.
        world_size: Accepted but not used by this function.

    Returns:
        Tuple ``(train_dl, val_dl, num_train, num_val)``.
    """
    print("Getting dataloaders...")
    assert image_var == 'images'

    def my_split_by_node(urls):
        # Pass every shard URL through unchanged.
        return urls

    train_url = list(braceexpand.braceexpand(train_url))
    val_url = list(braceexpand.braceexpand(val_url))

    if num_devices is None:
        num_devices = torch.cuda.device_count()

    if num_workers is None:
        num_workers = num_devices

    if num_train is None or num_val is None:
        # Read the metadata once and close the file deterministically.
        with open(meta_url) as meta_file:
            metadata = json.load(meta_file)
        if num_train is None:
            num_train = metadata['totals']['train']
        if num_val is None:
            num_val = metadata['totals']['val']

    if val_batch_size is None:
        val_batch_size = batch_size

    global_batch_size = batch_size * num_devices
    num_batches = math.floor(num_train / global_batch_size)
    num_worker_batches = math.floor(num_batches / num_workers)
    if num_worker_batches == 0: num_worker_batches = 1

    print("\nnum_train",num_train)
    print("global_batch_size",global_batch_size)
    print("batch_size",batch_size)
    print("num_workers",num_workers)
    print("num_batches",num_batches)
    print("num_worker_batches", num_worker_batches)

    # train_url = train_url[local_rank:world_size]
    train_data = (
        wds.WebDataset(train_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)
        .shuffle(500, initial=500, rng=random.Random(42))
        .decode("torch")
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")
        .to_tuple(*to_tuple)
    )
    # .batched(batch_size, partial=True)
    # .with_epoch(num_worker_batches)

    # BATCH SIZE SHOULD BE NONE!!! FOR TRAIN AND VAL | resampled=True for train | .batched(val_batch_size, partial=False)
    train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=1, shuffle=False)

    # Validation
    print("val_batch_size",val_batch_size)
    val_data = (
        wds.WebDataset(val_url, resampled=False, cache_dir=cache_dir, nodesplitter=my_split_by_node)
        .shuffle(500, initial=500, rng=random.Random(42))
        .decode("torch")
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")
        .to_tuple(*to_tuple)
    )
    # .batched(val_batch_size, partial=True)
    val_dl = torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, num_workers=1, shuffle=False, drop_last=True)

    return train_dl, val_dl, num_train, num_val
# Preprocessing used by pixcorr() on both of its inputs: a torchvision Resize
# with size 425 and bilinear interpolation.
pixcorr_preprocess = transforms.Compose([
transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr(images,brains):
    """Mean pixel-wise Pearson correlation between paired images.

    Both inputs are passed through ``pixcorr_preprocess`` and flattened to one
    row per sample; the correlation of each image with its same-index
    counterpart (the diagonal of the pairwise matrix) is averaged.

    Args:
        images: Batch of reference images.
        brains: Batch of images compared against ``images``, in the same order.

    Returns:
        Scalar tensor holding the mean correlation.
    """
    flat_images = pixcorr_preprocess(images).reshape(len(images), -1)
    # reshape (not view) so a non-contiguous resize result is handled too,
    # matching how ``images`` is flattened.
    flat_brains = pixcorr_preprocess(brains).reshape(len(brains), -1)
    pairwise = batchwise_pearson_correlation(flat_images, flat_brains)
    return torch.diag(pairwise).mean()
# Preprocessing used by pixcorr_origsize_nanmean() on its `images` input only:
# a torchvision Resize with size 128 and bilinear interpolation.
pixcorr_origsize_nanmean_preprocess = transforms.Compose([
transforms.Resize(128, interpolation=transforms.InterpolationMode.BILINEAR),
])
def pixcorr_origsize_nanmean(images,brains):
    """Mean pixel-wise Pearson correlation between paired images, ignoring NaNs.

    Only ``images`` is resized (via ``pixcorr_origsize_nanmean_preprocess``);
    ``brains`` is flattened as given. The correlation of each image with its
    same-index counterpart is averaged with ``torch.nanmean``.

    Args:
        images: Batch of reference images.
        brains: Batch of images compared against ``images``, in the same order.

    Returns:
        Scalar tensor holding the NaN-ignoring mean correlation.
    """
    flat_images = pixcorr_origsize_nanmean_preprocess(images).reshape(len(images), -1)
    # assuming it's already 128 size; reshape (not view) so non-contiguous
    # inputs are accepted as well.
    flat_brains = brains.reshape(len(brains), -1)
    pairwise = batchwise_pearson_correlation(flat_images, flat_brains)
    return torch.nanmean(torch.diag(pairwise))
def select_annotations(annots, random=False):
    """
    There are 5 annotations per image. Select one of them for each image.

    Args:
        annots: Iterable with one entry per image; each entry is indexed with
            positions 0..4 and empty annotations are the empty string.
        random: When True, pick a random non-empty annotation; otherwise pick
            the first non-empty one.

    Returns:
        1-D NumPy array with one selected annotation per image. An image whose
        five annotations are all empty yields ``''``. An empty ``annots``
        yields an empty array.
    """
    selected = []
    for b in annots:
        t = ''
        if random:
            # select random non-empty annotation; skip sampling entirely when
            # every annotation is empty, otherwise the loop would never end
            if any(b[j] != '' for j in range(5)):
                while t == '':
                    rand = torch.randint(5, (1,1))[0][0]
                    t = b[rand]
        else:
            # select first non-empty annotation
            t = next((b[j] for j in range(5) if b[j] != ''), '')
        selected.append(t)

    if not selected:
        return np.array([], dtype=str)
    # Build the array once instead of growing it with vstack per image.
    return np.array(selected).flatten()
def add_saturation(image, alpha=2):
    """Scale the color saturation of a batch of images.

    A grayscale version is computed from channels 0, 1 and 2 with weights
    0.2989 / 0.5870 / 0.1140, and the result is
    ``alpha * image + (1 - alpha) * gray`` clamped to [0, 1].

    Args:
        image: 4-D tensor whose dim 1 holds the channels.
        alpha: Blend factor; 1 leaves the image unchanged apart from clamping.

    Returns:
        Tensor with the same shape as ``image``.
    """
    red = image[:, 0, :, :]
    green = image[:, 1, :, :]
    blue = image[:, 2, :, :]
    luma = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    luma = luma.unsqueeze(1).expand_as(image)
    blended = alpha * image + (1 - alpha) * luma
    return blended.clamp(0, 1)