# util.py — pickle/file helpers, image conversion, and diffusion sampling utilities.
# (Removed scraped page-header residue: "JiminHeo's picture / util / c429825" —
# it was not valid Python and broke the module.)
import torch
import numpy as np
import os
import pickle
from ldm.util import default
import glob
import PIL
import matplotlib.pyplot as plt
def load_file(filename):
    """Deserialize and return the object stored in the pickle file *filename*."""
    with open(filename, "rb") as fh:
        return pickle.load(fh)
def save_file(filename, x, mode="wb"):
    """Serialize *x* to *filename* with pickle; *mode* defaults to binary write."""
    with open(filename, mode) as fh:
        pickle.dump(x, fh)
def normalize_np(img):
    """Normalize img in arbitrary range to [0, 1].

    Fixes vs. the original:
    - works on a float copy, so the caller's array is never mutated in place
      (the old in-place `-=`/`/=` also failed on integer arrays);
    - a constant image returns all zeros instead of NaN/inf from a
      division by zero.
    """
    out = np.asarray(img).astype(np.float64)
    out -= out.min()
    peak = out.max()
    if peak > 0:
        out /= peak
    return out
def clear_color(x):
    """Turn a (possibly complex) CHW image tensor into a display-ready
    HWC numpy array normalized to [0, 1]."""
    if torch.is_complex(x):
        x = torch.abs(x)
    arr = x.detach().cpu().squeeze().numpy()
    arr = np.transpose(arr, (1, 2, 0))
    return normalize_np(arr)
def to_img(sample):
    """Map an NCHW tensor in [-1, 1] to an NHWC float array clipped to [0, 255]."""
    arr = sample.detach().cpu().numpy()
    arr = arr.transpose(0, 2, 3, 1)
    return (arr * 127.5 + 128).clip(0, 255)
def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot each series in *tensors* against a common x-axis and save the figure
    to dir_name/file_name.

    tensors: sequence of equal-length 1-D series; labels: one legend label per
    series. Colors now cycle ("r", "b", "g", "r", ...), so plotting more than
    three series no longer raises IndexError as the old fixed indexing did.
    """
    colours = ["r", "b", "g"]
    t = np.linspace(0, len(tensors[0]), len(tensors[0]))
    plt.figure()
    for j, series in enumerate(tensors):
        plt.plot(t, series, color=colours[j % len(colours)], label=labels[j])
    plt.legend()
    plt.savefig(os.path.join(dir_name, file_name))
    #plt.show()
def save_samples(dir_name, sample, k=None, num_to_save=5, file_name=None):
    """Save the first images of a batch as RGB PNG files under *dir_name*.

    sample: NHWC array already in [0, 255], or an NCHW tensor in [-1, 1]
        (converted via to_img).
    k: optional outer index folded into the default name 'sample_{k+1}{j}.png'.
    file_name: fixed output name; note every j then writes the same path.

    Fixes: isinstance() instead of a `type(...) is` check, the loop is clamped
    to the batch size (the old code raised IndexError when num_to_save exceeded
    it), and a dead trailing `file_name_img = None` assignment was removed.
    """
    if isinstance(sample, np.ndarray):
        sample_np = sample.astype(np.uint8)
    else:
        sample_np = to_img(sample).astype(np.uint8)
    for j in range(min(num_to_save, len(sample_np))):
        if file_name is None:
            if k is not None:
                file_name_img = f'sample_{k+1}{j}.png'
            else:
                file_name_img = f'{j}.png'
        else:
            file_name_img = file_name
        image_path = os.path.join(dir_name, file_name_img)
        PIL.Image.fromarray(sample_np[j], 'RGB').save(image_path)
def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save=5, file_name=None):
    """Composite known pixels from *y* with generated pixels from *sample* and
    save the results as RGB PNG files under *dir_name*.

    mask_pixel is 1 where *y* (the observation) is kept and 0 where *sample*
    fills in. Naming follows save_samples: 'sample_{k+1}{j}.png' or '{j}.png',
    unless a fixed file_name is given (which then repeats for every j).

    Fixes: the loop is clamped to the batch size (old code raised IndexError
    when num_to_save exceeded it), the uint8 conversion is hoisted out of the
    loop, and a dead trailing `file_name_img = None` assignment was removed.
    """
    recon_in = y * (mask_pixel) + (1 - mask_pixel) * sample
    recon_np = to_img(recon_in).astype(np.uint8)
    for j in range(min(num_to_save, len(recon_np))):
        if file_name is None:
            if k is not None:
                file_name_img = f'sample_{k+1}{j}.png'
            else:
                file_name_img = f'{j}.png'
        else:
            file_name_img = file_name
        image_path = os.path.join(dir_name, file_name_img)
        PIL.Image.fromarray(recon_np[j], 'RGB').save(image_path)
def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Detach the variational parameters to CPU, freeze them, and save the
    list as <dir_name>/<k+1>.pt."""
    detached = [p.detach().cpu() for p in (mu_pos, logvar_pos, gamma)]
    params_path = os.path.join(dir_name, f'{k+1}.pt')
    torch.save(params_untrain(detached), params_path)
