from math import ceil, floor
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
# Module-level transform turning a PIL image into a torch tensor via
# torchvision's PILToTensor (per torchvision docs: channels-first, values not
# rescaled).
pil_to_torch = transforms.Compose([transforms.PILToTensor()])
from typing import Tuple

# NOTE: get_padding_for_aspect_ratio is defined once, further down in this
# module; a second definition here would be shadowed by it at import time.


def get_padding_for_aspect_ratio(
    img: "Image.Image", target_aspect_ratio: float = 16 / 9
) -> "list[int] | None":
    """Compute the padding needed to bring *img* to *target_aspect_ratio*.

    Exactly one axis is padded: the width amount goes in index 2 and the
    height amount in index 3 of the returned 4-element list; every other
    entry is 0.
    NOTE(review): the index layout looks like torchvision's
    ``(left, top, right, bottom)`` padding order, i.e. padding on the right
    or bottom -- confirm against callers.

    Args:
        img: Image whose ``width``, ``height`` and ``size`` are read.
        target_aspect_ratio: Desired width / height ratio; must be > 0.

    Returns:
        ``[0, 0, pad_w, 0]`` or ``[0, 0, 0, pad_h]``, or ``None`` when the
        image's width / height already equals the target exactly.

    Raises:
        ValueError: if *target_aspect_ratio* is not positive.
    """
    if target_aspect_ratio <= 0:
        raise ValueError(
            f"target_aspect_ratio must be positive, got {target_aspect_ratio}"
        )

    aspect_ratio = img.width / img.height
    # Exact float comparison: ratios that differ only by rounding fall
    # through and may produce an all-zero padding list rather than None.
    if aspect_ratio == target_aspect_ratio:
        return None

    # Size each axis would need if the other one were kept (r = w / h).
    # The width target is rounded up, the height target rounded down.
    w_target = ceil(target_aspect_ratio * img.height)
    h_target = floor(img.width * (1 / target_aspect_ratio))

    # Growth factor per axis; an axis that would have to shrink is ruled out.
    w_scale = w_target / img.width if w_target >= img.width else np.inf
    h_scale = h_target / img.height if h_target >= img.height else np.inf

    # Pad the axis needing the smaller growth; ties go to the height axis.
    if h_scale <= w_scale:
        scale_axis, target_size = 1, h_target
    else:
        scale_axis, target_size = 0, w_target

    pad_size = [0, 0, 0, 0]
    pad_size[2 + scale_axis] = int(target_size - img.size[scale_axis])
    return pad_size


def add_margin(pil_img, top, right, bottom, left, color):
    """Return a copy of *pil_img* surrounded by solid-colour margins.

    A new canvas of the same mode is filled with *color*, and the original
    image is pasted at offset ``(left, top)``.

    Args:
        pil_img: Source PIL image.
        top, right, bottom, left: Margin widths in pixels, added to the
            canvas size on the corresponding sides.
        color: Fill colour passed straight to ``Image.new``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas


def resize_to_fit(image, size):
    """Scale *image* so it fits inside ``size`` while keeping its aspect ratio.

    The dimension with the tighter constraint is set exactly to the target.
    The other dimension is scaled proportionally and truncated to an int.

    Args:
        image: Object exposing ``size`` as ``(w, h)`` and ``resize((w, h))``
            (e.g. a PIL image).
        size: ``(width, height)`` bounding box.
    """
    target_w, target_h = size
    src_w, src_h = image.size
    width_is_binding = target_h / src_h > target_w / src_w
    if width_is_binding:
        new_size = (target_w, int(src_h * target_w / src_w))
    else:
        new_size = (int(src_w * target_h / src_h), target_h)
    return image.resize(new_size)


def pad_to_fit(image, size, color=(0, 0, 0)):
    """Centre *image* on a canvas of exactly ``size`` by adding margins.

    Args:
        image: PIL image to pad.
        size: ``(width, height)`` of the output canvas.
        color: Fill colour for the margins; defaults to black ``(0, 0, 0)``.

    Returns:
        The result of ``add_margin`` with margins summing to
        ``size - image.size`` on each axis, so the canvas is exactly *size*.
        NOTE(review): if *image* is larger than *size* the margins are
        negative; the outcome then depends on ``add_margin``/PIL paste.
    """
    W, H = size
    w, h = image.size
    # Give any odd leftover pixel to the bottom/right so the output really
    # is ``size``; splitting (H - h) // 2 onto both sides loses one pixel.
    top = (H - h) // 2
    bottom = (H - h) - top
    left = (W - w) // 2
    right = (W - w) - left
    return add_margin(image, top, right, bottom, left, color)


def resize_and_keep(pil_img, target_height: int = 576):
    """Resize *pil_img* to ``target_height`` pixels tall, keeping aspect ratio.

    Args:
        pil_img: PIL image to resize.
        target_height: Output height in pixels (default 576).

    Returns:
        ``(resized_image, [orig_width, orig_height])``. The second item is
        the size *before* resizing, as a list.
    """
    original_size = [pil_img.width, pil_img.height]
    scale = target_height / float(pil_img.size[1])
    # Truncate like before, but never to 0: PIL cannot resize to a zero width.
    new_width = max(1, int(float(pil_img.size[0]) * float(scale)))
    resized = pil_img.resize((new_width, target_height))
    return resized, original_size


def resize_and_crop(pil_img: Image.Image) -> Tuple[Image.Image, list[int]]:
    """Resize *pil_img* via ``resize_and_keep`` and crop its top-left 1024x576.

    Returns:
        ``(cropped_image, original_size)``, where ``original_size`` is the
        list returned by ``resize_and_keep``.

    Raises:
        AssertionError: if the resized image is smaller than 1024x576.
    """
    img, original_size = resize_and_keep(pil_img)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # (otherwise PIL would crop out of bounds silently); AssertionError and
    # the message are kept for callers that already catch them.
    if img.width < 1024 or img.height < 576:
        raise AssertionError(f"Got {img.width} and {img.height}")
    return img.crop((0, 0, 1024, 576)), original_size


def center_crop(pil_img):
    """Crop the central 576x576 window out of *pil_img*.

    The box edges are computed with true division, so they may be
    half-integers when the slack is odd; they are passed to ``crop`` unchanged.
    """
    side = 576
    img_w, img_h = pil_img.size
    box = (
        (img_w - side) / 2,
        (img_h - side) / 2,
        (img_w + side) / 2,
        (img_h + side) / 2,
    )
    return pil_img.crop(box)