import torch
import numpy as np
import os
import pickle
from ldm.util import default
import glob
import PIL
import matplotlib.pyplot as plt

def load_file(filename):
    """Read a pickled object from ``filename`` and return it.

    The file is opened in binary mode and closed before returning.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def save_file(filename, x, mode="wb"):
    """Pickle ``x`` into ``filename``.

    ``mode`` is passed straight to ``open``; the default ``"wb"`` overwrites
    any existing file.
    """
    with open(filename, mode) as handle:
        pickle.dump(x, handle)

def normalize_np(img):
    """Normalize ``img`` from an arbitrary range to [0, 1].

    The input array is never modified; a new array is returned.  Floating
    (and complex) inputs keep their dtype, integer and boolean inputs are
    promoted to ``float64`` so the division is well defined.

    Args:
        img: array-like of numbers.

    Returns:
        A new ``numpy.ndarray`` with minimum 0 and maximum 1.  A constant
        input (max == min) yields all zeros instead of NaNs, and an empty
        input is returned as an empty copy.
    """
    arr = np.asarray(img)
    if arr.dtype.kind in "biu":
        # Integer/bool arrays cannot hold the fractional result.
        arr = arr.astype(np.float64)
    if arr.size == 0:
        return arr.copy()

    # Out-of-place arithmetic: callers often pass views that share memory
    # with a live tensor, which must not be altered.
    shifted = arr - np.min(arr)
    peak = np.max(shifted)
    if peak == 0:
        # Constant image: avoid 0/0.
        return shifted
    return shifted / peak

def clear_color(x):
    """Convert a tensor into a channels-last numpy image normalized by ``normalize_np``.

    Complex tensors are replaced by their magnitude first.  All size-1
    dimensions are squeezed away and the remaining array is transposed with
    axes ``(1, 2, 0)``, so the squeezed tensor must be 3-dimensional
    (presumably channels-first, i.e. (C, H, W) -- confirm against callers).
    """
    magnitude = torch.abs(x) if torch.is_complex(x) else x
    array = magnitude.detach().cpu().squeeze().numpy()
    channels_last = np.transpose(array, (1, 2, 0))
    return normalize_np(channels_last)

def to_img(sample):
    """Map a 4-D tensor to a float numpy array clipped to [0, 255].

    Axis 1 is moved to the last position (axes ``(0, 2, 3, 1)``), values are
    scaled by 127.5 and shifted by 128, then clipped.  The result is not cast
    to an integer dtype; callers do that themselves.
    (Presumably the input is in [-1, 1] with layout (N, C, H, W) -- confirm.)
    """
    array = sample.detach().cpu().numpy()
    channels_last = np.transpose(array, (0, 2, 3, 1))
    scaled = channels_last * 127.5 + 128
    return np.clip(scaled, 0, 255)

def save_plot(dir_name, tensors, labels, file_name="loss.png"):
    """Plot several equally long series on one figure and save it as an image.

    Args:
        dir_name: directory the image is written into (must already exist).
        tensors: sequence of series; the length of ``tensors[0]`` defines the
            x-axis, so all series are expected to have that length.
        labels: legend label for each series, indexed like ``tensors``.
        file_name: name of the image file inside ``dir_name``.

    The figure is closed after saving so repeated calls do not accumulate
    open matplotlib figures.  Colours cycle through red, blue, green, so more
    than three series no longer raise ``IndexError``.
    """
    t = np.linspace(0, len(tensors[0]), len(tensors[0]))
    colours = ["r", "b", "g"]
    fig = plt.figure()
    try:
        for j, series in enumerate(tensors):
            plt.plot(t, series, color=colours[j % len(colours)], label=labels[j])
        plt.legend()
        plt.savefig(os.path.join(dir_name, file_name))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)
    
def save_samples(dir_name, sample, k=None, num_to_save=5, file_name=None):
    """Save the first images of a batch as RGB files in ``dir_name``.

    Args:
        dir_name: output directory (must already exist).
        sample: either a numpy array that is cast to ``uint8`` as is, or a
            tensor that is first converted with ``to_img``.  Each
            ``sample[j]`` must be an (H, W, 3) image after conversion.
        k: optional index; when given, files are named
            ``sample_{k+1}{j}.png``, otherwise ``{j}.png``.
        num_to_save: maximum number of images to write.  It is clamped to the
            batch size, so a batch smaller than this no longer raises
            ``IndexError``.
        file_name: optional explicit file name.  NOTE: the same name is used
            for every image, so with ``num_to_save > 1`` each image overwrites
            the previous one and only the last survives.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that ``PIL.Image`` has been loaded.
    from PIL import Image

    if isinstance(sample, np.ndarray):
        sample_np = sample.astype(np.uint8)
    else:
        sample_np = to_img(sample).astype(np.uint8)

    count = min(num_to_save, len(sample_np))
    for j in range(count):
        if file_name is not None:
            file_name_img = file_name
        elif k is not None:
            file_name_img = f'sample_{k+1}{j}.png'
        else:
            file_name_img = f'{j}.png'
        image_path = os.path.join(dir_name, file_name_img)
        Image.fromarray(sample_np[j], 'RGB').save(image_path)

