|
import os |
|
import tempfile |
|
from typing import Any, Callable, List, Optional, Tuple, Union |
|
from urllib.parse import unquote, urlparse |
|
|
|
import PIL.Image |
|
import PIL.ImageOps |
|
import requests |
|
|
|
from .import_utils import BACKENDS_MAPPING, is_imageio_available |
|
|
|
|
|
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string is treated as a URL when it starts with
            `http://` or `https://`, and as a local file path otherwise.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be
            converted to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.

    Raises:
        ValueError: If `image` is neither a string nor a PIL image, or if it is a string that is neither an
            `http(s)://` URL nor the path of an existing file.
    """
    # Reject unsupported input types up front so the rest of the function only deals with str / PIL images.
    if not isinstance(image, (str, PIL.Image.Image)):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Stream the body and let PIL read straight from the raw response.
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )

    # Apply the EXIF orientation tag (if any) before any user-supplied conversion runs.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is None:
        return image.convert("RGB")
    return convert_method(image)
|
|
|
|
|
def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format. URLs are downloaded to a temporary
            file that is removed again before this function returns or raises.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the list of frames after loading it. When set to `None` the frames are
            returned as decoded, without any further conversion.

    Returns:
        `List[PIL.Image.Image]`:
            The video as a list of PIL images.

    Raises:
        ValueError: If `video` is neither an `http(s)://` URL nor the path of an existing file, or if the download
            does not answer with status code 200.
        ImportError: If the input is not a `.gif` and `imageio` is not available.
        AttributeError: If the input is not a `.gif` and the imageio ffmpeg plugin cannot be accessed.
    """
    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    # Path of the downloaded copy, if any. Stays `None` for local files so the cleanup below is a no-op.
    temp_path = None
    pil_images = []

    try:
        if is_url:
            # The context manager releases the connection even when the status check or the download fails.
            with requests.get(video, stream=True) as response:
                if response.status_code != 200:
                    raise ValueError(f"Failed to download video. Status code: {response.status_code}")

                # Keep the remote file extension so the decoder below can pick the right format.
                file_name = os.path.basename(unquote(urlparse(video).path))
                suffix = os.path.splitext(file_name)[1] or ".mp4"

                # `delete=False` because the file is re-opened by name for decoding after this handle is closed.
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    temp_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp_file.write(chunk)

            video = temp_path

        if video.endswith(".gif"):
            # Close the file handle once all frames are copied; an open handle would leak and can prevent
            # removing the temporary file on Windows.
            with PIL.Image.open(video) as gif:
                try:
                    while True:
                        pil_images.append(gif.copy())
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    # `seek` past the last frame signals the end of the animation.
                    pass
        else:
            if is_imageio_available():
                import imageio
            else:
                raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))

            try:
                imageio.plugins.ffmpeg.get_exe()
            except AttributeError as err:
                raise AttributeError(
                    "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
                ) from err

            with imageio.get_reader(video) as reader:
                for frame in reader:
                    pil_images.append(PIL.Image.fromarray(frame))
    finally:
        # Always remove the downloaded copy, including when the download or the decoding raised.
        if temp_path is not None:
            os.remove(temp_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return pil_images
|
|
|
|
|
|
|
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    """
    Resolves a dotted `tensor_name` relative to `module`.

    Every component of `tensor_name` except the last one is looked up with `getattr`, starting from `module` and
    descending one attribute at a time.

    Args:
        module:
            The object the dotted path starts from.
        tensor_name (`str`):
            An attribute name, optionally prefixed by a dot-separated path of parent attributes.

    Returns:
        `Tuple[Any, str]`:
            The object reached after following all parent components, and the last component of `tensor_name`.

    Raises:
        ValueError: If a parent attribute exists but is `None`.
        AttributeError: If a parent attribute does not exist (raised by `getattr`).
    """
    *parent_names, leaf_name = tensor_name.split(".")

    for attr_name in parent_names:
        child = getattr(module, attr_name)
        if child is None:
            raise ValueError(f"{module} has no attribute {attr_name}.")
        module = child

    return module, leaf_name
|
|