import collections.abc
from itertools import repeat
from typing import List, Union

import torch

from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd


def _ntuple(n):
    """
    Creates a helper function that converts inputs to tuples of a given length.

    Iterable inputs (excluding strings) are converted to tuples and passed
    through; single-element iterables and scalar inputs are repeated n times.
    Useful for handling multi-dimensional parameters such as sizes and strides.

    Args:
        n (int): Target length of the tuple

    Returns:
        function: Parser function that converts inputs to n-length tuples
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                x = tuple(repeat(x[0], n))
            return x
        return tuple(repeat(x, n))
    return parse


# Common tuple conversion helpers for 1-4 dimensions
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
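
# Illustrative sketch (not part of the original module): expected behavior of
# the tuple helpers above. Scalars are broadcast to the requested length;
# iterables pass through unchanged, except single-element iterables, which
# are also broadcast. Strings count as scalars, not iterables.
def _demo_ntuple():
    assert to_2tuple(7) == (7, 7)            # scalar -> repeated
    assert to_2tuple((3, 5)) == (3, 5)       # matching iterable -> passed through
    assert to_3tuple([4]) == (4, 4, 4)       # single-element iterable -> broadcast
    assert to_2tuple("ab") == ("ab", "ab")   # string treated as a scalar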

def get_rope_freq_from_size(
        latents_size,
        ndim,
        target_ndim,
        args,
        rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
        rope_interpolation_factor: Union[float, List[float]] = 1.0,
        concat_dict=None,
):
    """
    Calculates RoPE (Rotary Position Embedding) frequencies from latent dimensions.

    Converts latent-space dimensions to RoPE-compatible sizes by accounting for
    the patch size, then generates the frequency embeddings for each dimension.

    Args:
        latents_size: Dimensions of the latent space tensor
        ndim (int): Number of dimensions in the latent space
        target_ndim (int): Target number of dimensions for the embeddings
        args: Configuration arguments containing model parameters
            (patch_size, rope_theta, hidden_size, num_heads, rope_dim_list)
        rope_theta_rescale_factor: Rescaling factor(s) for theta (per dimension)
        rope_interpolation_factor: Interpolation factor(s) for position embeddings (per dimension)
        concat_dict: Optional dictionary for special concatenation modes
            (e.g., time-based extensions)

    Returns:
        tuple: Cosine and sine frequency embeddings (freqs_cos, freqs_sin)
    """
    # Avoid a mutable default argument; treat a missing dict as "no concat mode"
    concat_dict = concat_dict if concat_dict is not None else {}

    # Calculate rope sizes by dividing latent dimensions by the patch size
    if isinstance(args.patch_size, int):
        # Validate that all latent dimensions are divisible by the patch size
        assert all(s % args.patch_size == 0 for s in latents_size), (
            f"Latent size (last {ndim} dimensions) must be divisible by patch size "
            f"({args.patch_size}), but got {latents_size}."
        )
        rope_sizes = [s // args.patch_size for s in latents_size]
    elif isinstance(args.patch_size, list):
        # Validate with per-dimension patch sizes
        assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
            f"Latent size (last {ndim} dimensions) must be divisible by patch size "
            f"({args.patch_size}), but got {latents_size}."
        )
        rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
    else:
        # Guard against unsupported patch_size types instead of failing later
        raise ValueError(f"Unsupported patch_size type: {type(args.patch_size)}")

    # Prepend singleton dimensions if needed to match target_ndim (typically the time axis)
    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes

    # Calculate the head dimension and validate the rope dimension split
    head_dim = args.hidden_size // args.num_heads
    rope_dim_list = args.rope_dim_list

    # Default: split the head dimension equally across the target dimensions
    if rope_dim_list is None:
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]

    # Ensure the rope dimensions sum to the head dimension
    assert sum(rope_dim_list) == head_dim, (
        "Sum of rope_dim_list must equal the attention head dimension "
        "(hidden_size // num_heads)"
    )

    # Generate the rotary position embeddings
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
        rope_dim_list,
        rope_sizes,
        theta=args.rope_theta,
        use_real=True,
        theta_rescale_factor=rope_theta_rescale_factor,
        interpolation_factor=rope_interpolation_factor,
        concat_dict=concat_dict,
    )
    return freqs_cos, freqs_sin
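
# Illustrative sketch (not part of the original module): a minimal call to
# get_rope_freq_from_size. The `args` fields mirror the attributes the
# function reads, but the concrete values below are assumptions made for the
# example, not the model's real configuration.
def _demo_rope_freq_from_size():
    from types import SimpleNamespace
    args = SimpleNamespace(
        hidden_size=3072,            # assumed transformer width
        num_heads=24,                # head_dim = 3072 // 24 = 128
        patch_size=2,                # assumed uniform patch size
        rope_dim_list=[16, 56, 56],  # per-axis split; must sum to head_dim
        rope_theta=10000,            # assumed RoPE base
    )
    # A (16, 32, 32) latent grid patchified by 2 gives rope_sizes = [8, 16, 16],
    # so each returned tensor has 8 * 16 * 16 rows and head_dim channels.
    freqs_cos, freqs_sin = get_rope_freq_from_size(
        latents_size=[16, 32, 32], ndim=3, target_ndim=3, args=args,
    )
    return freqs_cos, freqs_sin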

def get_nd_rotary_pos_embed_new(
        rope_dim_list,
        start,
        *args,
        theta=10000.0,
        use_real=False,
        theta_rescale_factor: Union[float, List[float]] = 1.0,
        interpolation_factor: Union[float, List[float]] = 1.0,
        concat_dict=None,
):
    """
    Generates multi-dimensional Rotary Position Embeddings (RoPE).

    Creates position embeddings for n-dimensional spaces by generating a
    meshgrid of positions, applying 1D rotary embeddings along each axis,
    and concatenating the results.

    Args:
        rope_dim_list (list): Embedding dimension for each axis
        start: Starting dimensions for generating the meshgrid
        *args: Additional arguments for meshgrid generation
        theta (float): Base theta parameter for the RoPE frequency calculation
        use_real (bool): If True, returns separate cosine and sine embeddings
        theta_rescale_factor: Rescaling factor(s) for theta (per dimension)
        interpolation_factor: Interpolation factor(s) for position scaling (per dimension)
        concat_dict: Optional dictionary for special concatenation modes
            (e.g., time-based extensions)

    Returns:
        tuple or tensor: Cosine and sine embeddings if use_real=True,
        a single combined embedding otherwise
    """
    # Generate an n-dimensional meshgrid of positions (shape: [dim, *sizes])
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))

    # Handle special concatenation modes (e.g., adding a time-based bias)
    if concat_dict:
        if concat_dict['mode'] == 'timecat':
            # Prepend one frame of positions whose time coordinate is the bias
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            grid = torch.cat([bias, grid], dim=1)
        elif concat_dict['mode'] == 'timecat-w':
            # Prepend a biased frame with an additional spatial (width) offset
            bias = grid[:, :1].clone()
            bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
            bias[2] += start[-1]  # Spatial offset; reference: the OminiControl implementation
            grid = torch.cat([bias, grid], dim=1)

    # Normalize theta rescale factors to a per-dimension list
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), \
        "Length of theta_rescale_factor must match the number of dimensions"

    # Normalize interpolation factors to a per-dimension list
    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), \
        "Length of interpolation_factor must match the number of dimensions"

    # Generate 1D rotary embeddings for each axis and combine them
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),  # Flatten the axis grid to 1D positions
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )
        embs.append(emb)

    if use_real:
        # Return separate cosine and sine components, concatenated over channels
        cos = torch.cat([emb[0] for emb in embs], dim=1)
        sin = torch.cat([emb[1] for emb in embs], dim=1)
        return cos, sin
    # Return the combined embedding
    return torch.cat(embs, dim=1)
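
# Illustrative sketch (not part of the original module): calling
# get_nd_rotary_pos_embed_new directly for a 3D (t, h, w) grid with the
# 'timecat' mode enabled. The bias value is an assumption; 'timecat'
# prepends one extra (1, h, w) slab of positions whose time coordinate is
# replaced by that bias, so the flattened sequence grows from t*h*w to
# (t + 1)*h*w rows.
def _demo_nd_rope_timecat():
    rope_dim_list = [16, 56, 56]  # per-axis embedding dims
    rope_sizes = [8, 16, 16]      # (t, h, w) grid after patchification
    cos, sin = get_nd_rotary_pos_embed_new(
        rope_dim_list,
        rope_sizes,
        theta=10000.0,
        use_real=True,
        concat_dict={'mode': 'timecat', 'bias': 0},  # assumed bias value
    )
    return cos, sin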