def custom_to_np(img):
    """Return a detached, CPU-resident, contiguous version of *img*.

    Despite the name this still returns a torch tensor, not a numpy array
    (the uint8/permute conversion is intentionally left commented out).
    """
    #sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    #sample = sample.permute(0, 2, 3, 1)
    return img.detach().cpu().contiguous()
def encoder_kl(diff, img):
    """Encode *img* with the first-stage KL autoencoder of *diff* and return a
    noised latent mean together with the encoder's log-variance.

    NOTE(review): the noise added to the mean is scaled by diff.scale_factor,
    not by exp(0.5 * logvar) as a standard reparameterization would be —
    confirm this is intentional.
    """
    _, params = diff.encode_first_stage(img, return_all = True)
    # Latents are scaled by the model's scale_factor (latent-diffusion convention).
    params = diff.scale_factor * params
    # The encoder packs mean and log-variance along the channel axis.
    mean, logvar = torch.chunk(params, 2, dim=1)
    # default(None, fn) always falls through to fn() here, i.e. fresh Gaussian noise.
    noise = default(None, lambda: torch.randn_like(mean))
    mean = mean + diff.scale_factor*noise
    return mean, logvar
def encoder_vq(diff, img):
    """Encode *img* with the first-stage VQ autoencoder of *diff* and return
    the scale_factor-scaled quantized latent with Gaussian noise added.

    NOTE(review): noise scaled by diff.scale_factor is added to the quantized
    latent — confirm this perturbation is intentional.
    """
    quant = diff.encode_first_stage(img) #, diff, (_,_,ind)
    quant = diff.scale_factor * quant
    #mean, logvar = torch.chunk(params, 2, dim=1)
    # default(None, fn) always evaluates fn() -> fresh noise shaped like quant.
    noise = default(None, lambda: torch.randn_like(quant))
    mean = quant + diff.scale_factor*noise #
    return mean
def clean_directory(dir_name):
    """Delete every path matching the glob pattern *dir_name*.

    Note: the argument is a glob pattern (e.g. "out/*.png"), not a bare
    directory — passing a plain directory name would match the directory
    itself and os.remove would fail on it.
    """
    for path in glob.glob(dir_name):
        os.remove(path)
def params_train(params):
    """Enable gradient tracking on every tensor in *params*; returns the same list."""
    for p in params:
        p.requires_grad = True
    return params
def params_untrain(params):
    """Disable gradient tracking on every tensor in *params*; returns the same list."""
    for p in params:
        p.requires_grad = False
    return params
def time_descretization(sigma_min=0.002, sigma_max=80, rho=7, num_t_steps=18, device="cuda"):
    """Return the EDM (Karras et al.) sigma schedule, flipped to ascending order.

    Args:
        sigma_min / sigma_max: smallest / largest noise level of the schedule.
        rho: warping exponent of the EDM discretization.
        num_t_steps: number of discretization points.
        device: where the schedule tensor lives. Defaults to "cuda" for
            backward compatibility with the old hard-coded .cuda() call, but
            can now be "cpu" for GPU-free use.
    """
    step_indices = torch.arange(num_t_steps, dtype=torch.float64, device=device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_t_steps - 1)
               * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # EDM yields a descending schedule; flip it so time runs forward (ascending sigma).
    inv_idx = torch.arange(num_t_steps - 1, -1, -1).long()
    t_steps_fwd = t_steps[inv_idx]
    #t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
    return t_steps_fwd