def save_inpaintings(dir_name, sample, y, mask_pixel, k=None, num_to_save=5, file_name=None):
    """Blend ``y`` and ``sample`` with a mask and save the results as RGB images.

    The blend is ``y * mask_pixel + (1 - mask_pixel) * sample``: ``y`` is kept
    where the mask is 1 and ``sample`` is used where it is 0.  The blend is
    converted with ``to_img`` and cast to ``uint8``.

    Args:
        dir_name: output directory (must already exist).
        sample: tensor supplying the pixels where ``mask_pixel`` is 0.
        y: tensor supplying the pixels where ``mask_pixel`` is 1.
        mask_pixel: tensor broadcastable against ``y`` and ``sample``.
        k: optional index; when given, files are named
            ``sample_{k+1}{j}.png``, otherwise ``{j}.png``.
        num_to_save: maximum number of images to write; clamped to the batch
            size so a smaller batch no longer raises ``IndexError``.
        file_name: optional explicit file name.  NOTE: the same name is used
            for every image, so later images overwrite earlier ones.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that ``PIL.Image`` has been loaded.
    from PIL import Image

    blended = y * mask_pixel + (1 - mask_pixel) * sample
    # Convert once up front rather than re-casting the batch per image.
    recon_np = to_img(blended).astype(np.uint8)

    count = min(num_to_save, len(recon_np))
    for j in range(count):
        if file_name is not None:
            file_name_img = file_name
        elif k is not None:
            file_name_img = f'sample_{k+1}{j}.png'
        else:
            file_name_img = f'{j}.png'
        image_path = os.path.join(dir_name, file_name_img)
        Image.fromarray(recon_np[j], 'RGB').save(image_path)

def save_params(dir_name, mu_pos, logvar_pos, gamma, k):
    """Store detached CPU copies of the three parameters as ``{k+1}.pt``.

    The tensors are detached, moved to the CPU, passed through
    ``params_untrain`` and saved as a list ``[mu_pos, logvar_pos, gamma]``
    with ``torch.save`` inside ``dir_name``.
    """
    detached = [p.detach().cpu() for p in (mu_pos, logvar_pos, gamma)]
    frozen = params_untrain(detached)
    target = os.path.join(dir_name, f'{k+1}.pt')
    torch.save(frozen, target)

def custom_to_np(img):
    """Return a detached, contiguous CPU copy/view of ``img``.

    Despite the name, the result is still a ``torch.Tensor``: no rescaling,
    permutation or numpy conversion is performed here.
    """
    return img.detach().cpu().contiguous()

def encoder_kl(diff, img):
    """Encode ``img`` with ``diff``'s first stage and return a noised mean and a log-variance.

    ``diff.encode_first_stage(img, return_all=True)`` is expected to return a
    pair whose second element holds the parameters.  These are multiplied by
    ``diff.scale_factor`` and split in two halves along dim 1.  Standard
    normal noise scaled by ``diff.scale_factor`` is added to the first half.

    Returns:
        ``(mean, logvar)`` where ``mean`` already includes the added noise.
        NOTE(review): ``logvar`` is the scaled second half and is not used to
        scale the noise -- confirm this is intended.
    """
    _, raw_params = diff.encode_first_stage(img, return_all=True)
    scaled_params = diff.scale_factor * raw_params
    mean, logvar = torch.chunk(scaled_params, 2, dim=1)
    eps = default(None, lambda: torch.randn_like(mean))
    noisy_mean = mean + diff.scale_factor * eps
    return noisy_mean, logvar
      
def encoder_vq(diff, img):
    """Encode ``img`` with ``diff``'s first stage and return the scaled, noised latent.

    The output of ``diff.encode_first_stage(img)`` is multiplied by
    ``diff.scale_factor``; standard normal noise of the same shape, also
    scaled by ``diff.scale_factor``, is then added.
    """
    latent = diff.scale_factor * diff.encode_first_stage(img)
    eps = default(None, lambda: torch.randn_like(latent))
    return latent + diff.scale_factor * eps

def clean_directory(dir_name):
    """Delete the files matched by ``dir_name``.

    Args:
        dir_name: either a glob pattern (e.g. ``"out/*.png"``) or the path of
            an existing directory.  A directory path is treated as
            ``"<dir>/*"``, so its direct, non-hidden entries are removed
            (``*`` does not match dot-files).  Previously a plain directory
            path made ``os.remove`` fail on the directory itself.

    Only files and symlinks are removed; matched sub-directories are skipped
    instead of raising.  The directory itself is kept.
    """
    if os.path.isdir(dir_name):
        # Escape so glob metacharacters in the directory name stay literal.
        pattern = os.path.join(glob.escape(dir_name), "*")
    else:
        pattern = dir_name

    for path in glob.glob(pattern):
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)

def params_train(params):
    """Set ``requires_grad = True`` on every item of ``params`` and return ``params``.

    The items are modified in place; the same container object is returned.
    """
    for tensor in params:
        tensor.requires_grad = True
    return params

def params_untrain(params):
    """Set ``requires_grad = False`` on every item of ``params`` and return ``params``.

    The items are modified in place; the same container object is returned.
    """
    for tensor in params:
        tensor.requires_grad = False
    return params

def time_descretization(sigma_min=0.002, sigma_max=80, rho=7, num_t_steps=18, device=None):
    """Build an increasing noise-level schedule between ``sigma_min`` and ``sigma_max``.

    The levels are evenly spaced in ``sigma ** (1 / rho)`` and raised back to
    the power ``rho``.  They are computed from ``sigma_max`` down to
    ``sigma_min`` and then reversed, so the result starts at ``sigma_min``.

    Args:
        sigma_min: smallest noise level (first element of the result).
        sigma_max: largest noise level (last element of the result).
        rho: exponent controlling how the steps are distributed.
        num_t_steps: number of levels to produce.
        device: device for the returned tensor.  Defaults to CUDA when it is
            available (the previous hard-coded behaviour) and to the CPU
            otherwise, so the function also works on CPU-only hosts.

    Returns:
        A 1-D ``float64`` tensor of length ``num_t_steps``.  With
        ``num_t_steps == 1`` the single value is ``sigma_max`` instead of NaN.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    step_indices = torch.arange(num_t_steps, dtype=torch.float64, device=device)
    # Guard the single-step case, which would otherwise evaluate 0 / 0.
    denom = max(num_t_steps - 1, 1)
    t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Reverse so the schedule runs from the smallest to the largest level.
    inv_idx = torch.arange(num_t_steps - 1, -1, -1, device=device).long()
    t_steps_fwd = t_steps[inv_idx]
    return t_steps_fwd

