import numpy as np
from PIL import Image
import io

def image_to_byte_array(image: "Image.Image") -> bytes:
    """Serialize an image to PNG-encoded bytes.

    Args:
        image (Image.Image): image to encode; only its
            ``save(fp, format=...)`` method is used here.
    Returns:
        bytes: everything ``image.save`` wrote when asked for PNG format
    """
    # BytesIO is an in-memory file; image.save expects a file-like target,
    # so the encoded data is written there instead of to disk.
    with io.BytesIO() as buffer:
        image.save(buffer, format='png')
        # Copy the contents out as a bytes object before the buffer closes.
        return buffer.getvalue()


def get_mask(image_mask: np.ndarray) -> np.ndarray:
    """Get the binary foreground mask from a segmentation mask.

    A pixel counts as foreground when the mean of its channel values is
    greater than zero.

    Args:
        image_mask (np.ndarray): segmentation mask, either with the color
            channels on axis 2 (e.g. ``(H, W, C)``) or a 2-D single-channel
            array
    Returns:
        np.ndarray: integer mask, 1 for foreground pixels and 0 elsewhere.
            The dtype is the same integer type whether or not any
            foreground pixel is present.
    """
    image_mask = np.asarray(image_mask)
    if image_mask.ndim == 2:
        # Single-channel mask: there is no channel axis to average over.
        foreground = image_mask > 0
    else:
        # Average the colors of the segmentation mask over the channel axis.
        foreground = np.mean(image_mask, axis=2) > 0
    # Convert unconditionally so an all-background mask is returned with the
    # same integer dtype as a non-empty one (rather than as booleans).
    return foreground.astype(int)