def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01):
    """Build Adam optimizers for the means, variances and gamma parameters,
    plus one matching StepLR scheduler per optimizer.

    Returns ([opt_means, opt_vars, opt_gamma], [sched_means, sched_vars, sched_gamma]).
    """
    lr, step_size, decay = 0.1, 10, 0.99  # was 0.999 for right-half: [0.01, 10, 0.99]
    betas = (0.9, 0.99)
    opt_mu = torch.optim.Adam([means], lr=lr, betas=betas)
    opt_var = torch.optim.Adam([variances], lr=0.001, betas=betas)  # 0.001 for lsun
    opt_gamma = torch.optim.Adam([gamma_param], lr=lr_init_gamma, betas=betas)  # 0.01
    optimizers = [opt_mu, opt_var, opt_gamma]
    schedulers = [torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=decay)
                  for opt in optimizers]
    return optimizers, schedulers
def check_directory(filename_list):
    """Ensure every directory in *filename_list* exists, creating missing parents.

    Uses os.makedirs(..., exist_ok=True) instead of the old
    `if not exists: mkdir`, which was racy (TOCTOU) and failed when a parent
    directory was missing.
    """
    for dirname in filename_list:
        os.makedirs(dirname, exist_ok=True)
def s_file(filename, x, mode="wb"):
    """Pickle *x* to *filename* (duplicate of save_file, kept for compatibility)."""
    with open(filename, mode) as out:
        pickle.dump(x, out)
def r_file(filename, mode="rb"):
    """Unpickle and return the object stored in *filename* (duplicate of load_file)."""
    with open(filename, mode) as src:
        return pickle.load(src)
def sample_from_gaussian(mu, alpha, sigma):
    """Draw alpha*mu + sigma*eps with eps ~ N(0, I) shaped like *mu*."""
    eps = torch.randn_like(mu)
    return alpha * mu + sigma * eps