def get_optimizers(means, variances, gamma_param, lr_init_gamma=0.01):
    """Create one Adam optimizer and one StepLR scheduler per parameter tensor.

    Learning rates are 0.1 for ``means``, 0.001 for ``variances`` and
    ``lr_init_gamma`` for ``gamma_param``.  All optimizers use
    ``betas=(0.9, 0.99)``; all schedulers multiply the learning rate by 0.99
    every 10 steps.

    Returns:
        ``(optimizers, schedulers)``: two lists of length three, ordered as
        means, variances, gamma.
    """
    betas = (0.9, 0.99)
    step_size, decay = 10, 0.99
    param_lrs = (
        (means, 0.1),
        (variances, 0.001),
        (gamma_param, lr_init_gamma),
    )

    optimizers = [
        torch.optim.Adam([param], lr=lr, betas=betas) for param, lr in param_lrs
    ]
    schedulers = [
        torch.optim.lr_scheduler.StepLR(opt, step_size=step_size, gamma=decay)
        for opt in optimizers
    ]
    return optimizers, schedulers

def check_directory(filename_list):
    """Make sure every path in ``filename_list`` exists as a directory.

    Missing directories are created together with any missing parents.
    Paths that already exist are left untouched, including the case where
    the path is a regular file.

    Args:
        filename_list: iterable of directory paths.
    """
    for filename in filename_list:
        if os.path.exists(filename):
            continue
        # makedirs creates intermediate directories; exist_ok covers another
        # process creating the directory between the check and this call.
        os.makedirs(filename, exist_ok=True)

