|  |  | 
					
						
						|  |  | 
					
						
						|  | import math | 
					
						
						|  | from einops import rearrange | 
					
						
						|  |  | 
					
						
						|  | from torch import randint | 
					
						
						|  |  | 
					
						
						|  | def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: | 
					
						
						|  | min_value = min(min_value, value) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | divisors = [i for i in range(min_value, value + 1) if value % i == 0] | 
					
						
						|  |  | 
					
						
						|  | ns = [value // i for i in divisors[:max_options]] | 
					
						
						|  |  | 
					
						
						|  | if len(ns) - 1 > 0: | 
					
						
						|  | idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() | 
					
						
						|  | else: | 
					
						
						|  | idx = 0 | 
					
						
						|  |  | 
					
						
						|  | return ns[idx] | 
					
						
						|  |  | 
					
						
						|  | class HyperTile: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return {"required": { "model": ("MODEL",), | 
					
						
						|  | "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), | 
					
						
						|  | "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), | 
					
						
						|  | "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), | 
					
						
						|  | "scale_depth": ("BOOLEAN", {"default": False}), | 
					
						
						|  | }} | 
					
						
						|  | RETURN_TYPES = ("MODEL",) | 
					
						
						|  | FUNCTION = "patch" | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "model_patches/unet" | 
					
						
						|  |  | 
					
						
						|  | def patch(self, model, tile_size, swap_size, max_depth, scale_depth): | 
					
						
						|  | latent_tile_size = max(32, tile_size) // 8 | 
					
						
						|  | self.temp = None | 
					
						
						|  |  | 
					
						
						|  | def hypertile_in(q, k, v, extra_options): | 
					
						
						|  | model_chans = q.shape[-2] | 
					
						
						|  | orig_shape = extra_options['original_shape'] | 
					
						
						|  | apply_to = [] | 
					
						
						|  | for i in range(max_depth + 1): | 
					
						
						|  | apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) | 
					
						
						|  |  | 
					
						
						|  | if model_chans in apply_to: | 
					
						
						|  | shape = extra_options["original_shape"] | 
					
						
						|  | aspect_ratio = shape[-1] / shape[-2] | 
					
						
						|  |  | 
					
						
						|  | hw = q.size(1) | 
					
						
						|  | h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) | 
					
						
						|  |  | 
					
						
						|  | factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 | 
					
						
						|  | nh = random_divisor(h, latent_tile_size * factor, swap_size) | 
					
						
						|  | nw = random_divisor(w, latent_tile_size * factor, swap_size) | 
					
						
						|  |  | 
					
						
						|  | if nh * nw > 1: | 
					
						
						|  | q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) | 
					
						
						|  | self.temp = (nh, nw, h, w) | 
					
						
						|  | return q, k, v | 
					
						
						|  |  | 
					
						
						|  | return q, k, v | 
					
						
						|  | def hypertile_out(out, extra_options): | 
					
						
						|  | if self.temp is not None: | 
					
						
						|  | nh, nw, h, w = self.temp | 
					
						
						|  | self.temp = None | 
					
						
						|  | out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) | 
					
						
						|  | out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) | 
					
						
						|  | return out | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | m = model.clone() | 
					
						
						|  | m.set_model_attn1_patch(hypertile_in) | 
					
						
						|  | m.set_model_attn1_output_patch(hypertile_out) | 
					
						
						|  | return (m, ) | 
					
						
						|  |  | 
					
						
						|  | NODE_CLASS_MAPPINGS = { | 
					
						
						|  | "HyperTile": HyperTile, | 
					
						
						|  | } | 
					
						
						|  |  |