Spaces:
Running
on
Zero
Running
on
Zero
| import hashlib | |
| import json | |
| import os | |
| import re | |
| from pathlib import Path | |
| import folder_paths | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageOps | |
| from PIL.PngImagePlugin import PngInfo | |
| from ..log import log | |
class MTB_LoadImageSequence:
    """Load an image sequence from a folder.

    The current frame is used to determine which image to load. Usually used
    in conjunction with the `Primitive` node set to increment to load a
    sequence of images from a folder. Use -1 to load all matching frames as a
    batch.

    The optional ``range`` input ("start-end", inclusive) takes precedence
    over ``current_frame`` and loads the selected frames as a batch.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: a path pattern, a frame index and an optional range."""
        return {
            "required": {
                "path": ("STRING", {"default": "videos/####.png"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": -1, "max": 9999999},
                ),
            },
            "optional": {
                "range": ("STRING", {"default": ""}),
            },
        }

    CATEGORY = "mtb/IO"
    FUNCTION = "load_image"
    RETURN_TYPES = (
        "IMAGE",
        "MASK",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "image",
        "mask",
        "current_frame",
        "total_frames",
    )

    def load_image(self, path=None, current_frame=0, range=""):
        """Load one frame, a range of frames, or every frame matching ``path``.

        Args:
            path: File pattern (``#`` runs stand for the frame number), glob
                pattern or directory.
            current_frame: Frame number substituted into ``path``; -1 loads
                every matching frame.
            range: Optional "start-end" string; when non-empty it wins over
                ``current_frame``.

        Returns:
            ``(image, mask, current_frame, total_frames)``. Batch loads report
            -1 as the current frame and the number of loaded images as the
            total; single loads report a total of 1.

        Raises:
            ValueError: If the range is malformed or out of bounds, or if a
                batch load finds no frame at all.
        """
        if range:
            return self._load_batch(
                self.get_frames_from_range(path, range), path
            )

        if current_frame == -1:
            log.debug(f"Loading all frames from {path}")
            frames = resolve_all_frames(path)
            log.debug(f"Found {len(frames)} frames")
            return self._load_batch(frames, path)

        log.debug(f"Loading image: {path}, {current_frame}")
        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        image, mask = img_from_path(image_path)
        return (image, mask, current_frame, 1)

    @staticmethod
    def _load_batch(frames, path):
        """Load every file in ``frames`` and concatenate the results."""
        if not frames:
            # Unpacking an empty zip() would raise a cryptic ValueError;
            # keep the exception type but say what actually went wrong.
            raise ValueError(f"No frames found for: {path}")

        imgs, masks = zip(*(img_from_path(frame) for frame in frames))
        # NOTE(review): img_from_path appears to return masks without a batch
        # dimension, so this concatenates them along their first axis rather
        # than stacking a batch. Left unchanged to keep the output shape
        # callers currently receive - confirm whether that is intended.
        return (torch.cat(imgs, dim=0), torch.cat(masks, dim=0), -1, len(imgs))

    @staticmethod
    def _frame_number_regex(file_pattern):
        """Compile a regex capturing each run of ``#`` as one whole number.

        A run of N hashes becomes ``(\\d{N,})``: at least N digits, captured
        as a single group. Replacing every ``#`` individually would split the
        number over N groups and leave only its leading digits in group 1.
        """
        return re.compile(
            re.sub(
                r"(?:\\#)+",
                # re.escape() turns each "#" into the two characters "\#".
                lambda run: r"(\d{%d,})" % (len(run.group(0)) // 2),
                re.escape(file_pattern),
            )
        )

    def get_frames_from_range(self, path, range_str):
        """Return the frames of ``path`` selected by an inclusive range.

        When the file name of ``path`` contains ``#`` the range refers to the
        frame numbers embedded in the file names; otherwise (glob pattern or
        directory) it refers to positions in the sorted file list.

        Args:
            path: File pattern, glob pattern or directory.
            range_str: Range formatted as "start-end".

        Returns:
            The list of matching file paths.

        Raises:
            ValueError: If ``range_str`` is malformed or out of bounds.
        """
        try:
            start, end = map(int, range_str.split("-"))
        except ValueError as err:
            raise ValueError(
                f"Invalid range format: {range_str}. Expected format is 'start-end'."
            ) from err

        frames = resolve_all_frames(path)
        total_frames = len(frames)

        # Match on the file name only so that differing directory separators
        # between the pattern and the resolved paths cannot break the match.
        file_pattern = os.path.basename(path)
        if "#" in file_pattern:
            number_regex = self._frame_number_regex(file_pattern)
            numbered = []
            for frame in frames:
                match = number_regex.search(os.path.basename(frame))
                if match:
                    numbered.append((int(match.group(1)), frame))

            # Bounds are checked against the real frame numbers, not the file
            # count, so sequences that do not start at 0 remain selectable.
            highest = max((number for number, _ in numbered), default=-1)
            if start < 0 or end > highest:
                raise ValueError(
                    f"Range {range_str} is out of bounds. Total frames available: {total_frames}"
                )
            return [
                frame for number, frame in numbered if start <= number <= end
            ]

        if start < 0 or end >= total_frames:
            raise ValueError(
                f"Range {range_str} is out of bounds. Total frames available: {total_frames}"
            )
        log.warning(
            f"Wildcard pattern or directory will use indexes instead of frame numbers for : {path}"
        )
        return frames[start : end + 1]

    @staticmethod
    def IS_CHANGED(path="", current_frame=0, range=""):
        """Return a fingerprint that changes when the input files change.

        Batch loads hash the modification times of every matching frame;
        single loads hash the content of the resolved file. "NONE" is
        returned when the single file does not exist.
        """
        log.debug(f"Checking if changed: {path}, {current_frame}")
        if range or current_frame == -1:
            timestamps = [
                os.path.getmtime(folder_paths.get_annotated_filepath(p))
                for p in resolve_all_frames(path)
            ]
            return hashlib.sha256(
                "".join(map(str, timestamps)).encode()
            ).hexdigest()

        resolved_path = resolve_path(path, current_frame)
        image_path = folder_paths.get_annotated_filepath(resolved_path)
        if os.path.exists(image_path):
            with open(image_path, "rb") as f:
                return hashlib.sha256(f.read()).hexdigest()
        return "NONE"

    # Input validation is currently disabled; kept for reference.
    # @staticmethod
    # def VALIDATE_INPUTS(path="", current_frame=0):
    #     print(f"Validating inputs: {path}, {current_frame}")
    #     resolved_path = resolve_path(path, current_frame)
    #     if not folder_paths.exists_annotated_filepath(resolved_path):
    #         return f"Invalid image file: {resolved_path}"
    #     return True
| import glob | |
def img_from_path(path):
    """Load an image file as an ``(image, mask)`` pair of float32 tensors.

    The EXIF orientation is applied before conversion. The image is converted
    to RGB, scaled to the 0..1 range and given a leading dimension of size 1.
    The mask is the inverted alpha channel (``1 - alpha``) when the file has
    one, otherwise a 64x64 tensor of zeros.

    Args:
        path: Path of the image file to open.

    Returns:
        Tuple ``(image, mask)``.
    """
    # The context manager releases the file handle deterministically;
    # multi-frame files would otherwise stay open until garbage collection.
    with Image.open(path) as source:
        img = ImageOps.exif_transpose(source)

        rgb = np.array(img.convert("RGB")).astype(np.float32) / 255.0
        image = torch.from_numpy(rgb)[None,]

        if "A" in img.getbands():
            alpha = np.array(img.getchannel("A")).astype(np.float32) / 255.0
            mask = 1.0 - torch.from_numpy(alpha)
        else:
            # No alpha channel: fall back to a fixed-size empty mask.
            mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")

    return (
        image,
        mask,
    )
def resolve_all_frames(path: str):
    """List every frame file designated by ``path``.

    Three forms are supported:

    * a pattern containing ``#`` runs (e.g. ``videos/####.png``): each run
      stands for a frame number of at least that many digits; results are
      ordered by frame number so frames that outgrow the padding
      (``9999`` -> ``10000``) stay in sequence;
    * a directory: every ``.jpg`` / ``.png`` file in it (suffix compared
      case-insensitively), sorted by name;
    * a glob pattern containing ``*``: every match, sorted by name.

    Args:
        path: Frame pattern, directory or glob pattern.

    Returns:
        The list of matching file paths.

    Raises:
        ValueError: If ``path`` has no ``#``, no ``*`` and is not a directory.
    """
    if "#" not in path:
        directory = Path(path)
        if directory.is_dir():
            frames: list[str] = [
                entry.as_posix()
                for entry in directory.iterdir()
                if entry.is_file() and entry.suffix.lower() in (".jpg", ".png")
            ]
        elif "*" in path:
            frames = glob.glob(path)
        else:
            raise ValueError(
                "The path doesn't contain a # or a * or is not a directory"
            )
        frames.sort()
        return frames

    folder_path, file_pattern = os.path.split(path)
    log.debug(f"Resolving all frames in {folder_path}")

    frame_pattern = re.sub(r"#+", "*", file_pattern)
    log.debug(f"Found pattern: {frame_pattern}")

    matching_files = glob.glob(os.path.join(folder_path, frame_pattern))
    log.debug(f"Found {len(matching_files)} matching files")

    # A run of N hashes becomes one group of at least N digits, so the whole
    # frame number lands in a single capture (re.escape turns "#" into "\#").
    frame_number_regex = re.compile(
        re.sub(
            r"(?:\\#)+",
            lambda run: r"(\d{%d,})" % (len(run.group(0)) // 2),
            re.escape(file_pattern),
        )
    )

    numbered = []
    for file in matching_files:
        # The glob above also accepts non-digit characters; the regex keeps
        # only the files whose name really carries a frame number.
        match = frame_number_regex.search(os.path.basename(file))
        if not match:
            continue
        frame_numbers = match.groups()
        if frame_numbers:
            log.debug(f"Found frame number: {frame_numbers[0]}")
        numbered.append((tuple(int(digits) for digits in frame_numbers), file))

    # Numeric order first, file name as tie-breaker.
    numbered.sort()
    return [file for _, file in numbered]
def resolve_path(path, frame):
    """Substitute the frame number for every run of ``#`` in ``path``.

    The number is zero-padded to the total count of ``#`` characters found
    in ``path``; a path without any ``#`` is returned unchanged.
    """
    digits = str(frame).zfill(path.count("#"))
    return re.compile("#+").sub(digits, path)
class MTB_SaveImageSequence:
    """Save an image sequence to a folder.

    The current frame is used to determine which image to save. Each call
    writes one PNG named ``<prefix>_<frame>.png`` (frame zero-padded to five
    digits) into a folder named after the prefix inside the output directory.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: the image, a filename prefix and the frame number."""
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {"default": "Sequence"}),
                "current_frame": (
                    "INT",
                    {"default": 0, "min": 0, "max": 9999999},
                ),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"
    OUTPUT_NODE = True
    CATEGORY = "mtb/IO"

    def save_images(
        self,
        images,
        filename_prefix="Sequence",
        current_frame=0,
        prompt=None,
        extra_pnginfo=None,
    ):
        """Write a single image of the sequence to disk as a PNG.

        Args:
            images: Batch holding exactly one image; values are scaled by 255
                and clipped to 0..255 before saving.
            filename_prefix: Name of the destination folder (relative to the
                output directory); its last component also prefixes the file
                name.
            current_frame: Frame number used in the file name.
            prompt: Optional data stored as JSON in the PNG "prompt" text chunk.
            extra_pnginfo: Optional mapping; each entry is stored as JSON in
                its own PNG text chunk.

        Returns:
            A ``{"ui": {"images": [...]}}`` dict describing the saved file.

        Raises:
            ValueError: If more than one image is passed.
        """
        if len(images) > 1:
            raise ValueError("Can only save one image at a time")

        prefix = Path(filename_prefix)
        resolved_path = Path(self.output_dir) / prefix
        resolved_path.mkdir(parents=True, exist_ok=True)

        # Only the last component of the prefix goes into the file name, so a
        # nested prefix ("shots/a") does not point at a missing sub-directory.
        resolved_img = resolved_path / f"{prefix.name}_{current_frame:05}.png"

        output_image = images[0].cpu().numpy()
        img = Image.fromarray(
            np.clip(output_image * 255.0, 0, 255).astype(np.uint8)
        )

        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            for key, value in extra_pnginfo.items():
                metadata.add_text(key, json.dumps(value))

        img.save(resolved_img, pnginfo=metadata, compress_level=4)

        return {
            "ui": {
                "images": [
                    {
                        "filename": resolved_img.name,
                        # Report the folder relative to the output directory.
                        "subfolder": prefix.as_posix(),
                        "type": self.type,
                    }
                ]
            }
        }
# Node classes defined in this module, collected in a single list.
# NOTE(review): presumably read by the package's node-registration code - confirm.
__nodes__ = [
    MTB_LoadImageSequence,
    MTB_SaveImageSequence,
]