def s_file(filename, x, mode="wb"):
    """Pickle ``x`` into ``filename``, opening the file with ``mode``."""
    with open(filename, mode) as handle:
        pickle.dump(x, handle)

def r_file(filename, mode="rb"):
    """Unpickle and return the object stored in ``filename``, opened with ``mode``."""
    with open(filename, mode) as handle:
        return pickle.load(handle)

def sample_from_gaussian(mu, alpha, sigma):
    """Return ``alpha * mu + sigma * eps`` with ``eps`` standard normal noise shaped like ``mu``."""
    eps = torch.randn_like(mu)
    scaled_mean = alpha * mu
    return scaled_mean + sigma * eps

'''
def make_batch(image, mask=None, device=None):
    image = torch.permute(image, (0,3,1,2)) 
    batch_size = image.shape[0]
    if mask is None : 
        mask = torch.zeros_like(image)
        mask[0, :, :256, :128] = 1
    else : 
        mask = torch.tensor(mask)
    masked_image = (mask)*image #+ mask*noise*0.2
    mask = mask[:,0,:,:].reshape(batch_size,1,image.shape[2], image.shape[3])
    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device)
    return batch

def get_sigma_t_steps(net, n_steps=3, kwargs=None):
    sigma_min = kwargs["sigma_min"]
    sigma_max = kwargs["sigma_max"]
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    ##Get the time-steps based on iddpm discretization
    num_steps = n_steps #11 # kwargs["num_steps"]
    C_2 = kwargs["C_2"]
    C_1 = kwargs["C_1"]
    M = kwargs["M"]
    step_indices = torch.arange(num_steps, dtype=torch.float64).cuda()
    u = torch.zeros(M + 1, dtype=torch.float64).cuda()
    alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
    for j in torch.arange(M, 0, -1, device=step_indices.device): # M, ..., 1
        u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
    u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
    sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    #print(sigma_steps)

    ##get noise schedule
    sigma = lambda t: t
    sigma_deriv = lambda t: 1
    sigma_inv = lambda sigma: sigma

    ##scaling schedule 
    s = lambda t: 1
    s_deriv = lambda t: 0

    ##compute some final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))

    return t_steps, sigma_inv, sigma, s, sigma_deriv

def data_replicate(data, K):
    if len(data.shape)==2: data_batch = torch.Tensor.repeat(data,[K,1]) 
    else: data_batch = torch.Tensor.repeat(data,[K,1,1,1]) 
    return data_batch

'''


