diff --git "a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py" "b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py"
--- "a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py"
+++ "b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py"
@@ -1,4 +1,4 @@
-# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,149 +12,336 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-import gc
-import html
 import inspect
-import re
-import urllib.parse as ul
-from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-from diffusers import DiffusionPipeline, ImagePipelineOutput
+from dataclasses import dataclass
+from diffusers import DiffusionPipeline
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import AutoencoderKL
-from diffusers.schedulers import DPMSolverMultistepScheduler
+from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
+from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
+                                         get_3d_rotary_pos_embed)
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+    StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
-                             is_bs4_available, is_ftfy_available, logging,
+                             is_bs4_available, is_ftfy_available,
+                             is_torch_xla_available, logging,
                              replace_example_docstring)
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from PIL import Image
 from tqdm import tqdm
-from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
+from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
+                          Qwen2Tokenizer, Qwen2VLForConditionalGeneration, CLIPVisionModelWithProjection,
                           T5EncoderModel, T5Tokenizer)
 
-from ..models.transformer3d import Transformer3DModel
+from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
 
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
 
-if is_bs4_available():
-    from bs4 import BeautifulSoup
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
-if is_ftfy_available():
-    import ftfy
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import EasyAnimatePipeline
-
-        >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
-        >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
-        >>> # Enable memory optimizations.
-        >>> pipe.enable_model_cpu_offload()
-
-        >>> prompt = "A small cactus with a happy face in the Sahara desert."
-        >>> image = pipe(prompt).images[0]
+        >>> from easyanimate.pipeline.pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
+        >>> from easyanimate.utils.utils import get_image_to_video_latent
+        >>> from diffusers.utils import export_to_video, load_image
+
+        >>> pipe = EasyAnimateInpaintPipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh-InP", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+        >>> validation_image_start = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+        ... )
+        >>> validation_image_end = None
+        >>> sample_size = (576, 448)
+        >>> video_length = 49
+        >>> input_video, input_video_mask, _ = get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size)
+        >>> video = pipe(prompt, video_length=video_length, negative_prompt="bad detailed", height=sample_size[0], width=sample_size[1], input_video=input_video, mask_video=input_video_mask)
+        >>> export_to_video(video.frames[0], "output.mp4", fps=8)
         ```
 """
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(encoder_output, generator):
-    if hasattr(encoder_output, "latent_dist"):
-        return encoder_output.latent_dist.sample(generator)
-    elif hasattr(encoder_output, "latents"):
-        return encoder_output.latents
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
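+    # Fit the source grid (h, w) into the target grid while preserving aspect ratio and return the
+    # centered crop region as ((top, left), (bottom, right)).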
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+# Resize a pixel-space mask to the latent shape produced by the MagViT VAE
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+
+    if process_first_frame_only:
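+        # Resize the first temporal frame separately from the rest so the mask matches the MagViT
+        # latent layout of 1 + (T_latent - 1) frames.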
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+        
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
     else:
-        raise AttributeError("Could not access latents of provided encoder_output")
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+
+# Add noise to the reference video
+def add_noise_to_reference_video(image, ratio=None):
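+    # When ratio is None, draw a per-sample noise scale from a log-normal distribution
+    # (log sigma ~ N(-3, 0.5)); pixels equal to -1 (padding) receive no noise.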
+    if ratio is None:
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+    else:
+        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
     
+    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+    image = image + image_noise
+    return image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
 @dataclass
 class EasyAnimatePipelineOutput(BaseOutput):
-    videos: Union[torch.Tensor, np.ndarray]
+    r"""
+    Output class for EasyAnimate pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    frames: torch.Tensor
+
 
 class EasyAnimateInpaintPipeline(DiffusionPipeline):
     r"""
-    Pipeline for text-to-image generation using PixArt-Alpha.
+    Pipeline for text-to-video generation using EasyAnimate.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    EasyAnimate V5.1 uses a single text encoder, [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
+    EasyAnimate V5 uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP
+    fine-tuned by the HunyuanDiT team.
+
     Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`T5EncoderModel`]):
-            Frozen text-encoder. PixArt-Alpha uses
-            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
-            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`T5Tokenizer`):
-            Tokenizer of class
-            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        transformer ([`Transformer3DModel`]):
-            A text conditioned `Transformer3DModel` to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKLMagvit`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. 
+        text_encoder (`~transformers.Qwen2VLForConditionalGeneration` or `~transformers.BertModel`):
+            EasyAnimate uses [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1 and the
+            [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) text encoder in V5.
+        tokenizer (`~transformers.Qwen2Tokenizer` or `~transformers.BertTokenizer`):
+            A `Qwen2Tokenizer` or `BertTokenizer` to tokenize the prompt.
+        transformer ([`EasyAnimateTransformer3DModel`]):
+            The EasyAnimate model designed by EasyAnimate Team.
+        text_encoder_2 (`T5EncoderModel`, *optional*):
+            Not used in V5.1. EasyAnimate V5 uses the [mT5](https://huggingface.co/google/mt5-base) embedder.
+        tokenizer_2 (`T5Tokenizer`, *optional*):
+            The tokenizer for the mT5 embedder. Not used in V5.1.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with EasyAnimate to denoise the encoded video latents.
+        clip_image_processor (`CLIPImageProcessor`):
+            The image processor for the CLIP image encoder.
+        clip_image_encoder (`CLIPVisionModelWithProjection`):
+            The CLIP image encoder used to embed reference images.
     """
-    bad_punct_regex = re.compile(
-        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
-    )  # noqa
 
-    _optional_components = ["tokenizer", "text_encoder"]
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
+    _optional_components = [
+        "text_encoder_2",
+        "tokenizer_2",
+        "text_encoder",
+        "tokenizer",
+        "clip_image_encoder",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "prompt_embeds_2",
+        "negative_prompt_embeds_2",
+    ]
 
     def __init__(
         self,
-        tokenizer: T5Tokenizer,
-        text_encoder: T5EncoderModel,
-        vae: AutoencoderKL,
-        transformer: Transformer3DModel,
-        scheduler: DPMSolverMultistepScheduler,
-        clip_image_processor:CLIPImageProcessor = None,
-        clip_image_encoder:CLIPVisionModelWithProjection = None,
+        vae: AutoencoderKLMagvit,
+        text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+        tokenizer: Union[Qwen2Tokenizer, BertTokenizer], 
+        text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
+        tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
+        transformer: EasyAnimateTransformer3DModel,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        clip_image_processor: CLIPImageProcessor = None,
+        clip_image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, 
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            transformer=transformer,
             scheduler=scheduler,
-            clip_image_processor=clip_image_processor, clip_image_encoder=clip_image_encoder,
+            text_encoder_2=text_encoder_2,
+            clip_image_processor=clip_image_processor, 
+            clip_image_encoder=clip_image_encoder,
         )
 
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=True)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
         )
-        self.enable_autocast_float8_transformer_flag = False
 
-    # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
-    def mask_text_embeddings(self, emb, mask):
-        if emb.shape[0] == 1:
-            keep_index = mask.sum().item()
-            return emb[:, :, :keep_index, :], keep_index
-        else:
-            masked_feature = emb * mask[:, None, :, None]
-            return masked_feature, emb.shape[2]
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        super().enable_sequential_cpu_offload(*args, **kwargs)
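+        # Keep the transformer's clip_projection out of sequential offload: remove its accelerate hook
+        # and leave the module on CUDA.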
+        if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
+            import accelerate
+            accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
+            self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
 
-    # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
     def encode_prompt(
         self,
-        prompt: Union[str, List[str]],
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: str = "",
+        prompt: str,
+        device: torch.device,
+        dtype: torch.dtype,
         num_images_per_prompt: int = 1,
-        device: Optional[torch.device] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        prompt_attention_mask: Optional[torch.FloatTensor] = None,
-        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
-        clean_caption: bool = False,
-        max_sequence_length: int = 120,
-        **kwargs,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        max_sequence_length: Optional[int] = None,
+        text_encoder_index: int = 0,
+        actual_max_sequence_length: int = 256
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -162,33 +349,46 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
-                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
-                PixArt-Alpha, this should be "".
-            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
-                whether to use classifier free guidance or not
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
+            device: (`torch.device`):
+                torch device
+            dtype (`torch.dtype`):
+                torch dtype
+            num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            device: (`torch.device`, *optional*):
-                torch device to place the resulting embeddings on
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
-                string.
-            clean_caption (`bool`, defaults to `False`):
-                If `True`, the function will preprocess and clean the provided caption before encoding.
-            max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
+            max_sequence_length (`int`, *optional*): Maximum sequence length to use for the prompt.
+            text_encoder_index (`int`, *optional*):
+                Index of the text encoder to use. `0` for the primary encoder (Qwen2-VL in V5.1, bilingual CLIP/BERT
+                in V5) and `1` for the secondary mT5 encoder (V5 only).
+            actual_max_sequence_length (`int`, defaults to `256`):
+                Hard upper bound on the tokenized sequence length; longer prompts are truncated and re-tokenized.
         """
+        tokenizers = [self.tokenizer, self.tokenizer_2]
+        text_encoders = [self.text_encoder, self.text_encoder_2]
 
-        if "mask_feature" in kwargs:
-            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
-            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
+        tokenizer = tokenizers[text_encoder_index]
+        text_encoder = text_encoders[text_encoder_index]
 
-        if device is None:
-            device = self._execution_device
+        if max_sequence_length is None:
+            if text_encoder_index == 0:
+                max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
+            if text_encoder_index == 1:
+                max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
+        else:
+            max_length = max_sequence_length
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -197,74 +397,199 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # See Section 3.1. of the paper.
-        max_length = max_sequence_length
-
         if prompt_embeds is None:
-            prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {max_length} tokens: {removed_text}"
+            if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_attention_mask=True,
+                    return_tensors="pt",
+                )
+                text_input_ids = text_inputs.input_ids
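+                # If the tokenized prompt exceeds the hard cap, decode the first actual_max_sequence_length
+                # tokens and re-tokenize so the sequence fits.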
+                if text_input_ids.shape[-1] > actual_max_sequence_length:
+                    reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
+                    text_inputs = tokenizer(
+                        reprompt,
+                        padding="max_length",
+                        max_length=max_length,
+                        truncation=True,
+                        return_attention_mask=True,
+                        return_tensors="pt",
+                    )
+                    text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {_actual_max_sequence_length} tokens: {removed_text}"
+                    )
+
+                prompt_attention_mask = text_inputs.attention_mask.to(device)
+
+                if self.transformer.config.enable_text_attention_mask:
+                    prompt_embeds = text_encoder(
+                        text_input_ids.to(device),
+                        attention_mask=prompt_attention_mask,
+                    )
+                else:
+                    prompt_embeds = text_encoder(
+                        text_input_ids.to(device)
+                    )
+                prompt_embeds = prompt_embeds[0]
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            else:
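+                # Qwen2-VL text encoder: wrap the prompt as a chat message, apply the chat template, then encode.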
+                if prompt is not None and isinstance(prompt, str):
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": prompt}],
+                        }
+                    ]
+                else:
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": _prompt}],
+                        } for _prompt in prompt
+                    ]
+                text = tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
                 )
 
-            prompt_attention_mask = text_inputs.attention_mask
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0]
-
-        if self.text_encoder is not None:
-            dtype = self.text_encoder.dtype
-        elif self.transformer is not None:
-            dtype = self.transformer.dtype
-        else:
-            dtype = None
-
+                text_inputs = tokenizer(
+                    text=[text],
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_attention_mask=True,
+                    padding_side="right",
+                    return_tensors="pt",
+                )
+                text_inputs = text_inputs.to(text_encoder.device)
+
+                text_input_ids = text_inputs.input_ids
+                prompt_attention_mask = text_inputs.attention_mask
+                if self.transformer.config.enable_text_attention_mask:
+                    # Encode the chat-formatted prompt and take the penultimate layer's hidden states as embeddings
+                    prompt_embeds = text_encoder(
+                        input_ids=text_input_ids,
+                        attention_mask=prompt_attention_mask,
+                        output_hidden_states=True).hidden_states[-2]
+                else:
+                    raise ValueError("LLM needs attention_mask")
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+        prompt_attention_mask = prompt_attention_mask.to(device=device)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens = [negative_prompt] * batch_size
-            uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
+            if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
+                uncond_tokens: List[str]
+                if negative_prompt is None:
+                    uncond_tokens = [""] * batch_size
+                elif prompt is not None and type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif isinstance(negative_prompt, str):
+                    uncond_tokens = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                else:
+                    uncond_tokens = negative_prompt
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                uncond_input_ids = uncond_input.input_ids
+                if uncond_input_ids.shape[-1] > actual_max_sequence_length:
+                    reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
+                    uncond_input = tokenizer(
+                        reuncond_tokens,
+                        padding="max_length",
+                        max_length=max_length,
+                        truncation=True,
+                        return_attention_mask=True,
+                        return_tensors="pt",
+                    )
+                    uncond_input_ids = uncond_input.input_ids
+
+                negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
+                if self.transformer.config.enable_text_attention_mask:
+                    negative_prompt_embeds = text_encoder(
+                        uncond_input.input_ids.to(device),
+                        attention_mask=negative_prompt_attention_mask,
+                    )
+                else:
+                    negative_prompt_embeds = text_encoder(
+                        uncond_input.input_ids.to(device)
+                    )
+                negative_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+            else:
+                if negative_prompt is not None and isinstance(negative_prompt, str):
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": negative_prompt}],
+                        }
+                    ]
+                else:
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": _negative_prompt}],
+                        } for _negative_prompt in negative_prompt
+                    ]
+                text = tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
 
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
+                text_inputs = tokenizer(
+                    text=[text],
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_attention_mask=True,
+                    padding_side="right",
+                    return_tensors="pt",
+                )
+                text_inputs = text_inputs.to(text_encoder.device)
+
+                text_input_ids = text_inputs.input_ids
+                negative_prompt_attention_mask = text_inputs.attention_mask
+                if self.transformer.config.enable_text_attention_mask:
+                    # Encode the chat-formatted negative prompt and take the penultimate layer's hidden states as embeddings
+                    negative_prompt_embeds = text_encoder(
+                        input_ids=text_input_ids,
+                        attention_mask=negative_prompt_attention_mask,
+                        output_hidden_states=True).hidden_states[-2]
+                else:
+                    raise ValueError("LLM needs attention_mask")
+                negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -274,14 +599,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-            negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
-        else:
-            negative_prompt_embeds = None
-            negative_prompt_attention_mask = None
-
-        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
+            
+        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -306,20 +626,25 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         prompt,
         height,
         width,
-        negative_prompt,
-        callback_steps,
+        negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        prompt_attention_mask=None,
+        negative_prompt_attention_mask=None,
+        prompt_embeds_2=None,
+        negative_prompt_embeds_2=None,
+        prompt_attention_mask_2=None,
+        negative_prompt_attention_mask_2=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
 
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
             )
 
         if prompt is not None and prompt_embeds is not None:
@@ -331,14 +656,18 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
+        elif prompt is None and prompt_embeds_2 is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+            )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
-        if prompt is not None and negative_prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
-                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
-            )
+        if prompt_embeds is not None and prompt_attention_mask is None:
+            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
+            raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -346,6 +675,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )
 
+        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
+            raise ValueError(
+                "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
+            )
         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
                 raise ValueError(
@@ -353,201 +689,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
+        if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
+            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
+                raise ValueError(
+                    "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
+                    f" {negative_prompt_embeds_2.shape}."
+                )
 
-    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
-    def _text_preprocessing(self, text, clean_caption=False):
-        if clean_caption and not is_bs4_available():
-            logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
-            logger.warn("Setting `clean_caption` to False...")
-            clean_caption = False
-
-        if clean_caption and not is_ftfy_available():
-            logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
-            logger.warn("Setting `clean_caption` to False...")
-            clean_caption = False
-
-        if not isinstance(text, (tuple, list)):
-            text = [text]
-
-        def process(text: str):
-            if clean_caption:
-                text = self._clean_caption(text)
-                text = self._clean_caption(text)
-            else:
-                text = text.lower().strip()
-            return text
-
-        return [process(t) for t in text]
-
-    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
-    def _clean_caption(self, caption):
-        caption = str(caption)
-        caption = ul.unquote_plus(caption)
-        caption = caption.strip().lower()
-        caption = re.sub("<person>", "person", caption)
-        # urls:
-        caption = re.sub(
-            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
-            "",
-            caption,
-        )  # regex for urls
-        caption = re.sub(
-            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
-            "",
-            caption,
-        )  # regex for urls
-        # html:
-        caption = BeautifulSoup(caption, features="html.parser").text
-
-        # @<nickname>
-        caption = re.sub(r"@[\w\d]+\b", "", caption)
-
-        # 31C0—31EF CJK Strokes
-        # 31F0—31FF Katakana Phonetic Extensions
-        # 3200—32FF Enclosed CJK Letters and Months
-        # 3300—33FF CJK Compatibility
-        # 3400—4DBF CJK Unified Ideographs Extension A
-        # 4DC0—4DFF Yijing Hexagram Symbols
-        # 4E00—9FFF CJK Unified Ideographs
-        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
-        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
-        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
-        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
-        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
-        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
-        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
-        #######################################################
-
-        # все виды тире / all types of dash --> "-"
-        caption = re.sub(
-            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
-            "-",
-            caption,
-        )
-
-        # кавычки к одному стандарту
-        caption = re.sub(r"[`´«»“”¨]", '"', caption)
-        caption = re.sub(r"[‘’]", "'", caption)
-
-        # &quot;
-        caption = re.sub(r"&quot;?", "", caption)
-        # &amp
-        caption = re.sub(r"&amp", "", caption)
-
-        # ip adresses:
-        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
-
-        # article ids:
-        caption = re.sub(r"\d:\d\d\s+$", "", caption)
-
-        # \n
-        caption = re.sub(r"\\n", " ", caption)
-
-        # "#123"
-        caption = re.sub(r"#\d{1,3}\b", "", caption)
-        # "#12345.."
-        caption = re.sub(r"#\d{5,}\b", "", caption)
-        # "123456.."
-        caption = re.sub(r"\b\d{6,}\b", "", caption)
-        # filenames:
-        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
-
-        #
-        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
-        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
-
-        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
-        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
-
-        # this-is-my-cute-cat / this_is_my_cute_cat
-        regex2 = re.compile(r"(?:\-|\_)")
-        if len(re.findall(regex2, caption)) > 3:
-            caption = re.sub(regex2, " ", caption)
-
-        caption = ftfy.fix_text(caption)
-        caption = html.unescape(html.unescape(caption))
-
-        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
-        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
-        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231
-
-        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
-        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
-        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
-        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
-        caption = re.sub(r"\bpage\s+\d+\b", "", caption)
-
-        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...
-
-        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
-
-        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
-        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
-        caption = re.sub(r"\s+", " ", caption)
-
-        caption.strip()
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
 
-        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
-        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
-        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
-        caption = re.sub(r"^\.\S+$", "", caption)
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
 
-        return caption.strip()
+        return timesteps, num_inference_steps - t_start
 
     def prepare_mask_latents(
-        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
     ):
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
-        video_length = mask.shape[2]
-
-        mask = mask.to(device=device, dtype=self.vae.dtype)
-        if self.vae.quant_conv.weight.ndim==5:
-            bs = 1
-            new_mask = []
-            for i in range(0, mask.shape[0], bs):
-                mask_bs = mask[i : i + bs]
-                mask_bs = self.vae.encode(mask_bs)[0]
-                mask_bs = mask_bs.sample()
-                new_mask.append(mask_bs)
-            mask = torch.cat(new_mask, dim = 0)
-            mask = mask * self.vae.config.scaling_factor
+        if mask is not None:
+            mask = mask.to(device=device, dtype=dtype)
+            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
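+                # 3D video VAE: encode the mask one sample at a time to limit memory, taking the
+                # distribution mode instead of a random sample.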
+                bs = 1
+                new_mask = []
+                for i in range(0, mask.shape[0], bs):
+                    mask_bs = mask[i : i + bs]
+                    mask_bs = self.vae.encode(mask_bs)[0]
+                    mask_bs = mask_bs.mode()
+                    new_mask.append(mask_bs)
+                mask = torch.cat(new_mask, dim = 0)
+                mask = mask * self.vae.config.scaling_factor
 
-        else:
-            if mask.shape[1] == 4:
-                mask = mask
             else:
-                video_length = mask.shape[2]
-                mask = rearrange(mask, "b c f h w -> (b f) c h w")
-                mask = self._encode_vae_image(mask, generator=generator)
-                mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
-
-        masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
-        if self.vae.quant_conv.weight.ndim==5:
-            bs = 1
-            new_mask_pixel_values = []
-            for i in range(0, masked_image.shape[0], bs):
-                mask_pixel_values_bs = masked_image[i : i + bs]
-                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
-                mask_pixel_values_bs = mask_pixel_values_bs.sample()
-                new_mask_pixel_values.append(mask_pixel_values_bs)
-            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
-            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+                if mask.shape[1] == 4:
+                    mask = mask
+                else:
+                    video_length = mask.shape[2]
+                    mask = rearrange(mask, "b c f h w -> (b f) c h w")
+                    mask = self._encode_vae_image(mask, generator=generator)
+                    mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
+
+        if masked_image is not None:
+            masked_image = masked_image.to(device=device, dtype=dtype)
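+            # Optionally perturb the reference frames before VAE encoding; noise_aug_strength fixes the noise
+            # scale (a random log-normal scale is used when it is None).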
+            if self.transformer.config.add_noise_in_inpaint_model:
+                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
+            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
+                bs = 1
+                new_mask_pixel_values = []
+                for i in range(0, masked_image.shape[0], bs):
+                    mask_pixel_values_bs = masked_image[i : i + bs]
+                    mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+                    mask_pixel_values_bs = mask_pixel_values_bs.mode()
+                    new_mask_pixel_values.append(mask_pixel_values_bs)
+                masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+                masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
 
-        else:
-            if masked_image.shape[1] == 4:
-                masked_image_latents = masked_image
             else:
-                video_length = mask.shape[2]
-                masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
-                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
-                masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
+                if masked_image.shape[1] == 4:
+                    masked_image_latents = masked_image
+                else:
+                    video_length = masked_image.shape[2]
+                    masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
+                    masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+                    masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
+
+            # align the device to prevent device errors when concatenating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+        else:
+            masked_image_latents = None
 
-        # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents
-    
+
     def prepare_latents(
         self, 
         batch_size,
@@ -565,10 +783,15 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         return_noise=False,
         return_video_latents=False,
     ):
-        if self.vae.quant_conv.weight.ndim==5:
-            mini_batch_encoder = self.vae.mini_batch_encoder
-            mini_batch_decoder = self.vae.mini_batch_decoder
-            shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
+            if self.vae.cache_mag_vae:
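+                # With cache_mag_vae, the first frame is compressed on its own, so the latent has
+                # 1 + (video_length - 1) // mini_batch_encoder * mini_batch_decoder frames.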
+                mini_batch_encoder = self.vae.mini_batch_encoder
+                mini_batch_decoder = self.vae.mini_batch_decoder
+                shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
+            else:
+                mini_batch_encoder = self.vae.mini_batch_encoder
+                mini_batch_decoder = self.vae.mini_batch_decoder
+                shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
         else:
             shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
@@ -579,10 +802,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
             )
 
         if return_video_latents or (latents is None and not is_strength_max):
-            video = video.to(device=device, dtype=self.vae.dtype)
-            if self.vae.quant_conv.weight.ndim==5:
+            video = video.to(device=device, dtype=dtype)
+            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
                 bs = 1
-                mini_batch_encoder = self.vae.mini_batch_encoder
                 new_video = []
                 for i in range(0, video.shape[0], bs):
                     video_bs = video[i : i + bs]
@@ -601,16 +823,24 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
                     video = self._encode_vae_image(video, generator=generator)
                     video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
             video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
+            video_latents = video_latents.to(device=device, dtype=dtype)
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
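+            # Flow-matching schedulers expose `scale_noise` instead of `add_noise`.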
+            if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+                latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise)
+            else:
+                latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
             # if pure noise then scale the initial latents by the  Scheduler's init sigma
-            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
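+            # Only scale by `init_noise_sigma` when the scheduler defines it (flow-matching schedulers may not).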
+            if hasattr(self.scheduler, "init_noise_sigma"):
+                latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
         else:
-            noise = latents.to(device)
-            latents = noise * self.scheduler.init_noise_sigma
+            if hasattr(self.scheduler, "init_noise_sigma"):
+                noise = latents.to(device)
+                latents = noise * self.scheduler.init_noise_sigma
+            else:
+                latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
         outputs = (latents,)
@@ -632,22 +862,23 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
 
         # Encode middle videos
         latents = self.vae.encode(pixel_values)[0]
-        latents = latents.sample()
+        latents = latents.mode()
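+        # Taking the distribution mode (rather than sampling) keeps this smoothing pass deterministic.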
         # Decode middle videos
         middle_video = self.vae.decode(latents)[0]
 
         video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
         return video
-    
+
     def decode_latents(self, latents):
         video_length = latents.shape[2]
         latents = 1 / self.vae.config.scaling_factor * latents
-        if self.vae.quant_conv.weight.ndim==5:
+        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
             video = self.vae.decode(latents)[0]
             video = video.clamp(-1, 1)
-            video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
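+            # Cached VAE variants are expected to decode without mini-batch seams, so skip the overlap smoothing.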
+            if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
+                video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
         else:
             latents = rearrange(latents, "b c f h w -> (b f) c h w")
             video = []
@@ -660,32 +891,28 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         video = video.cpu().float().numpy()
         return video
 
-    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
-        if isinstance(generator, list):
-            image_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
-                for i in range(image.shape[0])
-            ]
-            image_latents = torch.cat(image_latents, dim=0)
-        else:
-            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
 
-        image_latents = self.vae.config.scaling_factor * image_latents
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
 
-        return image_latents
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
-    def get_timesteps(self, num_inference_steps, strength, device):
-        # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
 
-        t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
-
-        return timesteps, num_inference_steps - t_start
-
-    def enable_autocast_float8_transformer(self):
-        self.enable_autocast_float8_transformer_flag = True
+    @property
+    def interrupt(self):
+        return self._interrupt
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -696,109 +923,167 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         video: Union[torch.FloatTensor] = None,
         mask_video: Union[torch.FloatTensor] = None,
         masked_video_latents: Union[torch.FloatTensor] = None,
-        negative_prompt: str = "",
-        num_inference_steps: int = 20,
-        timesteps: List[int] = None,
-        guidance_scale: float = 4.5,
-        num_images_per_prompt: Optional[int] = 1,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        strength: float = 1.0,
-        eta: float = 0.0,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        prompt_attention_mask: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_2: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+        prompt_attention_mask: Optional[torch.Tensor] = None,
+        prompt_attention_mask_2: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "latent",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
-        clean_caption: bool = True,
-        mask_feature: bool = True,
-        max_sequence_length: int = 120,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        guidance_rescale: float = 0.0,
+        original_size: Optional[Tuple[int, int]] = (1024, 1024),
+        target_size: Optional[Tuple[int, int]] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
         clip_image: Image = None,
-        clip_apply_ratio: float = 0.50,
+        clip_apply_ratio: float = 0.40,
+        strength: float = 1.0,
+        noise_aug_strength: float = 0.0563,
         comfyui_progressbar: bool = False,
-        **kwargs,
-    ) -> Union[EasyAnimatePipelineOutput, Tuple]:
-        """
-        Function invoked when calling the pipeline for generation.
+        timesteps: Optional[List[int]] = None,
+    ):
+        r"""
+        The call function to the pipeline for video generation and inpainting with EasyAnimate.
 
-        Args:
+        Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            video_length (`int`, *optional*):
+                The number of frames to generate. This determines the temporal length and continuity of the
+                generated content.
+            video (`torch.FloatTensor`, *optional*):
+                The input video to be inpainted or transformed, laid out as `(batch, channels, frames, height, width)`.
+            mask_video (`torch.FloatTensor`, *optional*):
+                A tensor to specify areas of the video to be masked (omitted from generation).
+            masked_video_latents (`torch.FloatTensor`, *optional*):
+                Latents from masked portions of the video, utilized during image generation.
+            height (`int`, *optional*):
+                The height in pixels of the generated image or video frames.
+            width (`int`, *optional*):
+                The width in pixels of the generated image or video frames.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
+                inference time. This parameter is modulated by `strength`.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                A higher guidance scale value encourages the model to generate images closely linked to the text 
+                `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
             negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            num_inference_steps (`int`, *optional*, defaults to 100):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
-                timesteps are used. Must be in descending order.
-            guidance_scale (`float`, *optional*, defaults to 7.0):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+                The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
+                provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The width in pixels of the generated image.
             eta (`float`, *optional*, defaults to 0.0):
-                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-                [`schedulers.DDIMScheduler`], will be ignored for others.
+                A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
+                [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the 
+                inference process.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
-                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
+                random seeds which helps in making generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                A pre-computed latent representation which can be used to guide the generation process.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, embeddings are generated from the `prompt` input argument.
+            prompt_embeds_2 (`torch.Tensor`, *optional*):
+                Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
+                If not provided, embeddings are generated from the `negative_prompt` argument.
+            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+                Secondary set of pre-generated negative text embeddings for further control.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
+                `prompt_embeds`.
+            prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+                Attention mask for the secondary prompt embedding.
+            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
+            negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+                Attention mask for the secondary negative prompt embedding.
+            output_type (`str`, *optional*, defaults to `"latent"`):
+                The output format of the generated image. Choose between `PIL.Image` and `np.array` to define
+                how you want the results to be formatted.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
-            clean_caption (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
-                be installed. If the dependencies are not installed, the embeddings will be created from the raw
-                prompt.
-            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
+                If set to `True`, an [`EasyAnimatePipelineOutput`] will be returned; otherwise, the generated
+                video will be returned directly.
+            callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A callback function (or a list of them) that will be executed at the end of each denoising step,
+                allowing for custom processing during generation.
+            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+                The list of tensor inputs that will be passed to `callback_on_step_end` as `callback_kwargs`.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Rescales the noise prediction computed with classifier-free guidance to reduce over-exposure, based on
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+            original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
+                The original dimensions of the image. Used to compute time ids during the generation process.
+            target_size (`Tuple[int, int]`, *optional*):
+                The targeted dimensions of the generated image, also utilized in the time id calculations.
+            crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
+                Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
+            clip_image (`Image`, *optional*):
+                An optional image to assist in the generation process. It may be used as an additional visual cue.
+            clip_apply_ratio (`float`, *optional*, defaults to 0.40):
+                Fraction of the denoising steps, counted from the end, during which the CLIP image conditioning is
+                applied; earlier steps receive zeroed CLIP inputs.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates how much to transform the reference `video`. Must be between 0 and 1; a value of 1 adds the
+                maximum amount of noise, so the input video is essentially ignored.
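+            noise_aug_strength (`float`, *optional*, defaults to 0.0563):
+                Strength of the noise augmentation applied to the masked video inside `prepare_mask_latents`.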
+            comfyui_progressbar (`bool`, *optional*, defaults to `False`):
+                Enables a progress bar in ComfyUI, providing visual feedback during the generation process.
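+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process. If not provided, they are derived from
+                `num_inference_steps` by the scheduler.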
 
         Examples:
-
+            # Example usage of the pipeline for generating a video from a prompt.
+        
         Returns:
-            [`~pipelines.ImagePipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
-                returned where the first element is a list with the generated images
+            [`EasyAnimatePipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, an [`EasyAnimatePipelineOutput`] containing the generated video frames is
+                returned; otherwise the generated video is returned directly.
         """
-        # 1. Check inputs. Raise error if not correct
-        height = height or self.transformer.config.sample_size * self.vae_scale_factor
-        width = width or self.transformer.config.sample_size * self.vae_scale_factor
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 0. default height and width
         height = int(height // 16 * 16)
         width = int(width // 16 * 16)
 
-        # 2. Default height and width to transformer
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_attention_mask,
+            prompt_embeds_2,
+            negative_prompt_embeds_2,
+            prompt_attention_mask_2,
+            negative_prompt_attention_mask_2,
+            callback_on_step_end_tensor_inputs,
+        )
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._interrupt = False
+
+        # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -807,40 +1092,68 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
+        if self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        elif self.text_encoder_2 is not None:
+            dtype = self.text_encoder_2.dtype
+        else:
+            dtype = self.transformer.dtype
+            
         # 3. Encode input prompt
         (
             prompt_embeds,
-            prompt_attention_mask,
             negative_prompt_embeds,
+            prompt_attention_mask,
             negative_prompt_attention_mask,
         ) = self.encode_prompt(
-            prompt,
-            do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
+            prompt=prompt,
             device=device,
+            dtype=dtype,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
-            clean_caption=clean_caption,
-            max_sequence_length=max_sequence_length,
+            text_encoder_index=0,
         )
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        if self.tokenizer_2 is not None:
+            (
+                prompt_embeds_2,
+                negative_prompt_embeds_2,
+                prompt_attention_mask_2,
+                negative_prompt_attention_mask_2,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                num_images_per_prompt=num_images_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds_2,
+                negative_prompt_embeds=negative_prompt_embeds_2,
+                prompt_attention_mask=prompt_attention_mask_2,
+                negative_prompt_attention_mask=negative_prompt_attention_mask_2,
+                text_encoder_index=1,
+            )
+        else:
+            prompt_embeds_2 = None
+            negative_prompt_embeds_2 = None
+            prompt_attention_mask_2 = None
+            negative_prompt_attention_mask_2 = None
 
         # 4. set timesteps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
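+        # Flow-matching schedulers take an extra timestep-shift argument `mu` through `set_timesteps`.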
+        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+        else:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         timesteps, num_inference_steps = self.get_timesteps(
             num_inference_steps=num_inference_steps, strength=strength, device=device
         )
+        if comfyui_progressbar:
+            from comfy.utils import ProgressBar
+            pbar = ProgressBar(num_inference_steps + 3)
         # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
@@ -857,7 +1170,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         # Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_transformer = self.transformer.config.in_channels
-        return_image_latents = True # num_channels_transformer == 4
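+        # Clean video latents are only needed when the transformer has no extra inpaint channels and the
+        # mask blend therefore happens directly in latent space during denoising.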
+        return_image_latents = num_channels_transformer == num_channels_latents
 
         # 5. Prepare latents.
         latents_outputs = self.prepare_latents(
@@ -866,7 +1179,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
             height,
             width,
             video_length,
-            prompt_embeds.dtype,
+            dtype,
             device,
             generator,
             latents,
@@ -880,91 +1193,153 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
             latents, noise, image_latents = latents_outputs
         else:
             latents, noise = latents_outputs
-        latents_dtype = latents.dtype
 
+        if comfyui_progressbar:
+            pbar.update(1)
+
+        # 6. Prepare CLIP latents if needed.
+        if clip_image is not None and self.transformer.enable_clip_in_inpaint:
+            inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
+            inputs["pixel_values"] = inputs["pixel_values"].to(device, dtype=dtype)
+            if self.transformer.config.get("position_of_clip_embedding", "full") == "full":
+                clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
+                clip_encoder_hidden_states_neg = torch.zeros(
+                    [
+                        batch_size, 
+                        int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, 
+                        int(self.clip_image_encoder.config.hidden_size)
+                    ]
+                ).to(device, dtype=dtype)
+
+            else:
+                clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds
+                clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(device, dtype=dtype)
+
+            clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(device, dtype=dtype)
+            clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(device, dtype=dtype)
+
+            clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
+            clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
+
+        elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
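+            # No reference image was provided: feed zero CLIP embeddings with a zero attention mask so the
+            # image conditioning is effectively disabled.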
+            if self.transformer.config.get("position_of_clip_embedding", "full") == "full":
+                clip_encoder_hidden_states = torch.zeros(
+                    [
+                        batch_size, 
+                        int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, 
+                        int(self.clip_image_encoder.config.hidden_size)
+                    ]
+                ).to(device, dtype=dtype)
+            else:
+                clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(device, dtype=dtype)
+
+            clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
+            clip_attention_mask = clip_attention_mask.to(device, dtype=dtype)
+
+            clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
+            clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
+
+        else:
+            clip_encoder_hidden_states_input = None
+            clip_attention_mask_input = None
+        if comfyui_progressbar:
+            pbar.update(1)
+
+        # 7. Prepare inpaint latents if needed.
         if mask_video is not None:
-            # Prepare mask latent variables
-            video_length = video.shape[2]
-            mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) 
-            mask_condition = mask_condition.to(dtype=torch.float32)
-            mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
-
-            if num_channels_transformer == 12:
-                mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
-                if masked_video_latents is None:
-                    masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
+            if self.transformer.config.get("enable_zero_in_inpaint", True) and (mask_video == 255).all():
+                # Use zero inpaint latents when the mask is entirely white, i.e. pure text-to-video.
+                mask = torch.zeros_like(latents).to(device, dtype)
+                if self.transformer.resize_inpaint_mask_directly:
+                    mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
                 else:
-                    masked_video = masked_video_latents
-
-                mask_latents, masked_video_latents = self.prepare_mask_latents(
-                    mask_condition_tile,
-                    masked_video,
-                    batch_size,
-                    height,
-                    width,
-                    prompt_embeds.dtype,
-                    device,
-                    generator,
-                    do_classifier_free_guidance,
-                )
-                mask = torch.tile(mask_condition, [1, num_channels_transformer // 3, 1, 1, 1])
-                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
-                
-                mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+                    mask_latents = torch.zeros_like(latents).to(device, dtype)
+                masked_video_latents = torch.zeros_like(latents).to(device, dtype)
+
+                mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
                 masked_video_latents_input = (
-                    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+                    torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
                 )
-                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
+                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
             else:
-                mask = torch.tile(mask_condition, [1, num_channels_transformer, 1, 1, 1])
-                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
-                
-                inpaint_latents = None
+                # Prepare mask latent variables
+                video_length = video.shape[2]
+                mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) 
+                mask_condition = mask_condition.to(dtype=torch.float32)
+                mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
+
+                if num_channels_transformer != num_channels_latents:
+                    mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
+                    if masked_video_latents is None:
+                        masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
+                    else:
+                        masked_video = masked_video_latents
+                    
+                    if self.transformer.resize_inpaint_mask_directly:
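+                        # Resize the single-channel mask directly to latent resolution instead of encoding
+                        # it through the VAE; only the masked video itself is VAE-encoded.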
+                        _, masked_video_latents = self.prepare_mask_latents(
+                            None,
+                            masked_video,
+                            batch_size,
+                            height,
+                            width,
+                            dtype,
+                            device,
+                            generator,
+                            self.do_classifier_free_guidance,
+                            noise_aug_strength=noise_aug_strength,
+                        )
+                        mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
+                        mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor
+                    else:
+                        mask_latents, masked_video_latents = self.prepare_mask_latents(
+                            mask_condition_tile,
+                            masked_video,
+                            batch_size,
+                            height,
+                            width,
+                            dtype,
+                            device,
+                            generator,
+                            self.do_classifier_free_guidance,
+                            noise_aug_strength=noise_aug_strength,
+                        )
+                    
+                    mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
+                    masked_video_latents_input = (
+                        torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
+                    )
+                    inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
+                else:
+                    inpaint_latents = None
+
+                mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, dtype)
         else:
-            if num_channels_transformer == 12:
-                mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
-                masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+            if num_channels_transformer != num_channels_latents:
+                mask = torch.zeros_like(latents).to(device, dtype)
+                if self.transformer.resize_inpaint_mask_directly:
+                    mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype)
+                else:
+                    mask_latents = torch.zeros_like(latents).to(device, dtype)
+                masked_video_latents = torch.zeros_like(latents).to(device, dtype)
 
-                mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+                mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
                 masked_video_latents_input = (
-                    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+                    torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
                 )
-                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
+                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype)
             else:
                 mask = torch.zeros_like(init_video[:, :1])
-                mask = torch.tile(mask, [1, num_channels_transformer, 1, 1, 1])
-                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+                mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
+                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, dtype)
 
                 inpaint_latents = None
-    
-        if clip_image is not None:
-            inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
-            inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
-            clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds
-            clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
-
-            clip_attention_mask = torch.ones([batch_size, 8]).to(latents.device, dtype=latents.dtype)
-            clip_attention_mask_neg = torch.zeros([batch_size, 8]).to(latents.device, dtype=latents.dtype)
 
-            clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if do_classifier_free_guidance else clip_encoder_hidden_states
-            clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if do_classifier_free_guidance else clip_attention_mask
-
-        elif clip_image is None and num_channels_transformer == 12:
-            clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
-
-            clip_attention_mask = torch.zeros([batch_size, 8])
-            clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
-
-            clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if do_classifier_free_guidance else clip_encoder_hidden_states
-            clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if do_classifier_free_guidance else clip_attention_mask
-
-        else:
-            clip_encoder_hidden_states_input = None
-            clip_attention_mask_input = None
+        if comfyui_progressbar:
+            pbar.update(1)
 
         # Check that sizes of mask, masked image and latents match
-        if num_channels_transformer == 12:
-            # default case for runwayml/stable-diffusion-inpainting
+        if num_channels_transformer != num_channels_latents:
             num_channels_mask = mask_latents.shape[1]
             num_channels_masked_image = masked_video_latents.shape[1]
             if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
@@ -975,45 +1350,89 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
                     f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
                     " `pipeline.transformer` or your `mask_image` or `image` input."
                 )
-        elif num_channels_transformer != 4:
-            raise ValueError(
-                f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
-            )
-        
-        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 6.1 Prepare micro-conditions.
+        # 9. Create image_rotary_emb, style embedding & time ids
+        grid_height = height // 8 // self.transformer.config.patch_size
+        grid_width = width // 8 // self.transformer.config.patch_size
+        if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
+            base_size_width = 720 // 8 // self.transformer.config.patch_size
+            base_size_height = 480 // 8 // self.transformer.config.patch_size
+
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            image_rotary_emb = get_3d_rotary_pos_embed(
+                self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
+                temporal_size=latents.size(2), use_real=True,
+            )
+        else:
+            base_size = 512 // 8 // self.transformer.config.patch_size
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size, base_size
+            )
+            image_rotary_emb = get_2d_rotary_pos_embed(
+                self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
+            )
+
+        # Get other hunyuan params
+        target_size = target_size or (height, width)
+        add_time_ids = list(original_size + target_size + crops_coords_top_left)
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        style = torch.tensor([0], device=device)
+
+        if self.do_classifier_free_guidance:
+            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
+            style = torch.cat([style] * 2, dim=0)
+
+        # To latents.device
+        add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
+            batch_size * num_images_per_prompt, 1
+        )
+        style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
+
+        # Get other pixart params
         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-        if self.transformer.config.sample_size == 128:
+        if self.transformer.config.get("sample_size", 64) == 128:
             resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
-            resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
-            aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
+            resolution = resolution.to(dtype=dtype, device=device)
+            aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
 
-            if do_classifier_free_guidance:
+            if self.do_classifier_free_guidance:
                 resolution = torch.cat([resolution, resolution], dim=0)
                 aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
-            
+
             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
 
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-        if self.enable_autocast_float8_transformer_flag:
-            origin_weight_dtype = self.transformer.dtype
-            self.transformer = self.transformer.to(torch.float8_e4m3fn)
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+            if prompt_embeds_2 is not None:
+                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
+                prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
+
+        # To latents.device
+        prompt_embeds = prompt_embeds.to(device=device)
+        prompt_attention_mask = prompt_attention_mask.to(device=device)
+        if prompt_embeds_2 is not None:
+            prompt_embeds_2 = prompt_embeds_2.to(device=device)
+            prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
 
         # 10. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
-        if comfyui_progressbar:
-            from comfy.utils import ProgressBar
-            pbar = ProgressBar(num_inference_steps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                if hasattr(self.scheduler, "scale_model_input"):
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
                     clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
@@ -1021,74 +1440,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
                 else:
                     clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
                     clip_attention_mask_actual_input = clip_attention_mask_input
+                
+                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
+                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
+                    dtype=latent_model_input.dtype
+                )
 
-                current_timestep = t
-                if not torch.is_tensor(current_timestep):
-                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-                    # This would be a good case for the `match` statement (Python 3.10+)
-                    is_mps = latent_model_input.device.type == "mps"
-                    if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
-                    else:
-                        dtype = torch.int32 if is_mps else torch.int64
-                    current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
-                elif len(current_timestep.shape) == 0:
-                    current_timestep = current_timestep[None].to(latent_model_input.device)
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                current_timestep = current_timestep.expand(latent_model_input.shape[0])
-
-                # predict noise model_output
+                # predict the noise residual
                 noise_pred = self.transformer(
                     latent_model_input,
+                    t_expand,
                     encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=current_timestep,
-                    added_cond_kwargs=added_cond_kwargs,
+                    text_embedding_mask=prompt_attention_mask,
+                    encoder_hidden_states_t5=prompt_embeds_2,
+                    text_embedding_mask_t5=prompt_attention_mask_2,
+                    image_meta_size=add_time_ids,
+                    style=style,
+                    image_rotary_emb=image_rotary_emb,
                     inpaint_latents=inpaint_latents,
                     clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
                     clip_attention_mask=clip_attention_mask_actual_input,
+                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
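+                # Some transformer variants predict noise and learned sigma jointly; keep only the noise half.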
+                if noise_pred.size()[1] != self.vae.config.latent_channels:
+                    noise_pred, _ = noise_pred.chunk(2, dim=1)
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # learned sigma
-                noise_pred = noise_pred.chunk(2, dim=1)[0]
+                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
-                # compute previous image: x_t -> x_t-1
+                # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
-                if num_channels_transformer == 4:
+                if num_channels_transformer == num_channels_latents:
                     init_latents_proper = image_latents
                     init_mask = mask
                     if i < len(timesteps) - 1:
                         noise_timestep = timesteps[i + 1]
-                        init_latents_proper = self.scheduler.add_noise(
-                            init_latents_proper, noise, torch.tensor([noise_timestep])
-                        )
-                    
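+                        # Re-noise the clean reference latents to the next timestep so the unmasked region
+                        # keeps following the scheduler trajectory.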
+                        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+                            init_latents_proper = self.scheduler.scale_noise(
+                                init_latents_proper, torch.tensor([noise_timestep]), noise
+                            )
+                        else:
+                            init_latents_proper = self.scheduler.add_noise(
+                                init_latents_proper, noise, torch.tensor([noise_timestep])
+                            )
+                        
                     latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
-                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
+                    negative_prompt_embeds_2 = callback_outputs.pop(
+                        "negative_prompt_embeds_2", negative_prompt_embeds_2
+                    )
+
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
                 if comfyui_progressbar:
                     pbar.update(1)
 
-        if self.enable_autocast_float8_transformer_flag:
-            self.transformer = self.transformer.to("cpu", origin_weight_dtype)
-
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-
         # Post-processing
         video = self.decode_latents(latents)
 
@@ -1096,7 +1524,10 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         if output_type == "latent":
             video = torch.from_numpy(video)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return video
 
-        return EasyAnimatePipelineOutput(videos=video)
\ No newline at end of file
+        return EasyAnimatePipelineOutput(frames=video)
\ No newline at end of file