# NOTE(review): everything inside the triple-quoted literal below is dead code —
# a module-level string expression with no runtime effect, kept as a graveyard
# for retired helpers (make_batch, get_sigma_t_steps, data_replicate). Prefer
# deleting it and recovering it from version control if ever needed.
'''
def make_batch(image, mask=None, device=None):
    image = torch.permute(image, (0,3,1,2))
    batch_size = image.shape[0]
    if mask is None :
        mask = torch.zeros_like(image)
        mask[0, :, :256, :128] = 1
    else :
        mask = torch.tensor(mask)
    masked_image = (mask)*image #+ mask*noise*0.2
    mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device)
    return batch
def get_sigma_t_steps(net, n_steps=3, kwargs=None):
    sigma_min = kwargs["sigma_min"]
    sigma_max = kwargs["sigma_max"]
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)
    ##Get the time-steps based on iddpm discretization
    num_steps = n_steps #11 # kwargs["num_steps"]
    C_2 = kwargs["C_2"]
    C_1 = kwargs["C_1"]
    M = kwargs["M"]
    step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
    u = torch.zeros(M + 1, dtype=torch.float64).cuda()
    alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
    for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
        u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
    u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
    sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    #print(sigma_steps)
    ##get noise schedule
    sigma = lambda t: t
    sigma_deriv = lambda t: 1
    sigma_inv = lambda sigma: sigma
    ##scaling schedule
    s = lambda t: 1
    s_deriv = lambda t: 0
    ##compute some final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    return t_steps, sigma_inv, sigma, s, sigma_deriv
def data_replicate(data, K):
    if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1])
    else: data_batch = torch.Tensor.repeat(data,[K,1,1,1])
    return data_batch
'''
def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None):
    """DDIM-style sampling along a fixed hierarchy of timesteps.

    Noises the latent *x0* up to the first (largest) timestep of
    *t_steps_hierarchy*, then iteratively denoises through the remaining
    timesteps with the model's epsilon network, decoding and saving one
    preview PNG per level under out_temp2/. Returns None (side effects only).

    Args:
        x0: clean latent tensor (shape set by the first-stage model — confirm
            against the caller).
        eta: DDIM stochasticity coefficient (0 would be deterministic DDIM).
        t_steps_hierarchy: descending sequence of integer timesteps.

    NOTE(review): t_hat = torch.ones(10) hard-codes a batch size of 10, and
    the .cuda() calls assume GPU availability — verify both against callers.
    """
    '''
    sigma_discretization_edm = time_descretization(sigma_min=0.002, sigma_max = 999, rho = 7, num_t_steps = 10)/1000
    T_max = 1000
    beta_start = 1 # 0.0015*T_max
    beta_end = 15 # 0.0155*T_max
    def var(t):
        return 1.0 - (1.0) * torch.exp(- beta_start * t - 0.5 * (beta_end - beta_start) * t * t)
    '''
    t_steps_hierarchy = torch.tensor(t_steps_hierarchy).cuda()
    # var_t = (sqrt(1 - abar_t))**2 = 1 - abar_t at the first hierarchy level.
    var_t = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[0]].reshape(1, 1 ,1 ,1))**2 # self.var(t_steps_hierarchy[0])
    # Forward-noise x0 to that level: x_t ~ N(sqrt(abar)*x0, (1 - abar) I).
    x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)
    os.makedirs("out_temp2/", exist_ok=True)
    for i, t in enumerate(t_steps_hierarchy):
        # NOTE(review): hard-coded batch size of 10 for the timestep tensor.
        t_hat = torch.ones(10).cuda() * (t)
        e_out = self.model.model(x_t, t_hat)
        var_t = (self.model.sqrt_one_minus_alphas_cumprod[t].reshape(1, 1 ,1 ,1))**2
        #score_out = - e_out / torch.sqrt()
        a_t = 1 - var_t
        #beta_t = 1 - a_t/a_prev
        #std_pos = ((1 - a_prev)/(1 - a_t)).sqrt()*torch.sqrt(beta_t)
        # Predicted clean latent from the epsilon parameterization.
        pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()
        if i != len(t_steps_hierarchy) - 1:
            var_t1 = (self.model.sqrt_one_minus_alphas_cumprod[t_steps_hierarchy[i+1]].reshape(1, 1 ,1 ,1))**2
            a_prev = 1 - var_t1 # var(t_steps_hierarchy[i+1]/1000) # torch.full((10, 1, 1, 1), alphas[t_steps_hierarchy[i+1]]).cuda()
            # DDIM sigma_t scaled by eta.
            sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
            # NOTE(review): sigma_t-scaled noise is added TWICE (two independent
            # randn_like draws), doubling the noise variance relative to
            # standard DDIM — likely a bug; confirm intent before touching.
            x_t = a_prev.sqrt() * pred_x0 + dir_xt + torch.randn_like(x_t) * sigma_t + sigma_t*torch.randn_like(x_t)
            #x_t= (x_t - torch.sqrt( 1 - a_t/a_prev) * e_out ) / (a_t/a_prev).sqrt() + std_pos*torch.randn_like(x_t)
            '''
            def pred_mean(pred_x0, z_t):
                posterior_mean_coef1 = beta_t * torch.sqrt(a_prev) / (1. - a_t)
                posterior_mean_coef2 = (1. - a_prev) * torch.sqrt(a_t/a_prev) / (1. - a_t)
                return posterior_mean_coef1*pred_x0 + posterior_mean_coef2*z_t
            x_t = torch.sqrt(a_prev) * pred_x0 # pred_mean(pred_x0, x_t) #+ 0.4*torch.sqrt(beta_t) *torch.randn_like(x_t)
            '''
        # Decode the current x0 estimate and save a preview frame for this level.
        recon = self.model.decode_first_stage(pred_x0)
        image_path = os.path.join("out_temp2/", f'{i}.png')
        image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
        PIL.Image.fromarray(image_np, 'RGB').save(image_path)
    return