import PIL
import torch
import requests
import torchvision
from math import ceil
from io import BytesIO
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import math
from tqdm import tqdm
def download_image(url):
    return PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB")


def resize_image(image, size=768):
    tensor_image = F.to_tensor(image)
    resized_image = F.resize(tensor_image, size, antialias=True)
    return resized_image


def downscale_images(images, factor=3/4):
    scaled_height, scaled_width = int(((images.size(-2)*factor)//32)*32), int(((images.size(-1)*factor)//32)*32)
    scaled_image = torchvision.transforms.functional.resize(images, (scaled_height, scaled_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
    return scaled_image



def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
    resolution_multiple = 42.67
    latent_height = ceil(height / compression_factor_b)
    latent_width = ceil(width / compression_factor_b)
    stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
    
    latent_height = ceil(height / compression_factor_a)
    latent_width = ceil(width / compression_factor_a)
    stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
    
    return stage_c_latent_shape, stage_b_latent_shape


def get_views(H, W, window_size=64, stride=16):
    '''
    - H, W: height and width of the latent
    '''
    num_blocks_height = (H - window_size) // stride + 1
    num_blocks_width = (W - window_size) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size
        views.append((h_start, h_end, w_start, w_end))
    return views

       

def show_images(images, rows=None, cols=None, **kwargs):
    if images.size(1) == 1:
        images = images.repeat(1, 3, 1, 1)
    elif images.size(1) > 3:
        images = images[:, :3]
    
    if rows is None:
        rows = 1
    if cols is None:
        cols = images.size(0) // rows

    _, _, h, w = images.shape

    imgs = []
    for i, img in enumerate(images):
        imgs.append( torchvision.transforms.functional.to_pil_image(img.clamp(0, 1)))
    
    return imgs
    


def decode_b(conditions_b, unconditions_b, models_b, bshape,  extras_b, device, \
    stage_a_tiled=False, num_instance=4, patch_size=256, stride=24):
   
    
    sampling_b = extras_b.gdf.sample(
        models_b.generator.half(), conditions_b,  bshape,
            unconditions_b, device=device,
            **extras_b.sampling_configs,
        )
    models_b.generator.cuda()
    for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
        sampled_b = sampled_b
    models_b.generator.cpu()
    torch.cuda.empty_cache()
    if stage_a_tiled:
        with torch.cuda.amp.autocast(dtype=torch.float16):
            padding = (stride*2, stride*2, stride*2, stride*2)
            sampled_b = torch.nn.functional.pad(sampled_b, padding, mode='reflect')
            count = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
            sampled = torch.zeros((sampled_b.shape[0], 3, sampled_b.shape[-2]*4, sampled_b.shape[-1]*4), requires_grad=False, device=sampled_b.device)
            views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride)
           
            for view_idx, (h_start, h_end, w_start, w_end) in enumerate(tqdm(views, total=len(views))):
            
                sampled[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += models_b.stage_a.decode(sampled_b[:, :, h_start:h_end, w_start:w_end]).float()   
                count[:, :, h_start*4:h_end*4, w_start*4:w_end*4] += 1
            sampled /= count    
            sampled = sampled[:, :, stride*4*2:-stride*4*2, stride*4*2:-stride*4*2]
    else:
    
        sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled)

    return sampled.float()


def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None):
    if conditions is None:
        conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    if unconditions is None:
        unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)    
    sampling_c = extras.gdf.sample(
                models.generator,  conditions, stage_c_latent_shape, stage_c_latent_shape_lr, 
                unconditions, device=device, **extras.sampling_configs, 
            )
    for idx, (sampled_c, sampled_c_curr, _, _) in enumerate(tqdm(sampling_c, total=extras.sampling_configs['timesteps'])):
                sampled_c = sampled_c
    return sampled_c
    
def get_target_lr_size(ratio, std_size=24):
        w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) 
        return (h * 32 , w *32 )