def sample_T(self, x0, eta=0.4, t_steps_hierarchy=None, out_dir="out_temp2/"):
    """Noise ``x0`` to the first timestep, run a DDIM-style loop and save each step's decoded prediction.

    For every timestep ``t`` in ``t_steps_hierarchy`` the noise prediction
    ``self.model.model(x_t, t)`` is used to estimate the clean latent
    ``pred_x0``.  That estimate is decoded with
    ``self.model.decode_first_stage`` and the first image of the batch is
    written to ``<out_dir>/<i>.png``.  Except at the last step, ``x_t`` is then
    moved to the next timestep with the DDIM update, whose stochastic part is
    scaled by ``eta``.

    Args:
        x0: clean latent batch; must be 4-D because the schedule values are
            broadcast as shape (1, 1, 1, 1).
        eta: scale of the stochastic term (0 gives a deterministic update).
        t_steps_hierarchy: sequence of integer timesteps indexing
            ``self.model.sqrt_one_minus_alphas_cumprod``, in visiting order
            (presumably from most to least noisy -- confirm with callers).
        out_dir: directory for the per-step images; created if missing.

    Returns:
        None.  The only outputs are the image files.

    Raises:
        ValueError: if ``t_steps_hierarchy`` is not provided.

    Fixes relative to the earlier version: the stochastic noise is added once
    (it used to be added twice, giving variance 2 * sigma_t**2, inconsistent
    with the ``sqrt(1 - a_prev - sigma_t**2)`` direction coefficient); the
    timestep vector follows the batch size of ``x0`` instead of a fixed 10;
    tensors are created on ``x0.device`` instead of unconditionally on CUDA;
    the loop runs under ``torch.no_grad()`` since nothing is returned.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that ``PIL.Image`` has been loaded.
    from PIL import Image

    if t_steps_hierarchy is None:
        raise ValueError("t_steps_hierarchy must be provided")

    device = x0.device
    batch_size = x0.shape[0]
    t_steps_hierarchy = torch.tensor(t_steps_hierarchy).to(device)
    num_steps = len(t_steps_hierarchy)
    sqrt_one_minus_acp = self.model.sqrt_one_minus_alphas_cumprod

    def noise_var(t):
        # Variance of the forward-process noise at timestep t: 1 - alpha_bar_t.
        return sqrt_one_minus_acp[t].reshape(1, 1, 1, 1) ** 2

    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        # Forward-diffuse x0 to the first timestep of the schedule.
        var_t = noise_var(t_steps_hierarchy[0])
        x_t = torch.sqrt(1 - var_t) * x0 + torch.sqrt(var_t) * torch.randn_like(x0)

        for i, t in enumerate(t_steps_hierarchy):
            t_hat = torch.ones(batch_size, device=device) * t
            e_out = self.model.model(x_t, t_hat)

            a_t = 1 - noise_var(t)
            pred_x0 = (x_t - torch.sqrt(1 - a_t) * e_out) / a_t.sqrt()

            if i != num_steps - 1:
                a_prev = 1 - noise_var(t_steps_hierarchy[i + 1])
                sigma_t = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
                dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_out
                # A single noise draw with std sigma_t keeps the marginal
                # variance of x_t at 1 - a_prev.
                x_t = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)

            recon = self.model.decode_first_stage(pred_x0)
            image_path = os.path.join(out_dir, f'{i}.png')
            image_np = (recon.detach() * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()[0]
            Image.fromarray(image_np, 'RGB').save(image_path)

    return