Spaces:
Runtime error
Runtime error
# Copyright 2023 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import warnings | |
from typing import List, Optional, Union | |
import numpy as np | |
import PIL | |
import torch | |
from PIL import Image | |
from .configuration_utils import ConfigMixin, register_to_config | |
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate | |
# Alias covering the image inputs listed below: a single PIL image, NumPy array
# or float tensor, or a list whose items are one of those three kinds.
PipelineImageInput = Union[
    PIL.Image.Image, np.ndarray, torch.FloatTensor,
    List[PIL.Image.Image], List[np.ndarray], List[torch.FloatTensor],
]
class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.

    Converts between PIL images, NumPy arrays and PyTorch tensors, and optionally resizes, normalizes and
    binarizes images before they are handed to a VAE (`preprocess`) or after they come out of it (`postprocess`).

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to RGB format.
        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to grayscale format.

    Raises:
        ValueError: If both `do_convert_rgb` and `do_convert_grayscale` are `True`.
    """

    config_name = CONFIG_NAME

    # `register_to_config` is what makes the init arguments available as `self.config.<name>`,
    # which every method below relies on.
    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_rgb: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__()
        if do_convert_rgb and do_convert_grayscale:
            # One single message string (adjacent literals), not several exception arguments.
            raise ValueError(
                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`"
            )

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a numpy image or a batch of images to a list of PIL images.

        A 3-dimensional input is treated as a single image and gets a leading batch axis. Values are scaled by 255
        and rounded to `uint8`; a trailing channel axis of size 1 produces mode `"L"` images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            return [Image.fromarray(image.squeeze(), mode="L") for image in images]
        return [Image.fromarray(image) for image in images]

    @staticmethod
    def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        """
        Convert a PIL image or a list of PIL images to one stacked `float32` NumPy array scaled by 1/255.
        """
        if not isinstance(images, list):
            images = [images]
        arrays = [np.array(image).astype(np.float32) / 255.0 for image in images]
        return np.stack(arrays, axis=0)

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
        """
        Convert a NumPy image batch to a PyTorch tensor, moving the last axis to position 1.

        A 3-dimensional input first gets a trailing axis of size 1 so that the transpose below is valid.
        """
        if images.ndim == 3:
            images = images[..., None]
        return torch.from_numpy(images.transpose(0, 3, 1, 2))

    @staticmethod
    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
        """
        Convert a 4-dimensional PyTorch tensor to a float NumPy array on the CPU, moving axis 1 to the end.
        """
        return images.cpu().permute(0, 2, 3, 1).float().numpy()

    @staticmethod
    def normalize(images):
        """
        Normalize an image array to [-1,1] (computes `2 * images - 1`, i.e. expects input in [0,1]).
        """
        return 2.0 * images - 1.0

    @staticmethod
    def denormalize(images):
        """
        Denormalize an image tensor to [0,1] (computes `images / 2 + 0.5` and clamps the result).
        """
        return (images / 2 + 0.5).clamp(0, 1)

    @staticmethod
    def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Converts a PIL image to RGB format.
        """
        return image.convert("RGB")

    @staticmethod
    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
        """
        Converts a PIL image to grayscale format.
        """
        return image.convert("L")

    def get_default_height_width(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ):
        """
        This function return the height and width that are downscaled to the next integer multiple of
        `vae_scale_factor`.

        Args:
            image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
                The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
                shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
                have shape `[batch, channel, height, width]`.
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed image. If `None`, will use the height of `image` input.
            width (`int`, *optional*`, defaults to `None`):
                The width in preprocessed. If `None`, will use the width of the `image` input.

        Returns:
            `Tuple[int, int]`: `(height, width)`, each rounded down to a multiple of `vae_scale_factor`. Explicitly
            passed values are rounded down as well.
        """
        if height is None:
            if isinstance(image, PIL.Image.Image):
                height = image.height
            elif isinstance(image, torch.Tensor):
                height = image.shape[2]
            else:
                height = image.shape[1]
        if width is None:
            if isinstance(image, PIL.Image.Image):
                width = image.width
            elif isinstance(image, torch.Tensor):
                width = image.shape[3]
            else:
                width = image.shape[2]
        # resize to integer multiple of vae_scale_factor
        scale = self.config.vae_scale_factor
        height -= height % scale
        width -= width % scale
        return height, width

    def resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Resize image to `(height, width)` and return it in the same container type it came in.

        PIL images use the filter selected by `self.config.resample`. Tensors go through
        `torch.nn.functional.interpolate` (called without a `mode` argument); NumPy arrays are converted to a
        tensor, interpolated the same way and converted back. Any other input type is returned unchanged.
        """
        if isinstance(image, PIL.Image.Image):
            image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
        elif isinstance(image, torch.Tensor):
            image = torch.nn.functional.interpolate(image, size=(height, width))
        elif isinstance(image, np.ndarray):
            as_tensor = self.numpy_to_pt(image)
            as_tensor = torch.nn.functional.interpolate(as_tensor, size=(height, width))
            image = self.pt_to_numpy(as_tensor)
        return image

    def binarize(self, image: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """
        Create a mask: values below 0.5 become 0, all others become 1.

        The input is modified in place (boolean-mask assignment) and also returned.
        """
        image[image < 0.5] = 0
        image[image >= 0.5] = 1
        return image

    def preprocess(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.

        Args:
            image: A single PIL image / NumPy array / PyTorch tensor, or a non-empty list of them.
            height (`int`, *optional*): Target height; derived from the input when `None`.
            width (`int`, *optional*): Target width; derived from the input when `None`.

        Returns:
            `torch.Tensor`: The batched image tensor after the configured resize / normalize / binarize steps.
            Tensor inputs whose axis 1 has size 4 are returned as-is (treated as latents).

        Raises:
            ValueError: If `image` is not one of the supported formats, or is an empty list.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                #    1. batch x height x width: we should insert the channel dimension at position 1
                #    2. channel x height x width: we should insert batch dimension at position 0,
                #       however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                #    for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                #   1. batch x height x width: insert channel dimension on last position
                #   2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and image and all(isinstance(i, supported_formats) for i in image)):
            # Build the message from strings: the supported formats are classes and cannot be
            # joined directly, and `image` itself may not even be iterable.
            found = [type(i) for i in image] if isinstance(image, list) else type(image)
            supported = ", ".join(str(fmt) for fmt in supported_formats)
            raise ValueError(f"Input is in incorrect format: {found}. Currently, we only support {supported}")

        if isinstance(image[0], PIL.Image.Image):
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            if self.config.do_resize:
                height, width = self.get_default_height_width(image[0], height, width)
                image = [self.resize(i, height, width) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            # already-batched arrays are concatenated, single images are stacked into a batch
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
            image = self.numpy_to_pt(image)
            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, dim=0) if image[0].ndim == 4 else torch.stack(image, dim=0)
            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)
            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == 4:
                return image
            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            # negative values mean the input is taken to be normalized already
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)
        if self.config.do_binarize:
            image = self.binarize(image)
        return image

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        """
        Convert a batched image tensor into the requested output format.

        Args:
            image (`torch.FloatTensor`): The batched tensor to convert.
            output_type (`str`, *optional*, defaults to `"pil"`):
                One of `"latent"`, `"pt"`, `"np"` or `"pil"`. Any other value triggers a deprecation notice and is
                treated as `"np"`.
            do_denormalize (`List[bool]`, *optional*):
                One flag per batch item selecting whether it is denormalized. Defaults to `config.do_normalize` for
                every item.

        Returns:
            The input unchanged for `"latent"`, a tensor for `"pt"`, a NumPy array for `"np"`, or a list of PIL
            images for `"pil"`.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if output_type == "latent":
            return image

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]
        image = torch.stack(
            [self.denormalize(item) if flag else item for item, flag in zip(image, do_denormalize)]
        )
        if output_type == "pt":
            return image

        image = self.pt_to_numpy(image)
        if output_type == "np":
            return image

        # only "pil" is left at this point
        return self.numpy_to_pil(image)
class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

    Extends `VaeImageProcessor` with helpers for depth data: RGB-like depth encoding/decoding, depth preprocessing
    and a `postprocess` that returns an (rgb, depth) pair.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        # Forward the settings so the base class is initialised with the values the caller asked
        # for instead of its own defaults.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
        )

    # NOTE: `get_default_height_width` is inherited from `VaeImageProcessor`; the copy that used to
    # live here was line-for-line identical to the parent implementation.

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy image or a batch of images to a list of PIL images.

        A 3-dimensional input gets a leading batch axis. For multi-channel input only the first three channels
        are used for the picture; a trailing channel axis of size 1 produces mode `"L"` images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            return [Image.fromarray(image.squeeze(), mode="L") for image in images]
        return [Image.fromarray(image[:, :, :3]) for image in images]

    @staticmethod
    def rgblike_to_depthmap(image):
        """
        Args:
            image: RGB-like depth image (height x width x 3)
        Returns: depth map, computed as `channel_1 * 256 + channel_2`
        """
        if isinstance(image, np.ndarray) and image.dtype == np.uint8:
            # Widen first: `channel * 256` does not fit in uint8, and relying on implicit
            # promotion of the Python int is not portable across NumPy versions.
            image = image.astype(np.uint16)
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    @staticmethod
    def depthmap_to_rgblike(depthmap: np.ndarray) -> np.ndarray:
        """
        Encode a depth map as an RGB-like `uint8` array: channel 0 is zero, channel 1 holds the high byte and
        channel 2 the low byte of the depth value cast to `uint16`.
        """
        depthmap = depthmap.astype(np.uint16)
        r = np.zeros_like(depthmap, dtype=np.uint8)
        g = (depthmap // 2**8) % 2**8
        b = depthmap % 2**8
        return np.stack([r, g, b], axis=-1).astype(np.uint8)

    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        """
        Convert a NumPy depth image or a batch of images to a list of mode `"I;16"` PIL images.

        The depth data is read from channel index 3 onwards. Six-channel input carries an RGB-like depth encoding
        that is decoded with `rgblike_to_depthmap`; four-channel input carries a single depth channel that is
        scaled by 65535.

        Raises:
            Exception: If the last axis has neither 6 nor 4 channels.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            images_depth = (images_depth * 255).round().astype("uint8")
            return [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        if images.shape[-1] == 4:
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            return [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        raise Exception("Not supported")

    def preprocess_depth(self, image, height=None, width=None):
        """
        Turn a depth image into a batched tensor: divide by 65535, optionally resize, then map to [-1, 1] via
        `2 * (x - 0.5)` and move the channel axis to position 1.

        Args:
            image: Anything `np.array` accepts, e.g. a PIL depth image or an array.
            height (`int`, *optional*): Target height; derived from the input when `None`.
            width (`int`, *optional*): Target width; derived from the input when `None`.
        """
        image = np.array(image)
        image = image / 65535.0
        image = image[None, ...]
        if image.ndim == 3:
            # A 2-D depth map has no channel axis yet. Add it here so the final transpose also
            # works when resizing (which would otherwise add it) is disabled.
            image = image[..., None]
        if self.config.do_resize:
            height, width = self.get_default_height_width(image, height, width)
            image = self.resize(image, height, width)
        image = 2 * (image - 0.5)
        return torch.from_numpy(image.transpose(0, 3, 1, 2))

    def postprocess(
        self,
        image: torch.FloatTensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ):
        """
        Convert a batched RGB+depth tensor into an `(rgb, depth)` pair.

        Args:
            image (`torch.FloatTensor`): The batched tensor to convert.
            output_type (`str`, *optional*, defaults to `"pil"`):
                `"np"` or `"pil"`. Values outside `"latent"`, `"pt"`, `"np"`, `"pil"` trigger a deprecation notice
                and are treated as `"np"`.
            do_denormalize (`List[bool]`, *optional*):
                One flag per batch item selecting whether it is denormalized. Defaults to `config.do_normalize` for
                every item.

        Returns:
            For `"np"`: the first three channels and the depth (decoded with `rgblike_to_depthmap` when the input
            has six channels, otherwise the channels from index 3 on). For `"pil"`: two lists of PIL images.

        Raises:
            ValueError: If `image` is not a PyTorch tensor.
            Exception: For `"latent"` and `"pt"`, which this processor does not produce.
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        if do_denormalize is None:
            do_denormalize = [self.config.do_normalize] * image.shape[0]
        image = torch.stack(
            [self.denormalize(item) if flag else item for item, flag in zip(image, do_denormalize)]
        )
        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)

        raise Exception(f"This type {output_type} is not supported")