diff --git a/Dockerfile b/Dockerfile
index 3fcf866aa103015e06435cb6180633f9ea896544..aa608ed779d78f900852135b6703f336a921cdd4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -42,4 +42,4 @@ COPY ./datasets /content/datasets
 COPY ./reports /content/reports
 COPY ./requirements.txt /content/requirements.txt
 RUN pip install -r /content/requirements.txt
-WORKDIR /content
\ No newline at end of file
+WORKDIR /content
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..273bba3b0f85b937b18f37dab5d4eea78e2ff9b8
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,3 @@
+from .comfyui.comfyui_nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
+
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7206fd4ae080330ab65e6f5316ed2ffef48abd0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,49 @@
+import time 
+import torch
+
+from cogvideox.api.api import infer_forward_api, update_diffusion_transformer_api, update_edition_api
+from cogvideox.ui.ui import ui_modelscope, ui_eas, ui
+
+if __name__ == "__main__":
+    # Choose the ui mode  
+    ui_mode = "normal"
+    
+    # Low gpu memory mode, this is used when the GPU memory is under 16GB
+    low_gpu_memory_mode = False
+    # Use torch.float16 if GPU does not support torch.bfloat16
+    # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+    weight_dtype = torch.bfloat16
+
+    # Server ip
+    server_name = "0.0.0.0"
+    server_port = 7860
+
+    # Params below is used when ui_mode = "modelscope"
+    model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
+    # "Inpaint" or "Control"
+    model_type = "Inpaint"
+    # Save dir of this model
+    savedir_sample = "samples"
+
+    if ui_mode == "modelscope":
+        demo, controller = ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
+    elif ui_mode == "eas":
+        demo, controller = ui_eas(model_name, savedir_sample)
+    else:
+        demo, controller = ui(low_gpu_memory_mode, weight_dtype)
+
+    # launch gradio
+    app, _, _ = demo.queue(status_update_rate=1).launch(
+        server_name=server_name,
+        server_port=server_port,
+        prevent_thread_lock=True
+    )
+    
+    # launch api
+    infer_forward_api(None, app, controller)
+    update_diffusion_transformer_api(None, app, controller)
+    update_edition_api(None, app, controller)
+    
+    # not close the python
+    while True:
+        time.sleep(5)
\ No newline at end of file
diff --git a/asset/2.png b/asset/2.png
new file mode 100644
index 0000000000000000000000000000000000000000..5b1ff678fb51e9f2c38778a2b30c328b77332002
Binary files /dev/null and b/asset/2.png differ
diff --git a/asset/3.png b/asset/3.png
new file mode 100644
index 0000000000000000000000000000000000000000..faa84a0a0a0567f6dcbe0e451b426fe7765c10f8
Binary files /dev/null and b/asset/3.png differ
diff --git a/asset/4.png b/asset/4.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c7e5bb0fd130ccdae59d29341d72255bf03c9b3
Binary files /dev/null and b/asset/4.png differ
diff --git a/asset/5.png b/asset/5.png
new file mode 100644
index 0000000000000000000000000000000000000000..6a7af9420ad00587b484fca0b2607ad77d025e0c
Binary files /dev/null and b/asset/5.png differ
diff --git a/cogvideox/__init__.py b/cogvideox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cogvideox/api/api.py b/cogvideox/api/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..89405592a32909a1e3e8386b2b6e79111d024efd
--- /dev/null
+++ b/cogvideox/api/api.py
@@ -0,0 +1,173 @@
+import io
+import gc
+import base64
+import torch
+import gradio as gr
+import tempfile
+import hashlib
+import os
+
+from fastapi import FastAPI
+from io import BytesIO
+from PIL import Image
+
+# Function to encode a file to Base64
+def encode_file_to_base64(file_path):
+    with open(file_path, "rb") as file:
+        # Encode the data to Base64
+        file_base64 = base64.b64encode(file.read())
+        return file_base64
+
+def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
+    @app.post("/cogvideox_fun/update_edition")
+    def _update_edition_api(
+        datas: dict,
+    ):
+        edition = datas.get('edition', 'v2')
+
+        try:
+            controller.update_edition(
+                edition
+            )
+            comment = "Success"
+        except Exception as e:
+            torch.cuda.empty_cache()
+            comment = f"Error. error information is {str(e)}"
+
+        return {"message": comment}
+
+def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
+    @app.post("/cogvideox_fun/update_diffusion_transformer")
+    def _update_diffusion_transformer_api(
+        datas: dict,
+    ):
+        diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
+
+        try:
+            controller.update_diffusion_transformer(
+                diffusion_transformer_path
+            )
+            comment = "Success"
+        except Exception as e:
+            torch.cuda.empty_cache()
+            comment = f"Error. error information is {str(e)}"
+
+        return {"message": comment}
+
+def save_base64_video(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.mp4"  
+    
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def save_base64_image(base64_string):
+    video_data = base64.b64decode(base64_string)
+
+    md5_hash = hashlib.md5(video_data).hexdigest()
+    filename = f"{md5_hash}.jpg"  
+    
+    temp_dir = tempfile.gettempdir()
+    file_path = os.path.join(temp_dir, filename)
+
+    with open(file_path, 'wb') as video_file:
+        video_file.write(video_data)
+
+    return file_path
+
+def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
+    @app.post("/cogvideox_fun/infer_forward")
+    def _infer_forward_api(
+        datas: dict,
+    ):
+        base_model_path = datas.get('base_model_path', 'none')
+        lora_model_path = datas.get('lora_model_path', 'none')
+        lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
+        prompt_textbox = datas.get('prompt_textbox', None)
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
+        sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
+        sample_step_slider = datas.get('sample_step_slider', 30)
+        resize_method = datas.get('resize_method', "Generate by")
+        width_slider = datas.get('width_slider', 672)
+        height_slider = datas.get('height_slider', 384)
+        base_resolution = datas.get('base_resolution', 512)
+        is_image = datas.get('is_image', False)
+        generation_method = datas.get('generation_method', False)
+        length_slider = datas.get('length_slider', 49)
+        overlap_video_length = datas.get('overlap_video_length', 4)
+        partial_video_length = datas.get('partial_video_length', 72)
+        cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+        start_image = datas.get('start_image', None)
+        end_image = datas.get('end_image', None)
+        validation_video = datas.get('validation_video', None)
+        validation_video_mask = datas.get('validation_video_mask', None)
+        control_video = datas.get('control_video', None)
+        denoise_strength = datas.get('denoise_strength', 0.70)
+        seed_textbox = datas.get("seed_textbox", 43)
+
+        generation_method = "Image Generation" if is_image else generation_method
+
+        if start_image is not None:
+            start_image = base64.b64decode(start_image)
+            start_image = [Image.open(BytesIO(start_image))]
+        
+        if end_image is not None:
+            end_image = base64.b64decode(end_image)
+            end_image = [Image.open(BytesIO(end_image))]
+
+        if validation_video is not None:
+            validation_video = save_base64_video(validation_video)
+
+        if validation_video_mask is not None:
+            validation_video_mask = save_base64_image(validation_video_mask)
+
+        if control_video is not None:
+            control_video = save_base64_video(control_video)
+        
+        try:
+            save_sample_path, comment = controller.generate(
+                "",
+                base_model_path,
+                lora_model_path, 
+                lora_alpha_slider,
+                prompt_textbox, 
+                negative_prompt_textbox, 
+                sampler_dropdown, 
+                sample_step_slider, 
+                resize_method,
+                width_slider, 
+                height_slider, 
+                base_resolution,
+                generation_method,
+                length_slider, 
+                overlap_video_length, 
+                partial_video_length, 
+                cfg_scale_slider, 
+                start_image,
+                end_image,
+                validation_video,
+                validation_video_mask, 
+                control_video, 
+                denoise_strength,
+                seed_textbox,
+                is_api = True,
+            )
+        except Exception as e:
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            save_sample_path = ""
+            comment = f"Error. error information is {str(e)}"
+            return {"message": comment}
+        
+        if save_sample_path != "":
+            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+        else:
+            return {"message": comment, "save_sample_path": save_sample_path}
\ No newline at end of file
diff --git a/cogvideox/api/post_infer.py b/cogvideox/api/post_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..57f6ffe16c8ea09fa5d3aefcb5630e0da1bdcf4b
--- /dev/null
+++ b/cogvideox/api/post_infer.py
@@ -0,0 +1,89 @@
+import base64
+import json
+import sys
+import time
+from datetime import datetime
+from io import BytesIO
+
+import cv2
+import requests
+import base64
+
+
+def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
+    datas = json.dumps({
+        "diffusion_transformer_path": diffusion_transformer_path
+    })
+    r = requests.post(f'{url}/cogvideox_fun/update_diffusion_transformer', data=datas, timeout=1500)
+    data = r.content.decode('utf-8')
+    return data
+
+def post_update_edition(edition, url='http://0.0.0.0:7860'):
+    datas = json.dumps({
+        "edition": edition
+    })
+    r = requests.post(f'{url}/cogvideox_fun/update_edition', data=datas, timeout=1500)
+    data = r.content.decode('utf-8')
+    return data
+
+def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
+    datas = json.dumps({
+        "base_model_path": "none",
+        "motion_module_path": "none",
+        "lora_model_path": "none", 
+        "lora_alpha_slider": 0.55, 
+        "prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+        "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ", 
+        "sampler_dropdown": "Euler", 
+        "sample_step_slider": 50, 
+        "width_slider": 672, 
+        "height_slider": 384, 
+        "generation_method": "Video Generation",
+        "length_slider": length_slider,
+        "cfg_scale_slider": 6,
+        "seed_textbox": 43,
+    })
+    r = requests.post(f'{url}/cogvideox_fun/infer_forward', data=datas, timeout=1500)
+    data = r.content.decode('utf-8')
+    return data
+
+if __name__ == '__main__':
+    # initiate time
+    now_date    = datetime.now()
+    time_start  = time.time()  
+    
+    # -------------------------- #
+    #  Step 1: update edition
+    # -------------------------- #
+    diffusion_transformer_path = "models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+    outputs = post_diffusion_transformer(diffusion_transformer_path)
+    print('Output update edition: ', outputs)
+
+    # -------------------------- #
+    #  Step 2: infer
+    # -------------------------- #
+    # "Video Generation" and "Image Generation"
+    generation_method = "Video Generation"
+    length_slider = 49
+    outputs = post_infer(generation_method, length_slider)
+    
+    # Get decoded data
+    outputs = json.loads(outputs)
+    base64_encoding = outputs["base64_encoding"]
+    decoded_data = base64.b64decode(base64_encoding)
+
+    is_image = True if generation_method == "Image Generation" else False
+    if is_image or length_slider == 1:
+        file_path = "1.png"
+    else:
+        file_path = "1.mp4"
+    with open(file_path, "wb") as file:
+        file.write(decoded_data)
+        
+    # End of record time
+    # The calculated time difference is the execution time of the program, expressed in seconds / s
+    time_end = time.time()  
+    time_sum = (time_end - time_start) % 60 
+    print('# --------------------------------------------------------- #')
+    print(f'#   Total expenditure: {time_sum}s')
+    print('# --------------------------------------------------------- #')
\ No newline at end of file
diff --git a/cogvideox/data/bucket_sampler.py b/cogvideox/data/bucket_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5fded15beeb7f53bbf310a571f776cba932c52
--- /dev/null
+++ b/cogvideox/data/bucket_sampler.py
@@ -0,0 +1,379 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
+                    Sized, TypeVar, Union)
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torch.utils.data import BatchSampler, Dataset, Sampler
+
+ASPECT_RATIO_512 = {
+    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
+    '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
+    '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
+    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
+    '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
+    '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
+    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
+    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
+    '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
+    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
+}
+ASPECT_RATIO_RANDOM_CROP_512 = {
+    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0], 
+    '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0], 
+    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], 
+    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0], 
+    '2.0': [704.0, 352.0],  '2.4': [768.0, 320.0]
+}
+ASPECT_RATIO_RANDOM_CROP_PROB = [
+    1, 2,
+    4, 4, 4, 4,
+    8, 8, 8,
+    4, 4, 4, 4,
+    2, 1
+]
+ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
+
+def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
+    aspect_ratio = height / width
+    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
+    return ratios[closest_ratio], float(closest_ratio)
+
+def get_image_size_without_loading(path):
+    with Image.open(path) as img:
+        return img.size  # (width, height)
+
+class RandomSampler(Sampler[int]):
+    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+
+    If with replacement, then user can specify :attr:`num_samples` to draw.
+
+    Args:
+        data_source (Dataset): dataset to sample from
+        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
+        num_samples (int): number of samples to draw, default=`len(dataset)`.
+        generator (Generator): Generator used in sampling.
+    """
+
+    data_source: Sized
+    replacement: bool
+
+    def __init__(self, data_source: Sized, replacement: bool = False,
+                 num_samples: Optional[int] = None, generator=None) -> None:
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
+        self.generator = generator
+        self._pos_start = 0
+
+        if not isinstance(self.replacement, bool):
+            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
+
+    @property
+    def num_samples(self) -> int:
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
+    def __iter__(self) -> Iterator[int]:
+        n = len(self.data_source)
+        if self.generator is None:
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+        else:
+            generator = self.generator
+
+        if self.replacement:
+            for _ in range(self.num_samples // 32):
+                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
+            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+        else:
+            for _ in range(self.num_samples // n):
+                xx = torch.randperm(n, generator=generator).tolist()
+                if self._pos_start >= n:
+                    self._pos_start = 0
+                print("xx top 10", xx[:10], self._pos_start)
+                for idx in range(self._pos_start, n):
+                    yield xx[idx]
+                    self._pos_start = (self._pos_start + 1) % n
+                self._pos_start = 0
+            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+class AspectRatioBatchImageSampler(BatchSampler):
+    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        dataset (Dataset): Dataset providing data information.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+        aspect_ratios (dict): The predefined aspect ratios.
+    """
+    def __init__(
+        self,
+        sampler: Sampler,
+        dataset: Dataset,
+        batch_size: int,
+        train_folder: str = None,
+        aspect_ratios: dict = ASPECT_RATIO_512,
+        drop_last: bool = False,
+        config=None,
+        **kwargs
+    ) -> None:
+        if not isinstance(sampler, Sampler):
+            raise TypeError('sampler should be an instance of ``Sampler``, '
+                            f'but got {sampler}')
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError('batch_size should be a positive integer value, '
+                             f'but got batch_size={batch_size}')
+        self.sampler = sampler
+        self.dataset = dataset
+        self.train_folder = train_folder
+        self.batch_size = batch_size
+        self.aspect_ratios = aspect_ratios
+        self.drop_last = drop_last
+        self.config = config
+        # buckets for each aspect ratio 
+        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+        # [str(k) for k, v in aspect_ratios] 
+        self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+    def __iter__(self):
+        for idx in self.sampler:
+            try:
+                image_dict = self.dataset[idx]
+
+                width, height = image_dict.get("width", None), image_dict.get("height", None)
+                if width is None or height is None:
+                    image_id, name = image_dict['file_path'], image_dict['text']
+                    if self.train_folder is None:
+                        image_dir = image_id
+                    else:
+                        image_dir = os.path.join(self.train_folder, image_id)
+
+                    width, height = get_image_size_without_loading(image_dir)
+
+                    ratio = height / width # self.dataset[idx]
+                else:
+                    height = int(height)
+                    width = int(width)
+                    ratio = height / width # self.dataset[idx]
+            except Exception as e:
+                print(e)
+                continue
+            # find the closest aspect ratio
+            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+            if closest_ratio not in self.current_available_bucket_keys:
+                continue
+            bucket = self._aspect_ratio_buckets[closest_ratio]
+            bucket.append(idx)
+            # yield a batch of indices in the same aspect ratio group
+            if len(bucket) == self.batch_size:
+                yield bucket[:]
+                del bucket[:]
+
+class AspectRatioBatchSampler(BatchSampler):
+    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        dataset (Dataset): Dataset providing data information.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+        aspect_ratios (dict): The predefined aspect ratios.
+    """
+    def __init__(
+        self,
+        sampler: Sampler,
+        dataset: Dataset,
+        batch_size: int,
+        video_folder: str = None,
+        train_data_format: str = "webvid",
+        aspect_ratios: dict = ASPECT_RATIO_512,
+        drop_last: bool = False,
+        config=None,
+        **kwargs
+    ) -> None:
+        if not isinstance(sampler, Sampler):
+            raise TypeError('sampler should be an instance of ``Sampler``, '
+                            f'but got {sampler}')
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError('batch_size should be a positive integer value, '
+                             f'but got batch_size={batch_size}')
+        self.sampler = sampler
+        self.dataset = dataset
+        self.video_folder = video_folder
+        self.train_data_format = train_data_format
+        self.batch_size = batch_size
+        self.aspect_ratios = aspect_ratios
+        self.drop_last = drop_last
+        self.config = config
+        # buckets for each aspect ratio 
+        self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
+        # [str(k) for k, v in aspect_ratios] 
+        self.current_available_bucket_keys = list(aspect_ratios.keys())
+
+    def __iter__(self):
+        for idx in self.sampler:
+            try:
+                video_dict = self.dataset[idx]
+                width, more = video_dict.get("width", None), video_dict.get("height", None)
+
+                if width is None or height is None:
+                    if self.train_data_format == "normal":
+                        video_id, name = video_dict['file_path'], video_dict['text']
+                        if self.video_folder is None:
+                            video_dir = video_id
+                        else:
+                            video_dir = os.path.join(self.video_folder, video_id)
+                    else:
+                        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+                        video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
+                    cap = cv2.VideoCapture(video_dir)
+
+                    # 获取视频尺寸
+                    width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))   # 浮点数转换为整数
+                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # 浮点数转换为整数
+                    
+                    ratio = height / width # self.dataset[idx]
+                else:
+                    height = int(height)
+                    width = int(width)
+                    ratio = height / width # self.dataset[idx]
+            except Exception as e:
+                print(e)
+                continue
+            # find the closest aspect ratio
+            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+            if closest_ratio not in self.current_available_bucket_keys:
+                continue
+            bucket = self._aspect_ratio_buckets[closest_ratio]
+            bucket.append(idx)
+            # yield a batch of indices in the same aspect ratio group
+            if len(bucket) == self.batch_size:
+                yield bucket[:]
+                del bucket[:]
+
+class AspectRatioBatchImageVideoSampler(BatchSampler):
+    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        dataset (Dataset): Dataset providing data information.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+        aspect_ratios (dict): The predefined aspect ratios.
+    """
+
+    def __init__(self,
+                 sampler: Sampler,
+                 dataset: Dataset,
+                 batch_size: int,
+                 train_folder: str = None,
+                 aspect_ratios: dict = ASPECT_RATIO_512,
+                 drop_last: bool = False
+                ) -> None:
+        if not isinstance(sampler, Sampler):
+            raise TypeError('sampler should be an instance of ``Sampler``, '
+                            f'but got {sampler}')
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError('batch_size should be a positive integer value, '
+                             f'but got batch_size={batch_size}')
+        self.sampler = sampler
+        self.dataset = dataset
+        self.train_folder = train_folder
+        self.batch_size = batch_size
+        self.aspect_ratios = aspect_ratios
+        self.drop_last = drop_last
+
+        # buckets for each aspect ratio
+        self.current_available_bucket_keys = list(aspect_ratios.keys())
+        self.bucket = {
+            'image':{ratio: [] for ratio in aspect_ratios}, 
+            'video':{ratio: [] for ratio in aspect_ratios}
+        }
+
+    def __iter__(self):
+        for idx in self.sampler:
+            content_type = self.dataset[idx].get('type', 'image')
+            if content_type == 'image':
+                try:
+                    image_dict = self.dataset[idx]
+
+                    width, height = image_dict.get("width", None), image_dict.get("height", None)
+                    if width is None or height is None:
+                        image_id, name = image_dict['file_path'], image_dict['text']
+                        if self.train_folder is None:
+                            image_dir = image_id
+                        else:
+                            image_dir = os.path.join(self.train_folder, image_id)
+
+                        width, height = get_image_size_without_loading(image_dir)
+
+                        ratio = height / width # self.dataset[idx]
+                    else:
+                        height = int(height)
+                        width = int(width)
+                        ratio = height / width # self.dataset[idx]
+                except Exception as e:
+                    print(e)
+                    continue
+                # find the closest aspect ratio
+                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+                if closest_ratio not in self.current_available_bucket_keys:
+                    continue
+                bucket = self.bucket['image'][closest_ratio]
+                bucket.append(idx)
+                # yield a batch of indices in the same aspect ratio group
+                if len(bucket) == self.batch_size:
+                    yield bucket[:]
+                    del bucket[:]
+            else:
+                try:
+                    video_dict = self.dataset[idx]
+                    width, height = video_dict.get("width", None), video_dict.get("height", None)
+
+                    if width is None or height is None:
+                        video_id, name = video_dict['file_path'], video_dict['text']
+                        if self.train_folder is None:
+                            video_dir = video_id
+                        else:
+                            video_dir = os.path.join(self.train_folder, video_id)
+                        cap = cv2.VideoCapture(video_dir)
+
+                        # 获取视频尺寸
+                        width  = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))   # 浮点数转换为整数
+                        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # 浮点数转换为整数
+                        
+                        ratio = height / width # self.dataset[idx]
+                    else:
+                        height = int(height)
+                        width = int(width)
+                        ratio = height / width # self.dataset[idx]
+                except Exception as e:
+                    print(e)
+                    continue
+                # find the closest aspect ratio
+                closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
+                if closest_ratio not in self.current_available_bucket_keys:
+                    continue
+                bucket = self.bucket['video'][closest_ratio]
+                bucket.append(idx)
+                # yield a batch of indices in the same aspect ratio group
+                if len(bucket) == self.batch_size:
+                    yield bucket[:]
+                    del bucket[:]
\ No newline at end of file
diff --git a/cogvideox/data/dataset_image.py b/cogvideox/data/dataset_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..098d49a4044f8daa351cd01b4cb1ec5415412e80
--- /dev/null
+++ b/cogvideox/data/dataset_image.py
@@ -0,0 +1,76 @@
+import json
+import os
+import random
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+
+
+class CC15M(Dataset):
+    def __init__(
+            self,
+            json_path, 
+            video_folder=None,
+            resolution=512,
+            enable_bucket=False,
+        ):
+        print(f"loading annotations from {json_path} ...")
+        self.dataset = json.load(open(json_path, 'r'))
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+        
+        self.enable_bucket = enable_bucket
+        self.video_folder = video_folder
+
+        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
+        self.pixel_transforms = transforms.Compose([
+            transforms.Resize(resolution[0]),
+            transforms.CenterCrop(resolution),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ])
+    
+    def get_batch(self, idx):
+        video_dict = self.dataset[idx]
+        video_id, name = video_dict['file_path'], video_dict['text']
+
+        if self.video_folder is None:
+            video_dir = video_id
+        else:
+            video_dir = os.path.join(self.video_folder, video_id)
+
+        pixel_values = Image.open(video_dir).convert("RGB")
+        return pixel_values, name
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        while True:
+            try:
+                pixel_values, name = self.get_batch(idx)
+                break
+            except Exception as e:
+                print(e)
+                idx = random.randint(0, self.length-1)
+
+        if not self.enable_bucket:
+            pixel_values = self.pixel_transforms(pixel_values)
+        else:
+            pixel_values = np.array(pixel_values)
+
+        sample = dict(pixel_values=pixel_values, text=name)
+        return sample
+
+if __name__ == "__main__":
+    dataset = CC15M(
+        csv_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
+        resolution=512,
+    )
+    
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+    for idx, batch in enumerate(dataloader):
+        print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/cogvideox/data/dataset_image_video.py b/cogvideox/data/dataset_image_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..f714d694aa0a1f13b1d1ee30a98cff5336f027ec
--- /dev/null
+++ b/cogvideox/data/dataset_image_video.py
@@ -0,0 +1,550 @@
+import csv
+import io
+import json
+import math
+import os
+import random
+from threading import Thread
+
+import albumentations
+import cv2
+import gc
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+
+from func_timeout import func_timeout, FunctionTimedOut
+from decord import VideoReader
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+from contextlib import contextmanager
+
+VIDEO_READER_TIMEOUT = 20
+
+def get_random_mask(shape):
+    f, c, h, w = shape
+    
+    if f != 1:
+        mask_index = np.random.choice([0, 1, 2, 3, 4], p = [0.05, 0.3, 0.3, 0.3, 0.05]) # np.random.randint(0, 5)
+    else:
+        mask_index = np.random.choice([0, 1], p = [0.2, 0.8]) # np.random.randint(0, 2)
+    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+
+    if mask_index == 0:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # 方块的宽度范围
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # 方块的高度范围
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+        mask[:, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 1:
+        mask[:, :, :, :] = 1
+    elif mask_index == 2:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:, :, :, :] = 1
+    elif mask_index == 3:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+    elif mask_index == 4:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # 方块的宽度范围
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # 方块的高度范围
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+
+        mask_frame_before = np.random.randint(0, f // 2)
+        mask_frame_after = np.random.randint(f // 2, f)
+        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    else:
+        raise ValueError(f"The mask_index {mask_index} is not define")
+    return mask
+
+class ImageVideoSampler(BatchSampler):
+    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
+
+    Args:
+        sampler (Sampler): Base sampler.
+        dataset (Dataset): Dataset providing data information.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+        aspect_ratios (dict): The predefined aspect ratios.
+    """
+
+    def __init__(self,
+                 sampler: Sampler,
+                 dataset: Dataset,
+                 batch_size: int,
+                 drop_last: bool = False
+                ) -> None:
+        if not isinstance(sampler, Sampler):
+            raise TypeError('sampler should be an instance of ``Sampler``, '
+                            f'but got {sampler}')
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError('batch_size should be a positive integer value, '
+                             f'but got batch_size={batch_size}')
+        self.sampler = sampler
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+        # buckets for each aspect ratio
+        self.bucket = {'image':[], 'video':[]}
+
+    def __iter__(self):
+        for idx in self.sampler:
+            content_type = self.dataset.dataset[idx].get('type', 'image')
+            self.bucket[content_type].append(idx)
+
+            # yield a batch of indices in the same aspect ratio group
+            if len(self.bucket['video']) == self.batch_size:
+                bucket = self.bucket['video']
+                yield bucket[:]
+                del bucket[:]
+            elif len(self.bucket['image']) == self.batch_size:
+                bucket = self.bucket['image']
+                yield bucket[:]
+                del bucket[:]
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+    vr = VideoReader(*args, **kwargs)
+    try:
+        yield vr
+    finally:
+        del vr
+        gc.collect()
+
+def get_video_reader_batch(video_reader, batch_index):
+    frames = video_reader.get_batch(batch_index).asnumpy()
+    return frames
+
+def resize_frame(frame, target_short_side):
+    h, w, _ = frame.shape
+    if h < w:
+        if target_short_side > h:
+            return frame
+        new_h = target_short_side
+        new_w = int(target_short_side * w / h)
+    else:
+        if target_short_side > w:
+            return frame
+        new_w = target_short_side
+        new_h = int(target_short_side * h / w)
+    
+    resized_frame = cv2.resize(frame, (new_w, new_h))
+    return resized_frame
+
+class ImageVideoDataset(Dataset):
+    def __init__(
+            self,
+            ann_path, data_root=None,
+            video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+            image_sample_size=512,
+            video_repeat=0,
+            text_drop_ratio=-1,
+            enable_bucket=False,
+            video_length_drop_start=0.1, 
+            video_length_drop_end=0.9,
+            enable_inpaint=False,
+        ):
+        # Loading annotations from files
+        print(f"loading annotations from {ann_path} ...")
+        if ann_path.endswith('.csv'):
+            with open(ann_path, 'r') as csvfile:
+                dataset = list(csv.DictReader(csvfile))
+        elif ann_path.endswith('.json'):
+            dataset = json.load(open(ann_path))
+    
+        self.data_root = data_root
+
+        # It's used to balance num of images and videos.
+        self.dataset = []
+        for data in dataset:
+            if data.get('type', 'image') != 'video':
+                self.dataset.append(data)
+        if video_repeat > 0:
+            for _ in range(video_repeat):
+                for data in dataset:
+                    if data.get('type', 'image') == 'video':
+                        self.dataset.append(data)
+        del dataset
+
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+        # TODO: enable bucket training
+        self.enable_bucket = enable_bucket
+        self.text_drop_ratio = text_drop_ratio
+        self.enable_inpaint  = enable_inpaint
+
+        self.video_length_drop_start = video_length_drop_start
+        self.video_length_drop_end = video_length_drop_end
+
+        # Video params
+        self.video_sample_stride    = video_sample_stride
+        self.video_sample_n_frames  = video_sample_n_frames
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_transforms = transforms.Compose(
+            [
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+        # Image params
+        self.image_sample_size  = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+        self.image_transforms   = transforms.Compose([
+            transforms.Resize(min(self.image_sample_size)),
+            transforms.CenterCrop(self.image_sample_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+        ])
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+
+    def get_batch(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        
+        if data_info.get('type', 'image')=='video':
+            video_id, text = data_info['file_path'], data_info['text']
+
+            if self.data_root is None:
+                video_dir = video_id
+            else:
+                video_dir = os.path.join(self.data_root, video_id)
+
+            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+                min_sample_n_frames = min(
+                    self.video_sample_n_frames, 
+                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+                )
+                if min_sample_n_frames == 0:
+                    raise ValueError(f"No Frames in video.")
+
+                video_length = int(self.video_length_drop_end * len(video_reader))
+                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+                start_idx   = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+                try:
+                    sample_args = (video_reader, batch_index)
+                    pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(pixel_values)):
+                        frame = pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    del video_reader
+                else:
+                    pixel_values = pixel_values
+
+                if not self.enable_bucket:
+                    pixel_values = self.video_transforms(pixel_values)
+                
+                # Random use no text generation
+                if random.random() < self.text_drop_ratio:
+                    text = ''
+            return pixel_values, text, 'video'
+        else:
+            image_path, text = data_info['file_path'], data_info['text']
+            if self.data_root is not None:
+                image_path = os.path.join(self.data_root, image_path)
+            image = Image.open(image_path).convert('RGB')
+            if not self.enable_bucket:
+                image = self.image_transforms(image).unsqueeze(0)
+            else:
+                image = np.expand_dims(np.array(image), 0)
+            if random.random() < self.text_drop_ratio:
+                text = ''
+            return image, text, 'image'
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        data_type = data_info.get('type', 'image')
+        while True:
+            sample = {}
+            try:
+                data_info_local = self.dataset[idx % len(self.dataset)]
+                data_type_local = data_info_local.get('type', 'image')
+                if data_type_local != data_type:
+                    raise ValueError("data_type_local != data_type")
+
+                pixel_values, name, data_type = self.get_batch(idx)
+                sample["pixel_values"] = pixel_values
+                sample["text"] = name
+                sample["data_type"] = data_type
+                sample["idx"] = idx
+                
+                if len(sample) > 0:
+                    break
+            except Exception as e:
+                print(e, self.dataset[idx % len(self.dataset)])
+                idx = random.randint(0, self.length-1)
+
+        if self.enable_inpaint and not self.enable_bucket:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample["mask_pixel_values"] = mask_pixel_values
+            sample["mask"] = mask
+
+            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+            sample["clip_pixel_values"] = clip_pixel_values
+
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
+        return sample
+
+
+class ImageVideoControlDataset(Dataset):
+    def __init__(
+            self,
+            ann_path, data_root=None,
+            video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
+            image_sample_size=512,
+            video_repeat=0,
+            text_drop_ratio=-1,
+            enable_bucket=False,
+            video_length_drop_start=0.1, 
+            video_length_drop_end=0.9,
+            enable_inpaint=False,
+    ):
+        # Loading annotations from files
+        print(f"loading annotations from {ann_path} ...")
+        if ann_path.endswith('.csv'):
+            with open(ann_path, 'r') as csvfile:
+                dataset = list(csv.DictReader(csvfile))
+        elif ann_path.endswith('.json'):
+            dataset = json.load(open(ann_path))
+    
+        self.data_root = data_root
+
+        # It's used to balance num of images and videos.
+        self.dataset = []
+        for data in dataset:
+            if data.get('type', 'image') != 'video':
+                self.dataset.append(data)
+        if video_repeat > 0:
+            for _ in range(video_repeat):
+                for data in dataset:
+                    if data.get('type', 'image') == 'video':
+                        self.dataset.append(data)
+        del dataset
+
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+        # TODO: enable bucket training
+        self.enable_bucket = enable_bucket
+        self.text_drop_ratio = text_drop_ratio
+        self.enable_inpaint  = enable_inpaint
+
+        self.video_length_drop_start = video_length_drop_start
+        self.video_length_drop_end = video_length_drop_end
+
+        # Video params
+        self.video_sample_stride    = video_sample_stride
+        self.video_sample_n_frames  = video_sample_n_frames
+        self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
+        self.video_transforms = transforms.Compose(
+            [
+                transforms.Resize(min(self.video_sample_size)),
+                transforms.CenterCrop(self.video_sample_size),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+        # Image params
+        self.image_sample_size  = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
+        self.image_transforms   = transforms.Compose([
+            transforms.Resize(min(self.image_sample_size)),
+            transforms.CenterCrop(self.image_sample_size),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
+        ])
+
+        self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
+    
+    def get_batch(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        video_id, text = data_info['file_path'], data_info['text']
+
+        if data_info.get('type', 'image')=='video':
+            if self.data_root is None:
+                video_dir = video_id
+            else:
+                video_dir = os.path.join(self.data_root, video_id)
+
+            with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+                min_sample_n_frames = min(
+                    self.video_sample_n_frames, 
+                    int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
+                )
+                if min_sample_n_frames == 0:
+                    raise ValueError(f"No Frames in video.")
+
+                video_length = int(self.video_length_drop_end * len(video_reader))
+                clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
+                start_idx   = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
+                batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
+
+                try:
+                    sample_args = (video_reader, batch_index)
+                    pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(pixel_values)):
+                        frame = pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    del video_reader
+                else:
+                    pixel_values = pixel_values
+
+                if not self.enable_bucket:
+                    pixel_values = self.video_transforms(pixel_values)
+                
+                # Random use no text generation
+                if random.random() < self.text_drop_ratio:
+                    text = ''
+
+            control_video_id = data_info['control_file_path']
+
+            if self.data_root is None:
+                control_video_id = control_video_id
+            else:
+                control_video_id = os.path.join(self.data_root, control_video_id)
+            
+            with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+                try:
+                    sample_args = (control_video_reader, batch_index)
+                    control_pixel_values = func_timeout(
+                        VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                    )
+                    resized_frames = []
+                    for i in range(len(control_pixel_values)):
+                        frame = control_pixel_values[i]
+                        resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                        resized_frames.append(resized_frame)
+                    control_pixel_values = np.array(resized_frames)
+                except FunctionTimedOut:
+                    raise ValueError(f"Read {idx} timeout.")
+                except Exception as e:
+                    raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                if not self.enable_bucket:
+                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+                    control_pixel_values = control_pixel_values / 255.
+                    del control_video_reader
+                else:
+                    control_pixel_values = control_pixel_values
+
+                if not self.enable_bucket:
+                    control_pixel_values = self.video_transforms(control_pixel_values)
+            return pixel_values, control_pixel_values, text, "video"
+        else:
+            image_path, text = data_info['file_path'], data_info['text']
+            if self.data_root is not None:
+                image_path = os.path.join(self.data_root, image_path)
+            image = Image.open(image_path).convert('RGB')
+            if not self.enable_bucket:
+                image = self.image_transforms(image).unsqueeze(0)
+            else:
+                image = np.expand_dims(np.array(image), 0)
+
+            if random.random() < self.text_drop_ratio:
+                text = ''
+
+            control_image_id = data_info['control_file_path']
+
+            if self.data_root is None:
+                control_image_id = control_image_id
+            else:
+                control_image_id = os.path.join(self.data_root, control_image_id)
+
+            control_image = Image.open(control_image_id).convert('RGB')
+            if not self.enable_bucket:
+                control_image = self.image_transforms(control_image).unsqueeze(0)
+            else:
+                control_image = np.expand_dims(np.array(control_image), 0)
+            return image, control_image, text, 'image'
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        data_info = self.dataset[idx % len(self.dataset)]
+        data_type = data_info.get('type', 'image')
+        while True:
+            sample = {}
+            try:
+                data_info_local = self.dataset[idx % len(self.dataset)]
+                data_type_local = data_info_local.get('type', 'image')
+                if data_type_local != data_type:
+                    raise ValueError("data_type_local != data_type")
+
+                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+                sample["pixel_values"] = pixel_values
+                sample["control_pixel_values"] = control_pixel_values
+                sample["text"] = name
+                sample["data_type"] = data_type
+                sample["idx"] = idx
+                
+                if len(sample) > 0:
+                    break
+            except Exception as e:
+                print(e, self.dataset[idx % len(self.dataset)])
+                idx = random.randint(0, self.length-1)
+
+        if self.enable_inpaint and not self.enable_bucket:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample["mask_pixel_values"] = mask_pixel_values
+            sample["mask"] = mask
+
+            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+            sample["clip_pixel_values"] = clip_pixel_values
+
+            ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
+            if (mask == 1).all():
+                ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
+            sample["ref_pixel_values"] = ref_pixel_values
+
+        return sample
diff --git a/cogvideox/data/dataset_video.py b/cogvideox/data/dataset_video.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78367d0973ceb1abdcd005947612d16e2480831
--- /dev/null
+++ b/cogvideox/data/dataset_video.py
@@ -0,0 +1,262 @@
+import csv
+import gc
+import io
+import json
+import math
+import os
+import random
+from contextlib import contextmanager
+from threading import Thread
+
+import albumentations
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from decord import VideoReader
+from einops import rearrange
+from func_timeout import FunctionTimedOut, func_timeout
+from PIL import Image
+from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data.dataset import Dataset
+
+VIDEO_READER_TIMEOUT = 20
+
+def get_random_mask(shape):
+    f, c, h, w = shape
+    
+    mask_index = np.random.randint(0, 4)
+    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+    if mask_index == 0:
+        mask[1:, :, :, :] = 1
+    elif mask_index == 1:
+        mask_frame_index = 1
+        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+    elif mask_index == 2:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # 方块的宽度范围
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # 方块的高度范围
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+        mask[:, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 3:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # 方块的宽度范围
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # 方块的高度范围
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+
+        mask_frame_before = np.random.randint(0, f // 2)
+        mask_frame_after = np.random.randint(f // 2, f)
+        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    else:
+        raise ValueError(f"The mask_index {mask_index} is not define")
+    return mask
+
+
+@contextmanager
+def VideoReader_contextmanager(*args, **kwargs):
+    vr = VideoReader(*args, **kwargs)
+    try:
+        yield vr
+    finally:
+        del vr
+        gc.collect()
+
+
+def get_video_reader_batch(video_reader, batch_index):
+    frames = video_reader.get_batch(batch_index).asnumpy()
+    return frames
+
+
+class WebVid10M(Dataset):
+    def __init__(
+            self,
+            csv_path, video_folder,
+            sample_size=256, sample_stride=4, sample_n_frames=16,
+            enable_bucket=False, enable_inpaint=False, is_image=False,
+        ):
+        print(f"loading annotations from {csv_path} ...")
+        with open(csv_path, 'r') as csvfile:
+            self.dataset = list(csv.DictReader(csvfile))
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+
+        self.video_folder    = video_folder
+        self.sample_stride   = sample_stride
+        self.sample_n_frames = sample_n_frames
+        self.enable_bucket   = enable_bucket
+        self.enable_inpaint  = enable_inpaint
+        self.is_image        = is_image
+        
+        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+        self.pixel_transforms = transforms.Compose([
+            transforms.Resize(sample_size[0]),
+            transforms.CenterCrop(sample_size),
+            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+        ])
+    
+    def get_batch(self, idx):
+        video_dict = self.dataset[idx]
+        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+        
+        video_dir    = os.path.join(self.video_folder, f"{videoid}.mp4")
+        video_reader = VideoReader(video_dir)
+        video_length = len(video_reader)
+        
+        if not self.is_image:
+            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+            start_idx   = random.randint(0, video_length - clip_length)
+            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+        else:
+            batch_index = [random.randint(0, video_length - 1)]
+
+        if not self.enable_bucket:
+            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+            pixel_values = pixel_values / 255.
+            del video_reader
+        else:
+            pixel_values = video_reader.get_batch(batch_index).asnumpy()
+
+        if self.is_image:
+            pixel_values = pixel_values[0]
+        return pixel_values, name
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        while True:
+            try:
+                pixel_values, name = self.get_batch(idx)
+                break
+
+            except Exception as e:
+                print("Error info:", e)
+                idx = random.randint(0, self.length-1)
+
+        if not self.enable_bucket:
+            pixel_values = self.pixel_transforms(pixel_values)
+        if self.enable_inpaint:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+        else:
+            sample = dict(pixel_values=pixel_values, text=name)
+        return sample
+
+
+class VideoDataset(Dataset):
+    def __init__(
+        self,
+        json_path, video_folder=None,
+        sample_size=256, sample_stride=4, sample_n_frames=16,
+        enable_bucket=False, enable_inpaint=False
+    ):
+        print(f"loading annotations from {json_path} ...")
+        self.dataset = json.load(open(json_path, 'r'))
+        self.length = len(self.dataset)
+        print(f"data scale: {self.length}")
+
+        self.video_folder    = video_folder
+        self.sample_stride   = sample_stride
+        self.sample_n_frames = sample_n_frames
+        self.enable_bucket   = enable_bucket
+        self.enable_inpaint  = enable_inpaint
+        
+        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+        self.pixel_transforms = transforms.Compose(
+            [
+                transforms.Resize(sample_size[0]),
+                transforms.CenterCrop(sample_size),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+    
+    def get_batch(self, idx):
+        video_dict = self.dataset[idx]
+        video_id, name = video_dict['file_path'], video_dict['text']
+
+        if self.video_folder is None:
+            video_dir = video_id
+        else:
+            video_dir = os.path.join(self.video_folder, video_id)
+
+        with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
+            video_length = len(video_reader)
+        
+            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+            start_idx   = random.randint(0, video_length - clip_length)
+            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+
+            try:
+                sample_args = (video_reader, batch_index)
+                pixel_values = func_timeout(
+                    VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                )
+            except FunctionTimedOut:
+                raise ValueError(f"Read {idx} timeout.")
+            except Exception as e:
+                raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+            if not self.enable_bucket:
+                pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
+                pixel_values = pixel_values / 255.
+                del video_reader
+            else:
+                pixel_values = pixel_values
+
+            return pixel_values, name
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        while True:
+            try:
+                pixel_values, name = self.get_batch(idx)
+                break
+
+            except Exception as e:
+                print("Error info:", e)
+                idx = random.randint(0, self.length-1)
+
+        if not self.enable_bucket:
+            pixel_values = self.pixel_transforms(pixel_values)
+        if self.enable_inpaint:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
+        else:
+            sample = dict(pixel_values=pixel_values, text=name)
+        return sample
+
+
+if __name__ == "__main__":
+    if 1:
+        dataset = VideoDataset(
+            json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
+            sample_size=256,
+            sample_stride=4, sample_n_frames=16,
+        )
+
+    if 0:
+        dataset = WebVid10M(
+            csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+            video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+            sample_size=256,
+            sample_stride=4, sample_n_frames=16,
+            is_image=False,
+        )
+
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
+    for idx, batch in enumerate(dataloader):
+        print(batch["pixel_values"].shape, len(batch["text"]))
\ No newline at end of file
diff --git a/cogvideox/models/autoencoder_magvit.py b/cogvideox/models/autoencoder_magvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c2b9063956e12e9504a09dc87bccc6611787508
--- /dev/null
+++ b/cogvideox/models/autoencoder_magvit.py
@@ -0,0 +1,1296 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import logging
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.activations import get_activation
+from diffusers.models.downsampling import CogVideoXDownsample3D
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.upsampling import CogVideoXUpsample3D
+from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CogVideoXSafeConv3d(nn.Conv3d):
+    r"""
+    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
+    """
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+
+        # Set to 2GB, suitable for CuDNN
+        if memory_count > 2:
+            kernel_size = self.kernel_size[0]
+            part_num = int(memory_count / 2) + 1
+            input_chunks = torch.chunk(input, part_num, dim=2)
+
+            if kernel_size > 1:
+                input_chunks = [input_chunks[0]] + [
+                    torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
+                    for i in range(1, len(input_chunks))
+                ]
+
+            output_chunks = []
+            for input_chunk in input_chunks:
+                output_chunks.append(super().forward(input_chunk))
+            output = torch.cat(output_chunks, dim=2)
+            return output
+        else:
+            return super().forward(input)
+
+
+class CogVideoXCausalConv3d(nn.Module):
+    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+
+    Args:
+        in_channels (`int`): Number of channels in the input tensor.
+        out_channels (`int`): Number of output channels produced by the convolution.
+        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+        stride (`int`, defaults to `1`): Stride of the convolution.
+        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+        pad_mode (`str`, defaults to `"constant"`): Padding mode.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int, int]],
+        stride: int = 1,
+        dilation: int = 1,
+        pad_mode: str = "constant",
+    ):
+        super().__init__()
+
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size,) * 3
+
+        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+
+        self.pad_mode = pad_mode
+        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+        height_pad = height_kernel_size // 2
+        width_pad = width_kernel_size // 2
+
+        self.height_pad = height_pad
+        self.width_pad = width_pad
+        self.time_pad = time_pad
+        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
+
+        self.temporal_dim = 2
+        self.time_kernel_size = time_kernel_size
+
+        stride = (stride, 1, 1)
+        dilation = (dilation, 1, 1)
+        self.conv = CogVideoXSafeConv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation,
+        )
+
+        self.conv_cache = None
+
+    def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        kernel_size = self.time_kernel_size
+        if kernel_size > 1:
+            cached_inputs = (
+                [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
+            )
+            inputs = torch.cat(cached_inputs + [inputs], dim=2)
+        return inputs
+
+    def _clear_fake_context_parallel_cache(self):
+        del self.conv_cache
+        self.conv_cache = None
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        inputs = self.fake_context_parallel_forward(inputs)
+
+        self._clear_fake_context_parallel_cache()
+        # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
+        # hundred megabytes and so let's not do it for now
+        self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
+
+        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
+        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
+
+        output = self.conv(inputs)
+        return output
+
+
+class CogVideoXSpatialNorm3D(nn.Module):
+    r"""
+    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
+    to 3D-video like data.
+
+    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
+
+    Args:
+        f_channels (`int`):
+            The number of channels for input to group normalization layer, and output of the spatial norm layer.
+        zq_channels (`int`):
+            The number of channels for the quantized vector as described in the paper.
+        groups (`int`):
+            Number of groups to separate the channels into for group normalization.
+    """
+
+    def __init__(
+        self,
+        f_channels: int,
+        zq_channels: int,
+        groups: int = 32,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
+        self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+        self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+        if f.shape[2] > 1 and f.shape[2] % 2 == 1:
+            f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+            f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
+            z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
+            z_first = F.interpolate(z_first, size=f_first_size)
+            z_rest = F.interpolate(z_rest, size=f_rest_size)
+            zq = torch.cat([z_first, z_rest], dim=2)
+        else:
+            zq = F.interpolate(zq, size=f.shape[-3:])
+
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
+
+
+class CogVideoXResnetBlock3D(nn.Module):
+    r"""
+    A 3D ResNet block used in the CogVideoX model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+        conv_shortcut (bool, defaults to `False`):
+            Whether or not to use a convolution shortcut.
+        spatial_norm_dim (`int`, *optional*):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        dropout: float = 0.0,
+        temb_channels: int = 512,
+        groups: int = 32,
+        eps: float = 1e-6,
+        non_linearity: str = "swish",
+        conv_shortcut: bool = False,
+        spatial_norm_dim: Optional[int] = None,
+        pad_mode: str = "first",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.nonlinearity = get_activation(non_linearity)
+        self.use_conv_shortcut = conv_shortcut
+
+        if spatial_norm_dim is None:
+            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
+            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
+        else:
+            self.norm1 = CogVideoXSpatialNorm3D(
+                f_channels=in_channels,
+                zq_channels=spatial_norm_dim,
+                groups=groups,
+            )
+            self.norm2 = CogVideoXSpatialNorm3D(
+                f_channels=out_channels,
+                zq_channels=spatial_norm_dim,
+                groups=groups,
+            )
+
+        self.conv1 = CogVideoXCausalConv3d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+        )
+
+        if temb_channels > 0:
+            self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
+
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = CogVideoXCausalConv3d(
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+        )
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = CogVideoXCausalConv3d(
+                    in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
+                )
+            else:
+                self.conv_shortcut = CogVideoXSafeConv3d(
+                    in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        zq: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        hidden_states = inputs
+
+        if zq is not None:
+            hidden_states = self.norm1(hidden_states, zq)
+        else:
+            hidden_states = self.norm1(hidden_states)
+
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        if temb is not None:
+            hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+        if zq is not None:
+            hidden_states = self.norm2(hidden_states, zq)
+        else:
+            hidden_states = self.norm2(hidden_states)
+
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.in_channels != self.out_channels:
+            inputs = self.conv_shortcut(inputs)
+
+        hidden_states = hidden_states + inputs
+        return hidden_states
+
+
+class CogVideoXDownBlock3D(nn.Module):
+    r"""
+    A downsampling block used in the CogVideoX model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        add_downsample (`bool`, defaults to `True`):
+            Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to downsample across temporal dimension.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        add_downsample: bool = True,
+        downsample_padding: int = 0,
+        compress_time: bool = False,
+        pad_mode: str = "first",
+    ):
+        super().__init__()
+
+        resnets = []
+        for i in range(num_layers):
+            in_channel = in_channels if i == 0 else out_channels
+            resnets.append(
+                CogVideoXResnetBlock3D(
+                    in_channels=in_channel,
+                    out_channels=out_channels,
+                    dropout=dropout,
+                    temb_channels=temb_channels,
+                    groups=resnet_groups,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    pad_mode=pad_mode,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.downsamplers = None
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    CogVideoXDownsample3D(
+                        out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
+                    )
+                ]
+            )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        zq: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, zq)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
+class CogVideoXMidBlock3D(nn.Module):
+    r"""
+    A middle block used in the CogVideoX model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        spatial_norm_dim (`int`, *optional*):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        spatial_norm_dim: Optional[int] = None,
+        pad_mode: str = "first",
+    ):
+        super().__init__()
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(
+                CogVideoXResnetBlock3D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    dropout=dropout,
+                    temb_channels=temb_channels,
+                    groups=resnet_groups,
+                    eps=resnet_eps,
+                    spatial_norm_dim=spatial_norm_dim,
+                    non_linearity=resnet_act_fn,
+                    pad_mode=pad_mode,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        zq: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, zq)
+
+        return hidden_states
+
+
+class CogVideoXUpBlock3D(nn.Module):
+    r"""
+    An upsampling block used in the CogVideoX model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        temb_channels (`int`, defaults to `512`):
+            Number of time embedding channels.
+        dropout (`float`, defaults to `0.0`):
+            Dropout rate.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet layers.
+        resnet_eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        resnet_act_fn (`str`, defaults to `"swish"`):
+            Activation function to use.
+        resnet_groups (`int`, defaults to `32`):
+            Number of groups to separate the channels into for group normalization.
+        spatial_norm_dim (`int`, defaults to `16`):
+            The dimension to use for spatial norm if it is to be used instead of group norm.
+        add_upsample (`bool`, defaults to `True`):
+            Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
+        compress_time (`bool`, defaults to `False`):
+            Whether or not to downsample across temporal dimension.
+        pad_mode (str, defaults to `"first"`):
+            Padding mode.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        spatial_norm_dim: int = 16,
+        add_upsample: bool = True,
+        upsample_padding: int = 1,
+        compress_time: bool = False,
+        pad_mode: str = "first",
+    ):
+        super().__init__()
+
+        resnets = []
+        for i in range(num_layers):
+            in_channel = in_channels if i == 0 else out_channels
+            resnets.append(
+                CogVideoXResnetBlock3D(
+                    in_channels=in_channel,
+                    out_channels=out_channels,
+                    dropout=dropout,
+                    temb_channels=temb_channels,
+                    groups=resnet_groups,
+                    eps=resnet_eps,
+                    non_linearity=resnet_act_fn,
+                    spatial_norm_dim=spatial_norm_dim,
+                    pad_mode=pad_mode,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.upsamplers = None
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList(
+                [
+                    CogVideoXUpsample3D(
+                        out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
+                    )
+                ]
+            )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        zq: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""Forward method of the `CogVideoXUpBlock3D` class."""
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb, zq
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb, zq)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class CogVideoXEncoder3D(nn.Module):
+    r"""
+    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
+
+    Args:
+        in_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        out_channels (`int`, *optional*, defaults to 3):
+            The number of output channels.
+        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+            options.
+        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+            The number of output channels for each block.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups for normalization.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 16,
+        down_block_types: Tuple[str, ...] = (
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+        ),
+        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+        layers_per_block: int = 3,
+        act_fn: str = "silu",
+        norm_eps: float = 1e-6,
+        norm_num_groups: int = 32,
+        dropout: float = 0.0,
+        pad_mode: str = "first",
+        temporal_compression_ratio: float = 4,
+    ):
+        super().__init__()
+
+        # log2 of temporal_compress_times
+        temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+        self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+        self.down_blocks = nn.ModuleList([])
+
+        # down blocks
+        output_channel = block_out_channels[0]
+        for i, down_block_type in enumerate(down_block_types):
+            input_channel = output_channel
+            output_channel = block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+            compress_time = i < temporal_compress_level
+
+            if down_block_type == "CogVideoXDownBlock3D":
+                down_block = CogVideoXDownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    temb_channels=0,
+                    dropout=dropout,
+                    num_layers=layers_per_block,
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    add_downsample=not is_final_block,
+                    compress_time=compress_time,
+                )
+            else:
+                raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
+
+            self.down_blocks.append(down_block)
+
+        # mid block
+        self.mid_block = CogVideoXMidBlock3D(
+            in_channels=block_out_channels[-1],
+            temb_channels=0,
+            dropout=dropout,
+            num_layers=2,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
+            resnet_groups=norm_num_groups,
+            pad_mode=pad_mode,
+        )
+
+        self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
+        self.conv_act = nn.SiLU()
+        self.conv_out = CogVideoXCausalConv3d(
+            block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+        r"""The forward method of the `CogVideoXEncoder3D` class."""
+        hidden_states = self.conv_in(sample)
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # 1. Down
+            for down_block in self.down_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(down_block), hidden_states, temb, None
+                )
+
+            # 2. Mid
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, temb, None
+            )
+        else:
+            # 1. Down
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states, temb, None)
+
+            # 2. Mid
+            hidden_states = self.mid_block(hidden_states, temb, None)
+
+        # 3. Post-process
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        return hidden_states
+
+
+class CogVideoXDecoder3D(nn.Module):
+    r"""
+    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+    sample.
+
+    Args:
+        in_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        out_channels (`int`, *optional*, defaults to 3):
+            The number of output channels.
+        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+            The number of output channels for each block.
+        act_fn (`str`, *optional*, defaults to `"silu"`):
+            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups for normalization.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 3,
+        up_block_types: Tuple[str, ...] = (
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+        ),
+        block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
+        layers_per_block: int = 3,
+        act_fn: str = "silu",
+        norm_eps: float = 1e-6,
+        norm_num_groups: int = 32,
+        dropout: float = 0.0,
+        pad_mode: str = "first",
+        temporal_compression_ratio: float = 4,
+    ):
+        super().__init__()
+
+        reversed_block_out_channels = list(reversed(block_out_channels))
+
+        self.conv_in = CogVideoXCausalConv3d(
+            in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
+        )
+
+        # mid block
+        self.mid_block = CogVideoXMidBlock3D(
+            in_channels=reversed_block_out_channels[0],
+            temb_channels=0,
+            num_layers=2,
+            resnet_eps=norm_eps,
+            resnet_act_fn=act_fn,
+            resnet_groups=norm_num_groups,
+            spatial_norm_dim=in_channels,
+            pad_mode=pad_mode,
+        )
+
+        # up blocks
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = reversed_block_out_channels[0]
+        temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = reversed_block_out_channels[i]
+            is_final_block = i == len(block_out_channels) - 1
+            compress_time = i < temporal_compress_level
+
+            if up_block_type == "CogVideoXUpBlock3D":
+                up_block = CogVideoXUpBlock3D(
+                    in_channels=prev_output_channel,
+                    out_channels=output_channel,
+                    temb_channels=0,
+                    dropout=dropout,
+                    num_layers=layers_per_block + 1,
+                    resnet_eps=norm_eps,
+                    resnet_act_fn=act_fn,
+                    resnet_groups=norm_num_groups,
+                    spatial_norm_dim=in_channels,
+                    add_upsample=not is_final_block,
+                    compress_time=compress_time,
+                    pad_mode=pad_mode,
+                )
+                prev_output_channel = output_channel
+            else:
+                raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
+
+            self.up_blocks.append(up_block)
+
+        self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
+        self.conv_act = nn.SiLU()
+        self.conv_out = CogVideoXCausalConv3d(
+            reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
+        r"""The forward method of the `CogVideoXDecoder3D` class."""
+        hidden_states = self.conv_in(sample)
+
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def custom_forward(*inputs):
+                    return module(*inputs)
+
+                return custom_forward
+
+            # 1. Mid
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(self.mid_block), hidden_states, temb, sample
+            )
+
+            # 2. Up
+            for up_block in self.up_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(up_block), hidden_states, temb, sample
+                )
+        else:
+            # 1. Mid
+            hidden_states = self.mid_block(hidden_states, temb, sample)
+
+            # 2. Up
+            for up_block in self.up_blocks:
+                hidden_states = up_block(hidden_states, temb, sample)
+
+        # 3. Post-process
+        hidden_states = self.norm_out(hidden_states, sample)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        return hidden_states
+
+
+class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+    [CogVideoX](https://github.com/THUDM/CogVideo).
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        force_upcast (`bool`, *optional*, default to `True`):
+            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+            can be fine-tuned / trained to a lower range without loosing too much precision in which case
+            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogVideoXResnetBlock3D"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = (
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+            "CogVideoXDownBlock3D",
+        ),
+        up_block_types: Tuple[str] = (
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+            "CogVideoXUpBlock3D",
+        ),
+        block_out_channels: Tuple[int] = (128, 256, 256, 512),
+        latent_channels: int = 16,
+        layers_per_block: int = 3,
+        act_fn: str = "silu",
+        norm_eps: float = 1e-6,
+        norm_num_groups: int = 32,
+        temporal_compression_ratio: float = 4,
+        sample_height: int = 480,
+        sample_width: int = 720,
+        scaling_factor: float = 1.15258426,
+        shift_factor: Optional[float] = None,
+        latents_mean: Optional[Tuple[float]] = None,
+        latents_std: Optional[Tuple[float]] = None,
+        force_upcast: float = True,
+        use_quant_conv: bool = False,
+        use_post_quant_conv: bool = False,
+    ):
+        super().__init__()
+
+        self.encoder = CogVideoXEncoder3D(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_eps=norm_eps,
+            norm_num_groups=norm_num_groups,
+            temporal_compression_ratio=temporal_compression_ratio,
+        )
+        self.decoder = CogVideoXDecoder3D(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_eps=norm_eps,
+            norm_num_groups=norm_num_groups,
+            temporal_compression_ratio=temporal_compression_ratio,
+        )
+        self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
+        # recommended because the temporal parts of the VAE, here, are tricky to understand.
+        # If you decode X latent frames together, the number of output frames is:
+        #     (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
+        #
+        # Example with num_latent_frames_batch_size = 2:
+        #     - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
+        #         => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+        #         => 6 * 8 = 48 frames
+        #     - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
+        #         => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
+        #            ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
+        #         => 1 * 9 + 5 * 8 = 49 frames
+        # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
+        # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
+        # number of temporal frames.
+        self.num_latent_frames_batch_size = 2
+
+        # We make the minimum height and width of sample for tiling half that of the generally supported
+        self.tile_sample_min_height = sample_height // 2
+        self.tile_sample_min_width = sample_width // 2
+        self.tile_latent_min_height = int(
+            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+        )
+        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+
+        # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
+        # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
+        # and so the tiling implementation has only been tested on those specific resolutions.
+        self.tile_overlap_factor_height = 1 / 6
+        self.tile_overlap_factor_width = 1 / 5
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
+            module.gradient_checkpointing = value
+
+    def _clear_fake_context_parallel_cache(self):
+        for name, module in self.named_modules():
+            if isinstance(module, CogVideoXCausalConv3d):
+                logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
+                module._clear_fake_context_parallel_cache()
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_overlap_factor_height: Optional[float] = None,
+        tile_overlap_factor_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_overlap_factor_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+            tile_overlap_factor_width (`int`, *optional*):
+                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_latent_min_height = int(
+            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+        )
+        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    @apply_forward_hook
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+        """
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+                The latent representations of the encoded images. If `return_dict` is True, a
+                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        batch_size, num_channels, num_frames, height, width = x.shape
+        if num_frames == 1:
+            h = self.encoder(x)
+            if self.quant_conv is not None:
+                h = self.quant_conv(h)
+            posterior = DiagonalGaussianDistribution(h)
+        else:
+            frame_batch_size = 4
+            h = []
+            for i in range(num_frames // frame_batch_size):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                z_intermediate = x[:, :, start_frame:end_frame]
+                z_intermediate = self.encoder(z_intermediate)
+                if self.quant_conv is not None:
+                    z_intermediate = self.quant_conv(z_intermediate)
+                h.append(z_intermediate)
+            self._clear_fake_context_parallel_cache()
+            h = torch.cat(h, dim=2)
+            posterior = DiagonalGaussianDistribution(h)
+        if not return_dict:
+            return (posterior,)
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        batch_size, num_channels, num_frames, height, width = z.shape
+
+        if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+            return self.tiled_decode(z, return_dict=return_dict)
+
+        if num_frames == 1:
+            dec = []
+            z_intermediate = z
+            if self.post_quant_conv is not None:
+                z_intermediate = self.post_quant_conv(z_intermediate)
+            z_intermediate = self.decoder(z_intermediate)
+            dec.append(z_intermediate)
+        else:
+            frame_batch_size = self.num_latent_frames_batch_size
+            dec = []
+            for i in range(num_frames // frame_batch_size):
+                remaining_frames = num_frames % frame_batch_size
+                start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+                end_frame = frame_batch_size * (i + 1) + remaining_frames
+                z_intermediate = z[:, :, start_frame:end_frame]
+                if self.post_quant_conv is not None:
+                    z_intermediate = self.post_quant_conv(z_intermediate)
+                z_intermediate = self.decoder(z_intermediate)
+                dec.append(z_intermediate)
+
+        self._clear_fake_context_parallel_cache()
+        dec = torch.cat(dec, dim=2)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    @apply_forward_hook
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        # Rough memory assessment:
+        #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
+        #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
+        #   - Assume fp16 (2 bytes per value).
+        # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
+        #
+        # Memory assessment when using tiling:
+        #   - Assume everything as above but now HxW is 240x360 by tiling in half
+        # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
+
+        batch_size, num_channels, num_frames, height, width = z.shape
+
+        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_sample_min_height - blend_extent_height
+        row_limit_width = self.tile_sample_min_width - blend_extent_width
+        frame_batch_size = self.num_latent_frames_batch_size
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                time = []
+                for k in range(num_frames // frame_batch_size):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = z[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_latent_min_height,
+                        j : j + self.tile_latent_min_width,
+                    ]
+                    if self.post_quant_conv is not None:
+                        tile = self.post_quant_conv(tile)
+                    tile = self.decoder(tile)
+                    time.append(tile)
+                self._clear_fake_context_parallel_cache()
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[torch.Tensor, torch.Tensor]:
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        if not return_dict:
+            return (dec,)
+        return dec
diff --git a/cogvideox/models/transformer3d.py b/cogvideox/models/transformer3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..88c8013fc7c904436a7f570ca985e64d06be3f64
--- /dev/null
+++ b/cogvideox/models/transformer3d.py
@@ -0,0 +1,609 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import os
+import json
+import torch
+import glob
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import is_torch_version, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+    r"""
+    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        time_embed_dim (`int`):
+            The number of channels in timestep embedding.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to be used in feed-forward.
+        attention_bias (`bool`, defaults to `False`):
+            Whether or not to use bias in attention projection layers.
+        qk_norm (`bool`, defaults to `True`):
+            Whether or not to use normalization after query and key projections in Attention.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_eps (`float`, defaults to `1e-5`):
+            Epsilon value for normalization layers.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        ff_inner_dim (`int`, *optional*, defaults to `None`):
+            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+        ff_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in Feed-forward layer.
+        attention_out_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in Attention output projection layer.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        time_embed_dim: int,
+        dropout: float = 0.0,
+        activation_fn: str = "gelu-approximate",
+        attention_bias: bool = False,
+        qk_norm: bool = True,
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        final_dropout: bool = True,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+    ):
+        super().__init__()
+
+        # 1. Self Attention
+        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            qk_norm="layer_norm" if qk_norm else None,
+            eps=1e-6,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+            processor=CogVideoXAttnProcessor2_0(),
+        )
+
+        # 2. Feed Forward
+        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        # norm & modulate
+        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+            hidden_states, encoder_hidden_states, temb
+        )
+
+        # attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+        )
+
+        hidden_states = hidden_states + gate_msa * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+        # norm & modulate
+        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+            hidden_states, encoder_hidden_states, temb
+        )
+
+        # feed-forward
+        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+        return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+    """
+    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+    Parameters:
+        num_attention_heads (`int`, defaults to `30`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `16`):
+            The number of channels in the output.
+        flip_sin_to_cos (`bool`, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        num_layers (`int`, defaults to `30`):
+            The number of layers of Transformer blocks to use.
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        attention_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention projection layers.
+        sample_width (`int`, defaults to `90`):
+            The width of the input latents.
+        sample_height (`int`, defaults to `60`):
+            The height of the input latents.
+        sample_frames (`int`, defaults to `49`):
+            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        temporal_compression_ratio (`int`, defaults to `4`):
+            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+        max_text_seq_length (`int`, defaults to `226`):
+            The maximum sequence length of the input text embeddings.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        timestep_activation_fn (`str`, defaults to `"silu"`):
+            Activation function to use when generating the timestep embeddings.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether or not to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        spatial_interpolation_scale (`float`, defaults to `1.875`):
+            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+        temporal_interpolation_scale (`float`, defaults to `1.0`):
+            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int = 30,
+        attention_head_dim: int = 64,
+        in_channels: int = 16,
+        out_channels: Optional[int] = 16,
+        flip_sin_to_cos: bool = True,
+        freq_shift: int = 0,
+        time_embed_dim: int = 512,
+        text_embed_dim: int = 4096,
+        num_layers: int = 30,
+        dropout: float = 0.0,
+        attention_bias: bool = True,
+        sample_width: int = 90,
+        sample_height: int = 60,
+        sample_frames: int = 49,
+        patch_size: int = 2,
+        temporal_compression_ratio: int = 4,
+        max_text_seq_length: int = 226,
+        activation_fn: str = "gelu-approximate",
+        timestep_activation_fn: str = "silu",
+        norm_elementwise_affine: bool = True,
+        norm_eps: float = 1e-5,
+        spatial_interpolation_scale: float = 1.875,
+        temporal_interpolation_scale: float = 1.0,
+        use_rotary_positional_embeddings: bool = False,
+        add_noise_in_inpaint_model: bool = False,
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+
+        post_patch_height = sample_height // patch_size
+        post_patch_width = sample_width // patch_size
+        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
+        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+        self.post_patch_height = post_patch_height
+        self.post_patch_width = post_patch_width
+        self.post_time_compression_frames = post_time_compression_frames
+        self.patch_size = patch_size
+
+        # 1. Patch embedding
+        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
+        self.embedding_dropout = nn.Dropout(dropout)
+
+        # 2. 3D positional embeddings
+        spatial_pos_embedding = get_3d_sincos_pos_embed(
+            inner_dim,
+            (post_patch_width, post_patch_height),
+            post_time_compression_frames,
+            spatial_interpolation_scale,
+            temporal_interpolation_scale,
+        )
+        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
+        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
+        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
+        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+
+        # 3. Time embeddings
+        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 4. Define spatio-temporal transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogVideoXBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                    dropout=dropout,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+        # 5. Output blocks
+        self.norm_out = AdaLayerNorm(
+            embedding_dim=time_embed_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=norm_elementwise_affine,
+            norm_eps=norm_eps,
+            chunk_dim=1,
+        )
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        self.gradient_checkpointing = value
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: Union[int, float, torch.LongTensor],
+        timestep_cond: Optional[torch.Tensor] = None,
+        inpaint_latents: Optional[torch.Tensor] = None,
+        control_latents: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        return_dict: bool = True,
+    ):
+        batch_size, num_frames, channels, height, width = hidden_states.shape
+
+        # 1. Time embedding
+        timesteps = timestep
+        t_emb = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+        emb = self.time_embedding(t_emb, timestep_cond)
+
+        # 2. Patch embedding
+        if inpaint_latents is not None:
+            hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
+        if control_latents is not None:
+            hidden_states = torch.concat([hidden_states, control_latents], 2)
+        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+
+        # 3. Position embedding
+        text_seq_length = encoder_hidden_states.shape[1]
+        if not self.config.use_rotary_positional_embeddings:
+            seq_length = height * width * num_frames // (self.config.patch_size**2)
+            # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
+            pos_embeds = self.pos_embedding
+            emb_size = hidden_states.size()[-1]
+            pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
+            pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
+            pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.config.patch_size, width // self.config.patch_size],mode='trilinear',align_corners=False)
+            pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
+            pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
+            pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
+            hidden_states = hidden_states + pos_embeds
+            hidden_states = self.embedding_dropout(hidden_states)
+
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
+
+        # 4. Transformer blocks
+        for i, block in enumerate(self.transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=emb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+
+        if not self.config.use_rotary_positional_embeddings:
+            # CogVideoX-2B
+            hidden_states = self.norm_final(hidden_states)
+        else:
+            # CogVideoX-5B
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+            hidden_states = self.norm_final(hidden_states)
+            hidden_states = hidden_states[:, text_seq_length:]
+
+        # 5. Final block
+        hidden_states = self.norm_out(hidden_states, temb=emb)
+        hidden_states = self.proj_out(hidden_states)
+
+        # 6. Unpatchify
+        p = self.config.patch_size
+        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
+    @classmethod
+    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}):
+        if subfolder is not None:
+            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+        print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
+
+        config_file = os.path.join(pretrained_model_path, 'config.json')
+        if not os.path.isfile(config_file):
+            raise RuntimeError(f"{config_file} does not exist")
+        with open(config_file, "r") as f:
+            config = json.load(f)
+
+        from diffusers.utils import WEIGHTS_NAME
+        model = cls.from_config(config, **transformer_additional_kwargs)
+        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+        model_file_safetensors = model_file.replace(".bin", ".safetensors")
+        if os.path.exists(model_file):
+            state_dict = torch.load(model_file, map_location="cpu")
+        elif os.path.exists(model_file_safetensors):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(model_file_safetensors)
+        else:
+            from safetensors.torch import load_file, safe_open
+            model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
+            state_dict = {}
+            for model_file_safetensors in model_files_safetensors:
+                _state_dict = load_file(model_file_safetensors)
+                for key in _state_dict:
+                    state_dict[key] = _state_dict[key]
+        
+        if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
+            new_shape   = model.state_dict()['patch_embed.proj.weight'].size()
+            if len(new_shape) == 5:
+                state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
+                state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
+            else:
+                if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
+                    model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
+                    model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
+                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+                else:
+                    model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
+                    state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
+
+        tmp_state_dict = {} 
+        for key in state_dict:
+            if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
+                tmp_state_dict[key] = state_dict[key]
+            else:
+                print(key, "Size don't match, skip")
+        state_dict = tmp_state_dict
+
+        m, u = model.load_state_dict(state_dict, strict=False)
+        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
+        print(m)
+        
+        params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
+        print(f"### Mamba Parameters: {sum(params) / 1e6} M")
+
+        params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
+        print(f"### attn1 Parameters: {sum(params) / 1e6} M")
+        
+        return model
\ No newline at end of file
diff --git a/cogvideox/pipeline/pipeline_cogvideox.py b/cogvideox/pipeline/pipeline_cogvideox.py
new file mode 100644
index 0000000000000000000000000000000000000000..81da03c19b450d5b6c5fc99417ca26023acfbae8
--- /dev/null
+++ b/cogvideox/pipeline/pipeline_cogvideox.py
@@ -0,0 +1,751 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoX_Fun_Pipeline
+        >>> from diffusers.utils import export_to_video
+
+        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+        >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+        >>> prompt = (
+        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+        ...     "atmosphere of this unique musical performance."
+        ... )
+        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+        >>> export_to_video(video, "output.mp4", fps=8)
+        ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+@dataclass
+class CogVideoX_Fun_PipelineOutput(BaseOutput):
+    r"""
+    Output class for CogVideo pipelines.
+
+    Args:
+        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    videos: torch.Tensor
+
+
+class CogVideoX_Fun_Pipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-video generation using CogVideoX_Fun.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. CogVideoX uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+        tokenizer (`T5Tokenizer`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        transformer ([`CogVideoXTransformer3DModel`]):
+            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+    """
+
+    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
+    def __init__(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        vae: AutoencoderKLCogVideoX,
+        transformer: CogVideoXTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+        )
+        self.vae_scale_factor_spatial = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+        )
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_latents(
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+    ):
+        shape = (
+            batch_size,
+            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+        latents = 1 / self.vae.config.scaling_factor * latents
+
+        frames = self.vae.decode(latents).sample
+        frames = (frames / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        frames = frames.cpu().float().numpy()
+        return frames
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        grid_crops_coords = get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+            embed_dim=self.transformer.config.attention_head_dim,
+            crops_coords=grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=num_frames,
+            use_real=True,
+        )
+
+        freqs_cos = freqs_cos.to(device=device)
+        freqs_sin = freqs_sin.to(device=device)
+        return freqs_cos, freqs_sin
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 480,
+        width: int = 720,
+        num_frames: int = 49,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "numpy",
+        return_dict: bool = False,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
+    ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_frames (`int`, defaults to `48`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if num_frames > 49:
+            raise ValueError(
+                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+            )
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        self._num_timesteps = len(timesteps)
+
+        # 5. Prepare latents.
+        latent_channels = self.transformer.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            latent_channels,
+            num_frames,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            # for DPM-solver++
+            old_pred_original_sample = None
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                )[0]
+                noise_pred = noise_pred.float()
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
+                latents = latents.to(prompt_embeds.dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if output_type == "numpy":
+            video = self.decode_latents(latents)
+        elif not output_type == "latent":
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            video = torch.from_numpy(video)
+
+        return CogVideoX_Fun_PipelineOutput(videos=video)
diff --git a/cogvideox/pipeline/pipeline_cogvideox_control.py b/cogvideox/pipeline/pipeline_cogvideox_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..6461aae7858f4cafc399f608a106110a8360a7e3
--- /dev/null
+++ b/cogvideox/pipeline/pipeline_cogvideox_control.py
@@ -0,0 +1,843 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.image_processor import VaeImageProcessor
+from einops import rearrange
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoX_Fun_Pipeline
+        >>> from diffusers.utils import export_to_video
+
+        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+        >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+        >>> prompt = (
+        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+        ...     "atmosphere of this unique musical performance."
+        ... )
+        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+        >>> export_to_video(video, "output.mp4", fps=8)
+        ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+@dataclass
+class CogVideoX_Fun_PipelineOutput(BaseOutput):
+    r"""
+    Output class for CogVideo pipelines.
+
+    Args:
+        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    videos: torch.Tensor
+
+
+class CogVideoX_Fun_Pipeline_Control(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-video generation using CogVideoX.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. CogVideoX_Fun uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+        tokenizer (`T5Tokenizer`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        transformer ([`CogVideoXTransformer3DModel`]):
+            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+    """
+
+    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
+    def __init__(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        vae: AutoencoderKLCogVideoX,
+        transformer: CogVideoXTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+        )
+        self.vae_scale_factor_spatial = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+        )
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_latents(
+        self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+    ):
+        shape = (
+            batch_size,
+            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def prepare_control_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+
+        if mask is not None:
+            mask = mask.to(device=device, dtype=self.vae.dtype)
+            bs = 1
+            new_mask = []
+            for i in range(0, mask.shape[0], bs):
+                mask_bs = mask[i : i + bs]
+                mask_bs = self.vae.encode(mask_bs)[0]
+                mask_bs = mask_bs.mode()
+                new_mask.append(mask_bs)
+            mask = torch.cat(new_mask, dim = 0)
+            mask = mask * self.vae.config.scaling_factor
+
+        if masked_image is not None:
+            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
+            bs = 1
+            new_mask_pixel_values = []
+            for i in range(0, masked_image.shape[0], bs):
+                mask_pixel_values_bs = masked_image[i : i + bs]
+                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+                mask_pixel_values_bs = mask_pixel_values_bs.mode()
+                new_mask_pixel_values.append(mask_pixel_values_bs)
+            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+        else:
+            masked_image_latents = None
+
+        return mask, masked_image_latents
+
+    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+        latents = 1 / self.vae.config.scaling_factor * latents
+
+        frames = self.vae.decode(latents).sample
+        frames = (frames / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        frames = frames.cpu().float().numpy()
+        return frames
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        grid_crops_coords = get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+            embed_dim=self.transformer.config.attention_head_dim,
+            crops_coords=grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=num_frames,
+            use_real=True,
+        )
+
+        freqs_cos = freqs_cos.to(device=device)
+        freqs_sin = freqs_sin.to(device=device)
+        return freqs_cos, freqs_sin
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 480,
+        width: int = 720,
+        video: Union[torch.FloatTensor] = None,
+        control_video: Union[torch.FloatTensor] = None,
+        num_frames: int = 49,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "numpy",
+        return_dict: bool = False,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
+        comfyui_progressbar: bool = False,
+    ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_frames (`int`, defaults to `48`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if num_frames > 49:
+            raise ValueError(
+                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+            )
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 4. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        self._num_timesteps = len(timesteps)
+        if comfyui_progressbar:
+            from comfy.utils import ProgressBar
+            pbar = ProgressBar(num_inference_steps + 2)
+
+        # 5. Prepare latents.
+        latent_channels = self.vae.config.latent_channels
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            latent_channels,
+            num_frames,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+        if comfyui_progressbar:
+            pbar.update(1)
+
+        if control_video is not None:
+            video_length = control_video.shape[2]
+            control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) 
+            control_video = control_video.to(dtype=torch.float32)
+            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
+        else:
+            control_video = None
+        control_video_latents = self.prepare_control_latents(
+            None,
+            control_video,
+            batch_size,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance
+        )[1]
+        control_video_latents_input = (
+            torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
+        )
+        control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
+
+        if comfyui_progressbar:
+            pbar.update(1)
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            # for DPM-solver++
+            old_pred_original_sample = None
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                    control_latents=control_latents,
+                )[0]
+                noise_pred = noise_pred.float()
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
+                latents = latents.to(prompt_embeds.dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if comfyui_progressbar:
+                    pbar.update(1)
+
+        if output_type == "numpy":
+            video = self.decode_latents(latents)
+        elif not output_type == "latent":
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            video = torch.from_numpy(video)
+
+        return CogVideoX_Fun_PipelineOutput(videos=video)
diff --git a/cogvideox/pipeline/pipeline_cogvideox_inpaint.py b/cogvideox/pipeline/pipeline_cogvideox_inpaint.py
new file mode 100644
index 0000000000000000000000000000000000000000..75deae90a2d23bd4e821239b04c10a86f9b55365
--- /dev/null
+++ b/cogvideox/pipeline/pipeline_cogvideox_inpaint.py
@@ -0,0 +1,1020 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.image_processor import VaeImageProcessor
+from einops import rearrange
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoX_Fun_Pipeline
+        >>> from diffusers.utils import export_to_video
+
+        >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+        >>> pipe = CogVideoX_Fun_Pipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+        >>> prompt = (
+        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+        ...     "atmosphere of this unique musical performance."
+        ... )
+        >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+        >>> export_to_video(video, "output.mp4", fps=8)
+        ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+    tw = tgt_width
+    th = tgt_height
+    h, w = src
+    r = h / w
+    if r > (th / tw):
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+        
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+
+def add_noise_to_reference_video(image, ratio=None):
+    if ratio is None:
+        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
+        sigma = torch.exp(sigma).to(image.dtype)
+    else:
+        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
+    
+    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
+    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
+    image = image + image_noise
+    return image
+
+
+@dataclass
+class CogVideoX_Fun_PipelineOutput(BaseOutput):
+    r"""
+    Output class for CogVideo pipelines.
+
+    Args:
+        video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    videos: torch.Tensor
+
+
+class CogVideoX_Fun_Pipeline_Inpaint(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-video generation using CogVideoX.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        text_encoder ([`T5EncoderModel`]):
+            Frozen text-encoder. CogVideoX_Fun uses
+            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+        tokenizer (`T5Tokenizer`):
+            Tokenizer of class
+            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        transformer ([`CogVideoXTransformer3DModel`]):
+            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+    """
+
+    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+    ]
+
+    def __init__(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        vae: AutoencoderKLCogVideoX,
+        transformer: CogVideoXTransformer3DModel,
+        scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+    ):
+        super().__init__()
+
+        self.register_modules(
+            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+        )
+        self.vae_scale_factor_spatial = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+        )
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_videos_per_prompt: int = 1,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                Whether to use classifier free guidance or not.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            device: (`torch.device`, *optional*):
+                torch device
+            dtype: (`torch.dtype`, *optional*):
+                torch dtype
+        """
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def prepare_latents(
+        self, 
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        video_length,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        video=None,
+        timestep=None,
+        is_strength_max=True,
+        return_noise=False,
+        return_video_latents=False,
+    ):
+        shape = (
+            batch_size,
+            (video_length - 1) // self.vae_scale_factor_temporal + 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if return_video_latents or (latents is None and not is_strength_max):
+            video = video.to(device=device, dtype=self.vae.dtype)
+            
+            bs = 1
+            new_video = []
+            for i in range(0, video.shape[0], bs):
+                video_bs = video[i : i + bs]
+                video_bs = self.vae.encode(video_bs)[0]
+                video_bs = video_bs.sample()
+                new_video.append(video_bs)
+            video = torch.cat(new_video, dim = 0)
+            video = video * self.vae.config.scaling_factor
+
+            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
+            video_latents = video_latents.to(device=device, dtype=dtype)
+            video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
+
+        if latents is None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initial to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
+            # if pure noise then scale the initial latents by the  Scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        else:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_video_latents:
+            outputs += (video_latents,)
+
+        return outputs
+
+    def prepare_mask_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+
+        if mask is not None:
+            mask = mask.to(device=device, dtype=self.vae.dtype)
+            bs = 1
+            new_mask = []
+            for i in range(0, mask.shape[0], bs):
+                mask_bs = mask[i : i + bs]
+                mask_bs = self.vae.encode(mask_bs)[0]
+                mask_bs = mask_bs.mode()
+                new_mask.append(mask_bs)
+            mask = torch.cat(new_mask, dim = 0)
+            mask = mask * self.vae.config.scaling_factor
+
+        if masked_image is not None:
+            if self.transformer.config.add_noise_in_inpaint_model:
+                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
+            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
+            bs = 1
+            new_mask_pixel_values = []
+            for i in range(0, masked_image.shape[0], bs):
+                mask_pixel_values_bs = masked_image[i : i + bs]
+                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
+                mask_pixel_values_bs = mask_pixel_values_bs.mode()
+                new_mask_pixel_values.append(mask_pixel_values_bs)
+            masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
+            masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+        else:
+            masked_image_latents = None
+
+        return mask, masked_image_latents
+
+    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
+        latents = 1 / self.vae.config.scaling_factor * latents
+
+        frames = self.vae.decode(latents).sample
+        frames = (frames / 2 + 0.5).clamp(0, 1)
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+        frames = frames.cpu().float().numpy()
+        return frames
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt,
+        callback_on_step_end_tensor_inputs,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+    ):
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+    def fuse_qkv_projections(self) -> None:
+        r"""Enables fused QKV projections."""
+        self.fusing_transformer = True
+        self.transformer.fuse_qkv_projections()
+
+    def unfuse_qkv_projections(self) -> None:
+        r"""Disable QKV projection fusion if enabled."""
+        if not self.fusing_transformer:
+            logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+        else:
+            self.transformer.unfuse_qkv_projections()
+            self.fusing_transformer = False
+
+    def _prepare_rotary_positional_embeddings(
+        self,
+        height: int,
+        width: int,
+        num_frames: int,
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+        grid_crops_coords = get_resize_crop_region_for_grid(
+            (grid_height, grid_width), base_size_width, base_size_height
+        )
+        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+            embed_dim=self.transformer.config.attention_head_dim,
+            crops_coords=grid_crops_coords,
+            grid_size=(grid_height, grid_width),
+            temporal_size=num_frames,
+            use_real=True,
+        )
+
+        freqs_cos = freqs_cos.to(device=device)
+        freqs_sin = freqs_sin.to(device=device)
+        return freqs_cos, freqs_sin
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        height: int = 480,
+        width: int = 720,
+        video: Union[torch.FloatTensor] = None,
+        mask_video: Union[torch.FloatTensor] = None,
+        masked_video_latents: Union[torch.FloatTensor] = None,
+        num_frames: int = 49,
+        num_inference_steps: int = 50,
+        timesteps: Optional[List[int]] = None,
+        guidance_scale: float = 6,
+        use_dynamic_cfg: bool = False,
+        num_videos_per_prompt: int = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: str = "numpy",
+        return_dict: bool = False,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 226,
+        strength: float = 1,
+        noise_aug_strength: float = 0.0563,
+        comfyui_progressbar: bool = False,
+    ) -> Union[CogVideoX_Fun_PipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_frames (`int`, defaults to `48`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+                of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoX_Fun_PipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if num_frames > 49:
+            raise ValueError(
+                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+            )
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 4. set timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps=num_inference_steps, strength=strength, device=device
+        )
+        self._num_timesteps = len(timesteps)
+        if comfyui_progressbar:
+            from comfy.utils import ProgressBar
+            pbar = ProgressBar(num_inference_steps + 2)
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Prepare latents.
+        if video is not None:
+            video_length = video.shape[2]
+            init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) 
+            init_video = init_video.to(dtype=torch.float32)
+            init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
+        else:
+            init_video = None
+
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_transformer = self.transformer.config.in_channels
+        return_image_latents = num_channels_transformer == num_channels_latents
+
+        latents_outputs = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            video_length,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            video=init_video,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            return_noise=True,
+            return_video_latents=return_image_latents,
+        )
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+        if comfyui_progressbar:
+            pbar.update(1)
+        
+        if mask_video is not None:
+            if (mask_video == 255).all():
+                mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
+                masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+                mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+                masked_video_latents_input = (
+                    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+                )
+                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
+            else:
+                # Prepare mask latent variables
+                video_length = video.shape[2]
+                mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) 
+                mask_condition = mask_condition.to(dtype=torch.float32)
+                mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
+
+                if num_channels_transformer != num_channels_latents:
+                    mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
+                    if masked_video_latents is None:
+                        masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
+                    else:
+                        masked_video = masked_video_latents
+
+                    _, masked_video_latents = self.prepare_mask_latents(
+                        None,
+                        masked_video,
+                        batch_size,
+                        height,
+                        width,
+                        prompt_embeds.dtype,
+                        device,
+                        generator,
+                        do_classifier_free_guidance,
+                        noise_aug_strength=noise_aug_strength,
+                    )
+                    mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
+                    mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
+
+                    mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+                    mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+                    
+                    mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
+                    masked_video_latents_input = (
+                        torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+                    )
+
+                    mask = rearrange(mask, "b c f h w -> b f c h w")
+                    mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
+                    masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
+
+                    inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
+                else:
+                    mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
+                    mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+                    mask = rearrange(mask, "b c f h w -> b f c h w")
+                    
+                    inpaint_latents = None
+        else:
+            if num_channels_transformer != num_channels_latents:
+                mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
+                masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
+
+                mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+                masked_video_latents_input = (
+                    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
+                )
+                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
+            else:
+                mask = torch.zeros_like(init_video[:, :1])
+                mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
+                mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
+                mask = rearrange(mask, "b c f h w -> b f c h w")
+
+                inpaint_latents = None
+        if comfyui_progressbar:
+            pbar.update(1)
+
+        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7. Create rotary embeds if required
+        image_rotary_emb = (
+            self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+            if self.transformer.config.use_rotary_positional_embeddings
+            else None
+        )
+
+        # 8. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            # for DPM-solver++
+            old_pred_original_sample = None
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latent_model_input.shape[0])
+
+                # predict noise model_output
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                    inpaint_latents=inpaint_latents,
+                )[0]
+                noise_pred = noise_pred.float()
+
+                # perform guidance
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
+                latents = latents.to(prompt_embeds.dtype)
+
+                # call the callback, if provided
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if comfyui_progressbar:
+                    pbar.update(1)
+
+        if output_type == "numpy":
+            video = self.decode_latents(latents)
+        elif not output_type == "latent":
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            video = torch.from_numpy(video)
+
+        return CogVideoX_Fun_PipelineOutput(videos=video)
diff --git a/cogvideox/ui/ui.py b/cogvideox/ui/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..04aef8075ca5daf1f9cc936a6c87d3b816e02d2e
--- /dev/null
+++ b/cogvideox/ui/ui.py
@@ -0,0 +1,1614 @@
+"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
+"""
+import base64
+import gc
+import json
+import os
+import random
+from datetime import datetime
+from glob import glob
+
+import cv2
+import gradio as gr
+import numpy as np
+import pkg_resources
+import requests
+import torch
+from diffusers import (AutoencoderKL, AutoencoderKLCogVideoX,
+                       CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from diffusers.utils.import_utils import is_xformers_available
+from omegaconf import OmegaConf
+from PIL import Image
+from safetensors import safe_open
+from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
+                          T5EncoderModel, T5Tokenizer)
+
+from cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_control import \
+    CogVideoX_Fun_Pipeline_Control
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
+    CogVideoX_Fun_Pipeline_Inpaint
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import (
+    get_image_to_video_latent, get_video_to_video_latent,
+    get_width_and_height_from_image_and_base_resolution, save_videos_grid)
+
+scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler, 
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}
+
+gradio_version = pkg_resources.get_distribution("gradio").version
+gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
+
+css = """
+.toolbutton {
+    margin-buttom: 0em 0em 0em 0em;
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+}
+"""
+
+class CogVideoX_Fun_Controller:
+    def __init__(self, low_gpu_memory_mode, weight_dtype):
+        # config dirs
+        self.basedir                    = os.getcwd()
+        self.config_dir                 = os.path.join(self.basedir, "config")
+        self.diffusion_transformer_dir  = os.path.join(self.basedir, "models", "Diffusion_Transformer")
+        self.motion_module_dir          = os.path.join(self.basedir, "models", "Motion_Module")
+        self.personalized_model_dir     = os.path.join(self.basedir, "models", "Personalized_Model")
+        self.savedir                    = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+        self.savedir_sample             = os.path.join(self.savedir, "sample")
+        self.model_type                 = "Inpaint"
+        os.makedirs(self.savedir, exist_ok=True)
+
+        self.diffusion_transformer_list = []
+        self.motion_module_list      = []
+        self.personalized_model_list = []
+        
+        self.refresh_diffusion_transformer()
+        self.refresh_motion_module()
+        self.refresh_personalized_model()
+
+        # config models
+        self.tokenizer             = None
+        self.text_encoder          = None
+        self.vae                   = None
+        self.transformer           = None
+        self.pipeline              = None
+        self.motion_module_path    = "none"
+        self.base_model_path       = "none"
+        self.lora_model_path       = "none"
+        self.low_gpu_memory_mode   = low_gpu_memory_mode
+        
+        self.weight_dtype = weight_dtype
+
+    def refresh_diffusion_transformer(self):
+        self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
+
+    def refresh_motion_module(self):
+        motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
+        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+    def refresh_personalized_model(self):
+        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
+        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
+    def update_model_type(self, model_type):
+        self.model_type = model_type
+
+    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
+        print("Update diffusion transformer")
+        if diffusion_transformer_dropdown == "none":
+            return gr.update()
+        self.vae = AutoencoderKLCogVideoX.from_pretrained(
+            diffusion_transformer_dropdown, 
+            subfolder="vae", 
+        ).to(self.weight_dtype)
+
+        # Get Transformer
+        self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+            diffusion_transformer_dropdown, 
+            subfolder="transformer", 
+        ).to(self.weight_dtype)
+        
+        # Get pipeline
+        if self.model_type == "Inpaint":
+            if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+                    diffusion_transformer_dropdown,
+                    vae=self.vae, 
+                    transformer=self.transformer,
+                    scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
+                    torch_dtype=self.weight_dtype
+                )
+            else:
+                self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+                    diffusion_transformer_dropdown,
+                    vae=self.vae, 
+                    transformer=self.transformer,
+                    scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
+                    torch_dtype=self.weight_dtype
+                )
+        else:
+            self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
+                diffusion_transformer_dropdown,
+                vae=self.vae, 
+                transformer=self.transformer,
+                scheduler=scheduler_dict["Euler"].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
+                torch_dtype=self.weight_dtype
+            )
+
+        if self.low_gpu_memory_mode:
+            self.pipeline.enable_sequential_cpu_offload()
+        else:
+            self.pipeline.enable_model_cpu_offload()
+        print("Update diffusion transformer done")
+        return gr.update()
+
+    def update_base_model(self, base_model_dropdown):
+        self.base_model_path = base_model_dropdown
+        print("Update base model")
+        if base_model_dropdown == "none":
+            return gr.update()
+        if self.transformer is None:
+            gr.Info(f"Please select a pretrained model path.")
+            return gr.update(value=None)
+        else:
+            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
+            base_model_state_dict = {}
+            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    base_model_state_dict[key] = f.get_tensor(key)
+            self.transformer.load_state_dict(base_model_state_dict, strict=False)
+            print("Update base done")
+            return gr.update()
+
+    def update_lora_model(self, lora_model_dropdown):
+        print("Update lora model")
+        if lora_model_dropdown == "none":
+            self.lora_model_path = "none"
+            return gr.update()
+        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+        self.lora_model_path = lora_model_dropdown
+        return gr.update()
+
+    def generate(
+        self,
+        diffusion_transformer_dropdown,
+        base_model_dropdown,
+        lora_model_dropdown, 
+        lora_alpha_slider,
+        prompt_textbox, 
+        negative_prompt_textbox, 
+        sampler_dropdown, 
+        sample_step_slider, 
+        resize_method,
+        width_slider, 
+        height_slider, 
+        base_resolution, 
+        generation_method, 
+        length_slider, 
+        overlap_video_length, 
+        partial_video_length, 
+        cfg_scale_slider, 
+        start_image, 
+        end_image, 
+        validation_video,
+        validation_video_mask,
+        control_video,
+        denoise_strength,
+        seed_textbox,
+        is_api = False,
+    ):
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+        if self.transformer is None:
+            raise gr.Error(f"Please select a pretrained model path.")
+
+        if self.base_model_path != base_model_dropdown:
+            self.update_base_model(base_model_dropdown)
+
+        if self.lora_model_path != lora_model_dropdown:
+            print("Update lora model")
+            self.update_lora_model(lora_model_dropdown)
+
+        if control_video is not None and self.model_type == "Inpaint":
+            if is_api:
+                return "", f"If specifying the control video, please set the model_type == \"Control\". "
+            else:
+                raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
+
+        if control_video is None and self.model_type == "Control":
+            if is_api:
+                return "", f"If set the model_type == \"Control\", please specifying the control video. "
+            else:
+                raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
+
+        if resize_method == "Resize according to Reference":
+            if start_image is None and validation_video is None and control_video is None:
+                if is_api:
+                    return "", f"Please upload an image when using \"Resize according to Reference\"."
+                else:
+                    raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
+
+            aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+            if self.model_type == "Inpaint":
+                if validation_video is not None:
+                    original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
+                else:
+                    original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
+            else:
+                original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
+            closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+            height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
+
+        if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
+            if is_api:
+                return "", f"Please select an image to video pretrained model while using image to video."
+            else:
+                raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
+
+        if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
+            if is_api:
+                return "", f"Please select an image to video pretrained model while using long video generation."
+            else:
+                raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
+        
+        if start_image is None and end_image is not None:
+            if is_api:
+                return "", f"If specifying the ending image of the video, please specify a starting image of the video."
+            else:
+                raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
+
+        is_image = True if generation_method == "Image Generation" else False
+
+        self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
+        if self.lora_model_path != "none":
+            # lora part
+            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+        else: seed_textbox = np.random.randint(0, 1e10)
+        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
+        
+        try:
+            if self.model_type == "Inpaint":
+                if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                    if generation_method == "Long Video Generation":
+                        if validation_video is not None:
+                            raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
+                        init_frames = 0
+                        last_frames = init_frames + partial_video_length
+                        while init_frames < length_slider:
+                            if last_frames >= length_slider:
+                                _partial_video_length = length_slider - init_frames
+                                _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
+                                
+                                if _partial_video_length <= 0:
+                                    break
+                            else:
+                                _partial_video_length = partial_video_length
+
+                            if last_frames >= length_slider:
+                                input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
+                            else:
+                                input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
+
+                            with torch.no_grad():
+                                sample = self.pipeline(
+                                    prompt_textbox, 
+                                    negative_prompt     = negative_prompt_textbox,
+                                    num_inference_steps = sample_step_slider,
+                                    guidance_scale      = cfg_scale_slider,
+                                    width               = width_slider,
+                                    height              = height_slider,
+                                    num_frames          = _partial_video_length,
+                                    generator           = generator,
+
+                                    video        = input_video,
+                                    mask_video   = input_video_mask,
+                                    strength     = 1,
+                                ).videos
+                            
+                            if init_frames != 0:
+                                mix_ratio = torch.from_numpy(
+                                    np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
+                                ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                                
+                                new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
+                                    sample[:, :, :overlap_video_length] * mix_ratio
+                                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
+
+                                sample = new_sample
+                            else:
+                                new_sample = sample
+
+                            if last_frames >= length_slider:
+                                break
+
+                            start_image = [
+                                Image.fromarray(
+                                    (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
+                                ) for _index in range(-overlap_video_length, 0)
+                            ]
+
+                            init_frames = init_frames + _partial_video_length - overlap_video_length
+                            last_frames = init_frames + _partial_video_length
+                    else:
+                        if validation_video is not None:
+                            input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
+                            strength = denoise_strength
+                        else:
+                            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
+                            strength = 1
+
+                        sample = self.pipeline(
+                            prompt_textbox,
+                            negative_prompt     = negative_prompt_textbox,
+                            num_inference_steps = sample_step_slider,
+                            guidance_scale      = cfg_scale_slider,
+                            width               = width_slider,
+                            height              = height_slider,
+                            num_frames          = length_slider if not is_image else 1,
+                            generator           = generator,
+
+                            video        = input_video,
+                            mask_video   = input_video_mask,
+                            strength     = strength,
+                        ).videos
+                else:
+                    sample = self.pipeline(
+                        prompt_textbox,
+                        negative_prompt     = negative_prompt_textbox,
+                        num_inference_steps = sample_step_slider,
+                        guidance_scale      = cfg_scale_slider,
+                        width               = width_slider,
+                        height              = height_slider,
+                        num_frames          = length_slider if not is_image else 1,
+                        generator           = generator
+                    ).videos
+            else:
+                input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
+
+                sample = self.pipeline(
+                    prompt_textbox,
+                    negative_prompt     = negative_prompt_textbox,
+                    num_inference_steps = sample_step_slider,
+                    guidance_scale      = cfg_scale_slider,
+                    width               = width_slider,
+                    height              = height_slider,
+                    num_frames          = length_slider if not is_image else 1,
+                    generator           = generator,
+
+                    control_video = input_video,
+                ).videos
+        except Exception as e:
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            if self.lora_model_path != "none":
+                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            if is_api:
+                return "", f"Error. error information is {str(e)}"
+            else:
+                return gr.update(), gr.update(), f"Error. error information is {str(e)}"
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+        # lora part
+        if self.lora_model_path != "none":
+            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+        sample_config = {
+            "prompt": prompt_textbox,
+            "n_prompt": negative_prompt_textbox,
+            "sampler": sampler_dropdown,
+            "num_inference_steps": sample_step_slider,
+            "guidance_scale": cfg_scale_slider,
+            "width": width_slider,
+            "height": height_slider,
+            "video_length": length_slider,
+            "seed_textbox": seed_textbox
+        }
+        json_str = json.dumps(sample_config, indent=4)
+        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+            f.write(json_str)
+            f.write("\n\n")
+            
+        if not os.path.exists(self.savedir_sample):
+            os.makedirs(self.savedir_sample, exist_ok=True)
+        index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+        prefix = str(index).zfill(3)
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        if is_image or length_slider == 1:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+
+            image = sample[0, :, 0]
+            image = image.transpose(0, 1).transpose(1, 2)
+            image = (image * 255).numpy().astype(np.uint8)
+            image = Image.fromarray(image)
+            image.save(save_sample_path)
+
+            if is_api:
+                return save_sample_path, "Success"
+            else:
+                if gradio_version_is_above_4:
+                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
+                else:
+                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+        else:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+            save_videos_grid(sample, save_sample_path, fps=8)
+
+            if is_api:
+                return save_sample_path, "Success"
+            else:
+                if gradio_version_is_above_4:
+                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
+                else:
+                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui(low_gpu_memory_mode, weight_dtype):
+    controller = CogVideoX_Fun_Controller(low_gpu_memory_mode, weight_dtype)
+
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # CogVideoX-Fun:
+
+            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. 
+
+            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
+            """
+        )
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
+                """
+            )
+            with gr.Row():
+                model_type = gr.Dropdown(
+                    label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
+                    choices=["Inpaint", "Control"],
+                    value="Inpaint",
+                    interactive=True,
+                )
+
+            gr.Markdown(
+                """
+                ### 2. Model checkpoints (模型路径).
+                """
+            )
+            with gr.Row():
+                diffusion_transformer_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path (预训练模型路径)",
+                    choices=controller.diffusion_transformer_list,
+                    value="none",
+                    interactive=True,
+                )
+                diffusion_transformer_dropdown.change(
+                    fn=controller.update_diffusion_transformer, 
+                    inputs=[diffusion_transformer_dropdown], 
+                    outputs=[diffusion_transformer_dropdown]
+                )
+                
+                diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def refresh_diffusion_transformer():
+                    controller.refresh_diffusion_transformer()
+                    return gr.update(choices=controller.diffusion_transformer_list)
+                diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
+
+            with gr.Row():
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model (选择基模型[非必需])",
+                    choices=controller.personalized_model_list,
+                    value="none",
+                    interactive=True,
+                )
+                
+                lora_model_dropdown = gr.Dropdown(
+                    label="Select LoRA model (选择LoRA模型[非必需])",
+                    choices=["none"] + controller.personalized_model_list,
+                    value="none",
+                    interactive=True,
+                )
+
+                lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
+                
+                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_personalized_model():
+                    controller.refresh_personalized_model()
+                    return [
+                        gr.update(choices=controller.personalized_model_list),
+                        gr.update(choices=["none"] + controller.personalized_model_list)
+                    ]
+                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 3. Configs for Generation (生成参数配置).
+                """
+            )
+            
+            prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
+                
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
+                        
+                    resize_method = gr.Radio(
+                        ["Generate by", "Resize according to Reference"],
+                        value="Generate by",
+                        show_label=False,
+                    )
+                    width_slider     = gr.Slider(label="Width (视频宽度)",            value=672, minimum=128, maximum=1344, step=16)
+                    height_slider    = gr.Slider(label="Height (视频高度)",           value=384, minimum=128, maximum=1344, step=16)
+                    base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
+
+                    with gr.Group():
+                        generation_method = gr.Radio(
+                            ["Video Generation", "Image Generation", "Long Video Generation"],
+                            value="Video Generation",
+                            show_label=False,
+                        )
+                        with gr.Row():
+                            length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=1,   maximum=49,  step=4)
+                            overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1,   maximum=4,  step=1, visible=False)
+                            partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5,   maximum=49,  step=4, visible=False)
+                    
+                    source_method = gr.Radio(
+                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
+                        value="Text to Video (文本到视频)",
+                        show_label=False,
+                    )
+                    with gr.Column(visible = False) as image_to_video_col:
+                        start_image = gr.Image(
+                            label="The image at the beginning of the video (图片到视频的开始图片)",  show_label=True, 
+                            elem_id="i2v_start", sources="upload", type="filepath", 
+                        )
+                        
+                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
+                        def select_template(evt: gr.SelectData):
+                            text = {
+                                "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                            }[template_gallery_path[evt.index]]
+                            return template_gallery_path[evt.index], text
+
+                        template_gallery = gr.Gallery(
+                            template_gallery_path,
+                            columns=5, rows=1,
+                            height=140,
+                            allow_preview=False,
+                            container=False,
+                            label="Template Examples",
+                        )
+                        template_gallery.select(select_template, None, [start_image, prompt_textbox])
+                        
+                        with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
+                            end_image   = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
+
+                    with gr.Column(visible = False) as video_to_video_col:
+                        with gr.Row():
+                            validation_video = gr.Video(
+                                label="The video to convert (视频转视频的参考视频)",  show_label=True, 
+                                elem_id="v2v", sources="upload", 
+                            )
+                        with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
+                            gr.Markdown(
+                                """
+                                - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70  
+                                - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
+                                """
+                            )
+                            validation_video_mask = gr.Image(
+                                label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
+                                show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
+                            )
+                        denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
+
+                    with gr.Column(visible = False) as control_video_col:
+                        gr.Markdown(
+                            """
+                            Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
+                            """
+                        )
+                        control_video = gr.Video(
+                            label="The control video (用于提供控制信号的video)",  show_label=True, 
+                            elem_id="v2v_control", sources="upload", 
+                        )
+
+                    cfg_scale_slider  = gr.Slider(label="CFG Scale (引导系数)",        value=6.0, minimum=0,   maximum=20)
+                    
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
+                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(
+                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)), 
+                            inputs=[], 
+                            outputs=[seed_textbox]
+                        )
+
+                    generate_button = gr.Button(value="Generate (生成)", variant='primary')
+                    
+                with gr.Column():
+                    result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
+                    result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
+                    infer_progress = gr.Textbox(
+                        label="Generation Info (生成信息)",
+                        value="No task currently",
+                        interactive=False
+                    )
+
+            model_type.change(
+                fn=controller.update_model_type, 
+                inputs=[model_type], 
+                outputs=[]
+            )
+
+            def upload_generation_method(generation_method):
+                if generation_method == "Video Generation":
+                    return [gr.update(visible=True, maximum=49, value=49), gr.update(visible=False), gr.update(visible=False)]
+                elif generation_method == "Image Generation":
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
+                else:
+                    return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
+            generation_method.change(
+                upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
+            )
+
+            def upload_source_method(source_method):
+                if source_method == "Text to Video (文本到视频)":
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
+                elif source_method == "Image to Video (图片到视频)":
+                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
+                elif source_method == "Video to Video (视频到视频)":
+                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
+            source_method.change(
+                upload_source_method, source_method, [
+                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, 
+                    validation_video, validation_video_mask, control_video
+                ]
+            )
+
+            def upload_resize_method(resize_method):
+                if resize_method == "Generate by":
+                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
+            resize_method.change(
+                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
+            )
+
+            generate_button.click(
+                fn=controller.generate,
+                inputs=[
+                    diffusion_transformer_dropdown,
+                    base_model_dropdown,
+                    lora_model_dropdown,
+                    lora_alpha_slider,
+                    prompt_textbox, 
+                    negative_prompt_textbox, 
+                    sampler_dropdown, 
+                    sample_step_slider, 
+                    resize_method,
+                    width_slider, 
+                    height_slider, 
+                    base_resolution, 
+                    generation_method, 
+                    length_slider, 
+                    overlap_video_length, 
+                    partial_video_length, 
+                    cfg_scale_slider, 
+                    start_image, 
+                    end_image, 
+                    validation_video,
+                    validation_video_mask,
+                    control_video,
+                    denoise_strength, 
+                    seed_textbox,
+                ],
+                outputs=[result_image, result_video, infer_progress]
+            )
+    return demo, controller
+
+
+class CogVideoX_Fun_Controller_Modelscope:
+    def __init__(self, model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
+        # Basic dir
+        self.basedir                    = os.getcwd()
+        self.personalized_model_dir     = os.path.join(self.basedir, "models", "Personalized_Model")
+        self.lora_model_path            = "none"
+        self.savedir_sample             = savedir_sample
+        self.refresh_personalized_model()
+        os.makedirs(self.savedir_sample, exist_ok=True)
+
+        # model path
+        self.model_type = model_type
+        self.weight_dtype = weight_dtype
+        
+        self.vae = AutoencoderKLCogVideoX.from_pretrained(
+            model_name, 
+            subfolder="vae", 
+        ).to(self.weight_dtype)
+
+        # Get Transformer
+        self.transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+            model_name, 
+            subfolder="transformer", 
+        ).to(self.weight_dtype)
+        
+        # Get pipeline
+        if model_type == "Inpaint":
+            if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                self.pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+                    model_name,
+                    vae=self.vae, 
+                    transformer=self.transformer,
+                    scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
+                    torch_dtype=self.weight_dtype
+                )
+            else:
+                self.pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+                    model_name,
+                    vae=self.vae, 
+                    transformer=self.transformer,
+                    scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
+                    torch_dtype=self.weight_dtype
+                )
+        else:
+            self.pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
+                model_name,
+                vae=self.vae, 
+                transformer=self.transformer,
+                scheduler=scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
+                torch_dtype=self.weight_dtype
+            )
+
+        if low_gpu_memory_mode:
+            self.pipeline.enable_sequential_cpu_offload()
+        else:
+            self.pipeline.enable_model_cpu_offload()
+        print("Update diffusion transformer done")
+
+
+    def refresh_personalized_model(self):
+        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
+        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
+
+    def update_lora_model(self, lora_model_dropdown):
+        print("Update lora model")
+        if lora_model_dropdown == "none":
+            self.lora_model_path = "none"
+            return gr.update()
+        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+        self.lora_model_path = lora_model_dropdown
+        return gr.update()
+
+    
+    def generate(
+        self,
+        diffusion_transformer_dropdown,
+        base_model_dropdown,
+        lora_model_dropdown, 
+        lora_alpha_slider,
+        prompt_textbox, 
+        negative_prompt_textbox, 
+        sampler_dropdown, 
+        sample_step_slider, 
+        resize_method,
+        width_slider, 
+        height_slider, 
+        base_resolution, 
+        generation_method, 
+        length_slider, 
+        overlap_video_length, 
+        partial_video_length, 
+        cfg_scale_slider, 
+        start_image, 
+        end_image, 
+        validation_video,
+        validation_video_mask,
+        control_video,
+        denoise_strength,
+        seed_textbox,
+        is_api = False,
+    ):    
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+        if self.transformer is None:
+            raise gr.Error(f"Please select a pretrained model path.")
+
+        if self.lora_model_path != lora_model_dropdown:
+            print("Update lora model")
+            self.update_lora_model(lora_model_dropdown)
+        
+        if control_video is not None and self.model_type == "Inpaint":
+            if is_api:
+                return "", f"If specifying the control video, please set the model_type == \"Control\". "
+            else:
+                raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". ")
+
+        if control_video is None and self.model_type == "Control":
+            if is_api:
+                return "", f"If set the model_type == \"Control\", please specifying the control video. "
+            else:
+                raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ")
+
+        if resize_method == "Resize according to Reference":
+            if start_image is None and validation_video is None and control_video is None:
+                if is_api:
+                    return "", f"Please upload an image when using \"Resize according to Reference\"."
+                else:
+                    raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
+        
+            aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+            if self.model_type == "Inpaint":
+                if validation_video is not None:
+                    original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
+                else:
+                    original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
+            else:
+                original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
+            closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+            height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
+
+        if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
+            if is_api:
+                return "", f"Please select an image to video pretrained model while using image to video."
+            else:
+                raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
+
+        if start_image is None and end_image is not None:
+            if is_api:
+                return "", f"If specifying the ending image of the video, please specify a starting image of the video."
+            else:
+                raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
+
+        is_image = True if generation_method == "Image Generation" else False
+
+        self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
+        if self.lora_model_path != "none":
+            # lora part
+            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
+        else: seed_textbox = np.random.randint(0, 1e10)
+        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
+        
+        try:
+            if self.model_type == "Inpaint":
+                if self.transformer.config.in_channels != self.vae.config.latent_channels:
+                    if validation_video is not None:
+                        input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
+                        strength = denoise_strength
+                    else:
+                        input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
+                        strength = 1
+
+                    sample = self.pipeline(
+                        prompt_textbox,
+                        negative_prompt     = negative_prompt_textbox,
+                        num_inference_steps = sample_step_slider,
+                        guidance_scale      = cfg_scale_slider,
+                        width               = width_slider,
+                        height              = height_slider,
+                        num_frames          = length_slider if not is_image else 1,
+                        generator           = generator,
+
+                        video        = input_video,
+                        mask_video   = input_video_mask,
+                        strength     = strength,
+                    ).videos
+                else:
+                    sample = self.pipeline(
+                        prompt_textbox,
+                        negative_prompt     = negative_prompt_textbox,
+                        num_inference_steps = sample_step_slider,
+                        guidance_scale      = cfg_scale_slider,
+                        width               = width_slider,
+                        height              = height_slider,
+                        num_frames          = length_slider if not is_image else 1,
+                        generator           = generator
+                    ).videos
+            else:
+                input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
+
+                sample = self.pipeline(
+                    prompt_textbox,
+                    negative_prompt     = negative_prompt_textbox,
+                    num_inference_steps = sample_step_slider,
+                    guidance_scale      = cfg_scale_slider,
+                    width               = width_slider,
+                    height              = height_slider,
+                    num_frames          = length_slider if not is_image else 1,
+                    generator           = generator,
+
+                    control_video = input_video,
+                ).videos
+        except Exception as e:
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+            if self.lora_model_path != "none":
+                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            if is_api:
+                return "", f"Error. error information is {str(e)}"
+            else:
+                return gr.update(), gr.update(), f"Error. error information is {str(e)}"
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        
+        # lora part
+        if self.lora_model_path != "none":
+            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+
+        if not os.path.exists(self.savedir_sample):
+            os.makedirs(self.savedir_sample, exist_ok=True)
+        index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+        prefix = str(index).zfill(3)
+        
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        if is_image or length_slider == 1:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+
+            image = sample[0, :, 0]
+            image = image.transpose(0, 1).transpose(1, 2)
+            image = (image * 255).numpy().astype(np.uint8)
+            image = Image.fromarray(image)
+            image.save(save_sample_path)
+            if is_api:
+                return save_sample_path, "Success"
+            else:
+                if gradio_version_is_above_4:
+                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
+                else:
+                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+        else:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+            save_videos_grid(sample, save_sample_path, fps=8)
+            if is_api:
+                return save_sample_path, "Success"
+            else:
+                if gradio_version_is_above_4:
+                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
+                else:
+                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui_modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype):
+    controller = CogVideoX_Fun_Controller_Modelscope(model_name, model_type, savedir_sample, low_gpu_memory_mode, weight_dtype)
+
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # CogVideoX-Fun
+
+            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. 
+
+            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
+            """
+        )
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. CogVideoX-Fun Model Type (CogVideoX-Fun模型的种类,正常模型还是控制模型).
+                """
+            )
+            with gr.Row():
+                model_type = gr.Dropdown(
+                    label="The model type of CogVideoX-Fun (CogVideoX-Fun模型的种类,正常模型还是控制模型)",
+                    choices=[model_type],
+                    value=model_type,
+                    interactive=False,
+                )
+
+            gr.Markdown(
+                """
+                ### 2. Model checkpoints (模型路径).
+                """
+            )
+            with gr.Row():
+                diffusion_transformer_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path (预训练模型路径)",
+                    choices=[model_name],
+                    value=model_name,
+                    interactive=False,
+                )
+            with gr.Row():
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model (选择基模型[非必需])",
+                    choices=["none"],
+                    value="none",
+                    interactive=False,
+                    visible=False
+                )
+                with gr.Column(visible=False):
+                    gr.Markdown(
+                        """
+                        ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
+                        """
+                    )
+                    with gr.Row():
+                        lora_model_dropdown = gr.Dropdown(
+                            label="Select LoRA model",
+                            choices=["none"],
+                            value="none",
+                            interactive=True,
+                        )
+
+                        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
+                
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 3. Configs for Generation (生成参数配置).
+                """
+            )
+
+            prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
+                
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
+                    
+                    resize_method = gr.Radio(
+                        ["Generate by", "Resize according to Reference"],
+                        value="Generate by",
+                        show_label=False,
+                    )                        
+                    width_slider     = gr.Slider(label="Width (视频宽度)",            value=672, minimum=128, maximum=1280, step=16, interactive=False)
+                    height_slider    = gr.Slider(label="Height (视频高度)",           value=384, minimum=128, maximum=1280, step=16, interactive=False)
+                    base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
+
+                    with gr.Group():
+                        generation_method = gr.Radio(
+                            ["Video Generation", "Image Generation"],
+                            value="Video Generation",
+                            show_label=False,
+                            visible=True,
+                        )
+                        length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5,   maximum=49,  step=4)
+                        overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1,   maximum=4,  step=1, visible=False)
+                        partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5,   maximum=49,  step=4, visible=False)
+                        
+                    source_method = gr.Radio(
+                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"],
+                        value="Text to Video (文本到视频)",
+                        show_label=False,
+                    )
+                    with gr.Column(visible = False) as image_to_video_col:
+                        with gr.Row():
+                            start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
+                        
+                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
+                        def select_template(evt: gr.SelectData):
+                            text = {
+                                "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                            }[template_gallery_path[evt.index]]
+                            return template_gallery_path[evt.index], text
+
+                        template_gallery = gr.Gallery(
+                            template_gallery_path,
+                            columns=5, rows=1,
+                            height=140,
+                            allow_preview=False,
+                            container=False,
+                            label="Template Examples",
+                        )
+                        template_gallery.select(select_template, None, [start_image, prompt_textbox])
+
+                        with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
+                            end_image   = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
+
+                    with gr.Column(visible = False) as video_to_video_col:
+                        with gr.Row():
+                            validation_video = gr.Video(
+                                label="The video to convert (视频转视频的参考视频)",  show_label=True, 
+                                elem_id="v2v", sources="upload", 
+                            ) 
+                        with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
+                            gr.Markdown(
+                                """
+                                - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70  
+                                - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
+                                """
+                            )
+                            validation_video_mask = gr.Image(
+                                label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
+                                show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
+                            )
+                        denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
+
+                    with gr.Column(visible = False) as control_video_col:
+                        gr.Markdown(
+                            """
+                            Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
+                            """
+                        )
+                        control_video = gr.Video(
+                            label="The control video (用于提供控制信号的video)",  show_label=True, 
+                            elem_id="v2v_control", sources="upload", 
+                        )
+
+                    cfg_scale_slider  = gr.Slider(label="CFG Scale (引导系数)",        value=6.0, minimum=0,   maximum=20)
+                    
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
+                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(
+                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)), 
+                            inputs=[], 
+                            outputs=[seed_textbox]
+                        )
+
+                    generate_button = gr.Button(value="Generate (生成)", variant='primary')
+                    
+                with gr.Column():
+                    result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
+                    result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
+                    infer_progress = gr.Textbox(
+                        label="Generation Info (生成信息)",
+                        value="No task currently",
+                        interactive=False
+                    )
+
+            def upload_generation_method(generation_method):
+                if generation_method == "Video Generation":
+                    return gr.update(visible=True, minimum=8, maximum=49, value=49, interactive=True)
+                elif generation_method == "Image Generation":
+                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
+            generation_method.change(
+                upload_generation_method, generation_method, [length_slider]
+            )
+
+            def upload_source_method(source_method):
+                if source_method == "Text to Video (文本到视频)":
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
+                elif source_method == "Image to Video (图片到视频)":
+                    return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
+                elif source_method == "Video to Video (视频到视频)":
+                    return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
+            source_method.change(
+                upload_source_method, source_method, [
+                    image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, 
+                    validation_video, validation_video_mask, control_video
+                ]
+            )
+
+            def upload_resize_method(resize_method):
+                if resize_method == "Generate by":
+                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
+            resize_method.change(
+                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
+            )
+
+            generate_button.click(
+                fn=controller.generate,
+                inputs=[
+                    diffusion_transformer_dropdown,
+                    base_model_dropdown,
+                    lora_model_dropdown, 
+                    lora_alpha_slider,
+                    prompt_textbox, 
+                    negative_prompt_textbox, 
+                    sampler_dropdown, 
+                    sample_step_slider, 
+                    resize_method,
+                    width_slider, 
+                    height_slider, 
+                    base_resolution, 
+                    generation_method, 
+                    length_slider, 
+                    overlap_video_length, 
+                    partial_video_length, 
+                    cfg_scale_slider, 
+                    start_image, 
+                    end_image, 
+                    validation_video,
+                    validation_video_mask,
+                    control_video,
+                    denoise_strength, 
+                    seed_textbox,
+                ],
+                outputs=[result_image, result_video, infer_progress]
+            )
+    return demo, controller
+
+
+def post_eas(
+    diffusion_transformer_dropdown,
+    base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
+    prompt_textbox, negative_prompt_textbox, 
+    sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
+    base_resolution, generation_method, length_slider, cfg_scale_slider, 
+    start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
+):
+    if start_image is not None:
+        with open(start_image, 'rb') as file:
+            file_content = file.read()
+            start_image_encoded_content = base64.b64encode(file_content)
+            start_image = start_image_encoded_content.decode('utf-8')
+
+    if end_image is not None:
+        with open(end_image, 'rb') as file:
+            file_content = file.read()
+            end_image_encoded_content = base64.b64encode(file_content)
+            end_image = end_image_encoded_content.decode('utf-8')
+
+    if validation_video is not None:
+        with open(validation_video, 'rb') as file:
+            file_content = file.read()
+            validation_video_encoded_content = base64.b64encode(file_content)
+            validation_video = validation_video_encoded_content.decode('utf-8')
+
+    if validation_video_mask is not None:
+        with open(validation_video_mask, 'rb') as file:
+            file_content = file.read()
+            validation_video_mask_encoded_content = base64.b64encode(file_content)
+            validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
+
+    datas = {
+        "base_model_path": base_model_dropdown,
+        "lora_model_path": lora_model_dropdown, 
+        "lora_alpha_slider": lora_alpha_slider, 
+        "prompt_textbox": prompt_textbox, 
+        "negative_prompt_textbox": negative_prompt_textbox, 
+        "sampler_dropdown": sampler_dropdown, 
+        "sample_step_slider": sample_step_slider, 
+        "resize_method": resize_method,
+        "width_slider": width_slider, 
+        "height_slider": height_slider, 
+        "base_resolution": base_resolution,
+        "generation_method": generation_method,
+        "length_slider": length_slider,
+        "cfg_scale_slider": cfg_scale_slider,
+        "start_image": start_image,
+        "end_image": end_image,
+        "validation_video": validation_video,
+        "validation_video_mask": validation_video_mask,
+        "denoise_strength": denoise_strength,
+        "seed_textbox": seed_textbox,
+    }
+
+    session = requests.session()
+    session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
+
+    response = session.post(url=f'{os.environ.get("EAS_URL")}/cogvideox_fun/infer_forward', json=datas, timeout=300)
+
+    outputs = response.json()
+    return outputs
+
+
+class CogVideoX_Fun_Controller_EAS:
+    def __init__(self, model_name, savedir_sample):
+        self.savedir_sample = savedir_sample
+        os.makedirs(self.savedir_sample, exist_ok=True)
+
+    def generate(
+        self,
+        diffusion_transformer_dropdown,
+        base_model_dropdown,
+        lora_model_dropdown, 
+        lora_alpha_slider,
+        prompt_textbox, 
+        negative_prompt_textbox, 
+        sampler_dropdown, 
+        sample_step_slider, 
+        resize_method,
+        width_slider, 
+        height_slider, 
+        base_resolution, 
+        generation_method, 
+        length_slider, 
+        cfg_scale_slider, 
+        start_image, 
+        end_image, 
+        validation_video, 
+        validation_video_mask, 
+        denoise_strength,
+        seed_textbox
+    ):
+        is_image = True if generation_method == "Image Generation" else False
+
+        outputs = post_eas(
+            diffusion_transformer_dropdown,
+            base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
+            prompt_textbox, negative_prompt_textbox, 
+            sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
+            base_resolution, generation_method, length_slider, cfg_scale_slider, 
+            start_image, end_image, validation_video, validation_video_mask, denoise_strength, 
+            seed_textbox
+        )
+        try:
+            base64_encoding = outputs["base64_encoding"]
+        except:
+            return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
+            
+        decoded_data = base64.b64decode(base64_encoding)
+
+        if not os.path.exists(self.savedir_sample):
+            os.makedirs(self.savedir_sample, exist_ok=True)
+        index = len([path for path in os.listdir(self.savedir_sample)]) + 1
+        prefix = str(index).zfill(3)
+        
+        if is_image or length_slider == 1:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
+            with open(save_sample_path, "wb") as file:
+                file.write(decoded_data)
+            if gradio_version_is_above_4:
+                return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
+            else:
+                return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
+        else:
+            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
+            with open(save_sample_path, "wb") as file:
+                file.write(decoded_data)
+            if gradio_version_is_above_4:
+                return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
+            else:
+                return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
+
+
+def ui_eas(model_name, savedir_sample):
+    controller = CogVideoX_Fun_Controller_EAS(model_name, savedir_sample)
+
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # CogVideoX-Fun
+
+            A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. 
+
+            [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
+            """
+        )
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. Model checkpoints (模型路径).
+                """
+            )
+            with gr.Row():
+                diffusion_transformer_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path",
+                    choices=[model_name],
+                    value=model_name,
+                    interactive=False,
+                )
+            with gr.Row():
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model",
+                    choices=["none"],
+                    value="none",
+                    interactive=False,
+                    visible=False
+                )
+                with gr.Column(visible=False):
+                    gr.Markdown(
+                        """
+                        ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
+                        """
+                    )
+                    with gr.Row():
+                        lora_model_dropdown = gr.Dropdown(
+                            label="Select LoRA model",
+                            choices=["none"],
+                            value="none",
+                            interactive=True,
+                        )
+
+                        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
+                
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 2. Configs for Generation.
+                """
+            )
+            
+            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " )
+                
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=50, step=1, interactive=False)
+                    
+                    resize_method = gr.Radio(
+                        ["Generate by", "Resize according to Reference"],
+                        value="Generate by",
+                        show_label=False,
+                    )
+                    width_slider     = gr.Slider(label="Width (视频宽度)",            value=672, minimum=128, maximum=1280, step=16, interactive=False)
+                    height_slider    = gr.Slider(label="Height (视频高度)",           value=384, minimum=128, maximum=1280, step=16, interactive=False)
+                    base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
+
+                    with gr.Group():
+                        generation_method = gr.Radio(
+                            ["Video Generation", "Image Generation"],
+                            value="Video Generation",
+                            show_label=False,
+                            visible=True,
+                        )
+                        length_slider = gr.Slider(label="Animation length (视频帧数)", value=49, minimum=5,   maximum=49,  step=4)
+                    
+                    source_method = gr.Radio(
+                        ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"],
+                        value="Text to Video (文本到视频)",
+                        show_label=False,
+                    )
+                    with gr.Column(visible = False) as image_to_video_col:
+                        start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
+                        
+                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
+                        def select_template(evt: gr.SelectData):
+                            text = {
+                                "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                                "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", 
+                            }[template_gallery_path[evt.index]]
+                            return template_gallery_path[evt.index], text
+
+                        template_gallery = gr.Gallery(
+                            template_gallery_path,
+                            columns=5, rows=1,
+                            height=140,
+                            allow_preview=False,
+                            container=False,
+                            label="Template Examples",
+                        )
+                        template_gallery.select(select_template, None, [start_image, prompt_textbox])
+
+                        with gr.Accordion("The image at the ending of the video (Optional)", open=False):
+                            end_image   = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
+                    
+                    with gr.Column(visible = False) as video_to_video_col:
+                        with gr.Row():
+                            validation_video = gr.Video(
+                                label="The video to convert (视频转视频的参考视频)",  show_label=True, 
+                                elem_id="v2v", sources="upload", 
+                            )
+                        with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
+                            gr.Markdown(
+                                """
+                                - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70  
+                                - (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
+                                """
+                            )
+                            validation_video_mask = gr.Image(
+                                label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
+                                show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
+                            )
+                        denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
+
+                    cfg_scale_slider  = gr.Slider(label="CFG Scale (引导系数)",        value=6.0, minimum=0,   maximum=20)
+                    
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed", value=43)
+                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(
+                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)), 
+                            inputs=[], 
+                            outputs=[seed_textbox]
+                        )
+
+                    generate_button = gr.Button(value="Generate", variant='primary')
+                    
+                with gr.Column():
+                    result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
+                    result_video = gr.Video(label="Generated Animation", interactive=False)
+                    infer_progress = gr.Textbox(
+                        label="Generation Info",
+                        value="No task currently",
+                        interactive=False
+                    )
+
+            def upload_generation_method(generation_method):
+                if generation_method == "Video Generation":
+                    return gr.update(visible=True, minimum=5, maximum=49, value=49, interactive=True)
+                elif generation_method == "Image Generation":
+                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
+            generation_method.change(
+                upload_generation_method, generation_method, [length_slider]
+            )
+
+            def upload_source_method(source_method):
+                if source_method == "Text to Video (文本到视频)":
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
+                elif source_method == "Image to Video (图片到视频)":
+                    return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
+            source_method.change(
+                upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
+            )
+
+            def upload_resize_method(resize_method):
+                if resize_method == "Generate by":
+                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
+            resize_method.change(
+                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
+            )
+
+            generate_button.click(
+                fn=controller.generate,
+                inputs=[
+                    diffusion_transformer_dropdown,
+                    base_model_dropdown,
+                    lora_model_dropdown, 
+                    lora_alpha_slider,
+                    prompt_textbox, 
+                    negative_prompt_textbox, 
+                    sampler_dropdown, 
+                    sample_step_slider, 
+                    resize_method,
+                    width_slider, 
+                    height_slider, 
+                    base_resolution, 
+                    generation_method, 
+                    length_slider, 
+                    cfg_scale_slider, 
+                    start_image, 
+                    end_image, 
+                    validation_video,
+                    validation_video_mask,
+                    denoise_strength, 
+                    seed_textbox,
+                ],
+                outputs=[result_image, result_video, infer_progress]
+            )
+    return demo, controller
\ No newline at end of file
diff --git a/cogvideox/utils/__init__.py b/cogvideox/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/cogvideox/utils/lora_utils.py b/cogvideox/utils/lora_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b51fc6986e595ba7d609b827a8fd5667d02108
--- /dev/null
+++ b/cogvideox/utils/lora_utils.py
@@ -0,0 +1,477 @@
+# LoRA network module
+# reference:
+# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
+# https://github.com/bmaltais/kohya_ss
+
+import hashlib
+import math
+import os
+from collections import defaultdict
+from io import BytesIO
+from typing import List, Optional, Type, Union
+
+import safetensors.torch
+import torch
+import torch.utils.checkpoint
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from safetensors.torch import load_file
+from transformers import T5EncoderModel
+
+
+class LoRAModule(torch.nn.Module):
+    """
+    replaces forward method of the original Linear, instead of replacing the original Linear module.
+    """
+
+    def __init__(
+        self,
+        lora_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        lora_dim=4,
+        alpha=1,
+        dropout=None,
+        rank_dropout=None,
+        module_dropout=None,
+    ):
+        """if alpha == 0 or None, alpha is rank (no scaling)."""
+        super().__init__()
+        self.lora_name = lora_name
+
+        if org_module.__class__.__name__ == "Conv2d":
+            in_dim = org_module.in_channels
+            out_dim = org_module.out_channels
+        else:
+            in_dim = org_module.in_features
+            out_dim = org_module.out_features
+
+        self.lora_dim = lora_dim
+        if org_module.__class__.__name__ == "Conv2d":
+            kernel_size = org_module.kernel_size
+            stride = org_module.stride
+            padding = org_module.padding
+            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
+            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
+        else:
+            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
+            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        self.scale = alpha / self.lora_dim
+        self.register_buffer("alpha", torch.tensor(alpha))
+
+        # same as microsoft's
+        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(self.lora_up.weight)
+
+        self.multiplier = multiplier
+        self.org_module = org_module  # remove in applying
+        self.dropout = dropout
+        self.rank_dropout = rank_dropout
+        self.module_dropout = module_dropout
+
+    def apply_to(self):
+        self.org_forward = self.org_module.forward
+        self.org_module.forward = self.forward
+        del self.org_module
+
+    def forward(self, x, *args, **kwargs):
+        weight_dtype = x.dtype
+        org_forwarded = self.org_forward(x)
+
+        # module dropout
+        if self.module_dropout is not None and self.training:
+            if torch.rand(1) < self.module_dropout:
+                return org_forwarded
+
+        lx = self.lora_down(x.to(self.lora_down.weight.dtype))
+
+        # normal dropout
+        if self.dropout is not None and self.training:
+            lx = torch.nn.functional.dropout(lx, p=self.dropout)
+
+        # rank dropout
+        if self.rank_dropout is not None and self.training:
+            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
+            if len(lx.size()) == 3:
+                mask = mask.unsqueeze(1)  # for Text Encoder
+            elif len(lx.size()) == 4:
+                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
+            lx = lx * mask
+
+            # scaling for rank dropout: treat as if the rank is changed
+            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
+        else:
+            scale = self.scale
+
+        lx = self.lora_up(lx)
+
+        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
+
+
+def addnet_hash_legacy(b):
+    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+    m = hashlib.sha256()
+
+    b.seek(0x100000)
+    m.update(b.read(0x10000))
+    return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+    hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
+
+    b.seek(0)
+    header = b.read(8)
+    n = int.from_bytes(header, "little")
+
+    offset = n + 8
+    b.seek(offset)
+    for chunk in iter(lambda: b.read(blksize), b""):
+        hash_sha256.update(chunk)
+
+    return hash_sha256.hexdigest()
+
+
+def precalculate_safetensors_hashes(tensors, metadata):
+    """Precalculate the model hashes needed by sd-webui-additional-networks to
+    save time on indexing the model later."""
+
+    # Because writing user metadata to the file can change the result of
+    # sd_models.model_hash(), only retain the training metadata for purposes of
+    # calculating the hash, as they are meant to be immutable
+    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+    bytes = safetensors.torch.save(tensors, metadata)
+    b = BytesIO(bytes)
+
+    model_hash = addnet_hash_safetensors(b)
+    legacy_hash = addnet_hash_legacy(b)
+    return model_hash, legacy_hash
+
+
+class LoRANetwork(torch.nn.Module):
+    TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
+    LORA_PREFIX_TRANSFORMER = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    def __init__(
+        self,
+        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
+        unet,
+        multiplier: float = 1.0,
+        lora_dim: int = 4,
+        alpha: float = 1,
+        dropout: Optional[float] = None,
+        module_class: Type[object] = LoRAModule,
+        add_lora_in_attn_temporal: bool = False,
+        varbose: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.lora_dim = lora_dim
+        self.alpha = alpha
+        self.dropout = dropout
+
+        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
+        print(f"neuron dropout: p={self.dropout}")
+
+        # create module instances
+        def create_modules(
+            is_unet: bool,
+            root_module: torch.nn.Module,
+            target_replace_modules: List[torch.nn.Module],
+        ) -> List[LoRAModule]:
+            prefix = (
+                self.LORA_PREFIX_TRANSFORMER
+                if is_unet
+                else self.LORA_PREFIX_TEXT_ENCODER
+            )
+            loras = []
+            skipped = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
+                        is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+                        
+                        if not add_lora_in_attn_temporal:
+                            if "attn_temporal" in child_name:
+                                continue
+
+                        if is_linear or is_conv2d:
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+
+                            dim = None
+                            alpha = None
+
+                            if is_linear or is_conv2d_1x1:
+                                dim = self.lora_dim
+                                alpha = self.alpha
+
+                            if dim is None or dim == 0:
+                                if is_linear or is_conv2d_1x1:
+                                    skipped.append(lora_name)
+                                continue
+
+                            lora = module_class(
+                                lora_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                                dropout=dropout,
+                            )
+                            loras.append(lora)
+            return loras, skipped
+
+        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
+
+        self.text_encoder_loras = []
+        skipped_te = []
+        for i, text_encoder in enumerate(text_encoders):
+            if text_encoder is not None:
+                text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
+                self.text_encoder_loras.extend(text_encoder_loras)
+                skipped_te += skipped
+        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
+
+        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
+        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
+
+        # assertion
+        names = set()
+        for lora in self.text_encoder_loras + self.unet_loras:
+            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
+            names.add(lora.lora_name)
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        if apply_text_encoder:
+            print("enable LoRA for text encoder")
+        else:
+            self.text_encoder_loras = []
+
+        if apply_unet:
+            print("enable LoRA for U-Net")
+        else:
+            self.unet_loras = []
+
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.apply_to()
+            self.add_module(lora.lora_name, lora)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(loras):
+            params = []
+            for lora in loras:
+                params.extend(lora.parameters())
+            return params
+
+        if self.text_encoder_loras:
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
+
+        if self.unet_loras:
+            param_data = {"params": enumerate_params(self.unet_loras)}
+            if unet_lr is not None:
+                param_data["lr"] = unet_lr
+            all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        pass
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+def create_network(
+    multiplier: float,
+    network_dim: Optional[int],
+    network_alpha: Optional[float],
+    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
+    transformer,
+    neuron_dropout: Optional[float] = None,
+    add_lora_in_attn_temporal: bool = False,
+    **kwargs,
+):
+    if network_dim is None:
+        network_dim = 4  # default
+    if network_alpha is None:
+        network_alpha = 1.0
+
+    network = LoRANetwork(
+        text_encoder,
+        transformer,
+        multiplier=multiplier,
+        lora_dim=network_dim,
+        alpha=network_alpha,
+        dropout=neuron_dropout,
+        add_lora_in_attn_temporal=add_lora_in_attn_temporal,
+        varbose=True,
+    )
+    return network
+
+def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
+    LORA_PREFIX_TRANSFORMER = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    if state_dict is None:
+        state_dict = load_file(lora_path, device=device)
+    else:
+        state_dict = state_dict
+    updates = defaultdict(dict)
+    for key, value in state_dict.items():
+        layer, elem = key.split('.', 1)
+        updates[layer][elem] = value
+
+    for layer, elems in updates.items():
+
+        if "lora_te" in layer:
+            if transformer_only:
+                continue
+            else:
+                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+                curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
+            curr_layer = pipeline.transformer
+
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(layer_infos) == 0:
+                    print('Error loading layer')
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        weight_up = elems['lora_up.weight'].to(dtype)
+        weight_down = elems['lora_down.weight'].to(dtype)
+        if 'alpha' in elems.keys():
+            alpha = elems['alpha'].item() / weight_up.shape[1]
+        else:
+            alpha = 1.0
+
+        curr_layer.weight.data = curr_layer.weight.data.to(device)
+        if len(weight_up.shape) == 4:
+            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+                                                                    weight_down.squeeze(3).squeeze(2)).unsqueeze(
+                2).unsqueeze(3)
+        else:
+            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+
+    return pipeline
+
+# TODO: Refactor with merge_lora.
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
+    """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    state_dict = load_file(lora_path, device=device)
+
+    updates = defaultdict(dict)
+    for key, value in state_dict.items():
+        layer, elem = key.split('.', 1)
+        updates[layer][elem] = value
+
+    for layer, elems in updates.items():
+
+        if "lora_te" in layer:
+            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+            curr_layer = pipeline.transformer
+
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(layer_infos) == 0:
+                    print('Error loading layer')
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        weight_up = elems['lora_up.weight'].to(dtype)
+        weight_down = elems['lora_down.weight'].to(dtype)
+        if 'alpha' in elems.keys():
+            alpha = elems['alpha'].item() / weight_up.shape[1]
+        else:
+            alpha = 1.0
+
+        curr_layer.weight.data = curr_layer.weight.data.to(device)
+        if len(weight_up.shape) == 4:
+            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
+                                                                    weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+
+    return pipeline
\ No newline at end of file
diff --git a/cogvideox/utils/utils.py b/cogvideox/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9298a3b65097a8503a1baac7c7570603e78a87a
--- /dev/null
+++ b/cogvideox/utils/utils.py
@@ -0,0 +1,208 @@
+import os
+import gc
+import imageio
+import numpy as np
+import torch
+import torchvision
+import cv2
+from einops import rearrange
+from PIL import Image
+
+def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
+    target_pixels = int(base_resolution) * int(base_resolution)
+    original_width, original_height = Image.open(image).size
+    ratio = (target_pixels / (original_width * original_height)) ** 0.5
+    width_slider = round(original_width * ratio)
+    height_slider = round(original_height * ratio)
+    return height_slider, width_slider
+
+def color_transfer(sc, dc):
+    """
+    Transfer color distribution from of sc, referred to dc.
+
+    Args:
+        sc (numpy.ndarray): input image to be transfered.
+        dc (numpy.ndarray): reference image
+
+    Returns:
+        numpy.ndarray: Transferred color distribution on the sc.
+    """
+
+    def get_mean_and_std(img):
+        x_mean, x_std = cv2.meanStdDev(img)
+        x_mean = np.hstack(np.around(x_mean, 2))
+        x_std = np.hstack(np.around(x_std, 2))
+        return x_mean, x_std
+
+    sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
+    s_mean, s_std = get_mean_and_std(sc)
+    dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
+    t_mean, t_std = get_mean_and_std(dc)
+    img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
+    np.putmask(img_n, img_n > 255, 255)
+    np.putmask(img_n, img_n < 0, 0)
+    dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
+    return dst
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
+    videos = rearrange(videos, "b c t h w -> t b c h w")
+    outputs = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = (x * 255).numpy().astype(np.uint8)
+        outputs.append(Image.fromarray(x))
+
+    if color_transfer_post_process:
+        for i in range(1, len(outputs)):
+            outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
+
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if imageio_backend:
+        if path.endswith("mp4"):
+            imageio.mimsave(path, outputs, fps=fps)
+        else:
+            imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
+    else:
+        if path.endswith("mp4"):
+            path = path.replace('.mp4', '.gif')
+        outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
+
+def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
+    if validation_image_start is not None and validation_image_end is not None:
+        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+            image_start = clip_image = Image.open(validation_image_start).convert("RGB")
+            image_start = image_start.resize([sample_size[1], sample_size[0]])
+            clip_image = clip_image.resize([sample_size[1], sample_size[0]])
+        else:
+            image_start = clip_image = validation_image_start
+            image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
+            clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
+
+        if type(validation_image_end) is str and os.path.isfile(validation_image_end):
+            image_end = Image.open(validation_image_end).convert("RGB")
+            image_end = image_end.resize([sample_size[1], sample_size[0]])
+        else:
+            image_end = validation_image_end
+            image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
+
+        if type(image_start) is list:
+            clip_image = clip_image[0]
+            start_video = torch.cat(
+                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], 
+                dim=2
+            )
+            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+            input_video[:, :, :len(image_start)] = start_video
+            
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, len(image_start):] = 255
+        else:
+            input_video = torch.tile(
+                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), 
+                [1, 1, video_length, 1, 1]
+            )
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, 1:] = 255
+
+        if type(image_end) is list:
+            image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
+            end_video = torch.cat(
+                [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], 
+                dim=2
+            )
+            input_video[:, :, -len(end_video):] = end_video
+            
+            input_video_mask[:, :, -len(image_end):] = 0
+        else:
+            image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
+            input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
+            input_video_mask[:, :, -1:] = 0
+
+        input_video = input_video / 255
+
+    elif validation_image_start is not None:
+        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+            image_start = clip_image = Image.open(validation_image_start).convert("RGB")
+            image_start = image_start.resize([sample_size[1], sample_size[0]])
+            clip_image = clip_image.resize([sample_size[1], sample_size[0]])
+        else:
+            image_start = clip_image = validation_image_start
+            image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
+            clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
+        image_end = None
+        
+        if type(image_start) is list:
+            clip_image = clip_image[0]
+            start_video = torch.cat(
+                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], 
+                dim=2
+            )
+            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+            input_video[:, :, :len(image_start)] = start_video
+            input_video = input_video / 255
+            
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, len(image_start):] = 255
+        else:
+            input_video = torch.tile(
+                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), 
+                [1, 1, video_length, 1, 1]
+            ) / 255
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, 1:, ] = 255
+    else:
+        image_start = None
+        image_end = None
+        input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
+        input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
+        clip_image = None
+
+    del image_start
+    del image_end
+    gc.collect()
+
+    return  input_video, input_video_mask, clip_image
+
+def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None):
+    if isinstance(input_video_path, str):
+        cap = cv2.VideoCapture(input_video_path)
+        input_video = []
+
+        original_fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_skip = 1 if fps is None else int(original_fps // fps)
+
+        frame_count = 0
+
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            if frame_count % frame_skip == 0:
+                frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
+                input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+            frame_count += 1
+
+        cap.release()
+    else:
+        input_video = input_video_path
+
+    input_video = torch.from_numpy(np.array(input_video))[:video_length]
+    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+
+    if validation_video_mask is not None:
+        validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
+        input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
+        
+        input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
+        input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
+        input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
+    else:
+        input_video_mask = torch.zeros_like(input_video[:, :1])
+        input_video_mask[:, :, :] = 255
+
+    return  input_video, input_video_mask, None
\ No newline at end of file
diff --git a/cogvideox/video_caption/README.md b/cogvideox/video_caption/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6bad4199d05c49e54bd1aaa439cf77601c0e0846
--- /dev/null
+++ b/cogvideox/video_caption/README.md
@@ -0,0 +1,174 @@
+# Video Caption
+English | [简体中文](./README_zh-CN.md)
+
+The folder contains codes for dataset preprocessing (i.e., video splitting, filtering, and recaptioning), and beautiful prompt used by CogVideoX-Fun.
+The entire process supports distributed parallel processing, capable of handling large-scale datasets.
+
+Meanwhile, we are collaborating with [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md),
+allowing you to easily perform video data processing on [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/).
+
+# Table of Content
+- [Video Caption](#video-caption)
+- [Table of Content](#table-of-content)
+  - [Quick Start](#quick-start)
+    - [Setup](#setup)
+    - [Data Preprocessing](#data-preprocessing)
+      - [Data Preparation](#data-preparation)
+      - [Video Splitting](#video-splitting)
+      - [Video Filtering](#video-filtering)
+      - [Video Recaptioning](#video-recaptioning)
+    - [Beautiful Prompt (For CogVideoX-Fun Inference)](#beautiful-prompt-for-cogvideox-inference)
+      - [Batched Inference](#batched-inference)
+      - [OpenAI Server](#openai-server)
+
+## Quick Start
+
+### Setup
+AliyunDSW or Docker is recommended to setup the environment, please refer to [Quick Start](../../README.md#quick-start).
+You can also refer to the image build process in the [Dockerfile](../../Dockerfile.ds) to configure the conda environment and other dependencies locally.
+
+Since the video recaptioning depends on [llm-awq](https://github.com/mit-han-lab/llm-awq) for faster and memory efficient inference,
+the minimum GPU requirment should be RTX 3060 or A2 (CUDA Compute Capability >= 8.0).
+
+```shell
+# pull image
+docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
+
+# enter image
+docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
+
+# clone code
+git clone https://github.com/aigc-apps/CogVideoX-Fun.git
+
+# enter video_caption
+cd CogVideoX-Fun/cogvideox/video_caption
+```
+
+### Data Preprocessing
+#### Data Preparation
+Place the downloaded videos into a folder under [datasets](./datasets/) (preferably without nested structures, as the video names are used as unique IDs in subsequent processes).
+Taking Panda-70M as an example, the entire dataset directory structure is shown as follows:
+```
+📦 datasets/
+├── 📂 panda_70m/
+│   ├── 📂 videos/
+│   │   ├── 📂 data/
+│   │   │   └── 📄 --C66yU3LjM_2.mp4
+│   │   │   └── 📄 ...
+```
+
+#### Video Splitting
+CogVideoX-Fun utilizes [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) to identify scene changes within the video
+and performs video splitting via FFmpeg based on certain threshold values to ensure consistency of the video clip.
+Video clips shorter than 3 seconds will be discarded, and those longer than 10 seconds will be splitted recursively.
+
+The entire workflow of video splitting is in the [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh).
+After running
+```shell
+sh scripts/stage_1_video_splitting.sh
+```
+the video clips are obtained in `cogvideox/video_caption/datasets/panda_70m/videos_clips/data/`.
+
+#### Video Filtering
+Based on the videos obtained in the previous step, CogVideoX-Fun provides a simple yet effective pipeline to filter out high-quality videos for recaptioning.
+The overall process is as follows:
+
+- Aesthetic filtering: Filter out videos with poor content (blurry, dim, etc.) by calculating the average aesthetic score of uniformly sampled 4 frames via [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5).
+- Text filtering: Use [EasyOCR](https://github.com/JaidedAI/EasyOCR) to calculate the text area proportion of the middle frame to filter out videos with a large area of text.
+- Motion filtering: Calculate interframe optical flow differences to filter out videos that move too slowly or too quickly.
+
+The entire workflow of video filtering is in the [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh).
+After running
+```shell
+sh scripts/stage_2_video_filtering.sh
+```
+the aesthetic score, text score, and motion score of videos will be saved in the corresponding meta files in the folder `cogvideox/video_caption/datasets/panda_70m/videos_clips/`.
+
+> [!NOTE]
+> The computation of the aesthetic score depends on the [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384).
+Please run `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` if you cannot access to huggingface.com.
+
+
+#### Video Recaptioning
+After obtaining the aboved high-quality filtered videos, CogVideoX-Fun utilizes [VILA1.5](https://github.com/NVlabs/VILA) to perform video recaptioning. 
+Subsequently, the recaptioning results are rewritten by LLMs to better meet with the requirements of video generation tasks. 
+Finally, an advanced VideoCLIPXL model is developed to filter out video-caption pairs with poor alignment, resulting in the final training dataset.
+
+Please download the video caption model from [VILA1.5](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e) of the appropriate size based on the GPU memory of your machine.
+For A100 with 40G VRAM, you can download [VILA1.5-40b-AWQ](https://huggingface.co/Efficient-Large-Model/VILA1.5-40b-AWQ) by running
+```shell
+# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
+huggingface-cli download Efficient-Large-Model/VILA1.5-40b-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/VILA_MODEL
+```
+
+Optionally, you can prepare local LLMs to rewrite the recaption results.
+For example, you can download [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct) by running
+```shell
+# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
+huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL
+```
+
+The entire workflow of video recaption is in the [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh).
+After running
+```shell
+VILA_MODEL_PATH=/PATH/TO/VILA_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh
+``` 
+the final train file is obtained in `cogvideox/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`.
+
+
+### Beautiful Prompt (For CogVideoX-Fun Inference)
+Beautiful Prompt aims to rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of CogVideoX-Fun's training captions,
+making it more suitable as the inference prompt and thus improving the quality of the generated videos.
+We support batched inference with local LLMs or OpenAI compatible server based on [vLLM](https://github.com/vllm-project/vllm) for beautiful prompt.
+
+#### Batched Inference
+1. Prepare original prompts in a jsonl file `cogvideox/video_caption/datasets/original_prompt.jsonl` with the following format:
+    ```json
+    {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."}
+    {"prompt": "An underwater world with realistic fish and other creatures of the sea."}
+    {"prompt": "a monarch butterfly perched on a tree trunk in the forest."}
+    {"prompt": "a child in a room with a bottle of wine and a lamp."}
+    {"prompt": "two men in suits walking down a hallway."}
+    ```
+
+2. Then you can perform beautiful prompt by running
+    ```shell
+    # Meta-Llama-3-8B-Instruct is sufficient for this task.
+    # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
+
+    python caption_rewrite.py \
+        --video_metadata_path datasets/original_prompt.jsonl \
+        --caption_column "prompt" \
+        --batch_size 1 \
+        --model_name /path/to/your_llm \
+        --prompt prompt/beautiful_prompt.txt \
+        --prefix '"detailed description": ' \
+        --saved_path datasets/beautiful_prompt.jsonl \
+        --saved_freq 1
+    ```
+
+#### OpenAI Server
++ You can request OpenAI compatible server to perform beautiful prompt by running
+    ```shell
+    OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
+        --model "your_model_name" \
+        --prompt "your_prompt"
+    ```
+
++ You can also deploy the OpenAI Compatible Server locally using vLLM. For example:
+    ```shell
+    # Meta-Llama-3-8B-Instruct is sufficient for this task.
+    # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
+
+    # deploy the OpenAI compatible server
+    python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key"
+    ```
+
+    Then you can perform beautiful prompt by running
+    ```shell
+    python -m beautiful_prompt.py \
+        --model /path/to/your_llm \
+        --prompt "your_prompt" \
+        --base_url "http://localhost:8000/v1" \
+        --api_key "your_api_key"
+    ```
diff --git a/cogvideox/video_caption/README_zh-CN.md b/cogvideox/video_caption/README_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc3266adc80a2bf315b2bed4d35831516f5d27c5
--- /dev/null
+++ b/cogvideox/video_caption/README_zh-CN.md
@@ -0,0 +1,159 @@
+# 数据预处理
+[English](./README.md) | 简体中文
+
+该文件夹包含 CogVideoX-Fun 使用的数据集预处理(即视频切分、过滤和生成描述)和提示词美化的代码。整个过程支持分布式并行处理,能够处理大规模数据集。
+
+此外,我们和 [Data-Juicer](https://github.com/modelscope/data-juicer/blob/main/docs/DJ_SORA.md) 合作,能让你在 [Aliyun PAI-DLC](https://help.aliyun.com/zh/pai/user-guide/video-preprocessing/) 轻松进行视频数据的处理。
+
+# 目录
+- [数据预处理](#数据预处理)
+- [目录](#目录)
+  - [快速开始](#快速开始)
+    - [安装](#安装)
+    - [数据集预处理](#数据集预处理)
+      - [数据准备](#数据准备)
+      - [视频切分](#视频切分)
+      - [视频过滤](#视频过滤)
+      - [视频描述](#视频描述)
+    - [提示词美化](#提示词美化)
+      - [批量推理](#批量推理)
+      - [OpenAI 服务器](#openai-服务器)
+
+
+## 快速开始
+### 安装
+推荐使用阿里云 DSW 和 Docker 来安装环境,请参考 [快速开始](../../README_zh-CN.md#1-云使用-aliyundswdocker). 你也可以参考 [Dockerfile](../../Dockerfile.ds) 中的镜像构建流程在本地安装对应的 conda 环境和其余依赖。
+
+为了提高推理速度和节省推理的显存,生成视频描述依赖于 [llm-awq](https://github.com/mit-han-lab/llm-awq)。因此,需要 RTX 3060 或者 A2 及以上的显卡 (CUDA Compute Capability >= 8.0)。
+
+```shell
+# pull image
+docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
+
+# enter image
+docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun
+
+# clone code
+git clone https://github.com/aigc-apps/CogVideoX-Fun.git
+
+# enter video_caption
+cd CogVideoX-Fun/cogvideox/video_caption
+```
+
+### 数据集预处理
+#### 数据准备
+将下载的视频准备到文件夹 [datasets](./datasets/)(最好不使用嵌套结构,因为视频名称在后续处理中用作唯一 ID)。以 Panda-70M 为例,完整的数据集目录结构如下所示:
+```
+📦 datasets/
+├── 📂 panda_70m/
+│   ├── 📂 videos/
+│   │   ├── 📂 data/
+│   │   │   └── 📄 --C66yU3LjM_2.mp4
+│   │   │   └── 📄 ...
+```
+
+#### 视频切分
+CogVideoX-Fun 使用 [PySceneDetect](https://github.com/Breakthrough/PySceneDetect) 来识别视频中的场景变化
+并根据某些阈值通过 FFmpeg 执行视频分割,以确保视频片段的一致性。
+短于 3 秒的视频片段将被丢弃,长于 10 秒的视频片段将被递归切分。
+
+视频切分的完整流程在 [stage_1_video_splitting.sh](./scripts/stage_1_video_splitting.sh)。执行
+```shell
+sh scripts/stage_1_video_splitting.sh
+```
+后,切分后的视频位于 `cogvideox/video_caption/datasets/panda_70m/videos_clips/data/`。
+
+#### 视频过滤
+基于上一步获得的视频,CogVideoX-Fun 提供了一个简单而有效的流程来过滤出高质量的视频。总体流程如下:
+
+- 美学过滤:通过 [aesthetic-predictor-v2-5](https://github.com/discus0434/aesthetic-predictor-v2-5) 计算均匀采样的 4 帧视频的平均美学分数,从而筛选出内容不佳(模糊、昏暗等)的视频。
+- 文本过滤:使用 [EasyOCR](https://github.com/JaidedAI/EasyOCR) 计算中间帧的文本区域比例,过滤掉含有大面积文本的视频。
+- 运动过滤:计算帧间光流差,过滤掉移动太慢或太快的视频。
+
+视频过滤的完整流程在 [stage_2_video_filtering.sh](./scripts/stage_2_video_filtering.sh)。执行
+```shell
+sh scripts/stage_2_video_filtering.sh
+```
+后,视频的美学得分、文本得分和运动得分对应的元文件保存在 `cogvideox/video_caption/datasets/panda_70m/videos_clips/`。
+
+> [!NOTE]
+> 美学得分的计算依赖于 [google/siglip-so400m-patch14-384 model](https://huggingface.co/google/siglip-so400m-patch14-384).
+请执行 `HF_ENDPOINT=https://hf-mirror.com sh scripts/stage_2_video_filtering.sh` 如果你无法访问 huggingface.com.
+
+#### 视频描述
+在获得上述高质量的过滤视频后,CogVideoX-Fun 利用 [VILA1.5](https://github.com/NVlabs/VILA) 来生成视频描述。随后,使用 LLMs 对生成的视频描述进行重写,以更好地满足视频生成任务的要求。最后,使用自研的 VideoCLIPXL 模型来过滤掉描述和视频内容不一致的数据,从而得到最终的训练数据集。
+
+请根据机器的显存从 [VILA1.5](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e) 下载合适大小的模型。对于 A100 40G,你可以执行下面的命令来下载 [VILA1.5-40b-AWQ](https://huggingface.co/Efficient-Large-Model/VILA1.5-40b-AWQ)
+```shell
+# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
+huggingface-cli download Efficient-Large-Model/VILA1.5-40b-AWQ --local-dir-use-symlinks False --local-dir /PATH/TO/VILA_MODEL
+```
+
+你可以选择性地准备 LLMs 来改写上述视频描述的结果。例如,你执行下面的命令来下载 [Meta-Llama-3-8B-Instruct](https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct)
+```shell
+# Add HF_ENDPOINT=https://hf-mirror.com before the command if you cannot access to huggingface.com
+huggingface-cli download NousResearch/Meta-Llama-3-8B-Instruct --local-dir-use-symlinks False --local-dir /PATH/TO/REWRITE_MODEL
+```
+
+视频描述的完整流程在 [stage_3_video_recaptioning.sh](./scripts/stage_3_video_recaptioning.sh).
+执行
+```shell
+VILA_MODEL_PATH=/PATH/TO/VILA_MODEL REWRITE_MODEL_PATH=/PATH/TO/REWRITE_MODEL sh scripts/stage_3_video_recaptioning.sh
+```
+后,最后的训练文件会保存在 `cogvideox/video_caption/datasets/panda_70m/videos_clips/meta_train_info.json`。
+
+### 提示词美化
+提示词美化旨在通过 LLMs 重写和美化用户上传的提示,将其映射为 CogVideoX-Fun 训练所使用的视频描述风格、
+使其更适合用作推理提示词,从而提高生成视频的质量。
+
+基于 [vLLM](https://github.com/vllm-project/vllm),我们支持使用本地 LLM 进行批量推理或请求 OpenAI 服务器的方式,以进行提示词美化。
+
+#### 批量推理
+1. 将原始的提示词以下面的格式准备在文件 `cogvideox/video_caption/datasets/original_prompt.jsonl` 中:
+    ```json
+    {"prompt": "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."}
+    {"prompt": "An underwater world with realistic fish and other creatures of the sea."}
+    {"prompt": "a monarch butterfly perched on a tree trunk in the forest."}
+    {"prompt": "a child in a room with a bottle of wine and a lamp."}
+    {"prompt": "two men in suits walking down a hallway."}
+    ```
+
+2. 随后你可以通过执行以下的命令进行提示词美化
+    ```shell
+    # Meta-Llama-3-8B-Instruct is sufficient for this task.
+    # Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
+
+    python caption_rewrite.py \
+        --video_metadata_path datasets/original_prompt.jsonl \
+        --caption_column "prompt" \
+        --batch_size 1 \
+        --model_name /path/to/your_llm \
+        --prompt prompt/beautiful_prompt.txt \
+        --prefix '"detailed description": ' \
+        --saved_path datasets/beautiful_prompt.jsonl \
+        --saved_freq 1
+    ```
+
+#### OpenAI 服务器
++ 你可以通过请求 OpenAI 服务器的方式来进行提示词美化
+    ```shell
+    OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
+        --model "your_model_name" \
+        --prompt "your_prompt"
+    ```
+
++ 你也可以执行以下命令,通过 vLLM 将本地 LLMs 部署成兼容 OpenAI 的服务器
+    ```shell
+    OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
+        --model "your_model_name" \
+        --prompt "your_prompt"
+    ```
+
+    然后再执行下面的命令来进行提示词美化
+    ```shell
+    python -m beautiful_prompt.py \
+        --model /path/to/your_llm \
+        --prompt "your_prompt" \
+        --base_url "http://localhost:8000/v1" \
+        --api_key "your_api_key"
+    ```
\ No newline at end of file
diff --git a/cogvideox/video_caption/beautiful_prompt.py b/cogvideox/video_caption/beautiful_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..339d7f8e33ffc58789572e90a7ad0c829b1443db
--- /dev/null
+++ b/cogvideox/video_caption/beautiful_prompt.py
@@ -0,0 +1,103 @@
+"""
+This script (optional) can rewrite and beautify the user-uploaded prompt via LLMs, mapping it to the style of cogvideox's training captions,
+making it more suitable as the inference prompt and thus improving the quality of the generated videos.
+
+Usage:
++ You can request OpenAI compatible server to perform beautiful prompt by running
+```shell
+export OPENAI_API_KEY="your_openai_api_key" OPENAI_BASE_URL="your_openai_base_url" python beautiful_prompt.py \
+    --model "your_model_name" \
+    --prompt "your_prompt"
+```
++ You can also deploy the OpenAI Compatible Server locally using vLLM. For example:
+```shell
+# Meta-Llama-3-8B-Instruct is sufficient for this task.
+# Download it from https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct or https://www.modelscope.cn/models/LLM-Research/Meta-Llama-3-8B-Instruct to /path/to/your_llm
+
+# deploy the OpenAI compatible server
+python -m vllm.entrypoints.openai.api_server serve /path/to/your_llm --dtype auto --api-key "your_api_key"
+```
+
+Then you can perform beautiful prompt by running
+```shell
+python -m beautiful_prompt.py \
+    --model /path/to/your_llm \
+    --prompt "your_prompt" \
+    --base_url "http://localhost:8000/v1" \
+    --api_key "your_api_key"
+```
+"""
+import argparse
+import os
+
+from openai import OpenAI
+
+from cogvideox.video_caption.caption_rewrite import extract_output
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Beautiful prompt.")
+    parser.add_argument("--model", type=str, required=True, help="The OpenAI model or the path to your local LLM.")
+    parser.add_argument("--prompt", type=str, required=True, help="The user-uploaded prompt.")
+    parser.add_argument(
+        "--template",
+        type=str,
+        default="cogvideox/video_caption/prompt/beautiful_prompt.txt",
+        help="A string or a txt file contains the template for beautiful prompt."
+    )
+    parser.add_argument(
+        "--max_retry_nums",
+        type=int,
+        default=5,
+        help="Maximum number of retries to obtain an output that meets the JSON format."
+    )
+    parser.add_argument(
+        "--base_url",
+        type=str,
+        default=None,
+        help="OpenAI API server url. If it is None, the OPENAI_BASE_URL from the environment variables will be used.",
+    )
+    parser.add_argument(
+        "--api_key",
+        type=str,
+        default=None,
+        help="OpenAI API key. If it is None, the OPENAI_API_KEY from the environment variables will be used.",
+    )
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    client = OpenAI(
+        base_url=os.getenv("OPENAI_BASE_URL", args.base_url),
+        api_key=os.environ.get("OPENAI_API_KEY", args.api_key),
+    )
+    if args.template.endswith(".txt") and os.path.exists(args.template):
+        with open(args.template, "r") as f:
+            args.template = "".join(f.readlines())
+    # print(f"Beautiful prompt template: {args.template}")
+
+    for _ in range(args.max_retry_nums):
+        completion = client.chat.completions.create(
+            model=args.model,
+            messages=[
+                # {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": args.template + "\n" + str(args.prompt)}
+            ],
+            temperature=0.7,
+            top_p=1,
+            max_tokens=1024,
+        )
+
+        output = completion.choices[0].message.content
+        output = extract_output(output, prefix='"detailed description": ')
+        if output is not None:
+            break
+    print(f"Beautiful prompt: {output}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/cogvideox/video_caption/caption_rewrite.py b/cogvideox/video_caption/caption_rewrite.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdf1f12099ba773ab8738c489ab94b202ebcc1a
--- /dev/null
+++ b/cogvideox/video_caption/caption_rewrite.py
@@ -0,0 +1,224 @@
+import argparse
+import re
+import os
+from tqdm import tqdm
+
+import pandas as pd
+import torch
+from natsort import index_natsorted
+from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer
+
+from utils.logger import logger
+
+
+def extract_output(s, prefix='"rewritten description": '):
+    """Customize the function according to the prompt."""
+    # Since some LLMs struggles to output strictly formatted JSON strings as specified by the prompt,
+    # thus manually parse the output string `{"rewritten description": "your rewritten description here"}`.
+    match = re.search(r"{(.+?)}", s, re.DOTALL)
+    if not match:
+        logger.warning(f"{s} is not in the json format. Return None.")
+        return None
+    output = match.group(1).strip()
+    if output.startswith(prefix):
+        output = output[len(prefix) :]
+        if output[0] == '"' and output[-1] == '"':
+            return output[1:-1]
+        else:
+            logger.warning(f"{output} does not start and end with the double quote. Return None.")
+            return None
+    else:
+        logger.warning(f"{output} does not start with {prefix}. Return None.")
+        return None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Rewrite the video caption by LLMs.")
+    parser.add_argument(
+        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default=None,
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default="caption",
+        help="The column contains the video caption.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=128,
+        required=False,
+        help="The batch size for vllm inference. Adjust according to the number of GPUs to maximize inference throughput.",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="NousResearch/Meta-Llama-3-8B-Instruct",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        required=True,
+        help="A string or a txt file contains the prompt.",
+    )
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        required=True,
+        help="The prefix to extract the output from LLMs.",
+    )
+    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--saved_freq", type=int, default=1, help="The frequency to save the output results.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.video_metadata_path.endswith(".csv"):
+        video_metadata_df = pd.read_csv(args.video_metadata_path)
+    elif args.video_metadata_path.endswith(".jsonl"):
+        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    elif args.video_metadata_path.endswith(".json"):
+        video_metadata_df = pd.read_json(args.video_metadata_path)
+    else:
+        raise ValueError(f"The {args.video_metadata_path} must end with .csv, .jsonl or .json.")
+
+    saved_suffix = os.path.splitext(args.saved_path)[1]
+    if saved_suffix not in set([".csv", ".jsonl", ".json"]):
+        raise ValueError(f"The saved_path must end with .csv, .jsonl or .json.")
+
+    if os.path.exists(args.saved_path) and args.video_path_column is not None:
+        if args.saved_path.endswith(".csv"):
+            saved_metadata_df = pd.read_csv(args.saved_path)
+        elif args.saved_path.endswith(".jsonl"):
+            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+
+        # Filter out the unprocessed video-caption pairs by setting the indicator=True.
+        merged_df = video_metadata_df.merge(saved_metadata_df, on=args.video_path_column, how="outer", indicator=True)
+        video_metadata_df = merged_df[merged_df["_merge"] == "left_only"]
+        # Sorting to guarantee the same result for each process.
+        video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df[args.video_path_column])].reset_index(
+            drop=True
+        )
+        logger.info(
+            f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed."
+        )
+
+    if args.prompt.endswith(".txt") and os.path.exists(args.prompt):
+        with open(args.prompt, "r") as f:
+            args.prompt = "".join(f.readlines())
+    logger.info(f"Prompt: {args.prompt}")
+
+    if args.video_path_column is not None:
+        video_path_list = video_metadata_df[args.video_path_column].tolist()
+    if args.caption_column in video_metadata_df.columns:
+        sampled_frame_caption_list = video_metadata_df[args.caption_column].tolist()
+    else:
+        # When two columns with the same name, the dataframe merge operation on will distinguish them by adding 'x' and 'y'.
+        sampled_frame_caption_list = video_metadata_df[args.caption_column + "_x"].tolist()
+
+    CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES", None)
+    tensor_parallel_size = torch.cuda.device_count() if CUDA_VISIBLE_DEVICES is None else len(CUDA_VISIBLE_DEVICES.split(","))
+    logger.info(f"Automatically set tensor_parallel_size={tensor_parallel_size} based on the available devices.")
+
+    llm = LLM(model=args.model_name, trust_remote_code=True, tensor_parallel_size=tensor_parallel_size)
+    if "Meta-Llama-3" in args.model_name:
+        if "Meta-Llama-3-70B" in args.model_name:
+            # Llama-3-70B should use the tokenizer from Llama-3-8B
+            # https://github.com/vllm-project/vllm/issues/4180#issuecomment-2068292942
+            tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+        sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024, stop_token_ids=stop_token_ids)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+        sampling_params = SamplingParams(temperature=0.7, top_p=1, max_tokens=1024)
+
+    result_dict = {args.caption_column: []}
+    if args.video_path_column is not None:
+        result_dict = {args.video_path_column: [], args.caption_column: []}
+
+    for i in tqdm(range(0, len(sampled_frame_caption_list), args.batch_size)):
+        if args.video_path_column is not None:
+            batch_video_path = video_path_list[i : i + args.batch_size]
+        batch_caption = sampled_frame_caption_list[i : i + args.batch_size]
+        batch_prompt = []
+        for caption in batch_caption:
+            # batch_prompt.append("user:" + args.prompt + str(caption) + "\n assistant:")
+            messages = [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": args.prompt + "\n" + str(caption)},
+            ]
+            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            batch_prompt.append(text)
+
+        batch_output = llm.generate(batch_prompt, sampling_params)
+        batch_output = [output.outputs[0].text.rstrip() for output in batch_output]
+        batch_output = [extract_output(output, prefix=args.prefix) for output in batch_output]
+
+        # Filter out data that does not meet the output format.
+        batch_result = []
+        if args.video_path_column is not None:
+            for video_path, output in zip(batch_video_path, batch_output):
+                if output is not None:
+                    batch_result.append((video_path, output))
+            batch_video_path, batch_output = zip(*batch_result)
+
+            result_dict[args.video_path_column].extend(batch_video_path)
+        else:
+            for output in batch_output:
+                if output is not None:
+                    batch_result.append(output)
+
+            result_dict[args.caption_column].extend(batch_result)
+
+        # Save the metadata every args.saved_freq.
+        if i != 0 and ((i // args.batch_size) % args.saved_freq) == 0:
+            if len(result_dict[args.caption_column]) > 0:
+                result_df = pd.DataFrame(result_dict)
+                if args.saved_path.endswith(".csv"):
+                    header = True if not os.path.exists(args.saved_path) else False
+                    result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+                elif args.saved_path.endswith(".jsonl"):
+                    result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+                elif args.saved_path.endswith(".json"):
+                    # Append is not supported.
+                    if os.path.exists(args.saved_path):
+                        saved_df = pd.read_json(args.saved_path, orient="records")
+                        result_df = pd.concat([saved_df, result_df], ignore_index=True)
+                    result_df.to_json(args.saved_path, orient="records", indent=4, force_ascii=False)
+                logger.info(f"Save result to {args.saved_path}.")
+
+            result_dict = {args.caption_column: []}
+            if args.video_path_column is not None:
+                result_dict = {args.video_path_column: [], args.caption_column: []}
+
+    if len(result_dict[args.caption_column]) > 0:
+        result_df = pd.DataFrame(result_dict)
+        if args.saved_path.endswith(".csv"):
+            header = True if not os.path.exists(args.saved_path) else False
+            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+        elif args.saved_path.endswith(".jsonl"):
+            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a")
+        elif args.saved_path.endswith(".json"):
+            # Append is not supported.
+            if os.path.exists(args.saved_path):
+                saved_df = pd.read_json(args.saved_path, orient="records")
+                result_df = pd.concat([saved_df, result_df], ignore_index=True)
+            result_df.to_json(args.saved_path, orient="records", indent=4, force_ascii=False)
+        logger.info(f"Save the final result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cogvideox/video_caption/compute_motion_score.py b/cogvideox/video_caption/compute_motion_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1a8ee837bd689eb9c65f6580619653240e2d3d1
--- /dev/null
+++ b/cogvideox/video_caption/compute_motion_score.py
@@ -0,0 +1,186 @@
+import ast
+import argparse
+import gc
+import os
+from contextlib import contextmanager
+from pathlib import Path
+
+import cv2
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from natsort import natsorted
+from tqdm import tqdm
+
+from utils.logger import logger
+from utils.filter import filter
+
+
+@contextmanager
+def VideoCapture(video_path):
+    cap = cv2.VideoCapture(video_path)
+    try:
+        yield cap
+    finally:
+        cap.release()
+        del cap
+        gc.collect()
+
+
+def compute_motion_score(video_path):
+    video_motion_scores = []
+    sampling_fps = 2
+
+    try:
+        with VideoCapture(video_path) as cap:
+            fps = cap.get(cv2.CAP_PROP_FPS)
+            valid_fps = min(max(sampling_fps, 1), fps)
+            frame_interval = int(fps / valid_fps)
+            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+            # if cannot get the second frame, use the last one
+            frame_interval = min(frame_interval, total_frames - 1)
+
+            prev_frame = None
+            frame_count = -1
+            while cap.isOpened():
+                ret, frame = cap.read()
+                frame_count += 1
+
+                if not ret:
+                    break
+
+                # skip middle frames
+                if frame_count % frame_interval != 0:
+                    continue
+
+                gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+                if prev_frame is None:
+                    prev_frame = gray_frame
+                    continue
+
+                flow = cv2.calcOpticalFlowFarneback(
+                    prev_frame,
+                    gray_frame,
+                    None,
+                    pyr_scale=0.5,
+                    levels=3,
+                    winsize=15,
+                    iterations=3,
+                    poly_n=5,
+                    poly_sigma=1.2,
+                    flags=0,
+                )
+                mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
+                frame_motion_score = np.mean(mag)
+                video_motion_scores.append(frame_motion_score)
+                prev_frame = gray_frame
+
+            video_meta_info = {
+                "video_path": Path(video_path).name,
+                "motion_score": round(float(np.mean(video_motion_scores)), 5),
+            }
+            return video_meta_info
+
+    except Exception as e:
+        print(f"Compute motion score for video {video_path} with error: {e}.")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute the motion score of the videos.")
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument(
+        "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
+    parser.add_argument("--n_jobs", type=int, default=1, help="The number of concurrent processes.")
+
+    parser.add_argument(
+        "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
+    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
+    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
+    parser.add_argument(
+        "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
+    parser.add_argument(
+        "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
+    parser.add_argument(
+        "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.video_metadata_path.endswith(".csv"):
+        video_metadata_df = pd.read_csv(args.video_metadata_path)
+    elif args.video_metadata_path.endswith(".jsonl"):
+        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    else:
+        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+    video_path_list = video_metadata_df[args.video_path_column].tolist()
+
+    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+        raise ValueError("The saved_path must end with .csv or .jsonl.")
+    
+    if os.path.exists(args.saved_path):
+        if args.saved_path.endswith(".csv"):
+            saved_metadata_df = pd.read_csv(args.saved_path)
+        elif args.saved_path.endswith(".jsonl"):
+            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+        video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+        logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+    
+    video_path_list = filter(
+        video_path_list,
+        basic_metadata_path=args.basic_metadata_path,
+        min_resolution=args.min_resolution,
+        min_duration=args.min_duration,
+        max_duration=args.max_duration,
+        asethetic_score_metadata_path=args.asethetic_score_metadata_path,
+        min_asethetic_score=args.min_asethetic_score,
+        asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
+        min_asethetic_score_siglip=args.min_asethetic_score_siglip,
+        text_score_metadata_path=args.text_score_metadata_path,
+        min_text_score=args.min_text_score,
+    )
+    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
+    # Sorting to guarantee the same result for each process.
+    video_path_list = natsorted(video_path_list)
+
+    for i in tqdm(range(0, len(video_path_list), args.saved_freq)):
+        result_list = Parallel(n_jobs=args.n_jobs)(
+            delayed(compute_motion_score)(video_path) for video_path in tqdm(video_path_list[i: i + args.saved_freq])
+        )
+        result_list = [result for result in result_list if result is not None]
+        if len(result_list) == 0:
+            continue
+
+        result_df = pd.DataFrame(result_list)
+        if args.saved_path.endswith(".csv"):
+            header = False if os.path.exists(args.saved_path) else True
+            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+        elif args.saved_path.endswith(".jsonl"):
+            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+        logger.info(f"Save result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/cogvideox/video_caption/compute_text_score.py b/cogvideox/video_caption/compute_text_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..201ea0e53312bcfec0d2d584ff2f6eb1504ac5f4
--- /dev/null
+++ b/cogvideox/video_caption/compute_text_score.py
@@ -0,0 +1,214 @@
+import argparse
+import os
+from pathlib import Path
+
+import easyocr
+import numpy as np
+import pandas as pd
+from accelerate import PartialState
+from accelerate.utils import gather_object
+from natsort import natsorted
+from tqdm import tqdm
+from torchvision.datasets.utils import download_url
+
+from utils.logger import logger
+from utils.video_utils import extract_frames
+from utils.filter import filter
+
+
+def init_ocr_reader(root: str = "~/.cache/easyocr", device: str = "gpu"):
+    root = os.path.expanduser(root)
+    if not os.path.exists(root):
+        os.makedirs(root)
+    download_url(
+        "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/easyocr/craft_mlt_25k.pth",
+        root,
+        filename="craft_mlt_25k.pth",
+        md5="2f8227d2def4037cdb3b34389dcf9ec1",
+    )
+    ocr_reader = easyocr.Reader(
+        lang_list=["en", "ch_sim"],
+        gpu=device,
+        recognizer=False,
+        verbose=False,
+        model_storage_directory=root,
+    )
+
+    return ocr_reader
+
+
+def triangle_area(p1, p2, p3):
+    """Compute the triangle area according to its coordinates.
+    """
+    x1, y1 = p1
+    x2, y2 = p2
+    x3, y3 = p3
+    tri_area = 0.5 * np.abs(x1 * y2 + x2 * y3 + x3 * y1 - x2 * y1 - x3 * y2 - x1 * y3)
+    return tri_area
+
+
+def compute_text_score(video_path, ocr_reader):
+    _, images = extract_frames(video_path, sample_method="mid")
+    images = [np.array(image) for image in images]
+
+    frame_ocr_area_ratios = []
+    for image in images:
+        # horizontal detected results and free-form detected
+        horizontal_list, free_list = ocr_reader.detect(np.asarray(image))
+        width, height = image.shape[0], image.shape[1]
+
+        total_area = width * height
+        # rectangles
+        rect_area = 0
+        for xmin, xmax, ymin, ymax in horizontal_list[0]:
+            if xmax < xmin or ymax < ymin:
+                continue
+            rect_area += (xmax - xmin) * (ymax - ymin)
+        # free-form
+        quad_area = 0
+        try:
+            for points in free_list[0]:
+                triangle1 = points[:3]
+                quad_area += triangle_area(*triangle1)
+                triangle2 = points[3:] + [points[0]]
+                quad_area += triangle_area(*triangle2)
+        except:
+            quad_area = 0
+        text_area = rect_area + quad_area
+
+        frame_ocr_area_ratios.append(text_area / total_area)
+
+    video_meta_info = {
+        "video_path": Path(video_path).name,
+        "text_score": round(np.mean(frame_ocr_area_ratios), 5),
+    }
+
+    return video_meta_info
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute the text score of the middle frame in the videos.")
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument(
+        "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--saved_freq", type=int, default=100, help="The frequency to save the output results.")
+
+    parser.add_argument(
+        "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
+    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
+    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
+    parser.add_argument(
+        "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
+    parser.add_argument(
+        "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
+    parser.add_argument(
+        "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.video_metadata_path.endswith(".csv"):
+        video_metadata_df = pd.read_csv(args.video_metadata_path)
+    elif args.video_metadata_path.endswith(".jsonl"):
+        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    else:
+        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+    video_path_list = video_metadata_df[args.video_path_column].tolist()
+
+    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+        raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+    if os.path.exists(args.saved_path):
+        if args.saved_path.endswith(".csv"):
+            saved_metadata_df = pd.read_csv(args.saved_path)
+        elif args.saved_path.endswith(".jsonl"):
+            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+        video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+        logger.info(f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed.")
+    
+    video_path_list = filter(
+        video_path_list,
+        basic_metadata_path=args.basic_metadata_path,
+        min_resolution=args.min_resolution,
+        min_duration=args.min_duration,
+        max_duration=args.max_duration,
+        asethetic_score_metadata_path=args.asethetic_score_metadata_path,
+        min_asethetic_score=args.min_asethetic_score,
+        asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
+        min_asethetic_score_siglip=args.min_asethetic_score_siglip,
+        motion_score_metadata_path=args.motion_score_metadata_path,
+        min_motion_score=args.min_motion_score,
+    )
+    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
+    # Sorting to guarantee the same result for each process.
+    video_path_list = natsorted(video_path_list)
+
+    state = PartialState()
+    if state.is_main_process:
+        # Check if the model is downloaded in the main process.
+        ocr_reader = init_ocr_reader(device="cpu")
+    state.wait_for_everyone()
+    ocr_reader = init_ocr_reader(device=state.device)
+
+    index = len(video_path_list) - len(video_path_list) % state.num_processes
+    # Avoid the NCCL timeout in the final gather operation.
+    logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to ensure each process handles the same number of videos.")
+    video_path_list = video_path_list[:index]
+    logger.info(f"{len(video_path_list)} videos are to be processed.")
+
+    result_list = []
+    with state.split_between_processes(video_path_list) as splitted_video_path_list:
+        for i, video_path in enumerate(tqdm(splitted_video_path_list)):
+            try:
+                video_meta_info = compute_text_score(video_path, ocr_reader)
+                result_list.append(video_meta_info)
+            except Exception as e:
+                logger.warning(f"Compute text score for video {video_path} with error: {e}.")
+            if i != 0 and i % args.saved_freq == 0:
+                state.wait_for_everyone()
+                gathered_result_list = gather_object(result_list)
+                if state.is_main_process and len(gathered_result_list) != 0:
+                    result_df = pd.DataFrame(gathered_result_list)
+                    if args.saved_path.endswith(".csv"):
+                        header = False if os.path.exists(args.saved_path) else True
+                        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+                    elif args.saved_path.endswith(".jsonl"):
+                        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+                    logger.info(f"Save result to {args.saved_path}.")
+                result_list = []
+
+    state.wait_for_everyone()
+    gathered_result_list = gather_object(result_list)
+    if state.is_main_process and len(gathered_result_list) != 0:
+        result_df = pd.DataFrame(gathered_result_list)
+        if args.saved_path.endswith(".csv"):
+            header = False if os.path.exists(args.saved_path) else True
+            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+        elif args.saved_path.endswith(".jsonl"):
+            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+        logger.info(f"Save the final result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/cogvideox/video_caption/compute_video_quality.py b/cogvideox/video_caption/compute_video_quality.py
new file mode 100644
index 0000000000000000000000000000000000000000..8591a879801f72210833c06321889c67204cfde1
--- /dev/null
+++ b/cogvideox/video_caption/compute_video_quality.py
@@ -0,0 +1,201 @@
+import argparse
+import os
+
+import pandas as pd
+from accelerate import PartialState
+from accelerate.utils import gather_object
+from natsort import index_natsorted
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+
+import utils.image_evaluator as image_evaluator
+import utils.video_evaluator as video_evaluator
+from utils.logger import logger
+from utils.video_dataset import VideoDataset, collate_fn
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.")
+    parser.add_argument(
+        "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default=None,
+        help="The column contains the caption.",
+    )
+    parser.add_argument(
+        "--frame_sample_method",
+        type=str,
+        choices=["mid", "uniform", "image"],
+        default="uniform",
+    )
+    parser.add_argument(
+        "--num_sampled_frames",
+        type=int,
+        default=8,
+        help="num_sampled_frames",
+    )
+    parser.add_argument("--metrics", nargs="+", type=str, required=True, help="The evaluation metric(s) for generated images.")
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=10,
+        required=False,
+        help="The batch size for the video dataset.",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=4,
+        required=False,
+        help="The number of workers for the video dataset.",
+    )
+    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--saved_freq", type=int, default=1000, help="The frequency to save the output results.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    
+    if args.video_metadata_path.endswith(".csv"):
+        video_metadata_df = pd.read_csv(args.video_metadata_path)
+    elif args.video_metadata_path.endswith(".jsonl"):
+        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    else:
+        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+
+    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+        raise ValueError("The saved_path must end with .csv or .jsonl.")
+    
+    if os.path.exists(args.saved_path):
+        if args.saved_path.endswith(".csv"):
+            saved_metadata_df = pd.read_csv(args.saved_path)
+        elif args.saved_path.endswith(".jsonl"):
+            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+
+        # Filter out the unprocessed video-caption pairs by setting the indicator=True.
+        merged_df = video_metadata_df.merge(saved_metadata_df, on="video_path", how="outer", indicator=True)
+        video_metadata_df = merged_df[merged_df["_merge"] == "left_only"]
+        # Sorting to guarantee the same result for each process.
+        video_metadata_df = video_metadata_df.iloc[index_natsorted(video_metadata_df["video_path"])].reset_index(drop=True)
+        if args.caption_column is None:
+            video_metadata_df = video_metadata_df[[args.video_path_column]]
+        else:
+            video_metadata_df = video_metadata_df[[args.video_path_column, args.caption_column + "_x"]]
+            video_metadata_df.rename(columns={args.caption_column + "_x": args.caption_column}, inplace=True)
+        logger.info(f"Resume from {args.saved_path}: {len(saved_metadata_df)} processed and {len(video_metadata_df)} to be processed.")
+
+    state = PartialState()
+    metric_fns = []
+    for metric in args.metrics:
+        if hasattr(image_evaluator, metric):  # frame-wise
+            if state.is_main_process:
+                logger.info("Initializing frame-wise evaluator metrics...")
+                # Check if the model is downloaded in the main process.
+                getattr(image_evaluator, metric)(device="cpu")
+            state.wait_for_everyone()
+            metric_fns.append(getattr(image_evaluator, metric)(device=state.device))
+        else:  # video-wise
+            if state.is_main_process:
+                logger.info("Initializing video-wise evaluator metrics...")
+                # Check if the model is downloaded in the main process.
+                getattr(video_evaluator, metric)(device="cpu")
+            state.wait_for_everyone()
+            metric_fns.append(getattr(video_evaluator, metric)(device=state.device))
+
+    result_dict = {args.video_path_column: [], "sample_frame_idx": []}
+    for metric in metric_fns:
+        result_dict[str(metric)] = []
+    if args.caption_column is not None:
+        result_dict[args.caption_column] = []
+
+    if args.frame_sample_method == "image":
+        logger.warning("Set args.num_sampled_frames to 1 since args.frame_sample_method is image.")
+        args.num_sampled_frames = 1
+    
+    index = len(video_metadata_df) - len(video_metadata_df) % state.num_processes
+    # Avoid the NCCL timeout in the final gather operation.
+    logger.info(f"Drop {len(video_metadata_df) % state.num_processes} videos to ensure each process handles the same number of videos.")
+    video_metadata_df = video_metadata_df.iloc[:index]
+    logger.info(f"{len(video_metadata_df)} videos are to be processed.")
+
+    video_metadata_list = video_metadata_df.to_dict(orient='list')
+    with state.split_between_processes(video_metadata_list) as splitted_video_metadata:
+        video_dataset = VideoDataset(
+            dataset_inputs=splitted_video_metadata,
+            video_folder=args.video_folder,
+            text_column=args.caption_column,
+            sample_method=args.frame_sample_method,
+            num_sampled_frames=args.num_sampled_frames
+        )
+        video_loader = DataLoader(video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
+
+        for idx, batch in enumerate(tqdm(video_loader)):
+            if len(batch) > 0:
+                batch_video_path = batch["path"]
+                result_dict["sample_frame_idx"].extend(batch["sampled_frame_idx"])
+                batch_frame = batch["sampled_frame"]  # [batch_size, num_sampled_frames, H, W, C]
+                batch_caption = None
+                if args.caption_column is not None:
+                    batch_caption = batch["text"]
+                    result_dict["caption"].extend(batch_caption)
+                # Compute the quality.
+                for i, metric in enumerate(args.metrics):
+                    quality_scores = metric_fns[i](batch_frame, batch_caption)
+                    if isinstance(quality_scores[0], list):  # frame-wise
+                        quality_scores = [
+                            [round(score, 5) for score in inner_list]
+                            for inner_list in quality_scores
+                        ]
+                    else:  # video-wise
+                        quality_scores = [round(score, 5) for score in quality_scores]
+                    result_dict[str(metric_fns[i])].extend(quality_scores)
+        
+                if args.video_folder == "":
+                    saved_video_path_list = batch_video_path
+                else:
+                    saved_video_path_list = [os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path]
+                result_dict[args.video_path_column].extend(saved_video_path_list)
+
+            # Save the metadata in the main process every saved_freq.
+            if (idx != 0) and (idx % args.saved_freq == 0):
+                state.wait_for_everyone()
+                gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+                if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
+                    result_df = pd.DataFrame(gathered_result_dict)
+                    if args.saved_path.endswith(".csv"):
+                        header = False if os.path.exists(args.saved_path) else True
+                        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+                    elif args.saved_path.endswith(".jsonl"):
+                        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+                    logger.info(f"Save result to {args.saved_path}.")
+                for k in result_dict.keys():
+                    result_dict[k] = []
+    
+    # Wait for all processes to finish and gather the final result.
+    state.wait_for_everyone()
+    gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+    # Save the metadata in the main process.
+    if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
+        result_df = pd.DataFrame(gathered_result_dict)
+        if args.saved_path.endswith(".csv"):
+            header = False if os.path.exists(args.saved_path) else True
+            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+        elif args.saved_path.endswith(".jsonl"):
+            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+        logger.info(f"Save the final result to {args.saved_path}.")
+
+if __name__ == "__main__":
+    main()
diff --git a/cogvideox/video_caption/cutscene_detect.py b/cogvideox/video_caption/cutscene_detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f9a3230bf3662f2e5be83948ec818569721985d
--- /dev/null
+++ b/cogvideox/video_caption/cutscene_detect.py
@@ -0,0 +1,97 @@
+import argparse
+import os
+from copy import deepcopy
+from pathlib import Path
+from multiprocessing import Pool
+
+import pandas as pd
+from scenedetect import open_video, SceneManager
+from scenedetect.detectors import ContentDetector
+from tqdm import tqdm
+
+from utils.logger import logger
+
+
+def cutscene_detection_star(args):
+    return cutscene_detection(*args)
+
+
+def cutscene_detection(video_path, saved_path, cutscene_threshold=27, min_scene_len=15):
+    try:
+        if os.path.exists(saved_path):
+            logger.info(f"{video_path} has been processed.")
+            return
+        # Use PyAV as the backend to avoid (to some exent) containing the last frame of the previous scene.
+        # https://github.com/Breakthrough/PySceneDetect/issues/279#issuecomment-2152596761.
+        video = open_video(video_path, backend="pyav")
+        frame_rate, frame_size = video.frame_rate, video.frame_size
+        duration = deepcopy(video.duration)
+
+        frame_points, frame_timecode = [], {}
+        scene_manager = SceneManager()
+        scene_manager.add_detector(
+            # [ContentDetector, ThresholdDetector, AdaptiveDetector]
+            ContentDetector(threshold=cutscene_threshold, min_scene_len=min_scene_len)
+        )
+        scene_manager.detect_scenes(video, show_progress=False)
+        scene_list = scene_manager.get_scene_list()
+        for scene in scene_list:
+            for frame_time_code in scene:
+                frame_index = frame_time_code.get_frames()
+                if frame_index not in frame_points:
+                    frame_points.append(frame_index)
+                frame_timecode[frame_index] = frame_time_code
+        
+        del video, scene_manager
+        
+        frame_points = sorted(frame_points)
+        output_scene_list = []
+        for idx in range(len(frame_points) - 1):
+            output_scene_list.append((frame_timecode[frame_points[idx]], frame_timecode[frame_points[idx+1]]))
+        
+        timecode_list = [(frame_timecode_tuple[0].get_timecode(), frame_timecode_tuple[1].get_timecode()) for frame_timecode_tuple in output_scene_list]
+        meta_scene = [{
+            "video_path": Path(video_path).name,
+            "timecode_list": timecode_list,
+            "fram_rate": frame_rate,
+            "frame_size": frame_size,
+            "duration": str(duration)  # __repr__
+        }]
+        pd.DataFrame(meta_scene).to_json(saved_path, orient="records", lines=True)
+    except Exception as e:
+        logger.warning(f"Cutscene detection with {video_path} failed. Error is: {e}.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Cutscene Detection")
+    parser.add_argument(
+        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument("--saved_folder", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--n_jobs", type=int, default=1, help="The number of processes.")
+
+    args = parser.parse_args()
+
+    metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    video_path_list = metadata_df[args.video_path_column].tolist()
+    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
+
+    if not os.path.exists(args.saved_folder):
+        os.makedirs(args.saved_folder, exist_ok=True)
+    # The glob can be slow when there are many small jsonl files.
+    saved_path_list = [os.path.join(args.saved_folder, Path(video_path).stem + ".jsonl") for video_path in video_path_list]
+    args_list = [
+        (video_path, saved_path)
+        for video_path, saved_path in zip(video_path_list, saved_path_list)
+    ]
+    # Since the length of the video is not uniform, the gather operation is not performed.
+    # We need to run easyanimate/video_caption/utils/gather_jsonl.py after the program finised.
+    with Pool(args.n_jobs) as pool:
+        results = list(tqdm(pool.imap(cutscene_detection_star, args_list), total=len(video_path_list)))
diff --git a/cogvideox/video_caption/filter_meta_train.py b/cogvideox/video_caption/filter_meta_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fff7707884046ff5ef7d85978b12d897b2ae22f
--- /dev/null
+++ b/cogvideox/video_caption/filter_meta_train.py
@@ -0,0 +1,88 @@
+import argparse
+import os
+
+import pandas as pd
+from natsort import natsorted
+
+from utils.logger import logger
+from utils.filter import filter
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--caption_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument(
+        "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_resolution", type=float, default=720*1280, help="The resolution threshold.")
+    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
+    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
+    parser.add_argument(
+        "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
+    parser.add_argument(
+        "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality (SigLIP) metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
+    parser.add_argument(
+        "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.")
+    parser.add_argument(
+        "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.")
+    parser.add_argument(
+        "--videoclipxl_score_metadata_path", type=str, default=None, help="The path to the video-caption VideoCLIPXL score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_videoclipxl_score", type=float, default=0.20, help="The VideoCLIPXL score threshold.")
+    parser.add_argument("--saved_path", type=str, required=True)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    raw_caption_df = pd.read_json(args.caption_metadata_path, lines=True)
+    video_path_list = raw_caption_df[args.video_path_column].to_list()
+    filtered_video_path_list = filter(
+        video_path_list,
+        basic_metadata_path=args.basic_metadata_path,
+        min_resolution=args.min_resolution,
+        min_duration=args.min_duration,
+        max_duration=args.max_duration,
+        asethetic_score_metadata_path=args.asethetic_score_metadata_path,
+        min_asethetic_score=args.min_asethetic_score,
+        asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
+        min_asethetic_score_siglip=args.min_asethetic_score_siglip,
+        text_score_metadata_path=args.text_score_metadata_path,
+        min_text_score=args.min_text_score,
+        motion_score_metadata_path=args.motion_score_metadata_path,
+        min_motion_score=args.min_motion_score,
+        videoclipxl_score_metadata_path=args.videoclipxl_score_metadata_path,
+        min_videoclipxl_score=args.min_videoclipxl_score,
+        video_path_column=args.video_path_column
+    )
+    filtered_video_path_list = natsorted(filtered_video_path_list)
+    filtered_caption_df = raw_caption_df[raw_caption_df[args.video_path_column].isin(filtered_video_path_list)]
+    train_df = filtered_caption_df.rename(columns={"video_path": "file_path", "caption": "text"})
+    train_df["file_path"] = train_df["file_path"].map(lambda x: os.path.join(args.video_folder, x))
+    train_df["type"] = "video"
+    train_df.to_json(args.saved_path, orient="records", force_ascii=False, indent=2)
+    logger.info(f"The final train file with {len(train_df)} videos are saved to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/cogvideox/video_caption/package_patches/easyocr_detection_patched.py b/cogvideox/video_caption/package_patches/easyocr_detection_patched.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2cffa2b00c7c90aafcde307ce27307ed6e71dbf
--- /dev/null
+++ b/cogvideox/video_caption/package_patches/easyocr_detection_patched.py
@@ -0,0 +1,114 @@
+"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py.
+1. Disable DataParallel.
+"""
+import torch
+import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
+from PIL import Image
+from collections import OrderedDict
+
+import cv2
+import numpy as np
+from .craft_utils import getDetBoxes, adjustResultCoordinates
+from .imgproc import resize_aspect_ratio, normalizeMeanVariance
+from .craft import CRAFT
+
+def copyStateDict(state_dict):
+    if list(state_dict.keys())[0].startswith("module"):
+        start_idx = 1
+    else:
+        start_idx = 0
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        name = ".".join(k.split(".")[start_idx:])
+        new_state_dict[name] = v
+    return new_state_dict
+
+def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
+    if isinstance(image, np.ndarray) and len(image.shape) == 4:  # image is batch of np arrays
+        image_arrs = image
+    else:                                                        # image is single numpy array
+        image_arrs = [image]
+
+    img_resized_list = []
+    # resize
+    for img in image_arrs:
+        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(img, canvas_size,
+                                                                      interpolation=cv2.INTER_LINEAR,
+                                                                      mag_ratio=mag_ratio)
+        img_resized_list.append(img_resized)
+    ratio_h = ratio_w = 1 / target_ratio
+    # preprocessing
+    x = [np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
+         for n_img in img_resized_list]
+    x = torch.from_numpy(np.array(x))
+    x = x.to(device)
+
+    # forward pass
+    with torch.no_grad():
+        y, feature = net(x)
+
+    boxes_list, polys_list = [], []
+    for out in y:
+        # make score and link map
+        score_text = out[:, :, 0].cpu().data.numpy()
+        score_link = out[:, :, 1].cpu().data.numpy()
+
+        # Post-processing
+        boxes, polys, mapper = getDetBoxes(
+            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)
+
+        # coordinate adjustment
+        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
+        if estimate_num_chars:
+            boxes = list(boxes)
+            polys = list(polys)
+        for k in range(len(polys)):
+            if estimate_num_chars:
+                boxes[k] = (boxes[k], mapper[k])
+            if polys[k] is None:
+                polys[k] = boxes[k]
+        boxes_list.append(boxes)
+        polys_list.append(polys)
+
+    return boxes_list, polys_list
+
+def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
+    net = CRAFT()
+
+    if device == 'cpu':
+        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
+        if quantize:
+            try:
+                torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
+            except:
+                pass
+    else:
+        net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
+        # net = torch.nn.DataParallel(net).to(device)
+        net = net.to(device)
+        cudnn.benchmark = cudnn_benchmark
+
+    net.eval()
+    return net
+
+def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
+    result = []
+    estimate_num_chars = optimal_num_chars is not None
+    bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
+                                       image, text_threshold,
+                                       link_threshold, low_text, poly,
+                                       device, estimate_num_chars)
+    if estimate_num_chars:
+        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
+                      for polys in polys_list]
+
+    for polys in polys_list:
+        single_img_result = []
+        for i, box in enumerate(polys):
+            poly = np.array(box).astype(np.int32).reshape((-1))
+            single_img_result.append(poly)
+        result.append(single_img_result)
+
+    return result
diff --git a/cogvideox/video_caption/package_patches/vila_siglip_encoder_patched.py b/cogvideox/video_caption/package_patches/vila_siglip_encoder_patched.py
new file mode 100644
index 0000000000000000000000000000000000000000..50698389df8b4d9aca422ea9dbf798071604acb2
--- /dev/null
+++ b/cogvideox/video_caption/package_patches/vila_siglip_encoder_patched.py
@@ -0,0 +1,42 @@
+# Modified from https://github.com/NVlabs/VILA/blob/1c88211/llava/model/multimodal_encoder/siglip_encoder.py
+# 1. Support transformers >= 4.36.2.
+import torch
+import transformers
+from packaging import version
+from transformers import AutoConfig, AutoModel, PretrainedConfig
+
+from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2
+
+if version.parse(transformers.__version__) > version.parse("4.36.2"):
+    from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
+else:
+    from .siglip import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
+
+
+class SiglipVisionTower(VisionTower):
+    def __init__(self, model_name_or_path: str, config: PretrainedConfig, state_dict=None):
+        super().__init__(model_name_or_path, config)
+        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
+        self.vision_tower = SiglipVisionModel.from_pretrained(
+            # TODO(ligeng): why pass config here leading to errors?
+            model_name_or_path, torch_dtype=eval(config.model_dtype), state_dict=state_dict
+        )
+        self.is_loaded = True
+
+
+class SiglipVisionTowerS2(VisionTowerS2):
+    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
+        super().__init__(model_name_or_path, config)
+        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
+        self.vision_tower = SiglipVisionModel.from_pretrained(
+            model_name_or_path, torch_dtype=eval(config.model_dtype)
+        )
+
+        # Make sure it crops/resizes the image to the largest scale in self.scales to maintain high-res information
+        self.image_processor.size['height'] = self.image_processor.size['width'] = self.scales[-1]
+
+        self.is_loaded = True
+
+if version.parse(transformers.__version__) <= version.parse("4.36.2"):
+    AutoConfig.register("siglip_vision_model", SiglipVisionConfig)
+    AutoModel.register(SiglipVisionConfig, SiglipVisionModel)
diff --git a/cogvideox/video_caption/prompt/beautiful_prompt.txt b/cogvideox/video_caption/prompt/beautiful_prompt.txt
new file mode 100644
index 0000000000000000000000000000000000000000..88bd6e3ff7bc440fc2e882652b6fae3379f9a00e
--- /dev/null
+++ b/cogvideox/video_caption/prompt/beautiful_prompt.txt
@@ -0,0 +1,9 @@
+I will upload some brief prompt words to be used for AI-generated videos. Please expand these brief prompt words into a more detailed description to enhance the quality of the generated videos. The detailed description should include the main subject (person, object, animal, or none) actions and their attributes or status sequence, the background (the objects, location, weather, and time), the view shot and camera movement.
+The final detailed description must not exceed 200 words. Output with the following json format:
+{"detailed description": "your detailed description here"}
+
+Here is an example:
+brief prompt words: "A stylish woman in a black leather jacket, red dress, and boots walks confidently down a damp Tokyo street."
+{"detailed description": "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."}
+
+Here are the brief prompt words:
\ No newline at end of file
diff --git a/cogvideox/video_caption/prompt/rewrite.txt b/cogvideox/video_caption/prompt/rewrite.txt
new file mode 100644
index 0000000000000000000000000000000000000000..41212f8563c682867b00f3e40f04cc3756e9f779
--- /dev/null
+++ b/cogvideox/video_caption/prompt/rewrite.txt
@@ -0,0 +1,9 @@
+Please rewrite the video description to be useful for AI to re-generate the video, according to the following requirements
+1. Do not start with something similar to 'The video/scene/frame shows' or "In this video/scene/frame".
+2. Remove the subjective content deviates from describing the visual content of the video. For instance, a sentence like "It gives a feeling of ease and tranquility and makes people feel comfortable" is considered subjective.
+3. Remove the non-existent description that does not in the visual content of the video, For instance, a sentence like "There is no visible detail that could be used to identify the individual beyond what is shown." is considered as the non-existent description.
+4. Here are some examples of good descriptions: 1) A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about. 2) A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.
+5. Output with the following json format:
+{"rewritten description": "your rewritten description here"}
+
+Here is the video description:
\ No newline at end of file
diff --git a/cogvideox/video_caption/requirements.txt b/cogvideox/video_caption/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ddc0b66c4b8518c7ba57d846b5882189a5f5275b
--- /dev/null
+++ b/cogvideox/video_caption/requirements.txt
@@ -0,0 +1,9 @@
+pandas>=2.0.0
+easyocr==1.7.1
+git+https://github.com/openai/CLIP.git
+natsort
+joblib
+scenedetect
+av
+# https://github.com/NVlabs/VILA/issues/78#issuecomment-2195568292
+numpy<2.0.0
\ No newline at end of file
diff --git a/cogvideox/video_caption/scripts/stage_1_video_splitting.sh b/cogvideox/video_caption/scripts/stage_1_video_splitting.sh
new file mode 100644
index 0000000000000000000000000000000000000000..bf57dfbeaae8c35947fdfecb429aab586c6016cd
--- /dev/null
+++ b/cogvideox/video_caption/scripts/stage_1_video_splitting.sh
@@ -0,0 +1,39 @@
+VIDEO_FOLDER="datasets/panda_70m/videos/data/"
+META_FILE_PATH="datasets/panda_70m/videos/meta_file_info.jsonl"
+SCENE_FOLDER="datasets/panda_70m/videos/meta_scene_info/"
+SCENE_SAVED_PATH="datasets/panda_70m/videos/meta_scene_info.jsonl"
+OUTPUT_FOLDER="datasets/panda_70m/videos_clips/data/"
+RESOLUTION_THRESHOLD=$((512*512))
+
+# Set the duration range of video clips.
+export MIN_SECONDS=3
+export MAX_SECONDS=10
+
+# Save all video names in a video folder as a meta file.
+python -m utils.get_meta_file \
+    --video_folder $VIDEO_FOLDER \
+    --saved_path $META_FILE_PATH
+
+# Perform scene detection on the video dataset.
+# Adjust the n_jobs parameter based on the actual number of CPU cores in the machine.
+python cutscene_detect.py \
+    --video_metadata_path $META_FILE_PATH \
+    --video_folder $VIDEO_FOLDER \
+    --saved_folder $SCENE_FOLDER \
+    --n_jobs 32
+
+# Gather all scene jsonl files to a single scene jsonl file.
+# Adjust the n_jobs parameter based on the actual I/O speed in the machine.
+python -m utils.gather_jsonl \
+    --meta_folder $SCENE_FOLDER \
+    --meta_file_path $SCENE_SAVED_PATH \
+    --n_jobs 64
+
+# Perform video splitting filtered by the RESOLUTION_THRESHOLD.
+# It consumes more CPU computing resources compared to the above operations.
+python video_splitting.py \
+    --video_metadata_path $SCENE_SAVED_PATH \
+    --video_folder $VIDEO_FOLDER \
+    --output_folder $OUTPUT_FOLDER \
+    --n_jobs 16 \
+    --resolution_threshold $RESOLUTION_THRESHOLD
\ No newline at end of file
diff --git a/cogvideox/video_caption/scripts/stage_2_video_filtering.sh b/cogvideox/video_caption/scripts/stage_2_video_filtering.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0dd41a45ceb87b2d8f62bc8192c5c7befb5796bc
--- /dev/null
+++ b/cogvideox/video_caption/scripts/stage_2_video_filtering.sh
@@ -0,0 +1,41 @@
+META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl"
+VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/"
+VIDEO_QUALITY_SAVED_PATH="datasets/panda_70m/videos_clips/meta_quality_info_siglip.jsonl"
+MIN_ASETHETIC_SCORE_SIGLIP=4.0
+TEXT_SAVED_PATH="datasets/panda_70m/videos_clips/meta_text_info.jsonl"
+MIN_TEXT_SCORE=0.02
+MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl"
+
+python -m utils.get_meta_file \
+    --video_folder $VIDEO_FOLDER \
+    --saved_path $META_FILE_PATH
+
+# Get the asethetic score (SigLIP) of all videos
+accelerate launch compute_video_quality.py \
+    --video_metadata_path $META_FILE_PATH \
+    --video_folder $VIDEO_FOLDER \
+    --metrics "AestheticScoreSigLIP" \
+    --frame_sample_method uniform \
+    --num_sampled_frames 4 \
+    --saved_freq 10 \
+    --saved_path $VIDEO_QUALITY_SAVED_PATH \
+    --batch_size 4
+
+# Get the text score of all videos filtered by the video quality score.
+accelerate launch compute_text_score.py \
+    --video_metadata_path $META_FILE_PATH \
+    --video_folder $VIDEO_FOLDER  \
+    --saved_freq 10 \
+    --saved_path $TEXT_SAVED_PATH \
+    --asethetic_score_siglip_metadata_path $VIDEO_QUALITY_SAVED_PATH \
+    --min_asethetic_score_siglip $MIN_ASETHETIC_SCORE_SIGLIP
+
+# Get the motion score of all videos filtered by the video quality score and text score.
+python compute_motion_score.py \
+    --video_metadata_path $META_FILE_PATH \
+    --video_folder $VIDEO_FOLDER \
+    --saved_freq 10 \
+    --saved_path $MOTION_SAVED_PATH \
+    --n_jobs 8 \
+    --text_score_metadata_path $TEXT_SAVED_PATH \
+    --min_text_score $MIN_TEXT_SCORE
diff --git a/cogvideox/video_caption/scripts/stage_3_video_recaptioning.sh b/cogvideox/video_caption/scripts/stage_3_video_recaptioning.sh
new file mode 100644
index 0000000000000000000000000000000000000000..18cb2123ac7d140aaa141c541db338a2e4d0dc6f
--- /dev/null
+++ b/cogvideox/video_caption/scripts/stage_3_video_recaptioning.sh
@@ -0,0 +1,52 @@
+META_FILE_PATH="datasets/panda_70m/videos_clips/data/meta_file_info.jsonl"
+VIDEO_FOLDER="datasets/panda_70m/videos_clips/data/"
+MOTION_SAVED_PATH="datasets/panda_70m/videos_clips/meta_motion_info.jsonl"
+MIN_MOTION_SCORE=2
+VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b.jsonl"
+REWRITTEN_VIDEO_CAPTION_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b_rewritten.jsonl"
+VIDEOCLIPXL_SCORE_SAVED_PATH="datasets/panda_70m/meta_caption_info_vila_8b_rewritten_videoclipxl.jsonl"
+MIN_VIDEOCLIPXL_SCORE=0.20
+TRAIN_SAVED_PATH="datasets/panda_70m/train_panda_70m.json"
+# Manually download Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ to VILA_MODEL_PATH.
+# Manually download meta-llama/Meta-Llama-3-8B-Instruct to REWRITE_MODEL_PATH.
+
+# Use VILA1.5-AWQ to perform recaptioning.
+accelerate launch vila_video_recaptioning.py \
+    --video_metadata_path ${META_FILE_PATH} \
+    --video_folder ${VIDEO_FOLDER} \
+    --model_path ${VILA_MODEL_PATH} \
+    --precision "W4A16" \
+    --saved_path $VIDEO_CAPTION_SAVED_PATH \
+    --saved_freq 1 \
+    --motion_score_metadata_path $MOTION_SAVED_PATH \
+    --min_motion_score $MIN_MOTION_SCORE
+
+# Rewrite video captions (optional).
+python caption_rewrite.py \
+    --video_metadata_path $VIDEO_CAPTION_SAVED_PATH \
+    --batch_size 4096 \
+    --model_name $REWRITE_MODEL_PATH \
+    --prompt prompt/rewrite.txt \
+    --prefix '"rewritten description": ' \
+    --saved_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
+    --saved_freq 1
+
+# Compute caption-video alignment (optional).
+accelerate launch compute_video_quality.py \
+    --video_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
+    --caption_column caption \
+    --video_folder $VIDEO_FOLDER \
+    --frame_sample_method uniform \
+    --num_sampled_frames 8 \
+    --metrics VideoCLIPXLScore \
+    --batch_size 4 \
+    --saved_path $VIDEOCLIPXL_SCORE_SAVED_PATH \
+    --saved_freq 10
+
+# Get the final train file.
+python filter_meta_train.py \
+    --caption_metadata_path $REWRITTEN_VIDEO_CAPTION_SAVED_PATH \
+    --video_folder=$VIDEO_FOLDER \
+    --videoclipxl_score_metadata_path $VIDEOCLIPXL_SCORE_SAVED_PATH \
+    --min_videoclipxl_score $MIN_VIDEOCLIPXL_SCORE \
+    --saved_path=$TRAIN_SAVED_PATH
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/filter.py b/cogvideox/video_caption/utils/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9466623388c5517d289fe455321767e95c2f91
--- /dev/null
+++ b/cogvideox/video_caption/utils/filter.py
@@ -0,0 +1,162 @@
+import ast
+import os
+
+import pandas as pd
+
+from .logger import logger
+
+
+def filter(
+    video_path_list,
+    basic_metadata_path=None,
+    min_resolution=0,
+    min_duration=-1,
+    max_duration=-1,
+    asethetic_score_metadata_path=None,
+    min_asethetic_score=4,
+    asethetic_score_siglip_metadata_path=None,
+    min_asethetic_score_siglip=4,
+    text_score_metadata_path=None,
+    min_text_score=0.02,
+    motion_score_metadata_path=None,
+    min_motion_score=2,
+    videoclipxl_score_metadata_path=None,
+    min_videoclipxl_score=0.20,
+    video_path_column="video_path",
+):
+    video_path_list = [os.path.basename(video_path) for video_path in video_path_list]
+
+    if basic_metadata_path is not None:
+        if basic_metadata_path.endswith(".csv"):
+            basic_df = pd.read_csv(basic_metadata_path)
+        elif basic_metadata_path.endswith(".jsonl"):
+            basic_df = pd.read_json(basic_metadata_path, lines=True)
+
+        basic_df["resolution"] = basic_df["frame_size"].apply(lambda x: x[0] * x[1])
+        filtered_basic_df = basic_df[basic_df["resolution"] < min_resolution]
+        filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {basic_metadata_path} ({len(basic_df)}) and filter {len(filtered_video_path_list)} videos "
+            f"with resolution less than {min_resolution}."
+        )
+
+        if min_duration != -1:
+            filtered_basic_df = basic_df[basic_df["duration"] < min_duration]
+            filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
+            filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+            video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+            logger.info(
+                f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos "
+                f"with duration less than {min_duration}."
+            )
+
+        if max_duration != -1:
+            filtered_basic_df = basic_df[basic_df["duration"] > max_duration]
+            filtered_video_path_list = filtered_basic_df[video_path_column].tolist()
+            filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+            video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+            logger.info(
+                f"Load {basic_metadata_path} and filter {len(filtered_video_path_list)} videos "
+                f"with duration greater than {max_duration}."
+            )
+
+    if asethetic_score_metadata_path is not None:
+        if asethetic_score_metadata_path.endswith(".csv"):
+            asethetic_score_df = pd.read_csv(asethetic_score_metadata_path)
+        elif asethetic_score_metadata_path.endswith(".jsonl"):
+            asethetic_score_df = pd.read_json(asethetic_score_metadata_path, lines=True)
+
+        # In pandas, csv will save lists as strings, whereas jsonl will not.
+        asethetic_score_df["aesthetic_score"] = asethetic_score_df["aesthetic_score"].apply(
+            lambda x: ast.literal_eval(x) if isinstance(x, str) else x
+        )
+        asethetic_score_df["aesthetic_score_mean"] = asethetic_score_df["aesthetic_score"].apply(lambda x: sum(x) / len(x))
+        filtered_asethetic_score_df = asethetic_score_df[asethetic_score_df["aesthetic_score_mean"] < min_asethetic_score]
+        filtered_video_path_list = filtered_asethetic_score_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {asethetic_score_metadata_path} ({len(asethetic_score_df)}) and filter {len(filtered_video_path_list)} videos "
+            f"with aesthetic score less than {min_asethetic_score}."
+        )
+
+    if asethetic_score_siglip_metadata_path is not None:
+        if asethetic_score_siglip_metadata_path.endswith(".csv"):
+            asethetic_score_siglip_df = pd.read_csv(asethetic_score_siglip_metadata_path)
+        elif asethetic_score_siglip_metadata_path.endswith(".jsonl"):
+            asethetic_score_siglip_df = pd.read_json(asethetic_score_siglip_metadata_path, lines=True)
+
+        # In pandas, csv will save lists as strings, whereas jsonl will not.
+        asethetic_score_siglip_df["aesthetic_score_siglip"] = asethetic_score_siglip_df["aesthetic_score_siglip"].apply(
+            lambda x: ast.literal_eval(x) if isinstance(x, str) else x
+        )
+        asethetic_score_siglip_df["aesthetic_score_siglip_mean"] = asethetic_score_siglip_df["aesthetic_score_siglip"].apply(
+            lambda x: sum(x) / len(x)
+        )
+        filtered_asethetic_score_siglip_df = asethetic_score_siglip_df[
+            asethetic_score_siglip_df["aesthetic_score_siglip_mean"] < min_asethetic_score_siglip
+        ]
+        filtered_video_path_list = filtered_asethetic_score_siglip_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {asethetic_score_siglip_metadata_path} ({len(asethetic_score_siglip_df)}) and filter {len(filtered_video_path_list)} videos "
+            f"with aesthetic score (SigLIP) less than {min_asethetic_score_siglip}."
+        )
+
+    if text_score_metadata_path is not None:
+        if text_score_metadata_path.endswith(".csv"):
+            text_score_df = pd.read_csv(text_score_metadata_path)
+        elif text_score_metadata_path.endswith(".jsonl"):
+            text_score_df = pd.read_json(text_score_metadata_path, lines=True)
+
+        filtered_text_score_df = text_score_df[text_score_df["text_score"] > min_text_score]
+        filtered_video_path_list = filtered_text_score_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {text_score_metadata_path} ({len(text_score_df)}) and filter {len(filtered_video_path_list)} videos "
+            f"with text score greater than {min_text_score}."
+        )
+
+    if motion_score_metadata_path is not None:
+        if motion_score_metadata_path.endswith(".csv"):
+            motion_score_df = pd.read_csv(motion_score_metadata_path)
+        elif motion_score_metadata_path.endswith(".jsonl"):
+            motion_score_df = pd.read_json(motion_score_metadata_path, lines=True)
+
+        filtered_motion_score_df = motion_score_df[motion_score_df["motion_score"] < min_motion_score]
+        filtered_video_path_list = filtered_motion_score_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {motion_score_metadata_path} ({len(motion_score_df)}) and filter {len(filtered_video_path_list)} videos "
+            f"with motion score smaller than {min_motion_score}."
+        )
+    
+    if videoclipxl_score_metadata_path is not None:
+        if videoclipxl_score_metadata_path.endswith(".csv"):
+            videoclipxl_score_df = pd.read_csv(videoclipxl_score_metadata_path)
+        elif videoclipxl_score_metadata_path.endswith(".jsonl"):
+            videoclipxl_score_df = pd.read_json(videoclipxl_score_metadata_path, lines=True)
+        
+        filtered_videoclipxl_score_df = videoclipxl_score_df[videoclipxl_score_df["videoclipxl_score"] < min_videoclipxl_score]
+        filtered_video_path_list = filtered_videoclipxl_score_df[video_path_column].tolist()
+        filtered_video_path_list = [os.path.basename(video_path) for video_path in filtered_video_path_list]
+
+        video_path_list = list(set(video_path_list).difference(set(filtered_video_path_list)))
+        logger.info(
+            f"Load {videoclipxl_score_metadata_path} ({len(videoclipxl_score_df)}) and "
+            f"filter {len(filtered_video_path_list)} videos with mixclip score smaller than {min_videoclipxl_score}."
+        )
+
+    return video_path_list
diff --git a/cogvideox/video_caption/utils/gather_jsonl.py b/cogvideox/video_caption/utils/gather_jsonl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f722decb63a95ac2ccd9f5d19cf3e1a218f3cc53
--- /dev/null
+++ b/cogvideox/video_caption/utils/gather_jsonl.py
@@ -0,0 +1,55 @@
+import argparse
+import os
+import glob
+import json
+from multiprocessing import Pool, Manager
+
+import pandas as pd
+from natsort import index_natsorted
+
+from .logger import logger
+
+
+def process_file(file_path, shared_list):
+    with open(file_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            shared_list.append(data)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Gather all jsonl files in a folder (meta_folder) to a single jsonl file (meta_file_path).")
+    parser.add_argument("--meta_folder", type=str, required=True)
+    parser.add_argument("--meta_file_path", type=str, required=True)
+    parser.add_argument("--video_path_column", type=str, default="video_path")
+    parser.add_argument("--n_jobs", type=int, default=1)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    jsonl_files = glob.glob(os.path.join(args.meta_folder, "*.jsonl"))
+
+    with Manager() as manager:
+        shared_list = manager.list()
+        with Pool(processes=args.n_jobs) as pool:
+            for file_path in jsonl_files:
+                pool.apply_async(process_file, args=(file_path, shared_list))
+            pool.close()
+            pool.join()
+
+        with open(args.meta_file_path, "w") as f:
+            for item in shared_list:
+                f.write(json.dumps(item) + '\n')
+    
+    df = pd.read_json(args.meta_file_path, lines=True)
+    df = df.iloc[index_natsorted(df[args.video_path_column])].reset_index(drop=True)
+    logger.info(f"Save the gathered single jsonl file to {args.meta_file_path}.")
+    df.to_json(args.meta_file_path, orient="records", lines=True, force_ascii=False)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cogvideox/video_caption/utils/get_meta_file.py b/cogvideox/video_caption/utils/get_meta_file.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e28cd0e679df34d7d350260ff72cf27f7386491
--- /dev/null
+++ b/cogvideox/video_caption/utils/get_meta_file.py
@@ -0,0 +1,74 @@
+import argparse
+from pathlib import Path
+
+import pandas as pd
+from natsort import natsorted
+from tqdm import tqdm
+
+from .logger import logger
+
+
+ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov"])
+ALL_IMGAE_EXT = set(["png", "webp", "jpg", "jpeg", "bmp", "gif"])
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Compute scores of uniform sampled frames from videos.")
+    parser.add_argument(
+        "--image_path_column",
+        type=str,
+        default="image_path",
+        help="The column contains the image path (an absolute path or a relative path w.r.t the image_folder).",
+    )
+    parser.add_argument("--image_folder", type=str, default=None, help="The video folder.")
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--video_folder", type=str, default=None, help="The video folder.")
+    parser.add_argument("--saved_path", type=str, required=True, help="The save path to the output results (csv/jsonl).")
+    parser.add_argument("--recursive", action="store_true", help="Whether to search sub-folders recursively.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.video_folder is None and args.image_folder is None:
+        raise ValueError("Either video_folder or image_folder should be specified in the arguments.")
+    if args.video_folder is not None and args.image_folder is not None:
+        raise ValueError("Both video_folder and image_folder can not be specified in the arguments at the same time.")
+
+    # Use the path name instead of the file name as video_path/image_path (unique ID).
+    if args.video_folder is not None:
+        video_path_list = []
+        video_folder = Path(args.video_folder)
+        for ext in tqdm(list(ALL_VIDEO_EXT)):
+            if args.recursive:
+                video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.rglob(f"*.{ext}")]
+            else:
+                video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")]
+        video_path_list = natsorted(video_path_list)
+        meta_file_df = pd.DataFrame({args.video_path_column: video_path_list})
+    
+    if args.image_folder is not None:
+        image_path_list = []
+        image_folder = Path(args.image_folder)
+        for ext in tqdm(list(ALL_IMGAE_EXT)):
+            if args.recursive:
+                image_path_list += [str(file.relative_to(image_folder)) for file in image_folder.rglob(f"*.{ext}")]
+            else:
+                image_path_list += [str(file.relative_to(image_folder)) for file in image_folder.glob(f"*.{ext}")]
+        image_path_list = natsorted(image_path_list)
+        meta_file_df = pd.DataFrame({args.image_path_column: image_path_list})
+
+    logger.info(f"{len(meta_file_df)} files in total. Save the result to {args.saved_path}.")
+    meta_file_df.to_json(args.saved_path, orient="records", lines=True)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/image_evaluator.py b/cogvideox/video_caption/utils/image_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a17c4234d86b53e75cc157476959f83bccf10ef
--- /dev/null
+++ b/cogvideox/video_caption/utils/image_evaluator.py
@@ -0,0 +1,248 @@
+import os
+from typing import Union
+
+import clip
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.datasets.utils import download_url
+from transformers import AutoModel, AutoProcessor
+
+from .siglip_v2_5 import convert_v2_5_from_siglip
+
+# All metrics.
+__all__ = ["AestheticScore", "AestheticScoreSigLIP", "CLIPScore"]
+
+_MODELS = {
+    "CLIP_ViT-L/14": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViT-L-14.pt",
+    "Aesthetics_V2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/sac%2Blogos%2Bava1-l14-linearMSE.pth",
+    "aesthetic_predictor_v2_5": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/aesthetic_predictor_v2_5.pth",
+}
+_MD5 = {
+    "CLIP_ViT-L/14": "096db1af569b284eb76b3881534822d9",
+    "Aesthetics_V2": "b1047fd767a00134b8fd6529bf19521a",
+    "aesthetic_predictor_v2_5": "c46eb8c29f714c9231dc630b8226842a",
+}
+
+
+def get_list_depth(lst):
+    if isinstance(lst, list):
+        return 1 + max(get_list_depth(item) for item in lst)
+    else:
+        return 0
+
+
+def reshape_images(images: Union[list[list[Image.Image]], list[Image.Image]]):
+    # Check the input sanity.
+    depth = get_list_depth(images)
+    if depth == 1:  # batch image input
+        if not isinstance(images[0], Image.Image):
+            raise ValueError("The item in 1D images should be Image.Image.")
+        num_sampled_frames = None
+    elif depth == 2:  # batch video input
+        if not isinstance(images[0][0], Image.Image):
+            raise ValueError("The item in 2D images (videos) should be Image.Image.")
+        num_sampled_frames = len(images[0])
+        if not all(len(video_frames) == num_sampled_frames for video_frames in images):
+            raise ValueError("All item in 2D images should be with the same length.")
+        # [batch_size, num_sampled_frames, H, W, C] => [batch_size * num_sampled_frames, H, W, C].
+        reshaped_images = []
+        for video_frames in images:
+            reshaped_images.extend([frame for frame in video_frames])
+        images = reshaped_images
+    else:
+        raise ValueError("The input images should be in 1/2D list.")
+    
+    return images, num_sampled_frames
+
+
+def reshape_scores(scores: list[float], num_sampled_frames: int) -> list[float]:
+    if isinstance(scores, list):
+        if num_sampled_frames is not None: # Batch video input
+            batch_size = len(scores) // num_sampled_frames
+            scores = [
+                scores[i * num_sampled_frames:(i + 1) * num_sampled_frames]
+                for i in range(batch_size)
+            ]
+        return scores
+    else:
+        return [scores]
+
+
+# if you changed the MLP architecture during training, change it also here:
+class _MLP(nn.Module):
+    def __init__(self, input_size):
+        super().__init__()
+        self.input_size = input_size
+        self.layers = nn.Sequential(
+            nn.Linear(self.input_size, 1024),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(1024, 128),
+            # nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            # nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(64, 16),
+            # nn.ReLU(),
+            nn.Linear(16, 1),
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class AestheticScore:
+    """Compute LAION Aesthetics Score V2 based on openai/clip. Note that the default
+    inference dtype with GPUs is fp16 in openai/clip.
+
+    Ref:
+    1. https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py.
+    2. https://github.com/openai/CLIP/issues/30.
+    """
+
+    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
+        # The CLIP model is loaded in the evaluation mode.
+        self.root = os.path.expanduser(root)
+        if not os.path.exists(self.root):
+            os.makedirs(self.root)
+        filename = "ViT-L-14.pt"
+        download_url(_MODELS["CLIP_ViT-L/14"], self.root, filename=filename, md5=_MD5["CLIP_ViT-L/14"])
+        self.clip_model, self.preprocess = clip.load(os.path.join(self.root, filename), device=device)
+        self.device = device
+        self._load_mlp()
+
+    def _load_mlp(self):
+        filename = "sac+logos+ava1-l14-linearMSE.pth"
+        download_url(_MODELS["Aesthetics_V2"], self.root, filename=filename, md5=_MD5["Aesthetics_V2"])
+        state_dict = torch.load(os.path.join(self.root, filename))
+        self.mlp = _MLP(768)
+        self.mlp.load_state_dict(state_dict)
+        self.mlp.to(self.device)
+        self.mlp.eval()
+
+    def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]:
+        images, num_sampled_frames = reshape_images(images)
+
+        with torch.no_grad():
+            images = torch.stack([self.preprocess(image) for image in images]).to(self.device)
+            image_embs = F.normalize(self.clip_model.encode_image(images))
+            scores = self.mlp(image_embs.float())  # torch.float16 -> torch.float32, [N, 1]
+        
+        scores = scores.squeeze().tolist()  # scalar or list
+        return reshape_scores(scores, num_sampled_frames)
+    
+    def __repr__(self) -> str:
+        return "aesthetic_score"
+
+
+class AestheticScoreSigLIP:
+    """Compute Aesthetics Score V2.5 based on google/siglip-so400m-patch14-384.
+
+    Ref:
+    1. https://github.com/discus0434/aesthetic-predictor-v2-5.
+    2. https://github.com/discus0434/aesthetic-predictor-v2-5/issues/2.
+    """
+
+    def __init__(
+        self,
+        root: str = "~/.cache/clip",
+        device: str = "cpu",
+        torch_dtype=torch.float16
+    ):
+        self.root = os.path.expanduser(root)
+        if not os.path.exists(self.root):
+            os.makedirs(self.root)
+        filename = "aesthetic_predictor_v2_5.pth"
+        download_url(_MODELS["aesthetic_predictor_v2_5"], self.root, filename=filename, md5=_MD5["aesthetic_predictor_v2_5"])
+        self.model, self.preprocessor = convert_v2_5_from_siglip(
+            predictor_name_or_path=os.path.join(self.root, filename),
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+        )
+        self.model = self.model.to(device=device, dtype=torch_dtype)
+        self.device = device
+        self.torch_dtype = torch_dtype
+
+    def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts=None) -> list[float]:
+        images, num_sampled_frames = reshape_images(images)
+
+        pixel_values = self.preprocessor(images, return_tensors="pt").pixel_values
+        pixel_values = pixel_values.to(self.device, self.torch_dtype)
+        with torch.no_grad():
+            scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
+        
+        scores = scores.squeeze().tolist()  # scalar or list
+        return reshape_scores(scores, num_sampled_frames)
+    
+    def __repr__(self) -> str:
+        return "aesthetic_score_siglip"
+
+
+class CLIPScore:
+    """Compute CLIP scores for image-text pairs based on huggingface/transformers."""
+
+    def __init__(
+        self,
+        model_name_or_path: str = "openai/clip-vit-large-patch14",
+        torch_dtype=torch.float16,
+        device: str = "cpu",
+    ):
+        self.model = AutoModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).eval().to(device)
+        self.processor = AutoProcessor.from_pretrained(model_name_or_path)
+        self.torch_dtype = torch_dtype
+        self.device = device
+
+    def __call__(self, images: Union[list[list[Image.Image]], list[Image.Image]], texts: list[str]) -> list[float]:
+        assert len(images) == len(texts)
+        images, num_sampled_frames = reshape_images(images)
+        # Expand texts in the batch video input case.
+        if num_sampled_frames is not None:
+            texts = [[text] * num_sampled_frames for text in texts]
+            texts = [item for sublist in texts for item in sublist]
+
+        image_inputs = self.processor(images=images, return_tensors="pt")  # {"pixel_values": }
+        if self.torch_dtype == torch.float16:
+            image_inputs["pixel_values"] = image_inputs["pixel_values"].half()
+        text_inputs = self.processor(text=texts, return_tensors="pt", padding=True, truncation=True)  # {"inputs_id": }
+        image_inputs, text_inputs = image_inputs.to(self.device), text_inputs.to(self.device)
+        with torch.no_grad():
+            image_embs = F.normalize(self.model.get_image_features(**image_inputs))
+            text_embs = F.normalize(self.model.get_text_features(**text_inputs))
+            scores = text_embs @ image_embs.T  # [N, N]
+
+        scores = scores.squeeze().tolist()  # scalar or list
+        return reshape_scores(scores, num_sampled_frames)
+    
+    def __repr__(self) -> str:
+        return "clip_score"
+
+
+if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+    from tqdm import tqdm
+    from .video_dataset import VideoDataset, collate_fn
+
+    aesthetic_score = AestheticScore(device="cuda")
+    aesthetic_score_siglip = AestheticScoreSigLIP(device="cuda")
+    # clip_score = CLIPScore(device="cuda")
+
+    paths = ["your_image_path"] * 3
+    # texts = ["a joker", "a woman", "a man"]
+    images = [Image.open(p).convert("RGB") for p in paths]
+
+    print(aesthetic_score(images))
+    # print(clip_score(images, texts))
+
+    test_dataset = VideoDataset(
+        dataset_inputs={"video_path": ["your_video_path"] * 3},
+        sample_method="mid",
+        num_sampled_frames=2
+    )
+    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, collate_fn=collate_fn)
+
+    for idx, batch in enumerate(tqdm(test_loader)):
+        batch_frame = batch["sampled_frame"]
+        print(aesthetic_score_siglip(batch_frame))
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/logger.py b/cogvideox/video_caption/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..754eaf6b379aa39e8b9469c95e17c8ec8128e30d
--- /dev/null
+++ b/cogvideox/video_caption/utils/logger.py
@@ -0,0 +1,36 @@
+# Borrowed from sd-webui-controlnet/scripts/logging.py
+import copy
+import logging
+import sys
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+# Create a new logger
+logger = logging.getLogger("VideoCaption")
+logger.propagate = False
+
+# Add handler if we don't have one.
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
+    logger.addHandler(handler)
+
+# Configure logger
+logger.setLevel("INFO")
diff --git a/cogvideox/video_caption/utils/longclip/README.md b/cogvideox/video_caption/utils/longclip/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9e99054e2097d5133812daf6af2b20e3673b2b88
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/README.md
@@ -0,0 +1,19 @@
+# Long-CLIP
+Codes in this directory are borrowed from https://github.com/beichenzbc/Long-CLIP/tree/4e6f5da/model.
+
+We only modify the following code in [model_longclip.py](model_longclip.py) from
+```python
+@property
+def dtype(self):
+    return self.visual.conv1.weight.dtype
+```
+to
+```python
+@property
+def dtype(self):
+    # Fix: the VideoCLIP-XL inference.
+    if hasattr(self, "visual"):
+        return self.visual.conv1.weight.dtype
+    else:
+        return self.token_embedding.weight.dtype
+```
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/longclip/__init__.py b/cogvideox/video_caption/utils/longclip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8902d8a7ce19623f8537834b4fde73978b97d8
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/__init__.py
@@ -0,0 +1 @@
+from .longclip import *
diff --git a/cogvideox/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz b/cogvideox/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/cogvideox/video_caption/utils/longclip/longclip.py b/cogvideox/video_caption/utils/longclip/longclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a427f055560e3741da4db0515f4b184503a5269
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/longclip.py
@@ -0,0 +1,353 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Any, Union, List
+from pkg_resources import packaging
+from torch import nn
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model_longclip import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+    from torchvision.transforms import InterpolationMode
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
+    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return Compose([
+        Resize(n_px, interpolation=BICUBIC),
+        CenterCrop(n_px),
+        _convert_image_to_rgb,
+        ToTensor(),
+        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", download_root: str = None):
+    """Load a long CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+    
+    model_path = name
+
+    state_dict = torch.load(model_path, map_location="cpu")
+    
+    model = build_model(state_dict or model.state_dict(), load_from_clip = False).to(device)
+
+    if str(device) == "cpu":
+        model.float()
+
+    return model, _transform(model.visual.input_resolution)
+        
+    
+
+    def _node_get(node: torch._C.Node, key: str):
+        """Gets attributes of a node which is polymorphic over return type.
+        
+        From https://github.com/pytorch/pytorch/pull/82628
+        """
+        sel = node.kindOf(key)
+        return getattr(node, sel)(key)
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                        if _node_get(inputs[i].node(), "value") == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+
+        model.float()
+
+    return model, _transform(model.input_resolution.item())
+
+
+def load_from_clip(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+    """Load from CLIP model for fine-tuning 
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+    """
+
+    _MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+    }
+
+    def available_models() -> List[str]:
+        """Returns the names of available CLIP models"""
+        return list(_MODELS.keys())
+
+    def _download(url: str, root: str):
+        os.makedirs(root, exist_ok=True)
+        filename = os.path.basename(url)
+
+        expected_sha256 = url.split("/")[-2]
+        download_target = os.path.join(root, filename)
+
+        if os.path.exists(download_target) and not os.path.isfile(download_target):
+            raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+        if os.path.isfile(download_target):
+            if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+                return download_target
+            else:
+                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+        with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+            with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+                while True:
+                    buffer = source.read(8192)
+                    if not buffer:
+                        break
+
+                    output.write(buffer)
+                    loop.update(len(buffer))
+
+        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+            raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+        return download_target
+
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with open(model_path, 'rb') as opened_file:
+        try:
+            # loading JIT archive
+            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+            state_dict = None
+        except RuntimeError:
+            # loading saved state dict
+            if jit:
+                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+                jit = False
+            state_dict = torch.load(opened_file, map_location="cpu")
+
+    model = build_model(state_dict or model.state_dict(), load_from_clip = True).to(device)
+        
+    positional_embedding_pre = model.positional_embedding.type(model.dtype)
+            
+    length, dim = positional_embedding_pre.shape
+    keep_len = 20
+    posisitonal_embedding_new = torch.zeros([4*length-3*keep_len, dim], dtype=model.dtype)
+    for i in range(keep_len):
+        posisitonal_embedding_new[i] = positional_embedding_pre[i]
+    for i in range(length-1-keep_len):
+        posisitonal_embedding_new[4*i + keep_len] = positional_embedding_pre[i + keep_len]
+        posisitonal_embedding_new[4*i + 1 + keep_len] = 3*positional_embedding_pre[i + keep_len]/4 + 1*positional_embedding_pre[i+1+keep_len]/4
+        posisitonal_embedding_new[4*i + 2+keep_len] = 2*positional_embedding_pre[i+keep_len]/4 + 2*positional_embedding_pre[i+1+keep_len]/4
+        posisitonal_embedding_new[4*i + 3+keep_len] = 1*positional_embedding_pre[i+keep_len]/4 + 3*positional_embedding_pre[i+1+keep_len]/4
+
+    posisitonal_embedding_new[4*length -3*keep_len - 4] = positional_embedding_pre[length-1] + 0*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
+    posisitonal_embedding_new[4*length -3*keep_len - 3] = positional_embedding_pre[length-1] + 1*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
+    posisitonal_embedding_new[4*length -3*keep_len - 2] = positional_embedding_pre[length-1] + 2*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
+    posisitonal_embedding_new[4*length -3*keep_len - 1] = positional_embedding_pre[length-1] + 3*(positional_embedding_pre[length-1] - positional_embedding_pre[length-2])/4
+            
+    positional_embedding_res = posisitonal_embedding_new.clone()
+            
+    model.positional_embedding = nn.Parameter(posisitonal_embedding_new, requires_grad=False)
+    model.positional_embedding_res = nn.Parameter(positional_embedding_res, requires_grad=True)
+
+    if str(device) == "cpu":
+        model.float()
+    return model, _transform(model.visual.input_resolution)
+        
+    def _node_get(node: torch._C.Node, key: str):
+        """Gets attributes of a node which is polymorphic over return type.
+        
+        From https://github.com/pytorch/pytorch/pull/82628
+        """
+        sel = node.kindOf(key)
+        return getattr(node, sel)(key)
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
+                        if _node_get(inputs[i].node(), "value") == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+
+        model.float()
+
+    return model, _transform(model.input_resolution.item())
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    else:
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            if truncate:
+                tokens = tokens[:context_length]
+                tokens[-1] = eot_token
+            else:
+                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+        result[i, :len(tokens)] = torch.tensor(tokens)
+
+    return result
diff --git a/cogvideox/video_caption/utils/longclip/model_longclip.py b/cogvideox/video_caption/utils/longclip/model_longclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..aefd521b9a45d12982d8727f03d25ae1c4cb39c7
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/model_longclip.py
@@ -0,0 +1,471 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu3(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1], key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+        return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d(2)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            x = self.relu1(self.bn1(self.conv1(x)))
+            x = self.relu2(self.bn2(self.conv2(x)))
+            x = self.relu3(self.bn3(self.conv3(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(self,
+                 embed_dim: int,
+                 # vision
+                 image_resolution: int,
+                 vision_layers: Union[Tuple[int, int, int, int], int],
+                 vision_width: int,
+                 vision_patch_size: int,
+                 # text
+                 context_length: int,
+                 vocab_size: int,
+                 transformer_width: int,
+                 transformer_heads: int,
+                 transformer_layers: int, 
+                 load_from_clip: bool
+                 ):
+        super().__init__()
+
+        self.context_length = 248
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisionTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask()
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+
+        if load_from_clip == False:
+            self.positional_embedding = nn.Parameter(torch.empty(248, transformer_width))
+            self.positional_embedding_res = nn.Parameter(torch.empty(248, transformer_width))
+
+        else:
+            self.positional_embedding = nn.Parameter(torch.empty(77, transformer_width))
+
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+        self.mask1 = torch.zeros([248, 1])
+        self.mask1[:20, :] = 1
+        self.mask2 = torch.zeros([248, 1])
+        self.mask2[20:, :] = 1
+
+    
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    @property
+    def dtype(self):
+        # Fix: the mixclip inference.
+        if hasattr(self, "visual"):
+            return self.visual.conv1.weight.dtype
+        else:
+            return self.token_embedding.weight.dtype
+
+    def encode_image(self, image):
+        return self.visual(image.type(self.dtype))
+
+    def encode_text(self, text): 
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+        
+        x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device) 
+        
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def encode_text_full(self, text): 
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+        
+        x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device) 
+        
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        #x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # normalized features
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict, load_from_clip: bool):
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+    model = CLIP(
+        embed_dim,
+        image_resolution, vision_layers, vision_width, vision_patch_size,
+        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, load_from_clip
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    convert_weights(model)
+    model.load_state_dict(state_dict)
+    return model.eval()
diff --git a/cogvideox/video_caption/utils/longclip/simple_tokenizer.py b/cogvideox/video_caption/utils/longclip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/cogvideox/video_caption/utils/longclip/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
diff --git a/cogvideox/video_caption/utils/siglip_v2_5.py b/cogvideox/video_caption/utils/siglip_v2_5.py
new file mode 100644
index 0000000000000000000000000000000000000000..14db82673170cd062f29c7cb276335b3ed4e91f5
--- /dev/null
+++ b/cogvideox/video_caption/utils/siglip_v2_5.py
@@ -0,0 +1,127 @@
+# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py.
+import os
+from collections import OrderedDict
+from os import PathLike
+from typing import Final
+
+import torch
+import torch.nn as nn
+from transformers import (
+    SiglipImageProcessor,
+    SiglipVisionConfig,
+    SiglipVisionModel,
+    logging,
+)
+from transformers.image_processing_utils import BatchFeature
+from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
+
+logging.set_verbosity_error()
+
+URL: Final[str] = (
+    "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth"
+)
+
+
+class AestheticPredictorV2_5Head(nn.Module):
+    def __init__(self, config: SiglipVisionConfig) -> None:
+        super().__init__()
+        self.scoring_head = nn.Sequential(
+            nn.Linear(config.hidden_size, 1024),
+            nn.Dropout(0.5),
+            nn.Linear(1024, 128),
+            nn.Dropout(0.5),
+            nn.Linear(128, 64),
+            nn.Dropout(0.5),
+            nn.Linear(64, 16),
+            nn.Dropout(0.2),
+            nn.Linear(16, 1),
+        )
+
+    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
+        return self.scoring_head(image_embeds)
+
+
+class AestheticPredictorV2_5Model(SiglipVisionModel):
+    PATCH_SIZE = 14
+
+    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
+        super().__init__(config, *args, **kwargs)
+        self.layers = AestheticPredictorV2_5Head(config)
+        self.post_init()
+
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor | None = None,
+        labels: torch.Tensor | None = None,
+        return_dict: bool | None = None,
+    ) -> tuple | ImageClassifierOutputWithNoAttention:
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = super().forward(
+            pixel_values=pixel_values,
+            return_dict=return_dict,
+        )
+        image_embeds = outputs.pooler_output
+        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+        prediction = self.layers(image_embeds_norm)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.MSELoss()
+            loss = loss_fct()
+
+        if not return_dict:
+            return (loss, prediction, image_embeds)
+
+        return ImageClassifierOutputWithNoAttention(
+            loss=loss,
+            logits=prediction,
+            hidden_states=image_embeds,
+        )
+
+
+class AestheticPredictorV2_5Processor(SiglipImageProcessor):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs) -> BatchFeature:
+        return super().__call__(*args, **kwargs)
+
+    @classmethod
+    def from_pretrained(
+        self,
+        pretrained_model_name_or_path: str
+        | PathLike = "google/siglip-so400m-patch14-384",
+        *args,
+        **kwargs,
+    ) -> "AestheticPredictorV2_5Processor":
+        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+
+def convert_v2_5_from_siglip(
+    predictor_name_or_path: str | PathLike | None = None,
+    encoder_model_name: str = "google/siglip-so400m-patch14-384",
+    *args,
+    **kwargs,
+) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
+    model = AestheticPredictorV2_5Model.from_pretrained(
+        encoder_model_name, *args, **kwargs
+    )
+
+    processor = AestheticPredictorV2_5Processor.from_pretrained(
+        encoder_model_name, *args, **kwargs
+    )
+
+    if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path):
+        state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu")
+    else:
+        state_dict = torch.load(predictor_name_or_path, map_location="cpu")
+
+    assert isinstance(state_dict, OrderedDict)
+
+    model.layers.load_state_dict(state_dict)
+    model.eval()
+
+    return model, processor
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/README.md b/cogvideox/video_caption/utils/viclip/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f5c9e2f1dccc77369adfc87569d967e172831ec2
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/README.md
@@ -0,0 +1,2 @@
+# ViCLIP
+Codes in this directory are borrowed from https://github.com/OpenGVLab/InternVideo/tree/73271ba/Data/InternVid/viclip.
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/__init__.py b/cogvideox/video_caption/utils/viclip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d72c0be64f7140f7898da87fbf581e8885418a9
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/__init__.py
@@ -0,0 +1,72 @@
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .viclip import ViCLIP
+import torch
+import numpy as np
+import cv2
+import os
+
+
+def get_viclip(size='l', 
+               pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth")):
+    
+    tokenizer = _Tokenizer()
+    vclip = ViCLIP(tokenizer=tokenizer, size=size, pretrain=pretrain)
+    m = {'viclip':vclip, 'tokenizer':tokenizer}
+    
+    return m
+
+def get_text_feat_dict(texts, clip, tokenizer, text_feat_d={}):
+    for t in texts:
+        feat = clip.get_text_features(t, tokenizer, text_feat_d)
+        text_feat_d[t] = feat
+    return text_feat_d
+
+def get_vid_feat(frames, clip):
+    return clip.get_vid_features(frames)
+
+def _frame_from_video(video):
+    while video.isOpened():
+        success, frame = video.read()
+        if success:
+            yield frame
+        else:
+            break
+
+v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)
+v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)
+def normalize(data):
+    return (data/255.0-v_mean)/v_std
+
+def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
+    assert(len(vid_list) >= fnum)
+    step = len(vid_list) // fnum
+    vid_list = vid_list[::step][:fnum]
+    vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list]
+    vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]
+    vid_tube = np.concatenate(vid_tube, axis=1)
+    vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
+    vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()
+    return vid_tube
+
+def retrieve_text(frames, 
+                  texts, 
+                  models={'viclip':None, 
+                          'tokenizer':None},
+                  topk=5, 
+                  device=torch.device('cuda')):
+    # clip, tokenizer = get_clip(name, model_cfg['size'], model_cfg['pretrained'], model_cfg['reload'])
+    assert(type(models)==dict and models['viclip'] is not None and models['tokenizer'] is not None)
+    clip, tokenizer = models['viclip'], models['tokenizer']
+    clip = clip.to(device)
+    frames_tensor = frames2tensor(frames, device=device)
+    vid_feat = get_vid_feat(frames_tensor, clip)
+
+    text_feat_d = {}
+    text_feat_d = get_text_feat_dict(texts, clip, tokenizer, text_feat_d)
+    text_feats = [text_feat_d[t] for t in texts]
+    text_feats_tensor = torch.cat(text_feats, 0)
+    
+    probs, idxs = clip.get_predict_label(vid_feat, text_feats_tensor, top=topk)
+
+    ret_texts = [texts[i] for i in idxs.numpy()[0].tolist()]
+    return ret_texts, probs.numpy()[0]
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz b/cogvideox/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/cogvideox/video_caption/utils/viclip/simple_tokenizer.py b/cogvideox/video_caption/utils/viclip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..744cba818a8ffaf53d4b557490843a33f43777c2
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/simple_tokenizer.py
@@ -0,0 +1,135 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+# @lru_cache()
+# def default_bpe():
+#     return "bpe_simple_vocab_16e6.txt.gz"
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a signficant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/viclip.py b/cogvideox/video_caption/utils/viclip/viclip.py
new file mode 100644
index 0000000000000000000000000000000000000000..7238611dc3ec8697f22cad7a68eda529469c915b
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/viclip.py
@@ -0,0 +1,262 @@
+import os
+import logging
+
+import torch
+from einops import rearrange
+from torch import nn
+import math
+
+# from .criterions import VTC_VTM_Loss
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+from .viclip_vision import clip_joint_l14, clip_joint_b16
+from .viclip_text import clip_text_l14, clip_text_b16
+
+logger = logging.getLogger(__name__)
+
+
+class ViCLIP(nn.Module):
+    """docstring for ViCLIP"""
+
+    def __init__(self,  
+                 tokenizer=None, 
+                 size='l',
+                 pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
+                 freeze_text=True):
+        super(ViCLIP, self).__init__()
+        if tokenizer:
+            self.tokenizer = tokenizer
+        else:
+            self.tokenizer = _Tokenizer()
+        self.max_txt_l = 32
+
+        if size.lower() == 'l':
+            self.vision_encoder_name = 'vit_l14'
+        elif size.lower() == 'b':
+            self.vision_encoder_name = 'vit_b16'
+        else:
+            raise NotImplementedError(f"Size {size} not implemented")
+    
+        self.vision_encoder_pretrained = False
+        self.inputs_image_res = 224
+        self.vision_encoder_kernel_size = 1
+        self.vision_encoder_center = True
+        self.video_input_num_frames = 8
+        self.vision_encoder_drop_path_rate = 0.1
+        self.vision_encoder_checkpoint_num = 24
+        self.is_pretrain = pretrain
+        self.vision_width = 1024
+        self.text_width = 768 
+        self.embed_dim = 768 
+        self.masking_prob = 0.9
+        
+        if size.lower() == 'l':
+            self.text_encoder_name = 'vit_l14'
+        elif size.lower() == 'b':
+            self.text_encoder_name = 'vit_b16'
+        else:
+            raise NotImplementedError(f"Size {size} not implemented")
+        
+        self.text_encoder_pretrained = False#'bert-base-uncased'
+        self.text_encoder_d_model = 768
+
+        self.text_encoder_vocab_size = 49408
+        
+        # create modules.
+        self.vision_encoder = self.build_vision_encoder()
+        self.text_encoder = self.build_text_encoder()
+
+        self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
+        self.temp_min = 1 / 100.0
+
+        if pretrain:
+            logger.info(f"Load pretrained weights from {pretrain}")
+            state_dict = torch.load(pretrain, map_location='cpu')['model']
+            self.load_state_dict(state_dict)
+        
+        # Freeze weights
+        if freeze_text:
+            self.freeze_text()
+
+
+    def freeze_text(self):
+        """freeze text encoder"""
+        for p in self.text_encoder.parameters():
+            p.requires_grad = False
+
+    def no_weight_decay(self):
+        ret = {"temp"}
+        ret.update(
+            {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
+        )
+        ret.update(
+            {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
+        )
+
+        return ret
+
+    def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False):
+        """forward and calculate loss.
+
+        Args:
+            image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
+            text (dict): TODO
+            idx (torch.Tensor): TODO
+
+        Returns: TODO
+
+        """
+        self.clip_contrastive_temperature()
+
+        vision_embeds = self.encode_vision(image)
+        text_embeds = self.encode_text(raw_text)
+        if return_sims:
+            sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \
+                  torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1)
+            return sims
+
+        # calculate loss
+
+        ## VTC loss
+        loss_vtc = self.clip_loss.vtc_loss(
+            vision_embeds, text_embeds, idx, self.temp, all_gather=True
+        )
+
+        return dict(
+            loss_vtc=loss_vtc,
+        )
+
+    def encode_vision(self, image, test=False):
+        """encode image / videos as features.
+
+        Args:
+            image (torch.Tensor): The input images.
+            test (bool): Whether testing.
+
+        Returns: tuple.
+            - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C].
+            - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C].
+
+        """
+        if image.ndim == 5:
+            image = image.permute(0, 2, 1, 3, 4).contiguous()
+        else:
+            image = image.unsqueeze(2)
+
+        if not test and self.masking_prob > 0.0:
+            return self.vision_encoder(
+                image, masking_prob=self.masking_prob
+            )
+
+        return self.vision_encoder(image)
+
+    def encode_text(self, text):
+        """encode text.
+        Args:
+            text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
+                - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
+                - attention_mask (torch.Tensor): The mask indicate padded tokens. Shape: [B,L]. 0 is padded token.
+                - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
+        Returns: tuple.
+            - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
+            - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
+
+        """
+        device = next(self.text_encoder.parameters()).device
+        text = self.text_encoder.tokenize(
+            text, context_length=self.max_txt_l
+        ).to(device)
+        text_embeds = self.text_encoder(text)
+        return text_embeds
+
+    @torch.no_grad()
+    def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
+        """Seems only used during pre-training"""
+        self.temp.clamp_(min=self.temp_min)
+
+    def build_vision_encoder(self):
+        """build vision encoder
+        Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
+
+        """
+        encoder_name = self.vision_encoder_name
+        if encoder_name == "vit_l14":
+            vision_encoder = clip_joint_l14(
+                pretrained=self.vision_encoder_pretrained,
+                input_resolution=self.inputs_image_res,
+                kernel_size=self.vision_encoder_kernel_size,
+                center=self.vision_encoder_center,
+                num_frames=self.video_input_num_frames,
+                drop_path=self.vision_encoder_drop_path_rate,
+                checkpoint_num=self.vision_encoder_checkpoint_num,
+            )
+        elif encoder_name == "vit_b16":
+            vision_encoder = clip_joint_b16(
+                pretrained=self.vision_encoder_pretrained,
+                input_resolution=self.inputs_image_res,
+                kernel_size=self.vision_encoder_kernel_size,
+                center=self.vision_encoder_center,
+                num_frames=self.video_input_num_frames,
+                drop_path=self.vision_encoder_drop_path_rate,
+                checkpoint_num=self.vision_encoder_checkpoint_num,
+            )
+        else:
+            raise NotImplementedError(f"Not implemented: {encoder_name}")
+            
+        return vision_encoder
+
+    def build_text_encoder(self):
+        """build text_encoder and possiblly video-to-text multimodal fusion encoder.
+        Returns: nn.Module. The text encoder
+
+        """
+        encoder_name = self.text_encoder_name
+        
+        if encoder_name == "vit_l14":
+            text_encoder = clip_text_l14(
+                pretrained=self.text_encoder_pretrained,
+                context_length=self.max_txt_l,
+                vocab_size=self.text_encoder_vocab_size,
+                checkpoint_num=0,
+            )
+        elif encoder_name == "vit_b16":
+            text_encoder = clip_text_b16(
+                pretrained=self.text_encoder_pretrained,
+                context_length=self.max_txt_l,
+                vocab_size=self.text_encoder_vocab_size,
+                checkpoint_num=0,
+            )
+        else:
+            raise NotImplementedError(f"Not implemented: {encoder_name}")
+
+        return text_encoder
+
+    def get_text_encoder(self):
+        """get text encoder, used for text and cross-modal encoding"""
+        encoder = self.text_encoder
+        return encoder.bert if hasattr(encoder, "bert") else encoder
+    
+    def get_text_features(self, input_text, tokenizer, text_feature_dict={}):
+        if input_text in text_feature_dict:
+            return text_feature_dict[input_text]
+        text_template= f"{input_text}"
+        with torch.no_grad():
+            # text_token = tokenizer.encode(text_template).cuda()
+            text_features = self.encode_text(text_template).float()
+            text_features /= text_features.norm(dim=-1, keepdim=True)      
+            text_feature_dict[input_text] = text_features
+        return text_features
+
+    def get_vid_features(self, input_frames):
+        with torch.no_grad():
+            clip_feat = self.encode_vision(input_frames,test=True).float()
+            clip_feat /= clip_feat.norm(dim=-1, keepdim=True)    
+        return clip_feat
+
+    def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
+        label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
+        top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
+        return top_probs, top_labels
+
+    
+if __name__ =="__main__":
+    tokenizer = _Tokenizer()
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/viclip_text.py b/cogvideox/video_caption/utils/viclip/viclip_text.py
new file mode 100644
index 0000000000000000000000000000000000000000..42988221e9de3f7e0cfddf0d4e5490479359fb03
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/viclip_text.py
@@ -0,0 +1,297 @@
+import os
+import logging
+from collections import OrderedDict
+from pkg_resources import packaging
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+import torch.utils.checkpoint as checkpoint
+import functools
+
+logger = logging.getLogger(__name__)
+
+
+# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
+MODEL_PATH = 'https://huggingface.co/laion'
+_MODELS = {
+    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
+    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
+}
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
+                 checkpoint_num: int = 0):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+        self.checkpoint_num = checkpoint_num
+
+    def forward(self, x: torch.Tensor):
+        if self.checkpoint_num > 0:
+            segments = min(self.checkpoint_num, len(self.resblocks))
+            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
+        else:
+            return self.resblocks(x)
+
+
+class CLIP_TEXT(nn.Module):
+    def __init__(
+            self,
+            embed_dim: int,
+            context_length: int,
+            vocab_size: int,
+            transformer_width: int,
+            transformer_heads: int,
+            transformer_layers: int,
+            checkpoint_num: int,
+        ):
+        super().__init__()
+
+        self.context_length = context_length
+        self._tokenizer = _Tokenizer()
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+            checkpoint_num=checkpoint_num,
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+    
+    def no_weight_decay(self):
+        return {'token_embedding', 'positional_embedding'}
+
+    @functools.lru_cache(maxsize=None)
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def tokenize(self, texts, context_length=77, truncate=True):
+        """
+        Returns the tokenized representation of given input string(s)
+        Parameters
+        ----------
+        texts : Union[str, List[str]]
+            An input string or a list of input strings to tokenize
+        context_length : int
+            The context length to use; all CLIP models use 77 as the context length
+        truncate: bool
+            Whether to truncate the text in case its encoding is longer than the context length
+        Returns
+        -------
+        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self._tokenizer.encoder["<|startoftext|>"]
+        eot_token = self._tokenizer.encoder["<|endoftext|>"]
+        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
+        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+        else:
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+    def forward(self, text):
+        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+
+def clip_text_b16(
+    embed_dim=512,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=512,
+    transformer_heads=8,
+    transformer_layers=12,
+    checkpoint_num=0,
+    pretrained=True,
+):
+    # raise NotImplementedError
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+        checkpoint_num,
+    )
+    # pretrained = _MODELS["ViT-B/16"]
+    # logger.info(f"Load pretrained weights from {pretrained}")
+    # state_dict = torch.load(pretrained, map_location='cpu')
+    # model.load_state_dict(state_dict, strict=False)
+    # return model.eval()
+    if pretrained:
+        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+            pretrained = _MODELS[pretrained]
+        else:
+            pretrained = _MODELS["ViT-B/16"]
+        logger.info(f"Load pretrained weights from {pretrained}")
+        state_dict = torch.load(pretrained, map_location='cpu')
+        if context_length != state_dict["positional_embedding"].size(0):
+            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+            if context_length < state_dict["positional_embedding"].size(0):
+                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+            else:
+                state_dict["positional_embedding"] = F.pad(
+                    state_dict["positional_embedding"],
+                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+                    value=0,
+                )
+
+        message = model.load_state_dict(state_dict, strict=False)
+        print(f"Load pretrained weights from {pretrained}: {message}")
+    return model.eval()
+
+
+def clip_text_l14(
+    embed_dim=768,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=768,
+    transformer_heads=12,
+    transformer_layers=12,
+    checkpoint_num=0,
+    pretrained=True,
+):
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+        checkpoint_num,
+    )
+    if pretrained:
+        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+            pretrained = _MODELS[pretrained]
+        else:
+            pretrained = _MODELS["ViT-L/14"]
+        logger.info(f"Load pretrained weights from {pretrained}")
+        state_dict = torch.load(pretrained, map_location='cpu')
+        if context_length != state_dict["positional_embedding"].size(0):
+            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+            if context_length < state_dict["positional_embedding"].size(0):
+                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+            else:
+                state_dict["positional_embedding"] = F.pad(
+                    state_dict["positional_embedding"],
+                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+                    value=0,
+                )
+
+        message = model.load_state_dict(state_dict, strict=False)
+        print(f"Load pretrained weights from {pretrained}: {message}")
+    return model.eval()
+
+
+def clip_text_l14_336(
+    embed_dim=768,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=768,
+    transformer_heads=12,
+    transformer_layers=12,
+):
+    raise NotImplementedError
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers
+    )
+    pretrained = _MODELS["ViT-L/14_336"]
+    logger.info(f"Load pretrained weights from {pretrained}")
+    state_dict = torch.load(pretrained, map_location='cpu')
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
+
+
+def build_clip(config):
+    model_cls = config.text_encoder.clip_teacher
+    model = eval(model_cls)()
+    return model
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/viclip/viclip_vision.py b/cogvideox/video_caption/utils/viclip/viclip_vision.py
new file mode 100644
index 0000000000000000000000000000000000000000..2271b9ecec9ffb1ec4b5b57c7c8e60a77af4bd7a
--- /dev/null
+++ b/cogvideox/video_caption/utils/viclip/viclip_vision.py
@@ -0,0 +1,362 @@
+#!/usr/bin/env python
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+from torch import nn
+from einops import rearrange
+from timm.models.layers import DropPath
+from timm.models.registry import register_model
+
+import torch.utils.checkpoint as checkpoint
+
+# from models.utils import load_temp_embed_with_mismatch
+
+logger = logging.getLogger(__name__)
+
+def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
+    """
+    Add/Remove extra temporal_embeddings as needed.
+    https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
+
+    temp_embed_old: (1, num_frames_old, 1, d)
+    temp_embed_new: (1, num_frames_new, 1, d)
+    add_zero: bool, if True, add zero, else, interpolate trained embeddings.
+    """
+    # TODO zero pad
+    num_frms_new = temp_embed_new.shape[1]
+    num_frms_old = temp_embed_old.shape[1]
+    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
+    if num_frms_new > num_frms_old:
+        if add_zero:
+            temp_embed_new[
+                :, :num_frms_old
+            ] = temp_embed_old  # untrained embeddings are zeros.
+        else:
+            temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
+    elif num_frms_new < num_frms_old:
+        temp_embed_new = temp_embed_old[:, :num_frms_new]
+    else:  # =
+        temp_embed_new = temp_embed_old
+    return temp_embed_new
+
+
+# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
+MODEL_PATH = ''
+_MODELS = {
+    "ViT-L/14": os.path.join(MODEL_PATH, "ViCLIP-L_InternVid-FLT-10M.pth"),
+    "ViT-B/16": os.path.join(MODEL_PATH, "ViCLIP-B-InternVid-FLT-10M.pth"),
+}
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
+        super().__init__()
+
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        # logger.info(f'Droppath: {drop_path}')
+        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
+        self.ln_1 = nn.LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("drop1", nn.Dropout(dropout)),
+            ("c_proj", nn.Linear(d_model * 4, d_model)),
+            ("drop2", nn.Dropout(dropout)),
+        ]))
+        self.ln_2 = nn.LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x):
+        x = x + self.drop_path1(self.attention(self.ln_1(x)))
+        x = x + self.drop_path2(self.mlp(self.ln_2(x)))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
+        super().__init__()
+        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
+        self.resblocks = nn.ModuleList()
+        for idx in range(layers):
+            self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
+        self.checkpoint_num = checkpoint_num
+
+    def forward(self, x):
+        for idx, blk in enumerate(self.resblocks):
+            if idx < self.checkpoint_num:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    def __init__(
+        self, input_resolution, patch_size, width, layers, heads, output_dim=None, 
+        kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
+        temp_embed=True,
+    ):
+        super().__init__()
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv3d(
+            3, width, 
+            (kernel_size, patch_size, patch_size), 
+            (kernel_size, patch_size, patch_size), 
+            (0, 0, 0), bias=False
+        )
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = nn.LayerNorm(width)
+        if temp_embed:
+            self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
+        
+        self.transformer = Transformer(
+            width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
+            dropout=dropout)
+
+        self.ln_post = nn.LayerNorm(width)
+        if output_dim is not None:
+            self.proj = nn.Parameter(torch.empty(width, output_dim))
+        else:
+            self.proj = None
+        
+        self.dropout = nn.Dropout(dropout)
+
+    def get_num_layers(self):
+        return len(self.transformer.resblocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}
+    
+    def mask_tokens(self, inputs, masking_prob=0.0):
+        B, L, _ = inputs.shape
+
+        # This is different from text as we are masking a fix number of tokens
+        Lm = int(masking_prob * L)
+        masked_indices = torch.zeros(B, L)
+        indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
+        batch_indices = (
+            torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices)
+        )
+        masked_indices[batch_indices, indices] = 1
+
+        masked_indices = masked_indices.bool()
+
+        return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])
+
+    def forward(self, x, masking_prob=0.0):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        B, C, T, H, W = x.shape
+        x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
+
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+
+        # temporal pos
+        cls_tokens = x[:B, :1, :]
+        x = x[:, 1:]
+        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
+        if hasattr(self, 'temporal_positional_embedding'):
+            if x.size(1) == 1:
+                # This is a workaround for unused parameter issue
+                x = x + self.temporal_positional_embedding.mean(1)
+            else:
+                x = x + self.temporal_positional_embedding
+        x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
+
+        if masking_prob > 0.0:
+            x = self.mask_tokens(x, masking_prob)
+
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  #BND -> NBD
+        x = self.transformer(x)
+
+        x = self.ln_post(x)
+
+        if self.proj is not None:
+            x = self.dropout(x[0]) @ self.proj
+        else:
+            x = x.permute(1, 0, 2)  #NBD -> BND
+
+        return x
+
+
+def inflate_weight(weight_2d, time_dim, center=True):
+    logger.info(f'Init center: {center}')
+    if center:
+        weight_3d = torch.zeros(*weight_2d.shape)
+        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+        middle_idx = time_dim // 2
+        weight_3d[:, :, middle_idx, :, :] = weight_2d
+    else:
+        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
+        weight_3d = weight_3d / time_dim
+    return weight_3d
+
+
+def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
+    state_dict_3d = model.state_dict()
+    for k in state_dict.keys():
+        if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
+            if len(state_dict_3d[k].shape) <= 2:
+                logger.info(f'Ignore: {k}')
+                continue
+            logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
+            time_dim = state_dict_3d[k].shape[2]
+            state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)
+
+    pos_embed_checkpoint = state_dict['positional_embedding']
+    embedding_size = pos_embed_checkpoint.shape[-1]
+    num_patches = (input_resolution // patch_size) ** 2
+    orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
+    new_size = int(num_patches ** 0.5)
+    if orig_size != new_size:
+        logger.info(f'Pos_emb from {orig_size} to {new_size}')
+        extra_tokens = pos_embed_checkpoint[:1]
+        pos_tokens = pos_embed_checkpoint[1:]
+        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+        pos_tokens = torch.nn.functional.interpolate(
+            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
+        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
+        state_dict['positional_embedding'] = new_pos_embed
+    
+    message = model.load_state_dict(state_dict, strict=False)
+    logger.info(f"Load pretrained weights: {message}")
+
+
+@register_model
+def clip_joint_b16(
+    pretrained=False, input_resolution=224, kernel_size=1,
+    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
+    dropout=0.,
+):
+    model = VisionTransformer(
+        input_resolution=input_resolution, patch_size=16, 
+        width=768, layers=12, heads=12, output_dim=512,
+        kernel_size=kernel_size, num_frames=num_frames, 
+        drop_path=drop_path, checkpoint_num=checkpoint_num,
+        dropout=dropout,
+    )
+    # raise NotImplementedError
+    if pretrained:
+        if isinstance(pretrained, str):
+            model_name = pretrained
+        else:
+            model_name = "ViT-B/16"
+        
+        logger.info('load pretrained weights')
+        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
+        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
+    return model.eval()
+
+
+@register_model
+def clip_joint_l14(
+    pretrained=False, input_resolution=224, kernel_size=1,
+    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
+    dropout=0.,
+):
+    model = VisionTransformer(
+        input_resolution=input_resolution, patch_size=14,
+        width=1024, layers=24, heads=16, output_dim=768,
+        kernel_size=kernel_size, num_frames=num_frames, 
+        drop_path=drop_path, checkpoint_num=checkpoint_num,
+        dropout=dropout,
+    )
+    
+    if pretrained:
+        if isinstance(pretrained, str):
+            model_name = pretrained
+        else:
+            model_name = "ViT-L/14"
+        logger.info('load pretrained weights')
+        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
+        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
+    return model.eval()
+
+
+@register_model
+def clip_joint_l14_336(
+    pretrained=True, input_resolution=336, kernel_size=1,
+    center=True, num_frames=8, drop_path=0.
+):
+    raise NotImplementedError
+    model = VisionTransformer(
+        input_resolution=input_resolution, patch_size=14, 
+        width=1024, layers=24, heads=16, output_dim=768,
+        kernel_size=kernel_size, num_frames=num_frames,
+        drop_path=drop_path,
+    )
+    if pretrained:
+        logger.info('load pretrained weights')
+        state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
+        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
+    return model.eval()
+
+
+def interpolate_pos_embed_vit(state_dict, new_model):
+    key = "vision_encoder.temporal_positional_embedding"
+    if key in state_dict:
+        vision_temp_embed_new = new_model.state_dict()[key]
+        vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2)  # [1, n, d] -> [1, n, 1, d]
+        vision_temp_embed_old = state_dict[key]
+        vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)
+
+        state_dict[key] = load_temp_embed_with_mismatch(
+            vision_temp_embed_old, vision_temp_embed_new, add_zero=False
+        ).squeeze(2)
+
+    key = "text_encoder.positional_embedding"
+    if key in state_dict:
+        text_temp_embed_new = new_model.state_dict()[key]
+        text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2)  # [n, d] -> [1, n, 1, d]
+        text_temp_embed_old = state_dict[key]
+        text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)
+
+        state_dict[key] = load_temp_embed_with_mismatch(
+            text_temp_embed_old, text_temp_embed_new, add_zero=False
+        ).squeeze(2).squeeze(0)
+    return state_dict
+
+
+if __name__ == '__main__':
+    import time
+    from fvcore.nn import FlopCountAnalysis
+    from fvcore.nn import flop_count_table
+    import numpy as np
+
+    seed = 4217
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    num_frames = 8
+
+    # model = clip_joint_b16(pretrained=True, kernel_size=1, num_frames=8, num_classes=400, drop_path=0.1)
+    # logger.info(model)
+    model = clip_joint_l14(pretrained=False)
+
+    flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
+    s = time.time()
+    logger.info(flop_count_table(flops, max_depth=1))
+    logger.info(time.time()-s)
+    # logger.info(model(torch.rand(1, 3, num_frames, 224, 224)).shape)
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/video_dataset.py b/cogvideox/video_caption/utils/video_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a581ea0e843dab9443cf856c0d207a7f33564cc
--- /dev/null
+++ b/cogvideox/video_caption/utils/video_dataset.py
@@ -0,0 +1,101 @@
+import os
+from typing import Optional
+from pathlib import Path
+
+from func_timeout import func_timeout, FunctionTimedOut
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader
+
+from .logger import logger
+from .video_utils import extract_frames
+
+
+ALL_VIDEO_EXT = set(["mp4", "webm", "mkv", "avi", "flv", "mov"])
+VIDEO_READER_TIMEOUT = 300
+
+
+def collate_fn(batch):
+    batch = list(filter(lambda x: x is not None, batch))
+    if len(batch) != 0:
+        return {k: [item[k] for item in batch] for k in batch[0].keys()}
+    return {}
+
+
+class VideoDataset(Dataset):
+    def __init__(
+        self,
+        dataset_inputs: dict[str, list[str]],
+        video_folder: Optional[str] = None,
+        video_path_column: str = "video_path",
+        text_column: Optional[str] = None,
+        sample_method: str = "mid",
+        num_sampled_frames: int = 1,
+        num_sample_stride: Optional[int] = None
+    ):
+        length = len(dataset_inputs[list(dataset_inputs.keys())[0]])
+        if not all(len(v) == length for v in dataset_inputs.values()):
+            raise ValueError("All values in the dataset_inputs must have the same length.")
+        
+        self.video_path_column = video_path_column
+        self.video_folder = video_folder
+        self.video_path_list = dataset_inputs[video_path_column]
+        if self.video_folder is not None:
+            self.video_path_list = [os.path.join(self.video_folder, video_path) for video_path in self.video_path_list]
+        self.text_column = text_column
+        self.text_list = dataset_inputs[self.text_column] if self.text_column is not None else None
+
+        self.sample_method = sample_method
+        self.num_sampled_frames = num_sampled_frames
+        self.num_sample_stride = num_sample_stride
+
+    def __getitem__(self, index):
+        video_path = self.video_path_list[index]
+        if self.sample_method == "image":
+            try:
+                sampled_frame_idx_list = None
+                with open(video_path, "rb") as f:
+                    sampled_frame_list = [Image.open(f).convert("RGB")]
+            except Exception as e:
+                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
+                return None
+        else:
+            # It is a trick to deal with decord hanging when reading some abnormal videos.
+            try:
+                sample_args = (video_path, self.sample_method, self.num_sampled_frames, self.num_sample_stride)
+                sampled_frame_idx_list, sampled_frame_list = func_timeout(
+                    VIDEO_READER_TIMEOUT, extract_frames, args=sample_args
+                )
+            except FunctionTimedOut:
+                logger.warning(f"Read {video_path} timeout.")
+                return None
+            except Exception as e:
+                logger.warning(f"Failed to extract frames from video {video_path}. Error is {e}.")
+                return None
+            
+        item = {
+            "path": video_path,
+            "sampled_frame_idx": sampled_frame_idx_list,
+            "sampled_frame": sampled_frame_list,
+        }
+        if self.text_list is not None:
+            item["text"] = self.text_list[index]
+
+        return item
+
+    def __len__(self):
+        return len(self.video_path_list)
+
+
+if __name__ == "__main__":
+    video_folder = Path("your_video_folder")
+    video_path_list = []
+    for ext in ALL_VIDEO_EXT:
+        video_path_list += [str(file.relative_to(video_folder)) for file in video_folder.glob(f"*.{ext}")]
+
+    video_dataset = VideoDataset(dataset_inputs={"video_path": video_path_list})
+    video_dataloader = DataLoader(
+        video_dataset, batch_size=16, num_workers=16, collate_fn=collate_fn
+    )
+    for idx, batch in enumerate(video_dataloader):
+        if len(batch) != 0:
+            print(batch["video_path"], batch["sampled_frame_idx"], len(batch["video_path"]))
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/video_evaluator.py b/cogvideox/video_caption/utils/video_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b644907c639cab7c3829db84fea2569ffa91fbf
--- /dev/null
+++ b/cogvideox/video_caption/utils/video_evaluator.py
@@ -0,0 +1,120 @@
+import os
+from typing import List
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision.datasets.utils import download_url
+
+from .longclip import longclip
+from .viclip import get_viclip
+from .video_utils import extract_frames
+
+# All metrics.
+__all__ = ["VideoCLIPXLScore"]
+
+_MODELS = {
+    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
+    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
+    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
+}
+_MD5 = {
+    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
+    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
+    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
+}
+
+def normalize(
+    data: np.array,
+    mean: list[float] = [0.485, 0.456, 0.406],
+    std: list[float] = [0.229, 0.224, 0.225]
+):
+    v_mean = np.array(mean).reshape(1, 1, 3)
+    v_std = np.array(std).reshape(1, 1, 3)
+
+    return (data / 255.0 - v_mean) / v_std
+
+
+class VideoCLIPXL(nn.Module):
+    def __init__(self, root: str = "~/.cache/clip"):
+        super(VideoCLIPXL, self).__init__()
+
+        self.root = os.path.expanduser(root)
+        if not os.path.exists(self.root):
+            os.makedirs(self.root)
+        
+        k = "LongCLIP-L"
+        filename = os.path.basename(_MODELS[k])
+        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
+        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()
+
+        k = "ViClip-InternVid-10M-FLT"
+        filename = os.path.basename(_MODELS[k])
+        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
+        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()
+
+        # delete unused encoder
+        del self.model.visual
+        del self.viclip_model.text_encoder
+
+
+class VideoCLIPXLScore():
+    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
+        self.root = os.path.expanduser(root)
+        if not os.path.exists(self.root):
+            os.makedirs(self.root)
+
+        k = "VideoCLIP-XL-v2"
+        filename = os.path.basename(_MODELS[k])
+        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
+        self.model = VideoCLIPXL()
+        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
+        self.model.load_state_dict(state_dict)
+        self.model.to(device)
+
+        self.device = device
+    
+    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
+        assert len(videos) == len(texts)
+
+        # Use cv2.resize in accordance with the official demo. Resize and Normalize => B * [T, 224, 224, 3].
+        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
+        resize_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
+        resize_normalizied_videos = [normalize(np.stack(v)) for v in resize_videos]
+
+        video_inputs = torch.stack([torch.from_numpy(v) for v in resize_normalizied_videos])
+        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # BTCHW
+
+        with torch.no_grad():
+            vid_features = torch.stack(
+                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
+            )
+            vid_features.squeeze_()
+            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()
+            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
+            text_features = self.model.model.encode_text(text_inputs)
+            text_features = text_features / text_features.norm(dim=1, keepdim=True)
+            scores = text_features @ vid_features.T
+        
+        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()
+    
+    def __repr__(self):
+        return "videoclipxl_score"
+
+
+if __name__ == "__main__":
+    videos = ["your_video_path"] * 3
+    texts = [
+        "a joker",
+        "glasses and flower",
+        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues."
+    ]
+
+    video_clip_xl_score = VideoCLIPXLScore(device="cuda")
+    batch_frames = []
+    for v in videos:
+        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
+        batch_frames.append(sampled_frames)
+    print(video_clip_xl_score(batch_frames, texts))
\ No newline at end of file
diff --git a/cogvideox/video_caption/utils/video_utils.py b/cogvideox/video_caption/utils/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b23d39332dcb2c4bbf75bb1af37dffff174c8960
--- /dev/null
+++ b/cogvideox/video_caption/utils/video_utils.py
@@ -0,0 +1,44 @@
+import gc
+import random
+from contextlib import contextmanager
+from typing import List, Tuple, Optional
+
+import numpy as np
+from decord import VideoReader
+from PIL import Image
+
+
+@contextmanager
+def video_reader(*args, **kwargs):
+    """A context manager to solve the memory leak of decord.
+    """
+    vr = VideoReader(*args, **kwargs)
+    try:
+        yield vr
+    finally:
+        del vr
+        gc.collect()
+
+
+def extract_frames(
+    video_path: str,
+    sample_method: str = "mid",
+    num_sampled_frames: int = -1,
+    sample_stride: int = -1,
+    **kwargs
+) -> Optional[Tuple[List[int], List[Image.Image]]]:
+    with video_reader(video_path, num_threads=2, **kwargs) as vr:
+        if sample_method == "mid":
+            sampled_frame_idx_list = [len(vr) // 2]
+        elif sample_method == "uniform":
+            sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
+        elif sample_method == "random":
+            clip_length = min(len(vr), (num_sampled_frames - 1) * sample_stride + 1)
+            start_idx = random.randint(0, len(vr) - clip_length)
+            sampled_frame_idx_list = np.linspace(start_idx, start_idx + clip_length - 1, num_sampled_frames, dtype=int)
+        else:
+            raise ValueError(f"The sample_method {sample_method} must be mid, uniform or random.")
+        sampled_frame_list = vr.get_batch(sampled_frame_idx_list).asnumpy()
+        sampled_frame_list = [Image.fromarray(frame) for frame in sampled_frame_list]
+
+        return list(sampled_frame_idx_list), sampled_frame_list
diff --git a/cogvideox/video_caption/video_splitting.py b/cogvideox/video_caption/video_splitting.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bca41ced93b453bd3e27ef9df6399797174b54
--- /dev/null
+++ b/cogvideox/video_caption/video_splitting.py
@@ -0,0 +1,169 @@
+import argparse
+import os
+import subprocess
+from datetime import datetime, timedelta
+from pathlib import Path
+from multiprocessing import Pool
+
+import pandas as pd
+from tqdm import tqdm
+
+from utils.logger import logger
+
+
+MIN_SECONDS = int(os.getenv("MIN_SECONDS", 3))
+MAX_SECONDS = int(os.getenv("MAX_SECONDS", 10))
+
+
+def get_command(start_time, video_path, video_duration, output_path):
+    # Use FFmpeg to split the video. Re-encoding is needed to ensure the accuracy of the clip
+    # at the cost of consuming computational resources.
+    return [
+        'ffmpeg',
+        '-hide_banner',
+        '-loglevel', 'panic',
+        '-ss', str(start_time.time()),
+        '-i', video_path,
+        '-t', str(video_duration),
+        '-c:v', 'libx264',
+        '-preset', 'veryfast',
+        '-crf', '22',
+        '-c:a', 'aac',
+        '-sn',
+        output_path
+    ]
+
+
+def clip_video_star(args):
+    return clip_video(*args)
+
+
+def clip_video(video_path, timecode_list, output_folder, video_duration):
+    """Recursively clip the video within the range of [MIN_SECONDS, MAX_SECONDS], 
+    according to the timecode obtained from cogvideox/video_caption/cutscene_detect.py.
+    """
+    try:
+        video_name = Path(video_path).stem
+
+        if len(timecode_list) == 0:  # The video of a single scene.
+            splitted_timecode_list = []
+            start_time = datetime.strptime("00:00:00.000", "%H:%M:%S.%f")
+            end_time = datetime.strptime(video_duration, "%H:%M:%S.%f")
+            cur_start = start_time
+            splitted_index = 0
+            while cur_start < end_time:
+                cur_end = min(cur_start + timedelta(seconds=MAX_SECONDS), end_time)
+                cur_video_duration = (cur_end - cur_start).total_seconds()
+                if cur_video_duration < MIN_SECONDS:
+                    cur_start = cur_end
+                    splitted_index += 1
+                    continue
+                splitted_timecode_list.append([cur_start.strftime("%H:%M:%S.%f")[:-3], cur_end.strftime("%H:%M:%S.%f")[:-3]])
+                output_path = os.path.join(output_folder, video_name + f"_{splitted_index}.mp4")
+                if os.path.exists(output_path):
+                    logger.info(f"The clipped video {output_path} exists.")
+                    cur_start = cur_end
+                    splitted_index += 1
+                    continue
+                else:
+                    command = get_command(cur_start, video_path, cur_video_duration, output_path)
+                    try:
+                        subprocess.run(command, check=True)
+                    except Exception as e:
+                        logger.warning(f"Run {command} error: {e}.")
+                    finally:
+                        cur_start = cur_end
+                        splitted_index += 1
+
+        for i, timecode in enumerate(timecode_list):  # The video of multiple scenes.
+            start_time = datetime.strptime(timecode[0], "%H:%M:%S.%f")
+            end_time = datetime.strptime(timecode[1], "%H:%M:%S.%f")
+            video_duration = (end_time - start_time).total_seconds()
+            output_path = os.path.join(output_folder, video_name + f"_{i}.mp4")
+            if os.path.exists(output_path):
+                logger.info(f"The clipped video {output_path} exists.")
+                continue
+            if video_duration < MIN_SECONDS:
+                continue
+            if video_duration > MAX_SECONDS:
+                splitted_timecode_list = []
+                cur_start = start_time
+                splitted_index = 0
+                while cur_start < end_time:
+                    cur_end = min(cur_start + timedelta(seconds=MAX_SECONDS), end_time)
+                    cur_video_duration = (cur_end - cur_start).total_seconds()
+                    if cur_video_duration < MIN_SECONDS:
+                        break
+                    splitted_timecode_list.append([cur_start.strftime("%H:%M:%S.%f")[:-3], cur_end.strftime("%H:%M:%S.%f")[:-3]])
+                    splitted_output_path = os.path.join(output_folder, video_name + f"_{i}_{splitted_index}.mp4")
+                    if os.path.exists(splitted_output_path):
+                        logger.info(f"The clipped video {splitted_output_path} exists.")
+                        cur_start = cur_end
+                        splitted_index += 1
+                        continue
+                    else:
+                        command = get_command(cur_start, video_path, cur_video_duration, splitted_output_path)
+                        try:
+                            subprocess.run(command, check=True)
+                        except Exception as e:
+                            logger.warning(f"Run {command} error: {e}.")
+                        finally:
+                            cur_start = cur_end
+                            splitted_index += 1
+                
+                continue
+            
+            # We found that the current scene detected by PySceneDetect includes a few frames from
+            # the next scene occasionally. Directly discard the last few frames of the current scene.
+            video_duration = video_duration - 0.5
+            command = get_command(start_time, video_path, video_duration, output_path)
+            subprocess.run(command, check=True)
+    except Exception as e:
+        logger.warning(f"Clip video with {video_path}. Error is: {e}.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Video Splitting")
+    parser.add_argument(
+        "--video_metadata_path", type=str, default=None, help="The path to the video dataset metadata (csv/jsonl)."
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
+    parser.add_argument("--output_folder", type=str, default="outputs")
+    parser.add_argument("--n_jobs", type=int, default=16)
+
+    parser.add_argument("--resolution_threshold", type=float, default=0, help="The resolution threshold.")
+
+    args = parser.parse_args()
+
+    video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    num_videos = len(video_metadata_df)
+    video_metadata_df["resolution"] = video_metadata_df["frame_size"].apply(lambda x: x[0] * x[1])
+    video_metadata_df = video_metadata_df[video_metadata_df["resolution"] >= args.resolution_threshold]
+    logger.info(f"Filter {num_videos - len(video_metadata_df)} videos with resolution smaller than {args.resolution_threshold}.")
+    video_path_list = video_metadata_df[args.video_path_column].to_list()
+    video_id_list = [Path(video_path).stem for video_path in video_path_list]
+    if len(video_id_list) != len(list(set(video_id_list))):
+        logger.warning("Duplicate file names exist in the input video path list.")
+    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
+    video_timecode_list = video_metadata_df["timecode_list"].to_list()
+    video_duration_list = video_metadata_df["duration"].to_list()
+
+    assert len(video_path_list) == len(video_timecode_list)
+    os.makedirs(args.output_folder, exist_ok=True)
+    args_list = [
+        (video_path, timecode_list, args.output_folder, video_duration)
+        for video_path, timecode_list, video_duration in zip(
+            video_path_list, video_timecode_list, video_duration_list
+        )
+    ]
+    with Pool(args.n_jobs) as pool:
+        # results = list(tqdm(pool.imap(clip_video_star, args_list), total=len(video_path_list)))
+        results = pool.imap(clip_video_star, args_list)
+        for result in tqdm(results, total=len(video_path_list)):
+            pass
\ No newline at end of file
diff --git a/cogvideox/video_caption/vila_video_recaptioning.py b/cogvideox/video_caption/vila_video_recaptioning.py
new file mode 100644
index 0000000000000000000000000000000000000000..a172e0c3b043d0a7fcefc548b81d8991f0ae12c8
--- /dev/null
+++ b/cogvideox/video_caption/vila_video_recaptioning.py
@@ -0,0 +1,354 @@
+# Modified from https://github.com/mit-han-lab/llm-awq/blob/main/tinychat/vlm_demo_new.py.
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from accelerate import load_checkpoint_and_dispatch, PartialState
+from accelerate.utils import gather_object
+from decord import VideoReader
+from PIL import Image
+from natsort import natsorted
+from tqdm import tqdm
+from transformers import AutoConfig, AutoTokenizer
+
+import tinychat.utils.constants
+# from tinychat.models.llava_llama import LlavaLlamaForCausalLM
+from tinychat.models.vila_llama import VilaLlamaForCausalLM
+from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator
+from tinychat.utils.conversation_utils import gen_params
+from tinychat.utils.llava_image_processing import process_images
+from tinychat.utils.prompt_templates import (
+    get_image_token,
+    get_prompter,
+    get_stop_token_ids,
+)
+from tinychat.utils.tune import (
+    device_warmup,
+    tune_llava_patch_embedding,
+)
+
+from utils.filter import filter
+from utils.logger import logger
+
+gen_params.seed = 1
+gen_params.temp = 1.0
+gen_params.top_p = 1.0
+
+
+def extract_uniform_frames(video_path: str, num_sampled_frames: int = 8):
+    vr = VideoReader(video_path)
+    sampled_frame_idx_list = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
+    sampled_frame_list = []
+    for idx in sampled_frame_idx_list:
+        sampled_frame = Image.fromarray(vr[idx].asnumpy())
+        sampled_frame_list.append(sampled_frame)
+
+    return sampled_frame_list
+
+
+def stream_output(output_stream):
+    for outputs in output_stream:
+        output_text = outputs["text"]
+        output_text = output_text.strip().split(" ")
+        # print(f"output_text: {output_text}.")
+    return " ".join(output_text)
+
+
+def skip(*args, **kwargs):
+    pass
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Recaption videos with VILA1.5.")
+    parser.add_argument(
+        "--video_metadata_path",
+        type=str,
+        default=None,
+        help="The path to the video dataset metadata (csv/jsonl).",
+    )
+    parser.add_argument(
+        "--video_path_column",
+        type=str,
+        default="video_path",
+        help="The column contains the video path (an absolute path or a relative path w.r.t the video_folder).",
+    )
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default="caption",
+        help="The column contains the caption.",
+    )
+    parser.add_argument(
+        "--video_folder", type=str, default="", help="The video folder."
+    )
+    parser.add_argument("--input_prompt", type=str, default="<video>\\n Elaborate on the visual and narrative elements of the video in detail.")
+    parser.add_argument(
+        "--model_type", type=str, default="LLaMa", help="type of the model"
+    )
+    parser.add_argument(
+        "--model_path", type=str, default="Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ"
+    )
+    parser.add_argument(
+        "--quant_path",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--precision", type=str, default="W4A16", help="compute precision"
+    )
+    parser.add_argument("--num_sampled_frames", type=int, default=8)
+    parser.add_argument(
+        "--saved_path",
+        type=str,
+        required=True,
+        help="The save path to the output results (csv/jsonl).",
+    )
+    parser.add_argument(
+        "--saved_freq",
+        type=int,
+        default=100,
+        help="The frequency to save the output results.",
+    )
+
+    parser.add_argument(
+        "--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
+    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
+    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
+    parser.add_argument(
+        "--asethetic_score_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score", type=float, default=4.0, help="The asethetic score threshold.")
+    parser.add_argument(
+        "--asethetic_score_siglip_metadata_path", type=str, default=None, help="The path to the video quality metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_asethetic_score_siglip", type=float, default=4.0, help="The asethetic score (SigLIP) threshold.")
+    parser.add_argument(
+        "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_text_score", type=float, default=0.02, help="The text threshold.")
+    parser.add_argument(
+        "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
+    )
+    parser.add_argument("--min_motion_score", type=float, default=2, help="The motion threshold.")
+    
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    if args.video_metadata_path.endswith(".csv"):
+        video_metadata_df = pd.read_csv(args.video_metadata_path)
+    elif args.video_metadata_path.endswith(".jsonl"):
+        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
+    else:
+        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
+    video_path_list = video_metadata_df[args.video_path_column].tolist()
+    video_path_list = [os.path.basename(video_path) for video_path in video_path_list]
+
+    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
+        raise ValueError("The saved_path must end with .csv or .jsonl.")
+
+    if os.path.exists(args.saved_path):
+        if args.saved_path.endswith(".csv"):
+            saved_metadata_df = pd.read_csv(args.saved_path)
+        elif args.saved_path.endswith(".jsonl"):
+            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
+        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
+        video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
+        logger.info(
+            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed and {len(video_path_list)} to be processed."
+        )
+    
+    video_path_list = filter(
+        video_path_list,
+        basic_metadata_path=args.basic_metadata_path,
+        min_resolution=args.min_resolution,
+        min_duration=args.min_duration,
+        max_duration=args.max_duration,
+        asethetic_score_metadata_path=args.asethetic_score_metadata_path,
+        min_asethetic_score=args.min_asethetic_score,
+        asethetic_score_siglip_metadata_path=args.asethetic_score_siglip_metadata_path,
+        min_asethetic_score_siglip=args.min_asethetic_score_siglip,
+        text_score_metadata_path=args.text_score_metadata_path,
+        min_text_score=args.min_text_score,
+        motion_score_metadata_path=args.motion_score_metadata_path,
+        min_motion_score=args.min_motion_score,
+    )
+    video_path_list = [os.path.join(args.video_folder, video_path) for video_path in video_path_list]
+    # Sorting to guarantee the same result for each process.
+    video_path_list = natsorted(video_path_list)
+
+    state = PartialState()
+
+    # Accelerate model initialization
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+    torch.nn.init.kaiming_uniform_ = skip
+    torch.nn.init.kaiming_normal_ = skip
+    torch.nn.init.uniform_ = skip
+    torch.nn.init.normal_ = skip
+
+    tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.model_path, "llm"), use_fast=False)
+    tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
+        tokenizer.convert_tokens_to_ids(
+            [tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
+        )[0]
+    )
+    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
+    model = VilaLlamaForCausalLM(config).half()
+    tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN_IDX = (
+        tokenizer.convert_tokens_to_ids(
+            [tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_PATCH_TOKEN]
+        )[0]
+    )
+    vision_tower = model.get_vision_tower()
+    # if not vision_tower.is_loaded:
+    #     vision_tower.load_model()
+    image_processor = vision_tower.image_processor
+    # vision_tower = vision_tower.half()
+
+    if args.precision == "W16A16":
+        pbar = tqdm(range(1))
+        pbar.set_description("Loading checkpoint shards")
+        for i in pbar:
+            model.llm = load_checkpoint_and_dispatch(
+                model.llm,
+                os.path.join(args.model_path, "llm"),
+                no_split_module_classes=[
+                    "OPTDecoderLayer",
+                    "LlamaDecoderLayer",
+                    "BloomBlock",
+                    "MPTBlock",
+                    "DecoderLayer",
+                    "CLIPEncoderLayer",
+                ],
+            ).to(state.device)
+        model = model.to(state.device)
+
+    elif args.precision == "W4A16":
+        from tinychat.utils.load_quant import load_awq_model
+        # Auto load quant_path from the 3b/8b/13b/40b model.
+        if args.quant_path is None:
+            if "VILA1.5-3b-s2-AWQ" in args.model_path:
+                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-3b-s2-w4-g128-awq-v2.pt")
+            elif "VILA1.5-3b-AWQ" in args.model_path:
+                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-3b-w4-g128-awq-v2.pt")
+            elif "Llama-3-VILA1.5-8b-AWQ" in args.model_path:
+                args.quant_path = os.path.join(args.model_path, "llm/llama-3-vila1.5-8b-w4-g128-awq-v2.pt")
+            elif "VILA1.5-13b-AWQ" in args.model_path:
+                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-13b-w4-g128-awq-v2.pt")
+            elif "VILA1.5-40b-AWQ" in args.model_path:
+                args.quant_path = os.path.join(args.model_path, "llm/vila-1.5-40b-w4-g128-awq-v2.pt")
+        model.llm = load_awq_model(model.llm, args.quant_path, 4, 128, state.device)
+        from tinychat.modules import (
+            make_fused_mlp,
+            make_fused_vision_attn,
+            make_quant_attn,
+            make_quant_norm,
+        )
+
+        make_quant_attn(model.llm, state.device)
+        make_quant_norm(model.llm)
+        # make_fused_mlp(model)
+        # make_fused_vision_attn(model,state.device)
+        model = model.to(state.device)
+
+    else:
+        raise NotImplementedError(f"Precision {args.precision} is not supported.")
+    
+    device_warmup(state.device)
+    tune_llava_patch_embedding(vision_tower, device=state.device)
+
+    stream_generator = LlavaStreamGenerator
+
+    model_prompter = get_prompter(
+        args.model_type, args.model_path, False, False
+    )
+    stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
+
+    model.eval()
+
+    index = len(video_path_list) - len(video_path_list) % state.num_processes
+    # Avoid the NCCL timeout in the final gather operation.
+    logger.info(f"Drop {len(video_path_list) % state.num_processes} videos to ensure each process handles the same number of videos.")
+    video_path_list = video_path_list[:index]
+    logger.info(f"{len(video_path_list)} videos are to be processed.")
+    
+    result_dict = {args.video_path_column: [], args.caption_column: []}
+    with state.split_between_processes(video_path_list) as splitted_video_path_list:
+        # TODO: Use VideoDataset.
+        for i, video_path in enumerate(tqdm(splitted_video_path_list)):
+            try:
+                image_list = extract_uniform_frames(video_path, args.num_sampled_frames)
+                image_num = len(image_list)
+                # Similar operation in model_worker.py
+                image_tensor = process_images(image_list, image_processor, model.config)
+                if type(image_tensor) is list:
+                    image_tensor = [
+                        image.to(state.device, dtype=torch.float16) for image in image_tensor
+                    ]
+                else:
+                    image_tensor = image_tensor.to(state.device, dtype=torch.float16)
+
+                input_prompt = args.input_prompt
+                # Insert image here
+                image_token = get_image_token(model, args.model_path)
+                image_token_holder = tinychat.utils.constants.LLAVA_DEFAULT_IM_TOKEN_PLACE_HOLDER
+                im_token_count = input_prompt.count(image_token_holder)
+                if im_token_count == 0:
+                    model_prompter.insert_prompt(image_token * image_num + input_prompt)
+                else:
+                    assert im_token_count == image_num
+                    input_prompt = input_prompt.replace(image_token_holder, image_token)
+                    model_prompter.insert_prompt(input_prompt)
+                output_stream = stream_generator(
+                    model,
+                    tokenizer,
+                    model_prompter.model_input,
+                    gen_params,
+                    device=state.device,
+                    stop_token_ids=stop_token_ids,
+                    image_tensor=image_tensor,
+                )
+                outputs = stream_output(output_stream)
+                if len(outputs) != 0:
+                    result_dict[args.video_path_column].append(Path(video_path).name)
+                    result_dict[args.caption_column].append(outputs)
+            
+            except Exception as e:
+                logger.warning(f"VILA with {video_path} failed. Error is {e}.")
+
+            if i != 0 and i % args.saved_freq == 0:
+                state.wait_for_everyone()
+                gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+                if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
+                    result_df = pd.DataFrame(gathered_result_dict)
+                    if args.saved_path.endswith(".csv"):
+                        header = False if os.path.exists(args.saved_path) else True
+                        result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+                    elif args.saved_path.endswith(".jsonl"):
+                        result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+                    logger.info(f"Save result to {args.saved_path}.")
+                for k in result_dict.keys():
+                    result_dict[k] = []
+    
+    state.wait_for_everyone()
+    gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
+    if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
+        result_df = pd.DataFrame(gathered_result_dict)
+        if args.saved_path.endswith(".csv"):
+            header = False if os.path.exists(args.saved_path) else True
+            result_df.to_csv(args.saved_path, header=header, index=False, mode="a")
+        elif args.saved_path.endswith(".jsonl"):
+            result_df.to_json(args.saved_path, orient="records", lines=True, mode="a", force_ascii=False)
+        logger.info(f"Save result to {args.saved_path}.")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/comfyui/README.md b/comfyui/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..305664cb99f12e25ff9e7fc7e1a698caed359e4d
--- /dev/null
+++ b/comfyui/README.md
@@ -0,0 +1,84 @@
+# ComfyUI CogVideoX-Fun
+Easily use CogVideoX-Fun inside ComfyUI!
+
+- [Installation](#1-installation)
+- [Node types](#node-types)
+- [Example workflows](#example-workflows)
+
+## 1. Installation
+
+### Option 1: Install via ComfyUI Manager
+TBD
+
+### Option 2: Install manually
+The CogVideoX-Fun repository needs to be placed at `ComfyUI/custom_nodes/CogVideoX-Fun/`.
+
+```
+cd ComfyUI/custom_nodes/
+
+# Git clone the cogvideox_fun itself
+git clone https://github.com/aigc-apps/CogVideoX-Fun.git
+
+# Git clone the video outout node
+git clone https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git
+
+cd CogVideoX-Fun/
+python install.py
+```
+
+### 2. Download models into `ComfyUI/models/CogVideoX_Fun/`
+
+V1.1:
+
+| 名称 | 存储空间 | Hugging Face | Model Scope | 描述 |
+|--|--|--|--|--|
+| CogVideoX-Fun-V1.1-2b-InP.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. Noise has been added to the reference image, and the amplitude of motion is greater compared to V1.0. |
+| CogVideoX-Fun-V1.1-5b-InP.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB  | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. Noise has been added to the reference image, and the amplitude of motion is greater compared to V1.0. |
+| CogVideoX-Fun-V1.1-2b-Pose.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-2b-Pose) | Our official pose-control video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second.|
+| CogVideoX-Fun-V1.1-5b-Pose.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB  | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-V1.1-5b-Pose) | Our official pose-control video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second.|
+
+V1.0:
+
+| Name | Storage Space | Hugging Face | Model Scope | Description |
+|--|--|--|--|--|
+| CogVideoX-Fun-2b-InP.tar.gz | Before extraction:9.7 GB \/ After extraction: 13.0 GB | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-2b-InP) | [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-2b-InP) | Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. |
+| CogVideoX-Fun-5b-InP.tar.gz | Before extraction:16.0 GB \/ After extraction: 20.0 GB  | [🤗Link](https://huggingface.co/alibaba-pai/CogVideoX-Fun-5b-InP)| [😄Link](https://modelscope.cn/models/PAI/CogVideoX-Fun-5b-InP)| Our official graph-generated video model is capable of predicting videos at multiple resolutions (512, 768, 1024, 1280) and has been trained on 49 frames at a rate of 8 frames per second. |
+
+## Node types
+- **LoadCogVideoX_Fun_Model**
+    - Loads the CogVideoX-Fun model
+- **CogVideoX_FUN_TextBox**
+    - Write the prompt for CogVideoX-Fun model
+- **CogVideoX_Fun_I2VSampler**
+    - CogVideoX-Fun Sampler for Image to Video 
+- **CogVideoX_Fun_T2VSampler**
+    - CogVideoX-Fun Sampler for Text to Video
+- **CogVideoX_Fun_V2VSampler**
+    - CogVideoX-Fun Sampler for Video to Video
+
+## Example workflows
+
+### Video to video generation
+Our ui is shown as follow, this is the [download link](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_v2v.json) of the json:
+![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_v2v.jpg)
+
+You can run the demo using following video:
+[demo video](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/play_guitar.mp4)
+
+### Control video generation
+Our ui is shown as follow, this is the [download link](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_v2v_control.json) of the json:
+![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_v2v_control.jpg)
+
+You can run the demo using following video:
+[demo video](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4)
+
+### Image to video generation
+Our ui is shown as follow, this is the [download link](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_i2v.json) of the json:
+![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_i2v.jpg)
+
+You can run the demo using following photo:
+![demo image](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/firework.png)
+
+### Text to video generation
+Our ui is shown as follow, this is the [download link](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_t2v.json) of the json:
+![workflow graph](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/cogvideoxfunv1.1_workflow_t2v.jpg)
\ No newline at end of file
diff --git a/comfyui/comfyui_nodes.py b/comfyui/comfyui_nodes.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7eeb13f5fb43b0be8337394a5b6a24e23fc1e3
--- /dev/null
+++ b/comfyui/comfyui_nodes.py
@@ -0,0 +1,635 @@
+"""Modified from https://github.com/kijai/ComfyUI-EasyAnimateWrapper/blob/main/nodes.py
+"""
+import gc
+import json
+import os
+
+import comfy.model_management as mm
+import cv2
+import folder_paths
+import numpy as np
+import torch
+from comfy.utils import ProgressBar, load_torch_file
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from einops import rearrange
+from omegaconf import OmegaConf
+from PIL import Image
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ..cogvideox.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
+from ..cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from ..cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from ..cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from ..cogvideox.pipeline.pipeline_cogvideox_control import \
+    CogVideoX_Fun_Pipeline_Control
+from ..cogvideox.pipeline.pipeline_cogvideox_inpaint import (
+    CogVideoX_Fun_Pipeline_Inpaint)
+from ..cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from ..cogvideox.utils.utils import (get_image_to_video_latent,
+                                     get_video_to_video_latent,
+                                     save_videos_grid)
+
+# Compatible with Alibaba EAS for quick launch
+eas_cache_dir       = '/stable-diffusion-cache/models'
+# The directory of the cogvideoxfun
+script_directory    = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+def tensor2pil(image):
+    return Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8))
+
+def numpy2pil(image):
+    return Image.fromarray(np.clip(255. * image, 0, 255).astype(np.uint8))
+
+def to_pil(image):
+    if isinstance(image, Image.Image):
+        return image
+    if isinstance(image, torch.Tensor):
+        return tensor2pil(image)
+    if isinstance(image, np.ndarray):
+        return numpy2pil(image)
+    raise ValueError(f"Cannot convert {type(image)} to PIL.Image")
+
+class LoadCogVideoX_Fun_Model:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": (
+                    [ 
+                        'CogVideoX-Fun-2b-InP',
+                        'CogVideoX-Fun-5b-InP',
+                        'CogVideoX-Fun-V1.1-2b-InP',
+                        'CogVideoX-Fun-V1.1-5b-InP',
+                        'CogVideoX-Fun-V1.1-2b-Pose',
+                        'CogVideoX-Fun-V1.1-5b-Pose',
+                    ],
+                    {
+                        "default": 'CogVideoX-Fun-V1.1-2b-InP',
+                    }
+                ),
+                "model_type": (
+                    ["Inpaint", "Control"],
+                    {
+                        "default": "Inpaint",
+                    }
+                ),
+                "low_gpu_memory_mode":(
+                    [False, True],
+                    {
+                        "default": False,
+                    }
+                ),
+                "precision": (
+                    ['fp16', 'bf16'],
+                    {
+                        "default": 'fp16'
+                    }
+                ),
+                
+            },
+        }
+
+    RETURN_TYPES = ("CogVideoXFUNSMODEL",)
+    RETURN_NAMES = ("cogvideoxfun_model",)
+    FUNCTION = "loadmodel"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def loadmodel(self, low_gpu_memory_mode, model, model_type, precision):
+        # Init weight_dtype and device
+        device          = mm.get_torch_device()
+        offload_device  = mm.unet_offload_device()
+        weight_dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+
+        # Init processbar
+        pbar = ProgressBar(3)
+
+        # Detect model is existing or not 
+        model_path = os.path.join(folder_paths.models_dir, "CogVideoX_Fun", model)
+      
+        if not os.path.exists(model_path):
+            if os.path.exists(eas_cache_dir):
+                model_path = os.path.join(eas_cache_dir, 'CogVideoX_Fun', model)
+            else:
+                model_path = os.path.join(folder_paths.models_dir, "CogVideoX-Fun", model)
+                if not os.path.exists(model_path):
+                    if os.path.exists(eas_cache_dir):
+                        model_path = os.path.join(eas_cache_dir, 'CogVideoX_Fun', model)
+                    else:
+                        # Detect model is existing or not 
+                        print(f"Please download cogvideoxfun model to: {model_path}")
+
+        vae = AutoencoderKLCogVideoX.from_pretrained(
+            model_path, 
+            subfolder="vae", 
+        ).to(weight_dtype)
+        # Update pbar
+        pbar.update(1)
+
+        # Load Sampler
+        print("Load Sampler.")
+        scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
+        # Update pbar
+        pbar.update(1)
+        
+        # Get Transformer
+        transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+            model_path, 
+            subfolder="transformer", 
+        ).to(weight_dtype)
+        # Update pbar
+        pbar.update(1) 
+
+        # Get pipeline
+        if model_type == "Inpaint":
+            if transformer.config.in_channels != vae.config.latent_channels:
+                pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+                    model_path,
+                    vae=vae, 
+                    transformer=transformer,
+                    scheduler=scheduler,
+                    torch_dtype=weight_dtype
+                )
+            else:
+                pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+                    model_path,
+                    vae=vae, 
+                    transformer=transformer,
+                    scheduler=scheduler,
+                    torch_dtype=weight_dtype
+                )
+        else:
+            pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
+                    model_path,
+                    vae=vae, 
+                    transformer=transformer,
+                    scheduler=scheduler,
+                    torch_dtype=weight_dtype
+                )
+        if low_gpu_memory_mode:
+            pipeline.enable_sequential_cpu_offload()
+        else:
+            pipeline.enable_model_cpu_offload()
+
+        cogvideoxfun_model = {
+            'pipeline': pipeline, 
+            'dtype': weight_dtype,
+            'model_path': model_path,
+            'model_type': model_type,
+            'loras': [],
+            'strength_model': [],
+        }
+        return (cogvideoxfun_model,)
+
+class LoadCogVideoX_Fun_Lora:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "cogvideoxfun_model": ("CogVideoXFUNSMODEL",),
+                "lora_name": (folder_paths.get_filename_list("loras"), {"default": None,}),
+                "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}),
+            }
+        }
+    RETURN_TYPES = ("CogVideoXFUNSMODEL",)
+    RETURN_NAMES = ("cogvideoxfun_model",)
+    FUNCTION = "load_lora"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def load_lora(self, cogvideoxfun_model, lora_name, strength_model):
+        if lora_name is not None:
+            return (
+                {
+                    'pipeline': cogvideoxfun_model["pipeline"], 
+                    'dtype': cogvideoxfun_model["dtype"],
+                    'model_path': cogvideoxfun_model["model_path"],
+                    'loras': cogvideoxfun_model.get("loras", []) + [folder_paths.get_full_path("loras", lora_name)],
+                    'strength_model': cogvideoxfun_model.get("strength_model", []) + [strength_model],
+                }, 
+            )
+        else:
+            return (cogvideoxfun_model,)
+
+class CogVideoX_FUN_TextBox:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "prompt": ("STRING", {"multiline": True, "default": "",}),
+            }
+        }
+    
+    RETURN_TYPES = ("STRING_PROMPT",)
+    RETURN_NAMES =("prompt",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def process(self, prompt):
+        return (prompt, )
+
+class CogVideoX_Fun_I2VSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "cogvideoxfun_model": (
+                    "CogVideoXFUNSMODEL", 
+                ),
+                "prompt": (
+                    "STRING_PROMPT",
+                ),
+                "negative_prompt": (
+                    "STRING_PROMPT",
+                ),
+                "video_length": (
+                    "INT", {"default": 49, "min": 5, "max": 49, "step": 4}
+                ),
+                "base_resolution": (
+                    [ 
+                        512,
+                        768,
+                        960,
+                        1024,
+                    ], {"default": 768}
+                ),
+                "seed": (
+                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
+                ),
+                "steps": (
+                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
+                ),
+                "cfg": (
+                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
+                ),
+                "scheduler": (
+                    [ 
+                        "Euler",
+                        "Euler A",
+                        "DPM++",
+                        "PNDM",
+                        "DDIM",
+                    ],
+                    {
+                        "default": 'DDIM'
+                    }
+                )
+            },
+            "optional":{
+                "start_img": ("IMAGE",),
+                "end_img": ("IMAGE",),
+            },
+        }
+    
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES =("images",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def process(self, cogvideoxfun_model, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, scheduler, start_img=None, end_img=None):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        mm.soft_empty_cache()
+        gc.collect()
+
+        start_img = [to_pil(_start_img) for _start_img in start_img] if start_img is not None else None
+        end_img = [to_pil(_end_img) for _end_img in end_img] if end_img is not None else None
+        # Count most suitable height and width
+        aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+        original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size
+        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+        height, width = [int(x / 16) * 16 for x in closest_size]
+        
+        # Get Pipeline
+        pipeline = cogvideoxfun_model['pipeline']
+        model_path = cogvideoxfun_model['model_path']
+
+        # Load Sampler
+        if scheduler == "DPM++":
+            noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler":
+            noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler A":
+            noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "PNDM":
+            noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "DDIM":
+            noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        pipeline.scheduler = noise_scheduler
+
+        generator= torch.Generator(device).manual_seed(seed)
+
+        with torch.no_grad():
+            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
+
+            sample = pipeline(
+                prompt, 
+                num_frames = video_length,
+                negative_prompt = negative_prompt,
+                height      = height,
+                width       = width,
+                generator   = generator,
+                guidance_scale = cfg,
+                num_inference_steps = steps,
+
+                video        = input_video,
+                mask_video   = input_video_mask,
+                comfyui_progressbar = True,
+            ).videos
+            videos = rearrange(sample, "b c t h w -> (b t) h w c")
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
+        return (videos,)   
+
+
+class CogVideoX_Fun_T2VSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "cogvideoxfun_model": (
+                    "CogVideoXFUNSMODEL", 
+                ),
+                "prompt": (
+                    "STRING_PROMPT", 
+                ),
+                "negative_prompt": (
+                    "STRING_PROMPT", 
+                ),
+                "video_length": (
+                    "INT", {"default": 49, "min": 5, "max": 49, "step": 4}
+                ),
+                "width": (
+                    "INT", {"default": 1008, "min": 64, "max": 2048, "step": 16}
+                ),
+                "height": (
+                    "INT", {"default": 576, "min": 64, "max": 2048, "step": 16}
+                ),
+                "is_image":(
+                    [
+                        False,
+                        True
+                    ], 
+                    {
+                        "default": False,
+                    }
+                ),
+                "seed": (
+                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
+                ),
+                "steps": (
+                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
+                ),
+                "cfg": (
+                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
+                ),
+                "scheduler": (
+                    [ 
+                        "Euler",
+                        "Euler A",
+                        "DPM++",
+                        "PNDM",
+                        "DDIM",
+                    ],
+                    {
+                        "default": 'DDIM'
+                    }
+                ),
+            },
+        }
+    
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES =("images",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def process(self, cogvideoxfun_model, prompt, negative_prompt, video_length, width, height, is_image, seed, steps, cfg, scheduler):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        mm.soft_empty_cache()
+        gc.collect()
+
+        # Get Pipeline
+        pipeline = cogvideoxfun_model['pipeline']
+        model_path = cogvideoxfun_model['model_path']
+
+        # Load Sampler
+        if scheduler == "DPM++":
+            noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler":
+            noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler A":
+            noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "PNDM":
+            noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "DDIM":
+            noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        pipeline.scheduler = noise_scheduler
+
+        generator= torch.Generator(device).manual_seed(seed)
+        
+        video_length = 1 if is_image else video_length
+        with torch.no_grad():
+            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+            input_video, input_video_mask, clip_image = get_image_to_video_latent(None, None, video_length=video_length, sample_size=(height, width))
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
+
+            sample = pipeline(
+                prompt, 
+                num_frames = video_length,
+                negative_prompt = negative_prompt,
+                height      = height,
+                width       = width,
+                generator   = generator,
+                guidance_scale = cfg,
+                num_inference_steps = steps,
+
+                video        = input_video,
+                mask_video   = input_video_mask,
+                comfyui_progressbar = True,
+            ).videos
+            videos = rearrange(sample, "b c t h w -> (b t) h w c")
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
+        return (videos,)   
+
+class CogVideoX_Fun_V2VSampler:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "cogvideoxfun_model": (
+                    "CogVideoXFUNSMODEL", 
+                ),
+                "prompt": (
+                    "STRING_PROMPT",
+                ),
+                "negative_prompt": (
+                    "STRING_PROMPT",
+                ),
+                "video_length": (
+                    "INT", {"default": 49, "min": 5, "max": 49, "step": 4}
+                ),
+                "base_resolution": (
+                    [ 
+                        512,
+                        768,
+                        960,
+                        1024,
+                    ], {"default": 768}
+                ),
+                "seed": (
+                    "INT", {"default": 43, "min": 0, "max": 0xffffffffffffffff}
+                ),
+                "steps": (
+                    "INT", {"default": 50, "min": 1, "max": 200, "step": 1}
+                ),
+                "cfg": (
+                    "FLOAT", {"default": 6.0, "min": 1.0, "max": 20.0, "step": 0.01}
+                ),
+                "denoise_strength": (
+                    "FLOAT", {"default": 0.70, "min": 0.05, "max": 1.00, "step": 0.01}
+                ),
+                "scheduler": (
+                    [ 
+                        "Euler",
+                        "Euler A",
+                        "DPM++",
+                        "PNDM",
+                        "DDIM",
+                    ],
+                    {
+                        "default": 'DDIM'
+                    }
+                ),
+            },
+            "optional":{
+                "validation_video": ("IMAGE",),
+                "control_video": ("IMAGE",),
+            },
+        }
+    
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES =("images",)
+    FUNCTION = "process"
+    CATEGORY = "CogVideoXFUNWrapper"
+
+    def process(self, cogvideoxfun_model, prompt, negative_prompt, video_length, base_resolution, seed, steps, cfg, denoise_strength, scheduler, validation_video=None, control_video=None):
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        mm.soft_empty_cache()
+        gc.collect()
+        
+        # Get Pipeline
+        pipeline = cogvideoxfun_model['pipeline']
+        model_path = cogvideoxfun_model['model_path']
+        model_type = cogvideoxfun_model['model_type']
+
+        # Count most suitable height and width
+        aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+        if model_type == "Inpaint":
+            if type(validation_video) is str:
+                original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
+            else:
+                validation_video = np.array(validation_video.cpu().numpy() * 255, np.uint8)
+                original_width, original_height = Image.fromarray(validation_video[0]).size
+        else:
+            if type(control_video) is str:
+                original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
+            else:
+                control_video = np.array(control_video.cpu().numpy() * 255, np.uint8)
+                original_width, original_height = Image.fromarray(control_video[0]).size
+        closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+        height, width = [int(x / 16) * 16 for x in closest_size]
+
+        # Load Sampler
+        if scheduler == "DPM++":
+            noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler":
+            noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "Euler A":
+            noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "PNDM":
+            noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        elif scheduler == "DDIM":
+            noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder= 'scheduler')
+        pipeline.scheduler = noise_scheduler
+
+        generator= torch.Generator(device).manual_seed(seed)
+        
+        with torch.no_grad():
+            video_length = int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+            if model_type == "Inpaint":
+                input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=(height, width), fps=8)
+            else:
+                input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=(height, width), fps=8)
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
+            
+            if model_type == "Inpaint":
+                sample = pipeline(
+                    prompt, 
+                    num_frames = video_length,
+                    negative_prompt = negative_prompt,
+                    height      = height,
+                    width       = width,
+                    generator   = generator,
+                    guidance_scale = cfg,
+                    num_inference_steps = steps,
+
+                    video        = input_video,
+                    mask_video   = input_video_mask,
+                    strength = float(denoise_strength),
+                    comfyui_progressbar = True,
+                ).videos
+            else:
+                sample = pipeline(
+                    prompt, 
+                    num_frames = video_length,
+                    negative_prompt = negative_prompt,
+                    height      = height,
+                    width       = width,
+                    generator   = generator,
+                    guidance_scale = cfg,
+                    num_inference_steps = steps,
+
+                    control_video = input_video,
+                    comfyui_progressbar = True,
+                ).videos
+            videos = rearrange(sample, "b c t h w -> (b t) h w c")
+
+            for _lora_path, _lora_weight in zip(cogvideoxfun_model.get("loras", []), cogvideoxfun_model.get("strength_model", [])):
+                pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
+        return (videos,)   
+
+NODE_CLASS_MAPPINGS = {
+    "CogVideoX_FUN_TextBox": CogVideoX_FUN_TextBox,
+    "LoadCogVideoX_Fun_Model": LoadCogVideoX_Fun_Model,
+    "LoadCogVideoX_Fun_Lora": LoadCogVideoX_Fun_Lora,
+    "CogVideoX_Fun_I2VSampler": CogVideoX_Fun_I2VSampler,
+    "CogVideoX_Fun_T2VSampler": CogVideoX_Fun_T2VSampler,
+    "CogVideoX_Fun_V2VSampler": CogVideoX_Fun_V2VSampler,
+}
+
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CogVideoX_FUN_TextBox": "CogVideoX_FUN_TextBox",
+    "LoadCogVideoX_Fun_Model": "Load CogVideoX-Fun Model",
+    "LoadCogVideoX_Fun_Lora": "Load CogVideoX-Fun Lora",
+    "CogVideoX_Fun_I2VSampler": "CogVideoX-Fun Sampler for Image to Video",
+    "CogVideoX_Fun_T2VSampler": "CogVideoX-Fun Sampler for Text to Video",
+    "CogVideoX_Fun_V2VSampler": "CogVideoX-Fun Sampler for Video to Video",
+}
\ No newline at end of file
diff --git a/comfyui/v1.1/cogvideoxfunv1.1_workflow_i2v.json b/comfyui/v1.1/cogvideoxfunv1.1_workflow_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..b046768c2a1347a51a1c451b887dd29b9fb9403e
--- /dev/null
+++ b/comfyui/v1.1/cogvideoxfunv1.1_workflow_i2v.json
@@ -0,0 +1,451 @@
+{
+  "last_node_id": 83,
+  "last_link_id": 46,
+  "nodes": [
+    {
+      "id": 7,
+      "type": "LoadImage",
+      "pos": [
+        258.76883544921907,
+        468.15773315429715
+      ],
+      "size": [
+        378.07147216796875,
+        314.0000114440918
+      ],
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "IMAGE",
+          "type": "IMAGE",
+          "links": [
+            45
+          ],
+          "shape": 3,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "MASK",
+          "type": "MASK",
+          "links": null,
+          "shape": 3,
+          "label": "遮罩"
+        }
+      ],
+      "title": "Start Image(图片到视频的开始图片)",
+      "properties": {
+        "Node name for S&R": "LoadImage"
+      },
+      "widgets_values": [
+        "firework.png",
+        "image"
+      ]
+    },
+    {
+      "id": 79,
+      "type": "Note",
+      "pos": [
+        16,
+        460
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can upload image here\n(在此上传开始图像)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            43
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+      ]
+    },
+    {
+      "id": 82,
+      "type": "CogVideoX_Fun_I2VSampler",
+      "pos": [
+        758,
+        93
+      ],
+      "size": {
+        "0": 336,
+        "1": 282
+      },
+      "flags": {},
+      "order": 7,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 42
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 43
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 44
+        },
+        {
+          "name": "start_img",
+          "type": "IMAGE",
+          "link": 45,
+          "slot_index": 3
+        },
+        {
+          "name": "end_img",
+          "type": "IMAGE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            46
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_I2VSampler"
+      },
+      "widgets_values": [
+        49,
+        512,
+        43,
+        "fixed",
+        50,
+        6,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1134,
+        93
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 8,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 46,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "CogVideoX-Fun",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "CogVideoX-Fun_00003.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 83,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        300,
+        -294
+      ],
+      "size": {
+        "0": 315,
+        "1": 130
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            42
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-V1.1-2b-InP",
+        "Inpaint",
+        false,
+        "bf16"
+      ]
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            44
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    }
+  ],
+  "links": [
+    [
+      42,
+      83,
+      0,
+      82,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      43,
+      75,
+      0,
+      82,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      44,
+      73,
+      0,
+      82,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      45,
+      7,
+      0,
+      82,
+      3,
+      "IMAGE"
+    ],
+    [
+      46,
+      82,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    },
+    {
+      "title": "Upload Your Start Image",
+      "bounding": [
+        218,
+        382,
+        452,
+        418
+      ],
+      "color": "#a1309b",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.7513148009015778,
+      "offset": [
+        268.77277812624413,
+        436.3236112390962
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1.1/cogvideoxfunv1.1_workflow_t2v.json b/comfyui/v1.1/cogvideoxfunv1.1_workflow_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..1ac9c23ff19691f1e06d88d7eef56502f76b94e6
--- /dev/null
+++ b/comfyui/v1.1/cogvideoxfunv1.1_workflow_t2v.json
@@ -0,0 +1,359 @@
+{
+  "last_node_id": 88,
+  "last_link_id": 52,
+  "nodes": [
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            50
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+      ]
+    },
+    {
+      "id": 88,
+      "type": "CogVideoX_Fun_T2VSampler",
+      "pos": [
+        728,
+        -68
+      ],
+      "size": {
+        "0": 327.6000061035156,
+        "1": 290
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 49
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 50
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 51,
+          "slot_index": 2
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            52
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_T2VSampler"
+      },
+      "widgets_values": [
+        49,
+        672,
+        384,
+        false,
+        43,
+        "fixed",
+        50,
+        6,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1110,
+        -67
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 52,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "CogVideoX-Fun",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "CogVideoX-Fun_00004.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            51
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 87,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        302,
+        -285
+      ],
+      "size": {
+        "0": 315,
+        "1": 130
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            49
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-V1.1-2b-InP",
+        "Inpaint",
+        false,
+        "bf16"
+      ]
+    }
+  ],
+  "links": [
+    [
+      49,
+      87,
+      0,
+      88,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      50,
+      75,
+      0,
+      88,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      51,
+      73,
+      0,
+      88,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      52,
+      88,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.8264462809917354,
+      "offset": [
+        181.0702206286297,
+        544.9672051634072
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v.json b/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..66e25a9efe44f1618671ed021c109039c25b8f37
--- /dev/null
+++ b/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v.json
@@ -0,0 +1,492 @@
+{
+  "last_node_id": 90,
+  "last_link_id": 57,
+  "nodes": [
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 79,
+      "type": "Note",
+      "pos": [
+        15.739953613281248,
+        462.38664912015946
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can upload video here\n(在此上传视频)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 85,
+      "type": "VHS_LoadVideo",
+      "pos": [
+        336,
+        470
+      ],
+      "size": [
+        235.1999969482422,
+        398.971426827567
+      ],
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "IMAGE",
+          "type": "IMAGE",
+          "links": [
+            56
+          ],
+          "shape": 3,
+          "slot_index": 0
+        },
+        {
+          "name": "frame_count",
+          "type": "INT",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "audio",
+          "type": "AUDIO",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "video_info",
+          "type": "VHS_VIDEOINFO",
+          "links": null,
+          "shape": 3
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_LoadVideo"
+      },
+      "widgets_values": {
+        "video": "00000125.mp4",
+        "force_rate": 0,
+        "force_size": "Disabled",
+        "custom_width": 512,
+        "custom_height": 512,
+        "frame_load_cap": 0,
+        "skip_first_frames": 0,
+        "select_every_nth": 1,
+        "choose video to upload": "image",
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "frame_load_cap": 0,
+            "skip_first_frames": 0,
+            "force_rate": 0,
+            "filename": "00000125.mp4",
+            "type": "input",
+            "format": "video/mp4",
+            "select_every_nth": 1
+          }
+        }
+      }
+    },
+    {
+      "id": 88,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        309,
+        -286
+      ],
+      "size": {
+        "0": 315,
+        "1": 130
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            53
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-V1.1-2b-InP",
+        "Inpaint",
+        false,
+        "bf16"
+      ]
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            54
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "A cute cat is playing the guitar."
+      ]
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            55
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 90,
+      "type": "CogVideoX_Fun_V2VSampler",
+      "pos": [
+        754,
+        14
+      ],
+      "size": {
+        "0": 317.4000244140625,
+        "1": 306
+      },
+      "flags": {},
+      "order": 7,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 53
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 54
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 55
+        },
+        {
+          "name": "validation_video",
+          "type": "IMAGE",
+          "link": 56,
+          "slot_index": 3
+        },
+        {
+          "name": "control_video",
+          "type": "IMAGE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            57
+          ],
+          "shape": 3
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_V2VSampler"
+      },
+      "widgets_values": [
+        49,
+        768,
+        43,
+        "randomize",
+        50,
+        6,
+        0.7,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1125,
+        15
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 8,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 57,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "EasyAnimate",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "EasyAnimate_00045.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    }
+  ],
+  "links": [
+    [
+      53,
+      88,
+      0,
+      90,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      54,
+      75,
+      0,
+      90,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      55,
+      73,
+      0,
+      90,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      56,
+      85,
+      0,
+      90,
+      3,
+      "IMAGE"
+    ],
+    [
+      57,
+      90,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    },
+    {
+      "title": "Upload Your Video",
+      "bounding": [
+        218,
+        385,
+        456,
+        498
+      ],
+      "color": "#a1309b",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.683013455365071,
+      "offset": [
+        314.4077746994681,
+        444.69453403364594
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v_control.json b/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v_control.json
new file mode 100644
index 0000000000000000000000000000000000000000..dc7de7aadee1117cff3809d9ef770412736be575
--- /dev/null
+++ b/comfyui/v1.1/cogvideoxfunv1.1_workflow_v2v_control.json
@@ -0,0 +1,492 @@
+{
+  "last_node_id": 90,
+  "last_link_id": 59,
+  "nodes": [
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 79,
+      "type": "Note",
+      "pos": [
+        15.739953613281248,
+        462.38664912015946
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can upload video here\n(在此上传视频)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            55
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1125,
+        15
+      ],
+      "size": [
+        390.9534912109375,
+        973.1686096191406
+      ],
+      "flags": {},
+      "order": 8,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 57,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "CogVideoX-Fun",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "CogVideoX-Fun_00007.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 88,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        309,
+        -286
+      ],
+      "size": {
+        "0": 315,
+        "1": 130
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            53
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-V1.1-2b-Pose",
+        "Control",
+        false,
+        "bf16"
+      ]
+    },
+    {
+      "id": 85,
+      "type": "VHS_LoadVideo",
+      "pos": [
+        336,
+        470
+      ],
+      "size": [
+        235.1999969482422,
+        658.5777723524305
+      ],
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "IMAGE",
+          "type": "IMAGE",
+          "links": [
+            59
+          ],
+          "shape": 3,
+          "slot_index": 0
+        },
+        {
+          "name": "frame_count",
+          "type": "INT",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "audio",
+          "type": "AUDIO",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "video_info",
+          "type": "VHS_VIDEOINFO",
+          "links": null,
+          "shape": 3
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_LoadVideo"
+      },
+      "widgets_values": {
+        "video": "pose.mp4",
+        "force_rate": 8,
+        "force_size": "Disabled",
+        "custom_width": 512,
+        "custom_height": 512,
+        "frame_load_cap": 0,
+        "skip_first_frames": 0,
+        "select_every_nth": 1,
+        "choose video to upload": "image",
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "frame_load_cap": 0,
+            "skip_first_frames": 0,
+            "force_rate": 8,
+            "filename": "pose.mp4",
+            "type": "input",
+            "format": "video/mp4",
+            "select_every_nth": 1
+          }
+        }
+      }
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            54
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "A person wearing a knee-length white sleeveless dress and white high-heeled sandals performs a dance in a well-lit room with wooden flooring. The room's background features a closed door, a shelf displaying clear glass bottles of alcoholic beverages, and a partially visible dark-colored sofa. "
+      ]
+    },
+    {
+      "id": 90,
+      "type": "CogVideoX_Fun_V2VSampler",
+      "pos": [
+        754,
+        14
+      ],
+      "size": {
+        "0": 336,
+        "1": 306
+      },
+      "flags": {},
+      "order": 7,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 53
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 54
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 55
+        },
+        {
+          "name": "validation_video",
+          "type": "IMAGE",
+          "link": null,
+          "slot_index": 3
+        },
+        {
+          "name": "control_video",
+          "type": "IMAGE",
+          "link": 59
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            57
+          ],
+          "shape": 3
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_V2VSampler"
+      },
+      "widgets_values": [
+        49,
+        512,
+        43,
+        "fixed",
+        50,
+        6,
+        1,
+        "DDIM"
+      ]
+    }
+  ],
+  "links": [
+    [
+      53,
+      88,
+      0,
+      90,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      54,
+      75,
+      0,
+      90,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      55,
+      73,
+      0,
+      90,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      57,
+      90,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ],
+    [
+      59,
+      85,
+      0,
+      90,
+      4,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    },
+    {
+      "title": "Upload Your Video",
+      "bounding": [
+        218,
+        385,
+        457,
+        776
+      ],
+      "color": "#a1309b",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.6830134553650712,
+      "offset": [
+        250.2298948633902,
+        399.72391778748613
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1/cogvideoxfunv1_workflow_i2v.json b/comfyui/v1/cogvideoxfunv1_workflow_i2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..4fc21f2a29cb7fd68138cf7b9a0bdc4f25dfd3d9
--- /dev/null
+++ b/comfyui/v1/cogvideoxfunv1_workflow_i2v.json
@@ -0,0 +1,450 @@
+{
+  "last_node_id": 83,
+  "last_link_id": 46,
+  "nodes": [
+    {
+      "id": 7,
+      "type": "LoadImage",
+      "pos": [
+        258.76883544921907,
+        468.15773315429715
+      ],
+      "size": [
+        378.07147216796875,
+        314.0000114440918
+      ],
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "IMAGE",
+          "type": "IMAGE",
+          "links": [
+            45
+          ],
+          "shape": 3,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "MASK",
+          "type": "MASK",
+          "links": null,
+          "shape": 3,
+          "label": "遮罩"
+        }
+      ],
+      "title": "Start Image(图片到视频的开始图片)",
+      "properties": {
+        "Node name for S&R": "LoadImage"
+      },
+      "widgets_values": [
+        "firework.png",
+        "image"
+      ]
+    },
+    {
+      "id": 79,
+      "type": "Note",
+      "pos": [
+        16,
+        460
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can upload image here\n(在此上传开始图像)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            43
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+      ]
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            44
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 82,
+      "type": "CogVideoX_Fun_I2VSampler",
+      "pos": [
+        758,
+        93
+      ],
+      "size": {
+        "0": 336,
+        "1": 282
+      },
+      "flags": {},
+      "order": 7,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 42
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 43
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 44
+        },
+        {
+          "name": "start_img",
+          "type": "IMAGE",
+          "link": 45,
+          "slot_index": 3
+        },
+        {
+          "name": "end_img",
+          "type": "IMAGE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            46
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_I2VSampler"
+      },
+      "widgets_values": [
+        49,
+        512,
+        43,
+        "fixed",
+        50,
+        6,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1134,
+        93
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 8,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 46,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "CogVideoX-Fun",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "CogVideoX-Fun_00003.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 83,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        300,
+        -294
+      ],
+      "size": {
+        "0": 315,
+        "1": 106
+      },
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            42
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-2b-InP",
+        false,
+        "bf16"
+      ]
+    }
+  ],
+  "links": [
+    [
+      42,
+      83,
+      0,
+      82,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      43,
+      75,
+      0,
+      82,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      44,
+      73,
+      0,
+      82,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      45,
+      7,
+      0,
+      82,
+      3,
+      "IMAGE"
+    ],
+    [
+      46,
+      82,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    },
+    {
+      "title": "Upload Your Start Image",
+      "bounding": [
+        218,
+        382,
+        452,
+        418
+      ],
+      "color": "#a1309b",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.7513148009015778,
+      "offset": [
+        265.8612156262443,
+        436.6199667078462
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1/cogvideoxfunv1_workflow_t2v.json b/comfyui/v1/cogvideoxfunv1_workflow_t2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..60b75bf2f7fefcd7b855017971adf9175dc3e679
--- /dev/null
+++ b/comfyui/v1/cogvideoxfunv1_workflow_t2v.json
@@ -0,0 +1,358 @@
+{
+  "last_node_id": 88,
+  "last_link_id": 52,
+  "nodes": [
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            51
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            50
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+      ]
+    },
+    {
+      "id": 88,
+      "type": "CogVideoX_Fun_T2VSampler",
+      "pos": [
+        728,
+        -68
+      ],
+      "size": {
+        "0": 327.6000061035156,
+        "1": 290
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 49
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 50
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 51,
+          "slot_index": 2
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            52
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_T2VSampler"
+      },
+      "widgets_values": [
+        49,
+        672,
+        384,
+        false,
+        43,
+        "fixed",
+        50,
+        6,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1110,
+        -67
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 52,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "CogVideoX-Fun",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "CogVideoX-Fun_00004.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 87,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        302,
+        -285
+      ],
+      "size": {
+        "0": 315,
+        "1": 106
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            49
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-2b-InP",
+        false,
+        "bf16"
+      ]
+    }
+  ],
+  "links": [
+    [
+      49,
+      87,
+      0,
+      88,
+      0,
+      "CogVideoXFUNSMODEL"
+    ],
+    [
+      50,
+      75,
+      0,
+      88,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      51,
+      73,
+      0,
+      88,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      52,
+      88,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.8264462809917354,
+      "offset": [
+        171.40912687862968,
+        545.5627520384072
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/comfyui/v1/cogvideoxfunv1_workflow_v2v.json b/comfyui/v1/cogvideoxfunv1_workflow_v2v.json
new file mode 100644
index 0000000000000000000000000000000000000000..c37af567f7262d0af6e37c4b1a3bdb8bac6457a1
--- /dev/null
+++ b/comfyui/v1/cogvideoxfunv1_workflow_v2v.json
@@ -0,0 +1,488 @@
+{
+  "last_node_id": 88,
+  "last_link_id": 52,
+  "nodes": [
+    {
+      "id": 80,
+      "type": "Note",
+      "pos": [
+        20,
+        -300
+      ],
+      "size": {
+        "0": 210,
+        "1": 66.98204040527344
+      },
+      "flags": {},
+      "order": 0,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "Load model here\n(在此选择要使用的模型)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 78,
+      "type": "Note",
+      "pos": [
+        18,
+        -46
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 1,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can write prompt here\n(你可以在此填写提示词)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 79,
+      "type": "Note",
+      "pos": [
+        15.739953613281248,
+        462.38664912015946
+      ],
+      "size": {
+        "0": 210,
+        "1": 58
+      },
+      "flags": {},
+      "order": 2,
+      "mode": 0,
+      "properties": {
+        "text": ""
+      },
+      "widgets_values": [
+        "You can upload video here\n(在此上传视频)"
+      ],
+      "color": "#432",
+      "bgcolor": "#653"
+    },
+    {
+      "id": 85,
+      "type": "VHS_LoadVideo",
+      "pos": [
+        336,
+        470
+      ],
+      "size": [
+        235.1999969482422,
+        398.971426827567
+      ],
+      "flags": {},
+      "order": 3,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "IMAGE",
+          "type": "IMAGE",
+          "links": [
+            49
+          ],
+          "shape": 3,
+          "slot_index": 0
+        },
+        {
+          "name": "frame_count",
+          "type": "INT",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "audio",
+          "type": "AUDIO",
+          "links": null,
+          "shape": 3
+        },
+        {
+          "name": "video_info",
+          "type": "VHS_VIDEOINFO",
+          "links": null,
+          "shape": 3
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_LoadVideo"
+      },
+      "widgets_values": {
+        "video": "00000125.mp4",
+        "force_rate": 0,
+        "force_size": "Disabled",
+        "custom_width": 512,
+        "custom_height": 512,
+        "frame_load_cap": 0,
+        "skip_first_frames": 0,
+        "select_every_nth": 1,
+        "choose video to upload": "image",
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "frame_load_cap": 0,
+            "skip_first_frames": 0,
+            "force_rate": 0,
+            "filename": "00000125.mp4",
+            "type": "input",
+            "format": "video/mp4",
+            "select_every_nth": 1
+          }
+        }
+      }
+    },
+    {
+      "id": 73,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        160
+      ],
+      "size": {
+        "0": 383.7149963378906,
+        "1": 183.83506774902344
+      },
+      "flags": {},
+      "order": 4,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            50
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Negtive Prompt(反向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "The video is not of a high quality, it has a low resolution. Watermark present in each frame. Strange motion trajectory. "
+      ]
+    },
+    {
+      "id": 75,
+      "type": "CogVideoX_FUN_TextBox",
+      "pos": [
+        250,
+        -50
+      ],
+      "size": {
+        "0": 383.54010009765625,
+        "1": 156.71620178222656
+      },
+      "flags": {},
+      "order": 5,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "links": [
+            51
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "title": "Positive Prompt(正向提示词)",
+      "properties": {
+        "Node name for S&R": "CogVideoX_FUN_TextBox"
+      },
+      "widgets_values": [
+        "A cute cat is playing the guitar."
+      ]
+    },
+    {
+      "id": 87,
+      "type": "CogVideoX_Fun_V2VSampler",
+      "pos": [
+        778,
+        93
+      ],
+      "size": {
+        "0": 336,
+        "1": 286
+      },
+      "flags": {},
+      "order": 7,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "link": 52
+        },
+        {
+          "name": "prompt",
+          "type": "STRING_PROMPT",
+          "link": 51,
+          "slot_index": 1
+        },
+        {
+          "name": "negative_prompt",
+          "type": "STRING_PROMPT",
+          "link": 50
+        },
+        {
+          "name": "validation_video",
+          "type": "IMAGE",
+          "link": 49,
+          "slot_index": 3
+        }
+      ],
+      "outputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "links": [
+            48
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "CogVideoX_Fun_V2VSampler"
+      },
+      "widgets_values": [
+        49,
+        512,
+        43,
+        "fixed",
+        50,
+        6,
+        0.7,
+        "DDIM"
+      ]
+    },
+    {
+      "id": 17,
+      "type": "VHS_VideoCombine",
+      "pos": [
+        1134,
+        93
+      ],
+      "size": [
+        390.9534912109375,
+        535.9734235491071
+      ],
+      "flags": {},
+      "order": 8,
+      "mode": 0,
+      "inputs": [
+        {
+          "name": "images",
+          "type": "IMAGE",
+          "link": 48,
+          "label": "图像",
+          "slot_index": 0
+        },
+        {
+          "name": "audio",
+          "type": "VHS_AUDIO",
+          "link": null,
+          "label": "音频"
+        },
+        {
+          "name": "meta_batch",
+          "type": "VHS_BatchManager",
+          "link": null,
+          "label": "批次管理"
+        },
+        {
+          "name": "vae",
+          "type": "VAE",
+          "link": null
+        }
+      ],
+      "outputs": [
+        {
+          "name": "Filenames",
+          "type": "VHS_FILENAMES",
+          "links": null,
+          "shape": 3,
+          "label": "文件名",
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "VHS_VideoCombine"
+      },
+      "widgets_values": {
+        "frame_rate": 8,
+        "loop_count": 0,
+        "filename_prefix": "EasyAnimate",
+        "format": "video/h264-mp4",
+        "pix_fmt": "yuv420p",
+        "crf": 22,
+        "save_metadata": true,
+        "pingpong": false,
+        "save_output": true,
+        "videopreview": {
+          "hidden": false,
+          "paused": false,
+          "params": {
+            "filename": "EasyAnimate_00045.mp4",
+            "subfolder": "",
+            "type": "output",
+            "format": "video/h264-mp4",
+            "frame_rate": 8
+          }
+        }
+      }
+    },
+    {
+      "id": 88,
+      "type": "LoadCogVideoX_Fun_Model",
+      "pos": [
+        309,
+        -286
+      ],
+      "size": {
+        "0": 315,
+        "1": 106
+      },
+      "flags": {},
+      "order": 6,
+      "mode": 0,
+      "outputs": [
+        {
+          "name": "cogvideoxfun_model",
+          "type": "CogVideoXFUNSMODEL",
+          "links": [
+            52
+          ],
+          "shape": 3,
+          "slot_index": 0
+        }
+      ],
+      "properties": {
+        "Node name for S&R": "LoadCogVideoX_Fun_Model"
+      },
+      "widgets_values": [
+        "CogVideoX-Fun-2b-InP",
+        false,
+        "bf16"
+      ]
+    }
+  ],
+  "links": [
+    [
+      48,
+      87,
+      0,
+      17,
+      0,
+      "IMAGE"
+    ],
+    [
+      49,
+      85,
+      0,
+      87,
+      3,
+      "IMAGE"
+    ],
+    [
+      50,
+      73,
+      0,
+      87,
+      2,
+      "STRING_PROMPT"
+    ],
+    [
+      51,
+      75,
+      0,
+      87,
+      1,
+      "STRING_PROMPT"
+    ],
+    [
+      52,
+      88,
+      0,
+      87,
+      0,
+      "CogVideoXFUNSMODEL"
+    ]
+  ],
+  "groups": [
+    {
+      "title": "Prompts",
+      "bounding": [
+        218,
+        -127,
+        450,
+        483
+      ],
+      "color": "#3f789e",
+      "font_size": 24
+    },
+    {
+      "title": "Load CogVideoX-Fun",
+      "bounding": [
+        220,
+        -380,
+        472,
+        232
+      ],
+      "color": "#b06634",
+      "font_size": 24
+    },
+    {
+      "title": "Upload Your Video",
+      "bounding": [
+        218,
+        385,
+        456,
+        498
+      ],
+      "color": "#a1309b",
+      "font_size": 24
+    }
+  ],
+  "config": {},
+  "extra": {
+    "ds": {
+      "scale": 0.683013455365071,
+      "offset": [
+        322.5575500900931,
+        444.05399028364593
+      ]
+    },
+    "workspace_info": {
+      "id": "776b62b4-bd17-4ed3-9923-b7aad000b1ea"
+    }
+  },
+  "version": 0.4
+}
\ No newline at end of file
diff --git a/config/zero_stage2_config.json b/config/zero_stage2_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e60ea05a563827d4286955c1421e82d3b1bbe5cc
--- /dev/null
+++ b/config/zero_stage2_config.json
@@ -0,0 +1,16 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "train_micro_batch_size_per_gpu": 1,
+    "train_batch_size": "auto",
+    "gradient_accumulation_steps": "auto",
+    "dump_state": true,
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": 5e8
+    }
+}
\ No newline at end of file
diff --git a/datasets/put datasets here.txt b/datasets/put datasets here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/install.py b/install.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e349b4db1ccf44876c55e910b802875e2016fc4
--- /dev/null
+++ b/install.py
@@ -0,0 +1,45 @@
+import sys
+import subprocess
+import locale
+import threading
+import os
+
+def handle_stream(stream, prefix):
+    stream.reconfigure(encoding=locale.getpreferredencoding(), errors='replace')
+    for msg in stream:
+        if prefix == '[!]' and ('it/s]' in msg or 's/it]' in msg) and ('%|' in msg or 'it [' in msg):
+            if msg.startswith('100%'):
+                print('\r' + msg, end="", file=sys.stderr),
+            else:
+                print('\r' + msg[:-1], end="", file=sys.stderr),
+        else:
+            if prefix == '[!]':
+                print(prefix, msg, end="", file=sys.stderr)
+            else:
+                print(prefix, msg, end="")
+
+def process_wrap(cmd_str, cwd_path, handler=None):
+    process = subprocess.Popen(cmd_str, cwd=cwd_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)
+
+    if handler is None:
+        handler = handle_stream
+
+    stdout_thread = threading.Thread(target=handler, args=(process.stdout, ""))
+    stderr_thread = threading.Thread(target=handler, args=(process.stderr, "[!]"))
+
+    stdout_thread.start()
+    stderr_thread.start()
+
+    stdout_thread.join()
+    stderr_thread.join()
+
+    return process.wait()
+
+assert process_wrap([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install requirements.txt. Please install them manually, and restart ComfyUI."
+
+nodep_packages = [
+    "kornia>=0.6.9",
+    "xformers>=0.0.20",
+]
+
+assert process_wrap([sys.executable, "-m", "pip", "install", "--no-deps", *nodep_packages], cwd_path=os.path.dirname(os.path.realpath(__file__))) == 0, "ERROR: Failed to install last set of packages. Please install them manually, and restart ComfyUI."
\ No newline at end of file
diff --git a/model.json b/model.json
new file mode 100644
index 0000000000000000000000000000000000000000..58e4344ccc4d5cb39f2b845bc14f3e669c7195b4
--- /dev/null
+++ b/model.json
@@ -0,0 +1,4 @@
+{
+    "input_image_check": "https://files.catbox.moe/z3fh7f.png",
+    "prompt": "cat eating a cake"
+}
diff --git a/models/put models here.txt b/models/put models here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/predict_i2v.py b/predict_i2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..fce07a5af05a01009fc8c28e81a2ab8aa39abf57
--- /dev/null
+++ b/predict_i2v.py
@@ -0,0 +1,238 @@
+
+
+import json
+import os
+
+import numpy as np
+import torch
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from transformers import T5EncoderModel, T5Tokenizer
+from omegaconf import OmegaConf
+from PIL import Image
+
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
+
+# Low gpu memory mode, this is used when the GPU memory is under 16GB
+low_gpu_memory_mode = False
+
+# Config and model path
+model_name          = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
+
+# Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" "DDIM_Cog" and "DDIM_Origin"
+sampler_name        = "DDIM_Origin"
+
+# Load pretrained model if need
+transformer_path    = None 
+vae_path            = None
+lora_path           = None
+
+# Other params
+sample_size         = [384, 672]
+video_length        = 49
+fps                 = 8
+
+# If you want to generate ultra long videos, please set partial_video_length as the length of each sub video segment
+partial_video_length = None
+overlap_video_length = 4
+
+# Use torch.float16 if GPU does not support torch.bfloat16
+# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+weight_dtype            = torch.bfloat16
+# If you want to generate from text, please set the validation_image_start = None and validation_image_end = None
+validation_image_start  = "asset/1.png"
+validation_image_end    = None
+
+# prompts
+prompt                  = "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+negative_prompt         = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
+guidance_scale          = 6.0
+seed                    = 43
+num_inference_steps     = 50
+lora_weight             = 0.55
+save_path               = "samples/cogvideox-fun-videos_i2v"
+
+transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+    model_name, 
+    subfolder="transformer",
+).to(weight_dtype)
+
+if transformer_path is not None:
+    print(f"From checkpoint: {transformer_path}")
+    if transformer_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(transformer_path)
+    else:
+        state_dict = torch.load(transformer_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get Vae
+vae = AutoencoderKLCogVideoX.from_pretrained(
+    model_name, 
+    subfolder="vae"
+).to(weight_dtype)
+
+if vae_path is not None:
+    print(f"From checkpoint: {vae_path}")
+    if vae_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(vae_path)
+    else:
+        state_dict = torch.load(vae_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = vae.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+text_encoder = T5EncoderModel.from_pretrained(
+    model_name, subfolder="text_encoder", torch_dtype=weight_dtype
+)
+# Get Scheduler
+Choosen_Scheduler = scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler, 
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}[sampler_name]
+scheduler = Choosen_Scheduler.from_pretrained(
+    model_name, 
+    subfolder="scheduler"
+)
+
+if transformer.config.in_channels != vae.config.latent_channels:
+    pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+else:
+    pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+if low_gpu_memory_mode:
+    pipeline.enable_sequential_cpu_offload()
+else:
+    pipeline.enable_model_cpu_offload()
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+
+if lora_path is not None:
+    pipeline = merge_lora(pipeline, lora_path, lora_weight)
+
+if partial_video_length is not None:
+    init_frames = 0
+    last_frames = init_frames + partial_video_length
+    while init_frames < video_length:
+        if last_frames >= video_length:
+            if pipeline.vae.quant_conv.weight.ndim==5:
+                mini_batch_encoder = 4
+                _partial_video_length = video_length - init_frames
+                _partial_video_length = int((_partial_video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1
+            else:
+                _partial_video_length = video_length - init_frames
+            
+            if _partial_video_length <= 0:
+                break
+        else:
+            _partial_video_length = partial_video_length
+
+        input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image, None, video_length=_partial_video_length, sample_size=sample_size)
+        
+        with torch.no_grad():
+            sample = pipeline(
+                prompt, 
+                num_frames = _partial_video_length,
+                negative_prompt = negative_prompt,
+                height      = sample_size[0],
+                width       = sample_size[1],
+                generator   = generator,
+                guidance_scale = guidance_scale,
+                num_inference_steps = num_inference_steps,
+
+                video        = input_video,
+                mask_video   = input_video_mask
+            ).videos
+        
+        if init_frames != 0:
+            mix_ratio = torch.from_numpy(
+                np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
+            ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            
+            new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
+                sample[:, :, :overlap_video_length] * mix_ratio
+            new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
+
+            sample = new_sample
+        else:
+            new_sample = sample
+
+        if last_frames >= video_length:
+            break
+
+        validation_image = [
+            Image.fromarray(
+                (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
+            ) for _index in range(-overlap_video_length, 0)
+        ]
+
+        init_frames = init_frames + _partial_video_length - overlap_video_length
+        last_frames = init_frames + _partial_video_length
+else:
+    video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+    input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, validation_image_end, video_length=video_length, sample_size=sample_size)
+
+    with torch.no_grad():
+        sample = pipeline(
+            prompt, 
+            num_frames = video_length,
+            negative_prompt = negative_prompt,
+            height      = sample_size[0],
+            width       = sample_size[1],
+            generator   = generator,
+            guidance_scale = guidance_scale,
+            num_inference_steps = num_inference_steps,
+
+            video        = input_video,
+            mask_video   = input_video_mask
+        ).videos
+
+if lora_path is not None:
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
+
+if not os.path.exists(save_path):
+    os.makedirs(save_path, exist_ok=True)
+
+index = len([path for path in os.listdir(save_path)]) + 1
+prefix = str(index).zfill(8)
+
+if video_length == 1:
+    video_path = os.path.join(save_path, prefix + ".png")
+
+    image = sample[0, :, 0]
+    image = image.transpose(0, 1).transpose(1, 2)
+    image = (image * 255).numpy().astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save(video_path)
+else:
+    video_path = os.path.join(save_path, prefix + ".mp4")
+    save_videos_grid(sample, video_path, fps=fps)
diff --git a/predict_t2v.py b/predict_t2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94f6c2d95a5b532553e4ad97bf436e83a7f2c4b
--- /dev/null
+++ b/predict_t2v.py
@@ -0,0 +1,182 @@
+
+
+import json
+import os
+
+import numpy as np
+import torch
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from transformers import T5EncoderModel, T5Tokenizer
+from omegaconf import OmegaConf
+from PIL import Image
+
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
+
+# Low gpu memory mode, this is used when the GPU memory is under 16GB
+low_gpu_memory_mode = False
+
+# model path
+model_name          = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
+
+# Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" and "DDIM"
+sampler_name        = "DDIM_Origin"
+
+# Load pretrained model if need
+transformer_path    = None
+vae_path            = None
+lora_path           = None
+
+# Other params
+sample_size         = [384, 672]
+video_length        = 49
+fps                 = 8
+
+# Use torch.float16 if GPU does not support torch.bfloat16
+# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+weight_dtype        = torch.bfloat16
+prompt              = "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
+negative_prompt     = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
+guidance_scale      = 6.0
+seed                = 43
+num_inference_steps = 50
+lora_weight         = 0.55
+save_path           = "samples/cogvideox-fun-videos-t2v"
+
+transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+    model_name, 
+    subfolder="transformer",
+).to(weight_dtype)
+
+if transformer_path is not None:
+    print(f"From checkpoint: {transformer_path}")
+    if transformer_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(transformer_path)
+    else:
+        state_dict = torch.load(transformer_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get Vae
+vae = AutoencoderKLCogVideoX.from_pretrained(
+    model_name, 
+    subfolder="vae"
+).to(weight_dtype)
+
+if vae_path is not None:
+    print(f"From checkpoint: {vae_path}")
+    if vae_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(vae_path)
+    else:
+        state_dict = torch.load(vae_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = vae.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+text_encoder = T5EncoderModel.from_pretrained(
+    model_name, subfolder="text_encoder", torch_dtype=weight_dtype
+)
+# Get Scheduler
+Choosen_Scheduler = scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler, 
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}[sampler_name]
+scheduler = Choosen_Scheduler.from_pretrained(
+    model_name, 
+    subfolder="scheduler"
+)
+
+if transformer.config.in_channels != vae.config.latent_channels:
+    pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+else:
+    pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+if low_gpu_memory_mode:
+    pipeline.enable_sequential_cpu_offload()
+else:
+    pipeline.enable_model_cpu_offload()
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+
+if lora_path is not None:
+    pipeline = merge_lora(pipeline, lora_path, lora_weight)
+
+with torch.no_grad():
+    video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+    if transformer.config.in_channels != vae.config.latent_channels:
+        input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=sample_size)
+
+        sample = pipeline(
+            prompt, 
+            num_frames = video_length,
+            negative_prompt = negative_prompt,
+            height      = sample_size[0],
+            width       = sample_size[1],
+            generator   = generator,
+            guidance_scale = guidance_scale,
+            num_inference_steps = num_inference_steps,
+
+            video        = input_video,
+            mask_video   = input_video_mask,
+        ).videos
+    else:
+        sample = pipeline(
+            prompt, 
+            num_frames = video_length,
+            negative_prompt = negative_prompt,
+            height      = sample_size[0],
+            width       = sample_size[1],
+            generator   = generator,
+            guidance_scale = guidance_scale,
+            num_inference_steps = num_inference_steps,
+        ).videos
+
+if lora_path is not None:
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
+
+if not os.path.exists(save_path):
+    os.makedirs(save_path, exist_ok=True)
+
+index = len([path for path in os.listdir(save_path)]) + 1
+prefix = str(index).zfill(8)
+
+if video_length == 1:
+    video_path = os.path.join(save_path, prefix + ".png")
+
+    image = sample[0, :, 0]
+    image = image.transpose(0, 1).transpose(1, 2)
+    image = (image * 255).numpy().astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save(video_path)
+else:
+    video_path = os.path.join(save_path, prefix + ".mp4")
+    save_videos_grid(sample, video_path, fps=fps)
\ No newline at end of file
diff --git a/predict_v2v.py b/predict_v2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..62ab21a91aa67346f598b173b959e283715c67d7
--- /dev/null
+++ b/predict_v2v.py
@@ -0,0 +1,181 @@
+import json
+import os
+
+import cv2
+import numpy as np
+import torch
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from omegaconf import OmegaConf
+from PIL import Image
+from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
+                          T5EncoderModel, T5Tokenizer)
+
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
+    CogVideoX_Fun_Pipeline_Inpaint
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_video_to_video_latent, save_videos_grid
+
+# Low gpu memory mode, this is used when the GPU memory is under 16GB
+low_gpu_memory_mode = False
+
+# model path
+model_name          = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
+
+# Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" and "DDIM"
+sampler_name        = "DDIM_Origin"
+
+# Load pretrained model if need
+transformer_path    = None
+vae_path            = None
+lora_path           = None
+# Other params
+sample_size         = [384, 672]
+video_length        = 49
+fps                 = 8
+
+# Use torch.float16 if GPU does not support torch.bfloat16
+# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+weight_dtype            = torch.bfloat16
+# If you are preparing to redraw the reference video, set validation_video and validation_video_mask. 
+# If you do not use validation_video_mask, the entire video will be redrawn; 
+# if you use validation_video_mask, only a portion of the video will be redrawn.
+# Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
+validation_video        = "asset/1.mp4"
+validation_video_mask   = None 
+denoise_strength        = 0.70
+
+# prompts
+prompt                  = "A cute cat is playing the guitar. "
+negative_prompt         = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
+guidance_scale          = 6.0
+seed                    = 43
+num_inference_steps     = 50
+lora_weight             = 0.55
+save_path               = "samples/cogvideox-fun-videos_v2v"
+
+transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+    model_name, 
+    subfolder="transformer",
+).to(weight_dtype)
+
+if transformer_path is not None:
+    print(f"From checkpoint: {transformer_path}")
+    if transformer_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(transformer_path)
+    else:
+        state_dict = torch.load(transformer_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get Vae
+vae = AutoencoderKLCogVideoX.from_pretrained(
+    model_name, 
+    subfolder="vae"
+).to(weight_dtype)
+
+if vae_path is not None:
+    print(f"From checkpoint: {vae_path}")
+    if vae_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(vae_path)
+    else:
+        state_dict = torch.load(vae_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = vae.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+text_encoder = T5EncoderModel.from_pretrained(
+    model_name, subfolder="text_encoder", torch_dtype=weight_dtype
+)
+# Get Scheduler
+Choosen_Scheduler = scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler, 
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}[sampler_name]
+scheduler = Choosen_Scheduler.from_pretrained(
+    model_name, 
+    subfolder="scheduler"
+)
+
+if transformer.config.in_channels != vae.config.latent_channels:
+    pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+else:
+    pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+        model_name,
+        vae=vae,
+        text_encoder=text_encoder,
+        transformer=transformer,
+        scheduler=scheduler,
+        torch_dtype=weight_dtype
+    )
+
+if low_gpu_memory_mode:
+    pipeline.enable_sequential_cpu_offload()
+else:
+    pipeline.enable_model_cpu_offload()
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+
+if lora_path is not None:
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, "cuda")
+
+video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=sample_size, validation_video_mask=validation_video_mask, fps=fps)
+
+with torch.no_grad():
+    sample = pipeline(
+        prompt, 
+        num_frames = video_length,
+        negative_prompt = negative_prompt,
+        height      = sample_size[0],
+        width       = sample_size[1],
+        generator   = generator,
+        guidance_scale = guidance_scale,
+        num_inference_steps = num_inference_steps,
+
+        video       = input_video,
+        mask_video  = input_video_mask,
+        strength    = denoise_strength,
+    ).videos
+
+if lora_path is not None:
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, "cuda")
+    
+if not os.path.exists(save_path):
+    os.makedirs(save_path, exist_ok=True)
+
+index = len([path for path in os.listdir(save_path)]) + 1
+prefix = str(index).zfill(8)
+    
+if video_length == 1:
+    save_sample_path = os.path.join(save_path, prefix + f".png")
+
+    image = sample[0, :, 0]
+    image = image.transpose(0, 1).transpose(1, 2)
+    image = (image * 255).numpy().astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save(save_sample_path)
+else:
+    video_path = os.path.join(save_path, prefix + ".mp4")
+    save_videos_grid(sample, video_path, fps=fps)
\ No newline at end of file
diff --git a/predict_v2v_control.py b/predict_v2v_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..2628e0c52f5ee37ebfb0e5b3d087bc83aa0ce8fd
--- /dev/null
+++ b/predict_v2v_control.py
@@ -0,0 +1,163 @@
+import json
+import os
+
+import cv2
+import numpy as np
+import torch
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from omegaconf import OmegaConf
+from PIL import Image
+from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
+                          T5EncoderModel, T5Tokenizer)
+
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_control import \
+    CogVideoX_Fun_Pipeline_Control
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_video_to_video_latent, save_videos_grid
+
+# Low gpu memory mode, this is used when the GPU memory is under 16GB
+low_gpu_memory_mode = False
+
+# model path
+model_name          = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-Pose"
+
+# Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" and "DDIM"
+sampler_name        = "DDIM_Origin"
+
+# Load pretrained model if need
+transformer_path    = None
+vae_path            = None
+lora_path           = None
+# Other params
+sample_size         = [672, 384]
+video_length        = 49
+fps                 = 8
+
+# Use torch.float16 if GPU does not support torch.bfloat16
+# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
+weight_dtype            = torch.bfloat16
+control_video           = "asset/pose.mp4"
+
+# prompts
+prompt                  = "A person wearing a knee-length white sleeveless dress and white high-heeled sandals performs a dance in a well-lit room with wooden flooring. The room's background features a closed door, a shelf displaying clear glass bottles of alcoholic beverages, and a partially visible dark-colored sofa. "
+negative_prompt         = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
+guidance_scale          = 6.0
+seed                    = 43
+num_inference_steps     = 50
+lora_weight             = 0.55
+save_path               = "samples/cogvideox-fun-videos_control"
+
+transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+    model_name, 
+    subfolder="transformer",
+).to(weight_dtype)
+
+if transformer_path is not None:
+    print(f"From checkpoint: {transformer_path}")
+    if transformer_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(transformer_path)
+    else:
+        state_dict = torch.load(transformer_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = transformer.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+# Get Vae
+vae = AutoencoderKLCogVideoX.from_pretrained(
+    model_name, 
+    subfolder="vae"
+).to(weight_dtype)
+
+if vae_path is not None:
+    print(f"From checkpoint: {vae_path}")
+    if vae_path.endswith("safetensors"):
+        from safetensors.torch import load_file, safe_open
+        state_dict = load_file(vae_path)
+    else:
+        state_dict = torch.load(vae_path, map_location="cpu")
+    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+    m, u = vae.load_state_dict(state_dict, strict=False)
+    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+text_encoder = T5EncoderModel.from_pretrained(
+    model_name, subfolder="text_encoder", torch_dtype=weight_dtype
+)
+# Get Scheduler
+Choosen_Scheduler = scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler, 
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}[sampler_name]
+scheduler = Choosen_Scheduler.from_pretrained(
+    model_name, 
+    subfolder="scheduler"
+)
+
+pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
+    model_name,
+    vae=vae,
+    text_encoder=text_encoder,
+    transformer=transformer,
+    scheduler=scheduler,
+    torch_dtype=weight_dtype
+)
+
+if low_gpu_memory_mode:
+    pipeline.enable_sequential_cpu_offload()
+else:
+    pipeline.enable_model_cpu_offload()
+
+generator = torch.Generator(device="cuda").manual_seed(seed)
+
+if lora_path is not None:
+    pipeline = merge_lora(pipeline, lora_path, lora_weight, "cuda")
+
+video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps)
+
+with torch.no_grad():
+    sample = pipeline(
+        prompt, 
+        num_frames = video_length,
+        negative_prompt = negative_prompt,
+        height      = sample_size[0],
+        width       = sample_size[1],
+        generator   = generator,
+        guidance_scale = guidance_scale,
+        num_inference_steps = num_inference_steps,
+
+        control_video = input_video,
+    ).videos
+
+if lora_path is not None:
+    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, "cuda")
+    
+if not os.path.exists(save_path):
+    os.makedirs(save_path, exist_ok=True)
+
+index = len([path for path in os.listdir(save_path)]) + 1
+prefix = str(index).zfill(8)
+    
+if video_length == 1:
+    save_sample_path = os.path.join(save_path, prefix + f".png")
+
+    image = sample[0, :, 0]
+    image = image.transpose(0, 1).transpose(1, 2)
+    image = (image * 255).numpy().astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save(save_sample_path)
+else:
+    video_path = os.path.join(save_path, prefix + ".mp4")
+    save_videos_grid(sample, video_path, fps=fps)
\ No newline at end of file
diff --git a/reports/report_v1.md b/reports/report_v1.md
new file mode 100644
index 0000000000000000000000000000000000000000..802186ea7f854de912ebf43d497cdac5eb72134e
--- /dev/null
+++ b/reports/report_v1.md
@@ -0,0 +1,36 @@
+# CogVideoX FUN v1 Report
+In CogVideoX-FUN, we trained on approximately 1.2 million data points based on CogVideoX, supporting image and video predictions. It accommodates pixel values for video generation across different resolutions of 512x512x49, 768x768x49, and 1024x1024x49, as well as videos with different aspect ratios. Moreover, we support the generation of videos from images and the reconstruction of videos from other videos.
+
+Compared to CogVideoX, CogVideoX FUN also highlights the following features:
+- Introduction of the InPaint model, enabling the generation of videos from images with specified starting and ending images.
+- Training the model based on token lengths. This allows for the implementation of various sizes and resolutions within the same model.
+
+## InPaint Model
+We used [CogVideoX](https://github.com/THUDM/CogVideo/) as the foundational structure, referencing [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) for the model training to generate videos from images. 
+
+During video generation, the **reference video** is encoded using VAE, with the **black area in the above image representing the part to be reconstructed, and the white area representing the start image**. This is stacked with noise latents and input into the Transformer for video generation. We perform 3D resizing on the **masked area**, directly resizing it to fit the canvas size of the video that needs reconstruction. 
+
+Then, we concatenate the latent, the encoded reference video, and the masked area, inputting them into DiT for noise prediction to obtain the final video. 
+The pipeline structure of CogVideoX FUN is as follows:
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/pipeline.jpg" alt="ui" style="zoom:50%;" />
+
+## Token Length-Based Model Training
+We collected approximately 1.2 million high-quality data for the training of CogVideoX-Fun. During the training, we resized the videos based on different token lengths. The entire training process is divided into three phases, with each phase corresponding to 13312 (for 512x512x49 videos), 29952 (for 768x768x49 videos), and 53248 (for 1024x1024x49 videos).
+
+Taking CogVideoX-Fun-2B as an example:
+- In the 13312 phase, the batch size is 128 with 7k training steps.
+- In the 29952 phase, the batch size is 256 with 6.5k training steps.
+- In the 53248 phase, the batch size is 128 with 5k training steps.
+
+During training, we combined high and low resolutions, enabling the model to support video generation from any resolution between 512 and 1280. For example, with a token length of 13312:
+- At a resolution of 512x512, the number of video frames is 49.
+- At a resolution of 768x768, the number of video frames is 21.
+- At a resolution of 1024x1024, the number of video frames is 9.
+
+These resolutions and corresponding lengths were mixed for training, allowing the model to generate videos at different resolutions.
+
+## Resize 3D Embedding
+In adapting CogVideoX-2B to the CogVideoX-Fun framework, it was found that the source code obtains 3D embeddings in a truncated manner. This approach only accommodates a single resolution; when the resolution changes, the embedding should also change.
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/PE_Interpolation.jpg" alt="ui" style="zoom:50%;" />
+
+Referencing Pixart-Sigma, the above image is from the Pixart-Sigma paper. We used Positional Embeddings Interpolation (PE Interpolation) to resize 3D embeddings. PE Interpolation is more conducive to convergence than directly generating cosine and sine embeddings for different resolutions.
diff --git a/reports/report_v1_1.md b/reports/report_v1_1.md
new file mode 100644
index 0000000000000000000000000000000000000000..b716d8c72db3cc2f64bbe66b3095c6d6b8d16a34
--- /dev/null
+++ b/reports/report_v1_1.md
@@ -0,0 +1,32 @@
+# CogVideoX FUN v1.1 Report
+
+In CogVideoX-FUN v1.1, we performed additional filtering on the previous dataset, selecting videos with larger motion amplitudes rather than still images in motion, resulting in approximately 0.48 million videos. The model continues to support both image and video prediction, accommodating pixel values from 512x512x49, 768x768x49, 1024x1024x49, and videos with different aspect ratios. We support both image-to-video generation and video-to-video reconstruction.
+
+Additionally, we have released training and prediction code for adding control signals, along with the initial version of the Control model.
+
+Compared to version 1.0, CogVideoX-FUN V1.1 highlights the following features:
+- In the 5b model, Noise has been added to the reference images, increasing the motion amplitude of the videos.
+- Released training and prediction code for adding control signals, along with the initial version of the Control model.
+
+## Adding Noise to Reference Images
+
+Building on the original CogVideoX-FUN V1.0, we drew upon [CogVideoX](https://github.com/THUDM/CogVideo/) and [SVD](https://github.com/Stability-AI/generative-models) to add Noise upwards to the non-zero reference images to disrupt the original images, aiming for greater motion amplitude.
+
+In our 5b model, Noise has been added, while the 2b model only performed fine-tuning with new data. This is because, after attempting to add Noise in the 2b model, the generated videos exhibited excessive motion amplitude, leading to deformation and damaging the output. The 5b model, due to its stronger generative capabilities, maintains relatively stable outputs during motion.
+
+Furthermore, the prompt words significantly influence the generation results, so please describe the actions in detail to increase dynamism. If unsure how to write positive prompts, you can use phrases like "smooth motion" or "in the wind" to enhance dynamism. Additionally, it is advisable to avoid using dynamic terms like "motion" in negative prompts.
+
+## Adding Control Signals to CogVideoX-FUN
+
+On the basis of the original CogVideoX-FUN V1.0, we replaced the original mask signal with Pose control signals. The control signals are encoded using VAE and used as Guidance, along with latent data entering the patch processing flow.
+
+We filtered the 0.48 million dataset, selecting around 20,000 videos and images containing portraits for pose extraction, which served as condition control signals for training. 
+
+During the training process, the videos are scaled according to different Token lengths. The entire training process is divided into two phases, with each phase comprising 13,312 (corresponding to 512x512x49 videos) and 53,248 (corresponding to 1024x1024x49 videos).
+
+Taking CogVideoX-Fun-V1.1-5b-Pose as an example:
+- In the 13312 phase, the batch size is 128, with 2.4k training steps.
+- In the 53248 phase, the batch size is 128, with 1.2k training steps.
+
+The working principle diagram is shown below:
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pipeline_control.jpg" alt="ui" style="zoom:50%;" />
diff --git a/reports/report_v1_1_zh-CN.md b/reports/report_v1_1_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..81dfb7a902a71421b1baec0b3054a09a7a9b2b81
--- /dev/null
+++ b/reports/report_v1_1_zh-CN.md
@@ -0,0 +1,31 @@
+# CogVideoX FUN v1.1 Report
+
+在CogVideoX-FUN v1.1中,我们在之前的数据集中再次做了筛选,选出其中动作幅度较大,而不是静止画面移动的视频,数量大约为0.48m。模型依然支持图片与视频预测,支持像素值从512x512x49、768x768x49、1024x1024x49与不同纵横比的视频生成。我们支持图像到视频的生成与视频到视频的重建。
+
+另外,我们还发布了添加控制信号的训练代码与预测代码,并发布了初版的Control模型。
+
+对比V1.0版本,CogVideoX-FUN V1.1突出了以下功能:
+
+- 在5b模型中,给参考图片添加了Noise,增加了视频的运动幅度。
+- 发布了添加控制信号的训练代码与预测代码,并发布了初版的Control模型。
+
+## 参考图片添加Noise
+在原本CogVideoX-FUN V1.0的基础上,我们参考[CogVideoX](https://github.com/THUDM/CogVideo/)和[SVD](https://github.com/Stability-AI/generative-models),在非0的参考图向上添加Noise以破环原图,追求更大的运动幅度。
+
+我们5b模型中添加了Noise,2b模型仅使用了新数据进行了finetune,因为我们在2b模型中尝试添加Noise之后,生成的视频运动幅度过大导致结果变形,破坏了生成结果,而5b模型因为更为的强大生成能力,在运动中也保持了较为稳定的输出。
+
+另外,提示词对生成结果影响较大,请尽量描写动作以增加动态性。如果不知道怎么写正向提示词,可以使用smooth motion or in the wind来增加动态性。并且尽量避免在负向提示词中出现motion等表示动态的词汇。
+
+## 添加控制信号的CogVideoX-Fun
+在原本CogVideoX-FUN V1.0的基础上,我们使用Pose控制信号替代了原本的mask信号,将控制信号使用VAE编码后作为Guidance与latent一起进入patch流程,
+
+我们在0.48m数据中进行了筛选,选择出大约20000包含人像的视频与图片进行pose提取,作为condition控制信号进行训练。
+
+在进行训练时,我们根据不同Token长度,对视频进行缩放后进行训练。整个训练过程分为两个阶段,每个阶段的13312(对应512x512x49的视频),53248(对应1024x1024x49的视频)。
+
+以CogVideoX-Fun-V1.1-5b-Pose为例子,其中:
+- 13312阶段,Batch size为128,训练步数为2.4k
+- 53248阶段,Batch size为128,训练步数为1.2k。
+
+工作原理图如下:
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pipeline_control.jpg" alt="ui" style="zoom:50%;" />
diff --git a/reports/report_v1_zh-CN.md b/reports/report_v1_zh-CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..2d42f123eac46efcd8648fbee9bd4542e96924dd
--- /dev/null
+++ b/reports/report_v1_zh-CN.md
@@ -0,0 +1,43 @@
+# CogVideoX FUN v1 Report
+
+在CogVideoX-FUN中,我们基于CogVideoX在大约1.2m的数据上进行了训练,支持图片与视频预测,支持像素值从512x512x49、768x768x49、1024x1024x49与不同纵横比的视频生成。另外,我们支持图像到视频的生成与视频到视频的重建。
+
+对比与CogVideoX,CogVideoX FUN还突出了以下功能:
+
+- 引入InPaint模型,实现图生视频功能,可以通过首尾图指定视频生成。
+- 基于Token长度的模型训练。达成不同大小多分辨率在同一模型中的实现。
+
+## InPaint模型
+我们以[CogVideoX](https://github.com/THUDM/CogVideo/)作为基础结构,参考[EasyAnimate](https://github.com/aigc-apps/EasyAnimate)进行图生视频的模型训练。
+
+在进行视频生成的时候,将**参考视频**使用VAE进行encode,**上图黑色的部分代表需要重建的部分,白色的部分代表首图**,与噪声Latents一起堆叠后输入到Transformer中进行视频生成。
+
+我们对**被Mask的区域**进行3D Resize,直接Resize到需要重建的视频的画布大小。
+
+然后将Latent、Encode后的参考视频、被Mask的区域,concat后输入到DiT中进行噪声预测。获得最终的视频。
+
+CogVideoX FUN的Pipeline结构如下:
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/pipeline.jpg" alt="ui" style="zoom:50%;" />
+
+## 基于Token长度的模型训练
+我们收集了大约高质量的1.2m数据进行CogVideoX-Fun的训练。
+
+在进行训练时,我们根据不同Token长度,对视频进行缩放后进行训练。整个训练过程分为三个阶段,每个阶段的13312(对应512x512x49的视频),29952(对应768x768x49的视频),53248(对应1024x1024x49的视频)。
+
+以CogVideoX-Fun-2B为例子,其中:
+- 13312阶段,Batch size为128,训练步数为7k
+- 29952阶段,Batch size为256,训练步数为6.5k。
+- 53248阶段,Batch size为128,训练步数为5k。
+
+训练时我们采用高低分辨率结合训练,因此模型支持从512到1280任意分辨率的视频生成,以13312 token长度为例:
+- 在512x512分辨率下,视频帧数为49;
+- 在768x768分辨率下,视频帧数为21;
+- 在1024x1024分辨率下,视频帧数为9;
+这些分辨率与对应长度混合训练,模型可以完成不同大小分辨率的视频生成。
+
+## Resize 3D Embedding
+在适配CogVideoX-2B到CogVideoX-Fun框架的途中,发现源码是以截断的方式去得到3D Embedding的,这样的方式只能适配单一分辨率,当分辨率发生变化时,Embedding也应当发生变化。
+
+<img src="https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1/PE_Interpolation.jpg" alt="ui" style="zoom:50%;" />
+
+参考Pixart-Sigma,上图来自于Pixart-Sigma论文,我们采用Positional Embeddings Interpolation(PE Interpolation)对3D embedding进行Resize,PE Interpolation相比于直接生成不同分辨率的Cos Sin Embedding更易收敛。
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7cbfa50bc81c5b91bcdc95d857e085ae903421b0
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+Pillow
+einops
+safetensors
+timm
+tomesd
+torch>=2.1.2
+torchdiffeq
+torchsde
+xformers
+decord
+datasets
+numpy
+scikit-image
+opencv-python
+omegaconf
+SentencePiece
+albumentations
+imageio[ffmpeg]
+imageio[pyav]
+tensorboard
+beautifulsoup4
+ftfy
+func_timeout
+deepspeed
+accelerate>=0.25.0
+gradio>=3.41.2,<=3.48.0
+diffusers>=0.30.1
+transformers>=4.37.2
diff --git a/scripts/README_DEMO.md b/scripts/README_DEMO.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d0032cc84e71b0813e3415545960c3ea7d1f486
--- /dev/null
+++ b/scripts/README_DEMO.md
@@ -0,0 +1,37 @@
+## Demo
+
+Image generation video corresponding images and prompts.
+
+If you don't know how to write positive prompts, you can use "smooth motion" or "in the wind" to add dynamism.
+
+| Image | Prompt |
+|--|--|
+| ![1.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/1.png) | closeup face photo of man is smiling in black clothes, night city street, bokeh, fireworks in background |
+| ![2.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/2.png) | sunset, orange sky, warm lighting, fishing boats, ocean waves, seagulls, rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, golden hour, coastal landscape, seaside scenery |
+| ![3.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/3.png) | a man in an astronaut suit playing a guitar |
+| ![4.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/4.png) | time-lapse of a blooming flower with leaves and a stem, blossom |
+| ![5.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/5.png) | fireworks display over night city |
+| ![6.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/6.png) | a beautiful woman with long hair and a dress blowing in the wind |
+| ![7.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/7.png) | the dog is shaking head |
+| ![8.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/8.png) | a robot is walking through a destroyed city |
+| ![9.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/9.png) | a group of penguins walking on a beach |
+| ![10.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/10.png) | a bonfire is lit in the middle of a field |
+| ![11.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/11.png) | a boat traveling on the ocean |
+| ![12.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/12.png) | pouring honey onto some slices of bread |
+| ![13.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/13.png) | a sailboat sailing in rough seas with a dramatic sunset  |
+| ![14.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/14.png) | a boat traveling on the ocean |
+| ![15.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/15.png) | a scenic view of a lake with several seagulls flying above the water. In the foreground, there is a person wearing a red garment, possibly a jacket or a shawl, observing the scenery. The lake has clear blue water, and there's a structure that appears to be a wooden pavilion or boathouse on stilts situated in the water. In the background, hills or mountains can be seen under a clear blue sky, enhancing the tranquil and picturesque setting |
+| ![16.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/16.png) | A man's body shimmered with golden light in the wind |
+| ![17.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/17.png) | a buried broken emerald cross glazed by the sun emitting smoke, backlit, forgotten, atmospheric AF, detailed, 8k |
+| ![18.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/18.png) | A beautiful woman is smiling in the wind |
+| ![19.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/19.png) | A beautiful woman is smiling in the wind |
+| ![20.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/20.png) | A beautiful woman is smiling in the wind |
+| ![21.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/21.png) | A beautiful woman smiles in the heavy snow |
+| ![22.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/22.png) | cats smiling taking a selfie with a super wide angle lenses, opening mouth. |
+| ![23.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/23.png) | The sturdy sailboat in 'Temperamental Tides', masterfully navigating the restless, pulsating waves of the deep navy sea, maintaining balance on the surging storm grey crests |
+| ![24.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/24.png) | The sturdy sailboat in 'Temperamental Tides', masterfully navigating the restless, pulsating waves of the deep navy sea, maintaining balance on the surging storm grey crests |
+| ![25.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/25.png) | a beach with waves crashing against it and a sunset in the background a brigantine, a sailboat in the distance, 4k |
+| ![26.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/26.png) | Create an illustration that captures the essence of water. The scene should be a tranquil beach at sunrise, with the calm ocean stretching out to the horizon. The sky is painted in soft hues of pink and orange as the sun begins to rise. Gentle waves lap against the sandy shore, creating delicate ripples. The water is crystal clear, reflecting the colors of the sky, and small, glistening seashells are scattered along the shoreline. In the distance, a small sailboat with white sails drifts peacefully on the water. The overall mood of the illustration should be serene and calming, emphasizing the fluid and reflective nature of water.glowneon, glowing, sparks, lightning |
+| ![27.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/27.png) | a Lighthouse battered by high winds, huge crashing waves, realistic northern lights, behind lighthouse, realistic stormy seas, high quality image, photographic, mist, and sea spray, storm clouds, angry sky, dusk, peninsula, winter, almost dark, storm, gales, elevated view point, high up perspective, night time, lighthouse light beams, position lighthouse to left of image, view from on high |
+| ![28.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/28.png) | Two racing cars racing towards the camera, desert dune in the background, hyperrealistic, driver turning the wheel, more details, speed of light, a trail of intense light follows the cars, image evokes the sensation of speed, frozen movement, insane intricate detail, (masterpiece, best quality), high resolution, (ultra detailed), |
+| ![29.png](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/i2v_images/29.png) | one long eared dog, beagle, making goofy faces under water, lie the wind is blowing in his open mouth bubbles wide. an annoyed goldfish swims by |
\ No newline at end of file
diff --git a/scripts/README_TRAIN.md b/scripts/README_TRAIN.md
new file mode 100644
index 0000000000000000000000000000000000000000..bff1821ce4447c63ef7e69c04cd2e10217db37c5
--- /dev/null
+++ b/scripts/README_TRAIN.md
@@ -0,0 +1,109 @@
+## Training Code
+
+The default training commands for the different versions are as follows:
+
+We can choose whether to use deep speed in CogVideoX-Fun, which can save a lot of video memory. 
+
+Some parameters in the sh file can be confusing, and they are explained in this document:
+
+- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images and videos at the center, but instead, it trains the entire images and videos after grouping them into buckets based on resolution.
+- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts.
+- `random_hw_adapt` is used to enable automatic height and width scaling for images and videos. When random_hw_adapt is enabled, the training images will have their height and width set to image_sample_size as the maximum and video_sample_size as the minimum. For training videos, the height and width will be set to video_sample_size as the maximum and min(video_sample_size, 512) as the minimum.
+- `training_with_video_token_length` specifies training the model according to token length. The token length for a video with dimensions 512x512 and 49 frames is 13,312.
+  - At 512x512 resolution, the number of video frames is 49;
+  - At 768x768 resolution, the number of video frames is 21;
+  - At 1024x1024 resolution, the number of video frames is 9;
+  - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes.
+- `train_mode` is used to specify the training mode, which can be either normal or inpaint. Since CogVideoX-Fun uses the Inpaint model to achieve image-to-video generation, the default is set to inpaint mode. If you only wish to achieve text-to-video generation, you can remove this line, and it will default to the text-to-video mode.
+
+CogVideoX-Fun without deepspeed:
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=4 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_came \
+  --train_mode="inpaint" \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
+```
+
+CogVideoX-Fun with deepspeed:
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/train.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=4 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_came \
+  --use_deepspeed \
+  --train_mode="inpaint" \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
+```
diff --git a/scripts/README_TRAIN_CONTROL.md b/scripts/README_TRAIN_CONTROL.md
new file mode 100644
index 0000000000000000000000000000000000000000..14a1119811589d06529db1b423dffccd5e2d170c
--- /dev/null
+++ b/scripts/README_TRAIN_CONTROL.md
@@ -0,0 +1,126 @@
+## Training Code
+
+The default training commands for the different versions are as follows:
+
+We can choose whether to use deep speed in CogVideoX-Fun, which can save a lot of video memory. 
+
+The metadata_control.json is a little different from normal json in CogVideoX-Fun, you need to add a control_file_path, and [DWPose](https://github.com/IDEA-Research/DWPose) is suggested as tool to generate control file.
+
+```json
+[
+    {
+      "file_path": "train/00000001.mp4",
+      "control_file_path": "control/00000001.mp4",
+      "text": "A group of young men in suits and sunglasses are walking down a city street.",
+      "type": "video"
+    },
+    {
+      "file_path": "train/00000002.jpg",
+      "control_file_path": "control/00000002.jpg",
+      "text": "A group of young men in suits and sunglasses are walking down a city street.",
+      "type": "image"
+    },
+    .....
+]
+```
+
+Some parameters in the sh file can be confusing, and they are explained in this document:
+
+- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images and videos at the center, but instead, it trains the entire images and videos after grouping them into buckets based on resolution.
+- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts.
+- `random_hw_adapt` is used to enable automatic height and width scaling for images and videos. When random_hw_adapt is enabled, the training images will have their height and width set to image_sample_size as the maximum and video_sample_size as the minimum. For training videos, the height and width will be set to video_sample_size as the maximum and min(video_sample_size, 512) as the minimum.
+- `training_with_video_token_length` specifies training the model according to token length. The token length for a video with dimensions 512x512 and 49 frames is 13,312.
+  - At 512x512 resolution, the number of video frames is 49;
+  - At 768x768 resolution, the number of video frames is 21;
+  - At 1024x1024 resolution, the number of video frames is 9;
+  - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes.
+
+CogVideoX-Fun without deepspeed:
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-Pose"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train_control.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=4 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=50 \
+  --seed=43 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_came \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
+```
+
+CogVideoX-Fun with deepspeed:
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-Pose"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/train.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=4 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=50 \
+  --seed=43 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_came \
+  --use_deepspeed \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
+```
diff --git a/scripts/README_TRAIN_LORA.md b/scripts/README_TRAIN_LORA.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ffd1bf80847c7cbabea50fe329342c33cd1d658
--- /dev/null
+++ b/scripts/README_TRAIN_LORA.md
@@ -0,0 +1,100 @@
+## Lora Training Code
+
+We can choose whether to use deep speed in CogVideoX-Fun, which can save a lot of video memory. 
+
+Some parameters in the sh file can be confusing, and they are explained in this document:
+
+- `enable_bucket` is used to enable bucket training. When enabled, the model does not crop the images and videos at the center, but instead, it trains the entire images and videos after grouping them into buckets based on resolution.
+- `random_frame_crop` is used for random cropping on video frames to simulate videos with different frame counts.
+- `random_hw_adapt` is used to enable automatic height and width scaling for images and videos. When random_hw_adapt is enabled, the training images will have their height and width set to image_sample_size as the maximum and video_sample_size as the minimum. For training videos, the height and width will be set to video_sample_size as the maximum and min(video_sample_size, 512) as the minimum.
+- `training_with_video_token_length` specifies training the model according to token length. The token length for a video with dimensions 512x512 and 49 frames is 13,312.
+  - At 512x512 resolution, the number of video frames is 49;
+  - At 768x768 resolution, the number of video frames is 21;
+  - At 1024x1024 resolution, the number of video frames is 9;
+  - These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes.
+- `train_mode` is used to specify the training mode, which can be either normal or inpaint. Since CogVideoX-Fun uses the Inpaint model to achieve image-to-video generation, the default is set to inpaint mode. If you only wish to achieve text-to-video generation, you can remove this line, and it will default to the text-to-video mode.
+
+CogVideoX-Fun without deepspeed:
+
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=1 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=1e-04 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --low_vram \
+  --train_mode="inpaint" 
+```
+
+```sh
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/train_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=1 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=1e-04 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_deepspeed \
+  --low_vram \
+  --train_mode="inpaint" 
+```
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9eb64660dcab5e56a83908fe4fabde32aed1415
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,1711 @@
+"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+"""
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import gc
+import logging
+import math
+import os
+import pickle
+import shutil
+import sys
+
+import accelerate
+import diffusers
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+from einops import rearrange
+from huggingface_hub import create_repo, upload_folder
+from omegaconf import OmegaConf
+from packaging import version
+from PIL import Image
+from torch.utils.data import RandomSampler
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
+                          CLIPVisionModelWithProjection, MT5Tokenizer,
+                          T5EncoderModel, T5Tokenizer)
+from transformers.utils import ContextManagers
+
+import datasets
+
+current_file_path = os.path.abspath(__file__)
+project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path))]
+for project_root in project_roots:
+    sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+from cogvideox.data.bucket_sampler import (ASPECT_RATIO_512,
+                                           ASPECT_RATIO_RANDOM_CROP_512,
+                                           ASPECT_RATIO_RANDOM_CROP_PROB,
+                                           AspectRatioBatchImageVideoSampler,
+                                           RandomSampler, get_closest_ratio)
+from cogvideox.data.dataset_image_video import (ImageVideoDataset,
+                                                ImageVideoSampler,
+                                                get_random_mask)
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
+    CogVideoX_Fun_Pipeline_Inpaint, add_noise_to_reference_video
+from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
+
+if is_wandb_available():
+    import wandb
+
+
+def get_random_downsample_ratio(sample_size, image_ratio=[],
+                                all_choices=False, rng=None):
+    def _create_special_list(length):
+        if length == 1:
+            return [1.0]
+        if length >= 2:
+            first_element = 0.75
+            remaining_sum = 1.0 - first_element
+            other_elements_value = remaining_sum / (length - 1)
+            special_list = [first_element] + [other_elements_value] * (length - 1)
+            return special_list
+            
+    if sample_size >= 1536:
+        number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio 
+    elif sample_size >= 1024:
+        number_list = [1, 1.25, 1.5, 2] + image_ratio
+    elif sample_size >= 768:
+        number_list = [1, 1.25, 1.5] + image_ratio
+    elif sample_size >= 512:
+        number_list = [1] + image_ratio
+    else:
+        number_list = [1]
+
+    if all_choices:
+        return number_list
+
+    number_list_prob = np.array(_create_special_list(len(number_list)))
+    if rng is None:
+        return np.random.choice(number_list, p = number_list_prob)
+    else:
+        return rng.choice(number_list, p = number_list_prob)
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+        
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+def log_validation(vae, text_encoder, tokenizer, transformer3d, args, accelerator, weight_dtype, global_step):
+    try:
+        logger.info("Running validation... ")
+            
+        transformer3d_val = CogVideoXTransformer3DModel.from_pretrained_2d(
+            args.pretrained_model_name_or_path, subfolder="transformer"
+        ).to(weight_dtype)
+        transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict())
+        scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+        if args.train_mode != "normal":
+            pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+                args.pretrained_model_name_or_path, 
+                vae=accelerator.unwrap_model(vae).to(weight_dtype), 
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                tokenizer=tokenizer,
+                transformer=transformer3d_val,
+                scheduler=scheduler,
+                torch_dtype=weight_dtype
+            )
+        else:
+            pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+                args.pretrained_model_name_or_path, 
+                vae=accelerator.unwrap_model(vae).to(weight_dtype), 
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                tokenizer=tokenizer,
+                transformer=transformer3d_val,
+                scheduler=scheduler,
+                torch_dtype=weight_dtype
+            )
+        pipeline = pipeline.to(accelerator.device)
+
+        if args.seed is None:
+            generator = None
+        else:
+            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+        images = []
+        for i in range(len(args.validation_prompts)):
+            with torch.no_grad():
+                if args.train_mode != "normal":
+                    with torch.autocast("cuda", dtype=weight_dtype):
+                        video_length = int((args.video_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1
+                        input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size])
+                        sample = pipeline(
+                            args.validation_prompts[i],
+                            num_frames = video_length,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            guidance_scale = 6.0,
+                            generator   = generator,
+
+                            video        = input_video,
+                            mask_video   = input_video_mask,
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif"))
+
+                        video_length = 1
+                        input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size])
+                        sample = pipeline(
+                            args.validation_prompts[i],
+                            num_frames = video_length,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            guidance_scale = 6.0,
+                            generator   = generator, 
+
+                            video        = input_video,
+                            mask_video   = input_video_mask,
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif"))
+                else:
+                    with torch.autocast("cuda", dtype=weight_dtype):
+                        sample = pipeline(
+                            args.validation_prompts[i],
+                            num_frames = args.video_sample_n_frames,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            generator   = generator
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif"))
+
+                        sample = pipeline(
+                            args.validation_prompts[i], 
+                            num_frames = args.video_sample_n_frames,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            generator   = generator
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif"))
+
+        del pipeline
+        del transformer3d_val
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+        return images
+    except Exception as e:
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        print(f"Eval error with info {e}")
+        return None
+
+def linear_decay(initial_value, final_value, total_steps, current_step):
+    if current_step >= total_steps:
+        return final_value
+    current_step = max(0, current_step)
+    step_size = (final_value - initial_value) / total_steps
+    current_value = initial_value + step_size * current_step
+    return current_value
+
+def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
+    u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
+    t = 1 / (1 + torch.exp(-u)) * (high - low) + low
+    return torch.clip(t.to(torch.int32), low, high - 1)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--train_data_meta",
+        type=str,
+        default=None,
+        help=(
+            "A csv containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="sd-model-finetuned",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
+    parser.add_argument(
+        "--use_came",
+        action="store_true",
+        help="whether to use came",
+    )
+    parser.add_argument(
+        "--multi_stream",
+        action="store_true",
+        help="whether to use cuda multi-stream",
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--vae_mini_batch", type=int, default=32, help="mini batch size for vae."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=100)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default=None,
+        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+    )
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)."
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=("Max number of checkpoints to store."),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=2000,
+        help="Run validation every X steps.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
+    
+    parser.add_argument(
+        "--snr_loss", action="store_true", help="Whether or not to use snr_loss."
+    )
+    parser.add_argument(
+        "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader."
+    )
+    parser.add_argument(
+        "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets."
+    )
+    parser.add_argument(
+        "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets."
+    )
+    parser.add_argument(
+        "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.",
+    )
+    parser.add_argument(
+        "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss."
+    )
+    parser.add_argument(
+        "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss."
+    )
+    parser.add_argument(
+        "--train_sampling_steps",
+        type=int,
+        default=1000,
+        help="Run train_sampling_steps.",
+    )
+    parser.add_argument(
+        "--keep_all_node_same_token_length",
+        action="store_true", 
+        help="Reference of the length token.",
+    )
+    parser.add_argument(
+        "--token_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the token.",
+    )
+    parser.add_argument(
+        "--video_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--image_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_stride",
+        type=int,
+        default=4,
+        help="Sample stride of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_n_frames",
+        type=int,
+        default=17,
+        help="Num frame of video.",
+    )
+    parser.add_argument(
+        "--video_repeat",
+        type=int,
+        default=0,
+        help="Num of repeat video.",
+    )
+    parser.add_argument(
+        "--transformer_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other transformers, input its path."),
+    )
+    parser.add_argument(
+        "--vae_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other vaes, input its path."),
+    )
+
+    parser.add_argument(
+        '--trainable_modules', 
+        nargs='+', 
+        help='Enter a list of trainable modules'
+    )
+    parser.add_argument(
+        '--trainable_modules_low_learning_rate', 
+        nargs='+', 
+        default=[],
+        help='Enter a list of trainable modules with lower learning rate'
+    )
+    parser.add_argument(
+        '--tokenizer_max_length', 
+        type=int,
+        default=226,
+        help='Max length of tokenizer'
+    )
+    parser.add_argument(
+        "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed."
+    )
+    parser.add_argument(
+        "--low_vram", action="store_true", help="Whether enable low_vram mode."
+    )
+    parser.add_argument(
+        "--train_mode",
+        type=str,
+        default="normal",
+        help=(
+            'The format of training data. Support `"normal"`'
+            ' (default), `"inpaint"`.'
+        ),
+    )
+    parser.add_argument(
+        "--abnormal_norm_clip_start",
+        type=int,
+        default=1000,
+        help=(
+            'When do we start doing additional processing on abnormal gradients. '
+        ),
+    )
+    parser.add_argument(
+        "--initial_grad_norm_ratio",
+        type=int,
+        default=5,
+        help=(
+            'The initial gradient is relative to the multiple of the max_grad_norm. '
+        ),
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
+    logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+    if accelerator.is_main_process:
+        writer = SummaryWriter(log_dir=logging_dir)
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+        rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index))
+        torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index)
+    else:
+        rng = None
+        torch_rng = None
+    index_rng = np.random.default_rng(np.random.PCG64(43))
+    print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}")
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+        args.mixed_precision = accelerator.mixed_precision
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+        args.mixed_precision = accelerator.mixed_precision
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+    tokenizer = T5Tokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        returns either a context list that includes one that will disable zero.Init or an empty context list
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+    # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = T5EncoderModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant,
+            torch_dtype=weight_dtype
+        )
+
+        vae = AutoencoderKLCogVideoX.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+        )
+
+    transformer3d = CogVideoXTransformer3DModel.from_pretrained_2d(
+        args.pretrained_model_name_or_path, subfolder="transformer"
+    )
+
+    # Freeze vae and text_encoder and set transformer3d to trainable
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    transformer3d.requires_grad_(False)
+
+    if args.transformer_path is not None:
+        print(f"From checkpoint: {args.transformer_path}")
+        if args.transformer_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.transformer_path)
+        else:
+            state_dict = torch.load(args.transformer_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = transformer3d.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+
+    if args.vae_path is not None:
+        print(f"From checkpoint: {args.vae_path}")
+        if args.vae_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.vae_path)
+        else:
+            state_dict = torch.load(args.vae_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = vae.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+    
+    # A good trainable modules is showed below now.
+    # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position']
+    # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position']
+    transformer3d.train()
+    if accelerator.is_main_process:
+        accelerator.print(
+            f"Trainable modules '{args.trainable_modules}'."
+        )
+    for name, param in transformer3d.named_parameters():
+        for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate:
+            if trainable_module_name in name:
+                param.requires_grad = True
+                break
+
+    # Create EMA for the transformer3d.
+    if args.use_ema:
+        ema_transformer3d = CogVideoXTransformer3DModel.from_pretrained_2d(
+            args.pretrained_model_name_or_path, subfolder="transformer"
+        )
+        ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=CogVideoXTransformer3DModel, model_config=ema_transformer3d.config)
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema"))
+
+                models[0].save_pretrained(os.path.join(output_dir, "transformer"))
+                if not args.use_deepspeed:
+                    weights.pop()
+
+                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
+                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                ema_path = os.path.join(input_dir, "transformer_ema")
+                _, ema_kwargs = CogVideoXTransformer3DModel.load_config(ema_path, return_unused_kwargs=True)
+                load_model = CogVideoXTransformer3DModel.from_pretrained_2d(
+                    input_dir, subfolder="transformer_ema"
+                )
+                load_model = EMAModel(load_model.parameters(), model_cls=CogVideoXTransformer3DModel, model_config=load_model.config)
+                load_model.load_state_dict(ema_kwargs)
+
+                ema_transformer3d.load_state_dict(load_model.state_dict())
+                ema_transformer3d.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = CogVideoXTransformer3DModel.from_pretrained_2d(
+                    input_dir, subfolder="transformer"
+                )
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    loaded_number, _ = pickle.load(file)
+                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
+                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    if args.gradient_checkpointing:
+        transformer3d.enable_gradient_checkpointing()
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Initialize the optimizer
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+            )
+
+        optimizer_cls = bnb.optim.AdamW8bit
+    elif args.use_came:
+        try:
+            from came_pytorch import CAME
+        except:
+            raise ImportError(
+                "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`"
+            )
+
+        optimizer_cls = CAME
+    else:
+        optimizer_cls = torch.optim.AdamW
+
+    trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters()))
+    trainable_params_optim = [
+        {'params': [], 'lr': args.learning_rate},
+        {'params': [], 'lr': args.learning_rate / 2},
+    ]
+    in_already = []
+    for name, param in transformer3d.named_parameters():
+        high_lr_flag = False
+        if name in in_already:
+            continue
+        for trainable_module_name in args.trainable_modules:
+            if trainable_module_name in name:
+                in_already.append(name)
+                high_lr_flag = True
+                trainable_params_optim[0]['params'].append(param)
+                if accelerator.is_main_process:
+                    print(f"Set {name} to lr : {args.learning_rate}")
+                break
+        if high_lr_flag:
+            continue
+        for trainable_module_name in args.trainable_modules_low_learning_rate:
+            if trainable_module_name in name:
+                in_already.append(name)
+                trainable_params_optim[1]['params'].append(param)
+                if accelerator.is_main_process:
+                    print(f"Set {name} to lr : {args.learning_rate / 2}")
+                break
+
+    if args.use_came:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            # weight_decay=args.adam_weight_decay,
+            betas=(0.9, 0.999, 0.9999), 
+            eps=(1e-30, 1e-16)
+        )
+    else:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    # Get the training dataset
+    sample_n_frames_bucket_interval = 4
+
+    train_dataset = ImageVideoDataset(
+        args.train_data_meta, args.train_data_dir,
+        video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, 
+        video_repeat=args.video_repeat, 
+        image_sample_size=args.image_sample_size,
+        enable_bucket=args.enable_bucket, enable_inpaint=True if args.train_mode != "normal" else False,
+    )
+    
+    if args.enable_bucket:
+        aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+        aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = AspectRatioBatchImageVideoSampler(
+            sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, 
+            batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True,
+            aspect_ratios=aspect_ratio_sample_size,
+        )
+        if args.keep_all_node_same_token_length:
+            if args.token_sample_size > 256:
+                numbers_list = list(range(256, args.token_sample_size + 1, 128))
+
+                if numbers_list[-1] != args.token_sample_size:
+                    numbers_list.append(args.token_sample_size)
+            else:
+                numbers_list = [256]
+            numbers_list = [_number * _number * args.video_sample_n_frames for _number in  numbers_list]
+        else:
+            numbers_list = None
+
+        def get_length_to_frame_num(token_length):
+            if args.image_sample_size > args.video_sample_size:
+                sample_sizes = list(range(256, args.image_sample_size + 1, 128))
+
+                if sample_sizes[-1] != args.image_sample_size:
+                    sample_sizes.append(args.image_sample_size)
+            else:
+                sample_sizes = [256]
+            
+            length_to_frame_num = {
+                sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes
+            }
+
+            return length_to_frame_num
+
+        def collate_fn(examples):
+            target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size
+            length_to_frame_num = get_length_to_frame_num(
+                target_token_length, 
+            )
+
+            # Create new output
+            new_examples                 = {}
+            new_examples["target_token_length"] = target_token_length
+            new_examples["pixel_values"] = []
+            new_examples["text"]         = []
+            if args.train_mode != "normal":
+                new_examples["mask_pixel_values"] = []
+                new_examples["mask"] = []
+
+            # Get ratio
+            pixel_value     = examples[0]["pixel_values"]
+            data_type       = examples[0]["data_type"]
+            f, h, w, c      = np.shape(pixel_value)
+            if data_type == 'image':
+                random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size, image_ratio=[args.image_sample_size / args.video_sample_size], rng=rng)
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+                
+                batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+            else:
+                if args.random_hw_adapt:
+                    if args.training_with_video_token_length:
+                        local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples]))
+                        choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25]
+                        if len(choice_list) == 0:
+                            choice_list = list(length_to_frame_num.keys())
+                        if rng is None:
+                            local_video_sample_size = np.random.choice(choice_list)
+                        else:
+                            local_video_sample_size = rng.choice(choice_list)
+                        batch_video_length = length_to_frame_num[local_video_sample_size]
+                        random_downsample_ratio = args.video_sample_size / local_video_sample_size
+                    else:
+                        random_downsample_ratio = get_random_downsample_ratio(
+                                args.video_sample_size, rng=rng)
+                        batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+                else:
+                    random_downsample_ratio = 1
+                    batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+            closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size)
+            closest_size = [int(x / 16) * 16 for x in closest_size]
+            if args.random_ratio_crop:
+                if rng is None:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                else:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                random_sample_size = [int(x / 16) * 16 for x in random_sample_size]
+
+            for example in examples:
+                if args.random_ratio_crop:
+                    # To 0~1
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+
+                    # Get adapt hw for resize
+                    b, c, h, w = pixel_values.size()
+                    th, tw = random_sample_size
+                    if th / tw > h / w:
+                        nh = int(th)
+                        nw = int(w / h * nh)
+                    else:
+                        nw = int(tw)
+                        nh = int(h / w * nw)
+                    
+                    transform = transforms.Compose([
+                        transforms.Resize([nh, nw]),
+                        transforms.CenterCrop([int(x) for x in random_sample_size]),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                else:
+                    closest_size = list(map(lambda x: int(x), closest_size))
+                    if closest_size[0] / h > closest_size[1] / w:
+                        resize_size = closest_size[0], int(w * closest_size[0] / h)
+                    else:
+                        resize_size = int(h * closest_size[1] / w), closest_size[1]
+                    
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    transform = transforms.Compose([
+                        transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Image.BICUBIC
+                        transforms.CenterCrop(closest_size),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                new_examples["pixel_values"].append(transform(pixel_values))
+                new_examples["text"].append(example["text"])
+                batch_video_length = int(
+                    min(
+                        batch_video_length,
+                        (len(pixel_values) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1, 
+                    )
+                )
+
+                if batch_video_length == 0:
+                    batch_video_length = 1
+
+                if args.train_mode != "normal":
+                    mask = get_random_mask(new_examples["pixel_values"][-1].size())
+                    mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask
+                    new_examples["mask_pixel_values"].append(mask_pixel_values)
+                    new_examples["mask"].append(mask)
+
+            new_examples["pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["pixel_values"]])
+            if args.train_mode != "normal":
+                new_examples["mask_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["mask_pixel_values"]])
+                new_examples["mask"] = torch.stack([example[:batch_video_length] for example in new_examples["mask"]])
+
+            if args.enable_text_encoder_in_dataloader:
+                prompt_ids = tokenizer(
+                    new_examples['text'], 
+                    max_length=args.tokenizer_max_length, 
+                    padding="max_length", 
+                    add_special_tokens=True, 
+                    truncation=True, 
+                    return_tensors="pt"
+                )
+                encoder_hidden_states = text_encoder(
+                    prompt_ids.input_ids,
+                    return_dict=False
+                )[0]
+                new_examples['encoder_attention_mask'] = prompt_ids.attention_mask
+                new_examples['encoder_hidden_states'] = encoder_hidden_states
+
+            return new_examples
+        
+        # DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        # DataLoaders creation:
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler, 
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+
+    # Prepare everything with our `accelerator`.
+    transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer3d, optimizer, train_dataloader, lr_scheduler
+    )
+
+    if args.use_ema:
+        ema_transformer3d.to(accelerator.device)
+
+    # Move text_encode and vae to gpu and cast to weight_dtype
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.enable_text_encoder_in_dataloader:
+        text_encoder.to(accelerator.device)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initializes automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        tracker_config.pop("trainable_modules")
+        tracker_config.pop("trainable_modules_low_learning_rate")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+
+            pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    _, first_epoch = pickle.load(file)
+            else:
+                first_epoch = global_step // num_update_steps_per_epoch
+            print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.")
+
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    if args.multi_stream and args.train_mode != "normal":
+        # create extra cuda streams to speedup inpaint vae computation
+        vae_stream_1 = torch.cuda.Stream()
+        vae_stream_2 = torch.cuda.Stream()
+    else:
+        vae_stream_1 = None
+        vae_stream_2 = None
+
+    for epoch in range(first_epoch, args.num_train_epochs):
+        train_loss = 0.0
+        batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch)
+        for step, batch in enumerate(train_dataloader):
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True)
+                for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                    pixel_value = pixel_value[None, ...]
+                    gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}'
+                    save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True)
+                if args.train_mode != "normal":
+                    mask_pixel_values, texts = batch['mask_pixel_values'].cpu(), batch['text']
+                    mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(mask_pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.gif", rescale=True)
+
+            with accelerator.accumulate(transformer3d):
+                # Convert images to latent space
+                pixel_values = batch["pixel_values"].to(weight_dtype)
+                if args.training_with_video_token_length:
+                    if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1))
+                        else:
+                            batch['text'] = batch['text'] * 4
+                    elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1))
+                        else:
+                            batch['text'] = batch['text'] * 2
+                
+                if args.train_mode != "normal":
+                    mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype)
+                    mask = batch["mask"].to(weight_dtype)
+                    if args.training_with_video_token_length:
+                        if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (4, 1, 1, 1, 1))
+                        elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (2, 1, 1, 1, 1))
+
+                def create_special_list(length):
+                    if length == 1:
+                        return [1.0]
+                    if length >= 2:
+                        last_element = 0.90
+                        remaining_sum = 1.0 - last_element
+                        other_elements_value = remaining_sum / (length - 1)
+                        special_list = [other_elements_value] * (length - 1) + [last_element]
+                        return special_list
+                    
+                if args.keep_all_node_same_token_length:
+                    actual_token_length = index_rng.choice(numbers_list)
+
+                    actual_video_length = (min(
+                            actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames
+                    ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1
+                    actual_video_length = int(max(actual_video_length, 1))
+                else:
+                    actual_video_length = None
+
+                if args.random_frame_crop:
+                    select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))]
+                    select_frames_prob = np.array(create_special_list(len(select_frames)))
+                    
+                    if rng is None:
+                        temp_n_frames = np.random.choice(select_frames, p = select_frames_prob)
+                    else:
+                        temp_n_frames = rng.choice(select_frames, p = select_frames_prob)
+                    if args.keep_all_node_same_token_length:
+                        temp_n_frames = min(actual_video_length, temp_n_frames)
+
+                    pixel_values = pixel_values[:, :temp_n_frames, :, :]
+
+                    if args.train_mode != "normal":
+                        mask_pixel_values = mask_pixel_values[:, :temp_n_frames, :, :]
+                        mask = mask[:, :temp_n_frames, :, :]
+
+                if args.train_mode != "normal":
+                    t2v_flag = [(_mask == 1).all() for _mask in mask]
+                    new_t2v_flag = []
+                    for _mask in t2v_flag:
+                        if _mask and np.random.rand() < 0.90:
+                            new_t2v_flag.append(0)
+                        else:
+                            new_t2v_flag.append(1)
+                    t2v_flag = torch.from_numpy(np.array(new_t2v_flag)).to(accelerator.device, dtype=weight_dtype)
+
+                if args.low_vram:
+                    torch.cuda.empty_cache()
+                    vae.to(accelerator.device)
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                with torch.no_grad():
+                    # This way is quicker when batch grows up
+                    def _slice_vae(pixel_values):
+                        pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                        bs = args.vae_mini_batch
+                        new_pixel_values = []
+                        for i in range(0, pixel_values.shape[0], bs):
+                            pixel_values_bs = pixel_values[i : i + bs]
+                            pixel_values_bs = vae.encode(pixel_values_bs)[0]
+                            pixel_values_bs = pixel_values_bs.sample()
+                            new_pixel_values.append(pixel_values_bs)
+                            vae._clear_fake_context_parallel_cache()
+                        return torch.cat(new_pixel_values, dim = 0)
+                    if vae_stream_1 is not None:
+                        vae_stream_1.wait_stream(torch.cuda.current_stream())
+                        with torch.cuda.stream(vae_stream_1):
+                            latents = _slice_vae(pixel_values)
+                    else:
+                        latents = _slice_vae(pixel_values)
+                    latents = latents * vae.config.scaling_factor
+
+                    if args.train_mode != "normal":
+                        mask = rearrange(mask, "b f c h w -> b c f h w")
+                        mask = 1 - mask
+                        mask = resize_mask(mask, latents)
+
+                        if unwrap_model(transformer3d).config.add_noise_in_inpaint_model:
+                            mask_pixel_values = add_noise_to_reference_video(mask_pixel_values)
+                        mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w")
+                        bs = args.vae_mini_batch
+                        new_mask_pixel_values = []
+                        for i in range(0, mask_pixel_values.shape[0], bs):
+                            mask_pixel_values_bs = mask_pixel_values[i : i + bs]
+                            mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
+                            mask_pixel_values_bs = mask_pixel_values_bs.sample()
+                            new_mask_pixel_values.append(mask_pixel_values_bs)
+                            vae._clear_fake_context_parallel_cache()
+                        mask_latents = torch.cat(new_mask_pixel_values, dim = 0)
+
+                        if vae_stream_2 is not None:
+                            torch.cuda.current_stream().wait_stream(vae_stream_2) 
+
+                        inpaint_latents = torch.concat([mask, mask_latents], dim=1)
+                        inpaint_latents = t2v_flag[:, None, None, None, None] * inpaint_latents
+                        inpaint_latents = inpaint_latents * vae.config.scaling_factor
+                        inpaint_latents = rearrange(inpaint_latents, "b c f h w -> b f c h w")
+
+                    latents = rearrange(latents, "b c f h w -> b f c h w")
+                        
+                # wait for latents = vae.encode(pixel_values) to complete
+                if vae_stream_1 is not None:
+                    torch.cuda.current_stream().wait_stream(vae_stream_1)
+
+                if args.low_vram:
+                    vae.to('cpu')
+                    torch.cuda.empty_cache()
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                if args.enable_text_encoder_in_dataloader:
+                    prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device)
+                else:
+                    with torch.no_grad():
+                        prompt_ids = tokenizer(
+                            batch['text'], 
+                            max_length=args.tokenizer_max_length, 
+                            padding="max_length", 
+                            add_special_tokens=True, 
+                            truncation=True, 
+                            return_tensors="pt"
+                        )
+                        prompt_embeds = text_encoder(
+                            prompt_ids.input_ids.to(latents.device),
+                            return_dict=False
+                        )[0]
+
+                if args.low_vram and not args.enable_text_encoder_in_dataloader:
+                    text_encoder.to('cpu')
+                    torch.cuda.empty_cache()
+
+                bsz = latents.shape[0]
+                noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype)
+                # Sample a random timestep for each image
+                # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = timesteps.long()
+
+                # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+                def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+                    tw = tgt_width
+                    th = tgt_height
+                    h, w = src
+                    r = h / w
+                    if r > (th / tw):
+                        resize_height = th
+                        resize_width = int(round(th / h * w))
+                    else:
+                        resize_width = tw
+                        resize_height = int(round(tw / w * h))
+
+                    crop_top = int(round((th - resize_height) / 2.0))
+                    crop_left = int(round((tw - resize_width) / 2.0))
+
+                    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+                def _prepare_rotary_positional_embeddings(
+                    height: int,
+                    width: int,
+                    num_frames: int,
+                    device: torch.device
+                ):
+                    vae_scale_factor_spatial = (
+                        2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
+                    )
+                    grid_height = height // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    grid_width = width // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_width = 720 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_height = 480 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+
+                    grid_crops_coords = get_resize_crop_region_for_grid(
+                        (grid_height, grid_width), base_size_width, base_size_height
+                    )
+                    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                        embed_dim=unwrap_model(transformer3d).config.attention_head_dim,
+                        crops_coords=grid_crops_coords,
+                        grid_size=(grid_height, grid_width),
+                        temporal_size=num_frames,
+                        use_real=True,
+                    )
+                    freqs_cos = freqs_cos.to(device=device)
+                    freqs_sin = freqs_sin.to(device=device)
+                    return freqs_cos, freqs_sin
+
+                height, width = batch["pixel_values"].size()[-2], batch["pixel_values"].size()[-1]
+                # 7. Create rotary embeds if required
+                image_rotary_emb = (
+                    _prepare_rotary_positional_embeddings(height, width, latents.size(1), latents.device)
+                    if unwrap_model(transformer3d).config.use_rotary_positional_embeddings
+                    else None
+                )
+                prompt_embeds = prompt_embeds.to(device=latents.device)
+
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                # predict the noise residual
+                noise_pred = transformer3d(
+                    hidden_states=noisy_latents,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timesteps,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                    inpaint_latents=inpaint_latents if args.train_mode != "normal" else None,
+                )[0]
+                print(noise_pred.size(), noisy_latents.size(), latents.size(), pixel_values.size())
+                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                if args.motion_sub_loss and noise_pred.size()[1] > 2:
+                    gt_sub_noise = noise_pred[:, 1:, :].float() - noise_pred[:, :-1, :].float()
+                    pre_sub_noise = target[:, 1:, :].float() - target[:, :-1, :].float()
+                    sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean")
+                    loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio
+
+                # Gather the losses across all processes for logging (if we use distributed training).
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    if not args.use_deepspeed:
+                        trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None]
+                        trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2)
+                        max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step)
+                        if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start:
+                            actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10)
+                        else:
+                            actual_max_grad_norm = max_grad_norm
+                    else:
+                        actual_max_grad_norm = args.max_grad_norm
+
+                    if not args.use_deepspeed and args.report_model_info and accelerator.is_main_process:
+                        if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start:
+                            for name, param in transformer3d.named_parameters():
+                                if param.requires_grad:
+                                    writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step)
+
+                    norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm)
+                    if not args.use_deepspeed and args.report_model_info and accelerator.is_main_process:
+                        writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step)
+                        writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+
+                if args.use_ema:
+                    ema_transformer3d.step(transformer3d.parameters())
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % args.checkpointing_steps == 0:
+                    if args.use_deepspeed or accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
+                if accelerator.is_main_process:
+                    if args.validation_prompts is not None and global_step % args.validation_steps == 0:
+                        if args.use_ema:
+                            # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                            ema_transformer3d.store(transformer3d.parameters())
+                            ema_transformer3d.copy_to(transformer3d.parameters())
+                        log_validation(
+                            vae,
+                            text_encoder,
+                            tokenizer,
+                            transformer3d,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                        )
+                        if args.use_ema:
+                            # Switch back to the original transformer3d parameters.
+                            ema_transformer3d.restore(transformer3d.parameters())
+
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+
+            if global_step >= args.max_train_steps:
+                break
+
+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                if args.use_ema:
+                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                    ema_transformer3d.store(transformer3d.parameters())
+                    ema_transformer3d.copy_to(transformer3d.parameters())
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    transformer3d,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+                if args.use_ema:
+                    # Switch back to the original transformer3d parameters.
+                    ema_transformer3d.restore(transformer3d.parameters())
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        transformer3d = unwrap_model(transformer3d)
+        if args.use_ema:
+            ema_transformer3d.copy_to(transformer3d.parameters())
+
+        if args.use_deepspeed or accelerator.is_main_process:
+            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+            accelerator.save_state(save_path)
+            logger.info(f"Saved state to {save_path}")
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ce97428256b29db05642d43f3467a26cc9e58d4a
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,42 @@
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1280 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=1 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=100 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_ema \
+  --train_mode="inpaint" \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
\ No newline at end of file
diff --git a/scripts/train_control.py b/scripts/train_control.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8982531af80dbbf16c83612f8a9a5ad2c7c90
--- /dev/null
+++ b/scripts/train_control.py
@@ -0,0 +1,1607 @@
+"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+"""
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import gc
+import logging
+import math
+import os
+import pickle
+import shutil
+import sys
+
+import accelerate
+import diffusers
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from diffusers import AutoencoderKL, DDPMScheduler
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+from einops import rearrange
+from huggingface_hub import create_repo, upload_folder
+from omegaconf import OmegaConf
+from packaging import version
+from PIL import Image
+from torch.utils.data import RandomSampler
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
+                          CLIPVisionModelWithProjection, MT5Tokenizer,
+                          T5EncoderModel, T5Tokenizer)
+from transformers.utils import ContextManagers
+
+import datasets
+
+current_file_path = os.path.abspath(__file__)
+project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path))]
+for project_root in project_roots:
+    sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+from cogvideox.data.bucket_sampler import (ASPECT_RATIO_512,
+                                           ASPECT_RATIO_RANDOM_CROP_512,
+                                           ASPECT_RATIO_RANDOM_CROP_PROB,
+                                           AspectRatioBatchImageVideoSampler,
+                                           RandomSampler, get_closest_ratio)
+from cogvideox.data.dataset_image_video import (ImageVideoDataset, ImageVideoControlDataset,
+                                                ImageVideoSampler,
+                                                get_random_mask)
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox_control import \
+    CogVideoX_Fun_Pipeline_Control
+from cogvideox.utils.utils import get_video_to_video_latent, save_videos_grid
+
+if is_wandb_available():
+    import wandb
+
+
+def get_random_downsample_ratio(sample_size, image_ratio=[],
+                                all_choices=False, rng=None):
+    def _create_special_list(length):
+        if length == 1:
+            return [1.0]
+        if length >= 2:
+            first_element = 0.75
+            remaining_sum = 1.0 - first_element
+            other_elements_value = remaining_sum / (length - 1)
+            special_list = [first_element] + [other_elements_value] * (length - 1)
+            return special_list
+            
+    if sample_size >= 1536:
+        number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio 
+    elif sample_size >= 1024:
+        number_list = [1, 1.25, 1.5, 2] + image_ratio
+    elif sample_size >= 768:
+        number_list = [1, 1.25, 1.5] + image_ratio
+    elif sample_size >= 512:
+        number_list = [1] + image_ratio
+    else:
+        number_list = [1]
+
+    if all_choices:
+        return number_list
+
+    number_list_prob = np.array(_create_special_list(len(number_list)))
+    if rng is None:
+        return np.random.choice(number_list, p = number_list_prob)
+    else:
+        return rng.choice(number_list, p = number_list_prob)
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+        
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+def log_validation(vae, text_encoder, tokenizer, transformer3d, args, accelerator, weight_dtype, global_step):
+    try:
+        logger.info("Running validation... ")
+            
+        transformer3d_val = CogVideoXTransformer3DModel.from_pretrained_2d(
+            args.pretrained_model_name_or_path, subfolder="transformer"
+        ).to(weight_dtype)
+        transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict())
+
+        pipeline = CogVideoX_Fun_Pipeline_Control.from_pretrained(
+            args.pretrained_model_name_or_path, 
+            vae=accelerator.unwrap_model(vae).to(weight_dtype), 
+            text_encoder=accelerator.unwrap_model(text_encoder),
+            tokenizer=tokenizer,
+            transformer=transformer3d_val,
+            torch_dtype=weight_dtype,
+        )
+        pipeline = pipeline.to(accelerator.device)
+
+        if args.seed is None:
+            generator = None
+        else:
+            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+        images = []
+        for i in range(len(args.validation_prompts)):
+            with torch.no_grad():
+                with torch.autocast("cuda", dtype=weight_dtype):
+                    video_length = int(args.video_sample_n_frames // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1
+                    input_video, input_video_mask, clip_image = get_video_to_video_latent(args.validation_paths[i], video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size])
+                    sample = pipeline(
+                        args.validation_prompts[i], 
+                        num_frames = video_length,
+                        negative_prompt = "bad detailed",
+                        height      = args.video_sample_size,
+                        width       = args.video_sample_size,
+                        generator   = generator, 
+
+                        control_video = input_video,
+                    ).videos
+                    os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                    save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif"))
+
+        del pipeline
+        del transformer3d_val
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+        return images
+    except Exception as e:
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        print(f"Eval error with info {e}")
+        return None
+
+def linear_decay(initial_value, final_value, total_steps, current_step):
+    if current_step >= total_steps:
+        return final_value
+    current_step = max(0, current_step)
+    step_size = (final_value - initial_value) / total_steps
+    current_value = initial_value + step_size * current_step
+    return current_value
+
+def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
+    u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
+    t = 1 / (1 + torch.exp(-u)) * (high - low) + low
+    return torch.clip(t.to(torch.int32), low, high - 1)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--train_data_meta",
+        type=str,
+        default=None,
+        help=(
+            "A csv containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+    parser.add_argument(
+        "--validation_paths",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of control videos evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="sd-model-finetuned",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
+    parser.add_argument(
+        "--use_came",
+        action="store_true",
+        help="whether to use came",
+    )
+    parser.add_argument(
+        "--multi_stream",
+        action="store_true",
+        help="whether to use cuda multi-stream",
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--vae_mini_batch", type=int, default=32, help="mini batch size for vae."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=100)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default=None,
+        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+    )
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)."
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=("Max number of checkpoints to store."),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=2000,
+        help="Run validation every X steps.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
+    
+    parser.add_argument(
+        "--snr_loss", action="store_true", help="Whether or not to use snr_loss."
+    )
+    parser.add_argument(
+        "--not_sigma_loss", action="store_true", help="Whether or not to not use sigma_loss."
+    )
+    parser.add_argument(
+        "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader."
+    )
+    parser.add_argument(
+        "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets."
+    )
+    parser.add_argument(
+        "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets."
+    )
+    parser.add_argument(
+        "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.",
+    )
+    parser.add_argument(
+        "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss."
+    )
+    parser.add_argument(
+        "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss."
+    )
+    parser.add_argument(
+        "--train_sampling_steps",
+        type=int,
+        default=1000,
+        help="Run train_sampling_steps.",
+    )
+    parser.add_argument(
+        "--keep_all_node_same_token_length",
+        action="store_true", 
+        help="Reference of the length token.",
+    )
+    parser.add_argument(
+        "--token_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the token.",
+    )
+    parser.add_argument(
+        "--video_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--image_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_stride",
+        type=int,
+        default=4,
+        help="Sample stride of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_n_frames",
+        type=int,
+        default=17,
+        help="Num frame of video.",
+    )
+    parser.add_argument(
+        "--video_repeat",
+        type=int,
+        default=0,
+        help="Num of repeat video.",
+    )
+    parser.add_argument(
+        "--transformer_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other transformers, input its path."),
+    )
+    parser.add_argument(
+        "--vae_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other vaes, input its path."),
+    )
+
+    parser.add_argument(
+        '--trainable_modules', 
+        nargs='+', 
+        help='Enter a list of trainable modules'
+    )
+    parser.add_argument(
+        '--trainable_modules_low_learning_rate', 
+        nargs='+', 
+        default=[],
+        help='Enter a list of trainable modules with lower learning rate'
+    )
+    parser.add_argument(
+        '--tokenizer_max_length', 
+        type=int,
+        default=226,
+        help='Max length of tokenizer'
+    )
+    parser.add_argument(
+        "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed."
+    )
+    parser.add_argument(
+        "--low_vram", action="store_true", help="Whether enable low_vram mode."
+    )
+    parser.add_argument(
+        "--abnormal_norm_clip_start",
+        type=int,
+        default=1000,
+        help=(
+            'When do we start doing additional processing on abnormal gradients. '
+        ),
+    )
+    parser.add_argument(
+        "--initial_grad_norm_ratio",
+        type=int,
+        default=5,
+        help=(
+            'The initial gradient is relative to the multiple of the max_grad_norm. '
+        ),
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
+    logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+    if accelerator.is_main_process:
+        writer = SummaryWriter(log_dir=logging_dir)
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+        rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index))
+        torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index)
+    else:
+        rng = None
+        torch_rng = None
+    index_rng = np.random.default_rng(np.random.PCG64(43))
+    print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}")
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+        args.mixed_precision = accelerator.mixed_precision
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+        args.mixed_precision = accelerator.mixed_precision
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+    tokenizer = T5Tokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        returns either a context list that includes one that will disable zero.Init or an empty context list
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+    # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = T5EncoderModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant,
+            torch_dtype=weight_dtype
+        )
+
+        vae = AutoencoderKLCogVideoX.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+        )
+
+    transformer3d = CogVideoXTransformer3DModel.from_pretrained_2d(
+        args.pretrained_model_name_or_path, subfolder="transformer"
+    )
+
+    # Freeze vae and text_encoder and set transformer3d to trainable
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    transformer3d.requires_grad_(False)
+
+    if args.transformer_path is not None:
+        print(f"From checkpoint: {args.transformer_path}")
+        if args.transformer_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.transformer_path)
+        else:
+            state_dict = torch.load(args.transformer_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = transformer3d.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+
+    if args.vae_path is not None:
+        print(f"From checkpoint: {args.vae_path}")
+        if args.vae_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.vae_path)
+        else:
+            state_dict = torch.load(args.vae_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = vae.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+    
+    # A good trainable modules is showed below now.
+    # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position']
+    # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position']
+    transformer3d.train()
+    if accelerator.is_main_process:
+        accelerator.print(
+            f"Trainable modules '{args.trainable_modules}'."
+        )
+    for name, param in transformer3d.named_parameters():
+        for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate:
+            if trainable_module_name in name:
+                param.requires_grad = True
+                break
+
+    # Create EMA for the transformer3d.
+    if args.use_ema:
+        ema_transformer3d = CogVideoXTransformer3DModel.from_pretrained_2d(
+            args.pretrained_model_name_or_path, subfolder="transformer"
+        )
+        ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=CogVideoXTransformer3DModel, model_config=ema_transformer3d.config)
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                if args.use_ema:
+                    ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema"))
+
+                models[0].save_pretrained(os.path.join(output_dir, "transformer"))
+                if not args.use_deepspeed:
+                    weights.pop()
+
+                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
+                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                ema_path = os.path.join(input_dir, "transformer_ema")
+                _, ema_kwargs = CogVideoXTransformer3DModel.load_config(ema_path, return_unused_kwargs=True)
+                load_model = CogVideoXTransformer3DModel.from_pretrained_2d(
+                    input_dir, subfolder="transformer_ema"
+                )
+                load_model = EMAModel(load_model.parameters(), model_cls=CogVideoXTransformer3DModel, model_config=load_model.config)
+                load_model.load_state_dict(ema_kwargs)
+
+                ema_transformer3d.load_state_dict(load_model.state_dict())
+                ema_transformer3d.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = CogVideoXTransformer3DModel.from_pretrained_2d(
+                    input_dir, subfolder="transformer"
+                )
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    loaded_number, _ = pickle.load(file)
+                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
+                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    if args.gradient_checkpointing:
+        transformer3d.enable_gradient_checkpointing()
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Initialize the optimizer
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+            )
+
+        optimizer_cls = bnb.optim.AdamW8bit
+    elif args.use_came:
+        try:
+            from came_pytorch import CAME
+        except:
+            raise ImportError(
+                "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`"
+            )
+
+        optimizer_cls = CAME
+    else:
+        optimizer_cls = torch.optim.AdamW
+
+    trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters()))
+    trainable_params_optim = [
+        {'params': [], 'lr': args.learning_rate},
+        {'params': [], 'lr': args.learning_rate / 2},
+    ]
+    in_already = []
+    for name, param in transformer3d.named_parameters():
+        high_lr_flag = False
+        if name in in_already:
+            continue
+        for trainable_module_name in args.trainable_modules:
+            if trainable_module_name in name:
+                in_already.append(name)
+                high_lr_flag = True
+                trainable_params_optim[0]['params'].append(param)
+                if accelerator.is_main_process:
+                    print(f"Set {name} to lr : {args.learning_rate}")
+                break
+        if high_lr_flag:
+            continue
+        for trainable_module_name in args.trainable_modules_low_learning_rate:
+            if trainable_module_name in name:
+                in_already.append(name)
+                trainable_params_optim[1]['params'].append(param)
+                if accelerator.is_main_process:
+                    print(f"Set {name} to lr : {args.learning_rate / 2}")
+                break
+
+    if args.use_came:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            # weight_decay=args.adam_weight_decay,
+            betas=(0.9, 0.999, 0.9999), 
+            eps=(1e-30, 1e-16)
+        )
+    else:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    # Get the training dataset
+    sample_n_frames_bucket_interval = 4
+
+    train_dataset = ImageVideoControlDataset(
+        args.train_data_meta, args.train_data_dir,
+        video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, 
+        video_repeat=args.video_repeat, 
+        image_sample_size=args.image_sample_size,
+        enable_bucket=args.enable_bucket, enable_inpaint=False,
+    )
+    
+    if args.enable_bucket:
+        aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+        aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = AspectRatioBatchImageVideoSampler(
+            sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, 
+            batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True,
+            aspect_ratios=aspect_ratio_sample_size,
+        )
+        if args.keep_all_node_same_token_length:
+            if args.token_sample_size > 256:
+                numbers_list = list(range(256, args.token_sample_size + 1, 128))
+
+                if numbers_list[-1] != args.token_sample_size:
+                    numbers_list.append(args.token_sample_size)
+            else:
+                numbers_list = [256]
+            numbers_list = [_number * _number * args.video_sample_n_frames for _number in  numbers_list]
+        else:
+            numbers_list = None
+
+        def get_length_to_frame_num(token_length):
+            if args.image_sample_size > args.video_sample_size:
+                sample_sizes = list(range(256, args.image_sample_size + 1, 128))
+
+                if sample_sizes[-1] != args.image_sample_size:
+                    sample_sizes.append(args.image_sample_size)
+            else:
+                sample_sizes = [256]
+            
+            length_to_frame_num = {
+                sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes
+            }
+
+            return length_to_frame_num
+
+        def collate_fn(examples):
+            target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size
+            length_to_frame_num = get_length_to_frame_num(
+                target_token_length, 
+            )
+
+            # Create new output
+            new_examples                 = {}
+            new_examples["target_token_length"] = target_token_length
+            new_examples["pixel_values"] = []
+            new_examples["text"]         = []
+            new_examples["control_pixel_values"] = []
+
+            # Get ratio
+            pixel_value     = examples[0]["pixel_values"]
+            data_type       = examples[0]["data_type"]
+            f, h, w, c      = np.shape(pixel_value)
+            if data_type == 'image':
+                random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size, image_ratio=[args.image_sample_size / args.video_sample_size], rng=rng)
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+                
+                batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+            else:
+                if args.random_hw_adapt:
+                    if args.training_with_video_token_length:
+                        local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples]))
+                        choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25]
+                        if len(choice_list) == 0:
+                            choice_list = list(length_to_frame_num.keys())
+                        if rng is None:
+                            local_video_sample_size = np.random.choice(choice_list)
+                        else:
+                            local_video_sample_size = rng.choice(choice_list)
+                        batch_video_length = length_to_frame_num[local_video_sample_size]
+                        random_downsample_ratio = args.video_sample_size / local_video_sample_size
+                    else:
+                        random_downsample_ratio = get_random_downsample_ratio(
+                                args.video_sample_size, rng=rng)
+                        batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+                else:
+                    random_downsample_ratio = 1
+                    batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+            closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size)
+            closest_size = [int(x / 16) * 16 for x in closest_size]
+            if args.random_ratio_crop:
+                if rng is None:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                else:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                random_sample_size = [int(x / 16) * 16 for x in random_sample_size]
+
+            for example in examples:
+                if args.random_ratio_crop:
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+
+                    control_pixel_values = torch.from_numpy(example["control_pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    control_pixel_values = control_pixel_values / 255.
+
+                    # Get adapt hw for resize
+                    b, c, h, w = pixel_values.size()
+                    th, tw = random_sample_size
+                    if th / tw > h / w:
+                        nh = int(th)
+                        nw = int(w / h * nh)
+                    else:
+                        nw = int(tw)
+                        nh = int(h / w * nw)
+                    
+                    transform = transforms.Compose([
+                        transforms.Resize([nh, nw]),
+                        transforms.CenterCrop([int(x) for x in random_sample_size]),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                else:
+                    closest_size = list(map(lambda x: int(x), closest_size))
+                    if closest_size[0] / h > closest_size[1] / w:
+                        resize_size = closest_size[0], int(w * closest_size[0] / h)
+                    else:
+                        resize_size = int(h * closest_size[1] / w), closest_size[1]
+                    
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+
+                    control_pixel_values = torch.from_numpy(example["control_pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    control_pixel_values = control_pixel_values / 255.
+
+                    transform = transforms.Compose([
+                        transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Image.BICUBIC
+                        transforms.CenterCrop(closest_size),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                new_examples["pixel_values"].append(transform(pixel_values))
+                new_examples["control_pixel_values"].append(transform(control_pixel_values))
+                new_examples["text"].append(example["text"])
+                batch_video_length = int(
+                    min(
+                        batch_video_length,
+                        (len(pixel_values) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1, 
+                    )
+                )
+
+                if batch_video_length == 0:
+                    batch_video_length = 1
+
+            new_examples["pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["pixel_values"]])
+            new_examples["control_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["control_pixel_values"]])
+
+            if args.enable_text_encoder_in_dataloader:
+                prompt_ids = tokenizer(
+                    new_examples['text'], 
+                    max_length=args.tokenizer_max_length, 
+                    padding="max_length", 
+                    add_special_tokens=True, 
+                    truncation=True, 
+                    return_tensors="pt"
+                )
+                encoder_hidden_states = text_encoder(
+                    prompt_ids.input_ids,
+                    return_dict=False
+                )[0]
+                new_examples['encoder_attention_mask'] = prompt_ids.attention_mask
+                new_examples['encoder_hidden_states'] = encoder_hidden_states
+
+            return new_examples
+        
+        # DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        # DataLoaders creation:
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler, 
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+
+    # Prepare everything with our `accelerator`.
+    transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer3d, optimizer, train_dataloader, lr_scheduler
+    )
+
+    if args.use_ema:
+        ema_transformer3d.to(accelerator.device)
+
+    # Move text_encode and vae to gpu and cast to weight_dtype
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.enable_text_encoder_in_dataloader:
+        text_encoder.to(accelerator.device)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initializes automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        tracker_config.pop("validation_paths")
+        tracker_config.pop("trainable_modules")
+        tracker_config.pop("trainable_modules_low_learning_rate")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+
+            pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    _, first_epoch = pickle.load(file)
+            else:
+                first_epoch = global_step // num_update_steps_per_epoch
+            print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.")
+
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    if args.multi_stream:
+        # create extra cuda streams to speedup inpaint vae computation
+        vae_stream_1 = torch.cuda.Stream()
+        vae_stream_2 = torch.cuda.Stream()
+    else:
+        vae_stream_1 = None
+        vae_stream_2 = None
+
+    for epoch in range(first_epoch, args.num_train_epochs):
+        train_loss = 0.0
+        batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch)
+        for step, batch in enumerate(train_dataloader):
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True)
+                for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                    pixel_value = pixel_value[None, ...]
+                    gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}'
+                    save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True)
+
+            with accelerator.accumulate(transformer3d):
+                # Convert images to latent space
+                pixel_values = batch["pixel_values"].to(weight_dtype)
+                control_pixel_values = batch["control_pixel_values"].to(weight_dtype)
+                if args.training_with_video_token_length:
+                    if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1))
+                        control_pixel_values = torch.tile(control_pixel_values, (4, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1))
+                        else:
+                            batch['text'] = batch['text'] * 4
+                    elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1))
+                        control_pixel_values = torch.tile(control_pixel_values, (2, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1))
+                        else:
+                            batch['text'] = batch['text'] * 2
+
+                def create_special_list(length):
+                    if length == 1:
+                        return [1.0]
+                    if length >= 2:
+                        last_element = 0.90
+                        remaining_sum = 1.0 - last_element
+                        other_elements_value = remaining_sum / (length - 1)
+                        special_list = [other_elements_value] * (length - 1) + [last_element]
+                        return special_list
+                    
+                if args.keep_all_node_same_token_length:
+                    actual_token_length = index_rng.choice(numbers_list)
+
+                    actual_video_length = (min(
+                            actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames
+                    ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1
+                    actual_video_length = int(max(actual_video_length, 1))
+                else:
+                    actual_video_length = None
+
+                if args.random_frame_crop:
+                    select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))]
+                    select_frames_prob = np.array(create_special_list(len(select_frames)))
+                    
+                    if rng is None:
+                        temp_n_frames = np.random.choice(select_frames, p = select_frames_prob)
+                    else:
+                        temp_n_frames = rng.choice(select_frames, p = select_frames_prob)
+                    if args.keep_all_node_same_token_length:
+                        temp_n_frames = min(actual_video_length, temp_n_frames)
+
+                    pixel_values = pixel_values[:, :temp_n_frames, :, :]
+                    control_pixel_values = control_pixel_values[:, :temp_n_frames, :, :]
+
+                if args.low_vram:
+                    torch.cuda.empty_cache()
+                    vae.to(accelerator.device)
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                with torch.no_grad():
+                    # This way is quicker when batch grows up
+                    def _slice_vae(pixel_values):
+                        pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                        bs = args.vae_mini_batch
+                        new_pixel_values = []
+                        for i in range(0, pixel_values.shape[0], bs):
+                            pixel_values_bs = pixel_values[i : i + bs]
+                            pixel_values_bs = vae.encode(pixel_values_bs)[0]
+                            pixel_values_bs = pixel_values_bs.sample()
+                            new_pixel_values.append(pixel_values_bs)
+                            vae._clear_fake_context_parallel_cache()
+                        return torch.cat(new_pixel_values, dim = 0)
+                    if vae_stream_1 is not None:
+                        vae_stream_1.wait_stream(torch.cuda.current_stream())
+                        with torch.cuda.stream(vae_stream_1):
+                            latents = _slice_vae(pixel_values)
+                    else:
+                        latents = _slice_vae(pixel_values)
+                    latents = latents * vae.config.scaling_factor
+                        
+                    control_latents = _slice_vae(control_pixel_values)
+                    control_latents = control_latents * vae.config.scaling_factor
+                    control_latents = rearrange(control_latents, "b c f h w -> b f c h w")
+
+                    latents = rearrange(latents, "b c f h w -> b f c h w")
+                        
+                # wait for latents = vae.encode(pixel_values) to complete
+                if vae_stream_1 is not None:
+                    torch.cuda.current_stream().wait_stream(vae_stream_1)
+
+                if args.low_vram:
+                    vae.to('cpu')
+                    torch.cuda.empty_cache()
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                if args.enable_text_encoder_in_dataloader:
+                    prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device)
+                else:
+                    with torch.no_grad():
+                        prompt_ids = tokenizer(
+                            batch['text'], 
+                            max_length=args.tokenizer_max_length, 
+                            padding="max_length", 
+                            add_special_tokens=True, 
+                            truncation=True, 
+                            return_tensors="pt"
+                        )
+                        prompt_embeds = text_encoder(
+                            prompt_ids.input_ids.to(latents.device),
+                            return_dict=False
+                        )[0]
+
+                if args.low_vram and not args.enable_text_encoder_in_dataloader:
+                    text_encoder.to('cpu')
+                    torch.cuda.empty_cache()
+
+                bsz = latents.shape[0]
+                noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype)
+                # Sample a random timestep for each image
+                # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = timesteps.long()
+
+                # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+                def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+                    tw = tgt_width
+                    th = tgt_height
+                    h, w = src
+                    r = h / w
+                    if r > (th / tw):
+                        resize_height = th
+                        resize_width = int(round(th / h * w))
+                    else:
+                        resize_width = tw
+                        resize_height = int(round(tw / w * h))
+
+                    crop_top = int(round((th - resize_height) / 2.0))
+                    crop_left = int(round((tw - resize_width) / 2.0))
+
+                    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+                def _prepare_rotary_positional_embeddings(
+                    height: int,
+                    width: int,
+                    num_frames: int,
+                    device: torch.device
+                ):
+                    vae_scale_factor_spatial = (
+                        2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
+                    )
+                    grid_height = height // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    grid_width = width // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_width = 720 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_height = 480 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+
+                    grid_crops_coords = get_resize_crop_region_for_grid(
+                        (grid_height, grid_width), base_size_width, base_size_height
+                    )
+                    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                        embed_dim=unwrap_model(transformer3d).config.attention_head_dim,
+                        crops_coords=grid_crops_coords,
+                        grid_size=(grid_height, grid_width),
+                        temporal_size=num_frames,
+                        use_real=True,
+                    )
+                    freqs_cos = freqs_cos.to(device=device)
+                    freqs_sin = freqs_sin.to(device=device)
+                    return freqs_cos, freqs_sin
+
+                height, width = batch["pixel_values"].size()[-2], batch["pixel_values"].size()[-1]
+                # 7. Create rotary embeds if required
+                image_rotary_emb = (
+                    _prepare_rotary_positional_embeddings(height, width, latents.size(1), latents.device)
+                    if unwrap_model(transformer3d).config.use_rotary_positional_embeddings
+                    else None
+                )
+                prompt_embeds = prompt_embeds.to(device=latents.device)
+
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+                # predict the noise residual
+                noise_pred = transformer3d(
+                    hidden_states=noisy_latents,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timesteps,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                    control_latents=control_latents,
+                )[0]
+                print(noise_pred.size(), noisy_latents.size(), latents.size(), pixel_values.size())
+                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                if args.motion_sub_loss and noise_pred.size()[1] > 2:
+                    gt_sub_noise = noise_pred[:, 1:, :].float() - noise_pred[:, :-1, :].float()
+                    pre_sub_noise = target[:, 1:, :].float() - target[:, :-1, :].float()
+                    sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean")
+                    loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio
+
+                # Gather the losses across all processes for logging (if we use distributed training).
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    if not args.use_deepspeed:
+                        trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None]
+                        trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2)
+                        max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step)
+                        if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start:
+                            actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10)
+                        else:
+                            actual_max_grad_norm = max_grad_norm
+                    else:
+                        actual_max_grad_norm = args.max_grad_norm
+
+                    if not args.use_deepspeed and args.report_model_info and accelerator.is_main_process:
+                        if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start:
+                            for name, param in transformer3d.named_parameters():
+                                if param.requires_grad:
+                                    writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step)
+
+                    norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm)
+                    if not args.use_deepspeed and args.report_model_info and accelerator.is_main_process:
+                        writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step)
+                        writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+
+                if args.use_ema:
+                    ema_transformer3d.step(transformer3d.parameters())
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % args.checkpointing_steps == 0:
+                    if args.use_deepspeed or accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
+                if accelerator.is_main_process:
+                    if args.validation_prompts is not None and global_step % args.validation_steps == 0:
+                        if args.use_ema:
+                            # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                            ema_transformer3d.store(transformer3d.parameters())
+                            ema_transformer3d.copy_to(transformer3d.parameters())
+                        log_validation(
+                            vae,
+                            text_encoder,
+                            tokenizer,
+                            transformer3d,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                        )
+                        if args.use_ema:
+                            # Switch back to the original transformer3d parameters.
+                            ema_transformer3d.restore(transformer3d.parameters())
+
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+
+            if global_step >= args.max_train_steps:
+                break
+
+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                if args.use_ema:
+                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                    ema_transformer3d.store(transformer3d.parameters())
+                    ema_transformer3d.copy_to(transformer3d.parameters())
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    transformer3d,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+                if args.use_ema:
+                    # Switch back to the original transformer3d parameters.
+                    ema_transformer3d.restore(transformer3d.parameters())
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        transformer3d = unwrap_model(transformer3d)
+        if args.use_ema:
+            ema_transformer3d.copy_to(transformer3d.parameters())
+
+        if args.use_deepspeed or accelerator.is_main_process:
+            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+            accelerator.save_state(save_path)
+            logger.info(f"Saved state to {save_path}")
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_control.sh b/scripts/train_control.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1d31d076ef0f5f5eca157eb6b1c121dbe00e403e
--- /dev/null
+++ b/scripts/train_control.sh
@@ -0,0 +1,41 @@
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-Pose"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata_control.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train_control.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1024 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=4 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=2e-05 \
+  --lr_scheduler="constant_with_warmup" \
+  --lr_warmup_steps=50 \
+  --seed=43 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --use_came \
+  --resume_from_checkpoint="latest" \
+  --trainable_modules "."
\ No newline at end of file
diff --git a/scripts/train_lora.py b/scripts/train_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7bca0789855aacbbf3ed3c44386569a312600b7
--- /dev/null
+++ b/scripts/train_lora.py
@@ -0,0 +1,1629 @@
+"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
+"""
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import gc
+import logging
+import math
+import os
+import pickle
+import shutil
+import sys
+
+import accelerate
+import diffusers
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+from einops import rearrange
+from huggingface_hub import create_repo, upload_folder
+from omegaconf import OmegaConf
+from packaging import version
+from PIL import Image
+from torch.utils.data import RandomSampler
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
+                          CLIPVisionModelWithProjection, MT5Tokenizer,
+                          T5EncoderModel, T5Tokenizer)
+from transformers.utils import ContextManagers
+
+import datasets
+
+current_file_path = os.path.abspath(__file__)
+project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path))]
+for project_root in project_roots:
+    sys.path.insert(0, project_root) if project_root not in sys.path else None
+from cogvideox.data.bucket_sampler import (ASPECT_RATIO_512,
+                                           ASPECT_RATIO_RANDOM_CROP_512,
+                                           ASPECT_RATIO_RANDOM_CROP_PROB,
+                                           AspectRatioBatchImageSampler,
+                                           AspectRatioBatchImageVideoSampler,
+                                           AspectRatioBatchSampler,
+                                           RandomSampler, get_closest_ratio)
+from cogvideox.data.dataset_image import CC15M
+from cogvideox.data.dataset_image_video import (ImageVideoDataset,
+                                                ImageVideoSampler,
+                                                get_random_mask)
+from cogvideox.data.dataset_video import VideoDataset, WebVid10M
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import \
+    CogVideoX_Fun_Pipeline_Inpaint, add_noise_to_reference_video
+from cogvideox.utils.lora_utils import create_network, merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
+
+if is_wandb_available():
+    import wandb
+
+
+def get_random_downsample_ratio(sample_size, image_ratio=[],
+                                all_choices=False, rng=None):
+    def _create_special_list(length):
+        if length == 1:
+            return [1.0]
+        if length >= 2:
+            first_element = 0.75
+            remaining_sum = 1.0 - first_element
+            other_elements_value = remaining_sum / (length - 1)
+            special_list = [first_element] + [other_elements_value] * (length - 1)
+            return special_list
+            
+    if sample_size >= 1536:
+        number_list = [1, 1.25, 1.5, 2, 2.5, 3] + image_ratio 
+    elif sample_size >= 1024:
+        number_list = [1, 1.25, 1.5, 2] + image_ratio
+    elif sample_size >= 768:
+        number_list = [1, 1.25, 1.5] + image_ratio
+    elif sample_size >= 512:
+        number_list = [1] + image_ratio
+    else:
+        number_list = [1]
+
+    if all_choices:
+        return number_list
+
+    number_list_prob = np.array(_create_special_list(len(number_list)))
+    if rng is None:
+        return np.random.choice(number_list, p = number_list_prob)
+    else:
+        return rng.choice(number_list, p = number_list_prob)
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+        
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+def log_validation(vae, text_encoder, tokenizer, transformer3d, network, args, accelerator, weight_dtype, global_step):
+    try:
+        logger.info("Running validation... ")
+
+        transformer3d_val = CogVideoXTransformer3DModel.from_pretrained_2d(
+            args.pretrained_model_name_or_path, subfolder="transformer",
+        ).to(weight_dtype)
+        transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict())
+        scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+        
+        if args.train_mode != "normal":
+            pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+                args.pretrained_model_name_or_path, 
+                vae=accelerator.unwrap_model(vae).to(weight_dtype), 
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                tokenizer=tokenizer,
+                transformer=transformer3d_val,
+                scheduler=scheduler,
+                torch_dtype=weight_dtype,
+            )
+        else:
+            pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+                args.pretrained_model_name_or_path, 
+                vae=accelerator.unwrap_model(vae).to(weight_dtype), 
+                text_encoder=accelerator.unwrap_model(text_encoder),
+                tokenizer=tokenizer,
+                transformer=transformer3d_val,
+                scheduler=scheduler,
+                torch_dtype=weight_dtype
+            )
+
+        pipeline = pipeline.to(accelerator.device)
+        pipeline = merge_lora(
+            pipeline, None, 1, accelerator.device, state_dict=accelerator.unwrap_model(network).state_dict(), transformer_only=True
+        )
+
+        if args.seed is None:
+            generator = None
+        else:
+            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+        for i in range(len(args.validation_prompts)):
+            with torch.no_grad():
+                if args.train_mode != "normal":
+                    with torch.autocast("cuda", dtype=weight_dtype):
+                        video_length = int((args.video_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1
+                        input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size])
+                        sample = pipeline(
+                            args.validation_prompts[i],
+                            num_frames = video_length,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            guidance_scale = 7,
+                            generator   = generator,
+
+                            video        = input_video,
+                            mask_video   = input_video_mask,
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif"))
+
+                        video_length = 1
+                        input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=[args.video_sample_size, args.video_sample_size])
+                        sample = pipeline(
+                            args.validation_prompts[i],
+                            num_frames = video_length,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            generator   = generator,
+
+                            video        = input_video,
+                            mask_video   = input_video_mask,
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif"))
+                else:
+                    with torch.autocast("cuda", dtype=weight_dtype):
+                        sample = pipeline(
+                            args.validation_prompts[i], 
+                            num_frames = args.video_sample_n_frames,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            generator   = generator
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-{i}.gif"))
+
+                        sample = pipeline(
+                            args.validation_prompts[i], 
+                            num_frames = 1,
+                            negative_prompt = "bad detailed",
+                            height      = args.video_sample_size,
+                            width       = args.video_sample_size,
+                            generator   = generator
+                        ).videos
+                        os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True)
+                        save_videos_grid(sample, os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif"))
+
+        del pipeline
+        del transformer3d_val
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    except Exception as e:
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        print(f"Eval error with info {e}")
+        return None
+
+def linear_decay(initial_value, final_value, total_steps, current_step):
+    if current_step >= total_steps:
+        return final_value
+    current_step = max(0, current_step)
+    step_size = (final_value - initial_value) / total_steps
+    current_value = initial_value + step_size * current_step
+    return current_value
+
+def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
+    u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
+    t = 1 / (1 + torch.exp(-u)) * (high - low) + low
+    return torch.clip(t.to(torch.int32), low, high - 1)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--train_data_meta",
+        type=str,
+        default=None,
+        help=(
+            "A csv containing the training data. "
+        ),
+    )
+    parser.add_argument(
+        "--max_train_samples",
+        type=int,
+        default=None,
+        help=(
+            "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        ),
+    )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="sd-model-finetuned",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
+    parser.add_argument(
+        "--use_came",
+        action="store_true",
+        help="whether to use came",
+    )
+    parser.add_argument(
+        "--multi_stream",
+        action="store_true",
+        help="whether to use cuda multi-stream",
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--vae_mini_batch", type=int, default=32, help="mini batch size for vae."
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=100)
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default=None,
+        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+    )
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="tensorboard",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=None,
+        help=("Max number of checkpoints to store."),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=2000,
+        help="Run validation every X steps.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
+    
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=128,
+        help=("The dimension of the LoRA update matrices."),
+    )
+    parser.add_argument(
+        "--network_alpha",
+        type=int,
+        default=64,
+        help=("The dimension of the LoRA update matrices."),
+    )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
+    parser.add_argument(
+        "--snr_loss", action="store_true", help="Whether or not to use snr_loss."
+    )
+    parser.add_argument(
+        "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader."
+    )
+    parser.add_argument(
+        "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets."
+    )
+    parser.add_argument(
+        "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets."
+    )
+    parser.add_argument(
+        "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets."
+    )
+    parser.add_argument(
+        "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.",
+    )
+    parser.add_argument(
+        "--noise_share_in_frames", action="store_true", help="Whether enable noise share in frames."
+    )
+    parser.add_argument(
+        "--noise_share_in_frames_ratio", type=float, default=0.5, help="Noise share ratio.",
+    )
+    parser.add_argument(
+        "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss."
+    )
+    parser.add_argument(
+        "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss."
+    )
+    parser.add_argument(
+        "--keep_all_node_same_token_length",
+        action="store_true", 
+        help="Reference of the length token.",
+    )
+    parser.add_argument(
+        "--train_sampling_steps",
+        type=int,
+        default=1000,
+        help="Run train_sampling_steps.",
+    )
+    parser.add_argument(
+        "--token_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the token.",
+    )
+    parser.add_argument(
+        "--video_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--image_sample_size",
+        type=int,
+        default=512,
+        help="Sample size of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_stride",
+        type=int,
+        default=4,
+        help="Sample stride of the video.",
+    )
+    parser.add_argument(
+        "--video_sample_n_frames",
+        type=int,
+        default=17,
+        help="Num frame of video.",
+    )
+    parser.add_argument(
+        "--video_repeat",
+        type=int,
+        default=0,
+        help="Num of repeat video.",
+    )
+    parser.add_argument(
+        "--image_repeat_in_forward",
+        type=int,
+        default=0,
+        help="Num of repeat image in forward.",
+    )
+    parser.add_argument(
+        "--transformer_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other transformers, input its path."),
+    )
+    parser.add_argument(
+        "--vae_path",
+        type=str,
+        default=None,
+        help=("If you want to load the weight from other vaes, input its path."),
+    )
+    parser.add_argument("--save_state", action="store_true", help="Whether or not to save state.")
+
+    parser.add_argument(
+        '--tokenizer_max_length', 
+        type=int,
+        default=226,
+        help='Max length of tokenizer'
+    )
+    parser.add_argument(
+        "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed."
+    )
+    parser.add_argument(
+        "--low_vram", action="store_true", help="Whether enable low_vram mode."
+    )
+    parser.add_argument(
+        "--train_mode",
+        type=str,
+        default="normal",
+        help=(
+            'The format of training data. Support `"normal"`'
+            ' (default), `"inpaint"`.'
+        ),
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.report_to == "wandb" and args.hub_token is not None:
+        raise ValueError(
+            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+            " Please use `huggingface-cli login` to authenticate with the Hub."
+        )
+
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
+    logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+    accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
+        mixed_precision=args.mixed_precision,
+        log_with=args.report_to,
+        project_config=accelerator_project_config,
+    )
+    if accelerator.is_main_process:
+        writer = SummaryWriter(log_dir=logging_dir)
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+        rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index))
+        torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index)
+    else:
+        rng = None
+        torch_rng = None
+    index_rng = np.random.default_rng(np.random.PCG64(43))
+    print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}")
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision
+    # as these weights are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+        args.mixed_precision = accelerator.mixed_precision
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+        args.mixed_precision = accelerator.mixed_precision
+
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+    tokenizer = T5Tokenizer.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+    )
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        returns either a context list that includes one that will disable zero.Init or an empty context list
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+    # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = T5EncoderModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant,
+            torch_dtype=weight_dtype
+        )
+
+        vae = AutoencoderKLCogVideoX.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+        )
+
+    transformer3d = CogVideoXTransformer3DModel.from_pretrained_2d(
+        args.pretrained_model_name_or_path, subfolder="transformer"
+    )
+
+    # Freeze vae and text_encoder and set transformer3d to trainable
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+    transformer3d.requires_grad_(False)
+
+    # Lora will work with this...
+    network = create_network(
+        1.0,
+        args.rank,
+        args.network_alpha,
+        text_encoder,
+        transformer3d,
+        neuron_dropout=None,
+        add_lora_in_attn_temporal=True,
+    )
+    network.apply_to(text_encoder, transformer3d, args.train_text_encoder and not args.training_with_video_token_length, True)
+
+    if args.transformer_path is not None:
+        print(f"From checkpoint: {args.transformer_path}")
+        if args.transformer_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.transformer_path)
+        else:
+            state_dict = torch.load(args.transformer_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = transformer3d.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+
+    if args.vae_path is not None:
+        print(f"From checkpoint: {args.vae_path}")
+        if args.vae_path.endswith("safetensors"):
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(args.vae_path)
+        else:
+            state_dict = torch.load(args.vae_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+        m, u = vae.load_state_dict(state_dict, strict=False)
+        print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+        assert len(u) == 0
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if accelerator.is_main_process:
+                safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors")
+                save_model(safetensor_save_path, accelerator.unwrap_model(models[-1]))
+                if not args.use_deepspeed:
+                    for _ in range(len(weights)):
+                        weights.pop()
+
+                with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file:
+                    pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file)
+
+        def load_model_hook(models, input_dir):
+            pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    loaded_number, _ = pickle.load(file)
+                    batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0)
+                print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.")
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
+    if args.gradient_checkpointing:
+        transformer3d.enable_gradient_checkpointing()
+
+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
+    if args.scale_lr:
+        args.learning_rate = (
+            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+        )
+
+    # Initialize the optimizer
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+            )
+
+        optimizer_cls = bnb.optim.AdamW8bit
+    elif args.use_came:
+        try:
+            from came_pytorch import CAME
+        except:
+            raise ImportError(
+                "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`"
+            )
+
+        optimizer_cls = CAME
+    else:
+        optimizer_cls = torch.optim.AdamW
+
+    logging.info("Add network parameters")
+    trainable_params = list(filter(lambda p: p.requires_grad, network.parameters()))
+    trainable_params_optim = network.prepare_optimizer_params(args.learning_rate / 2, args.learning_rate, args.learning_rate)
+
+    if args.use_came:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            # weight_decay=args.adam_weight_decay,
+            betas=(0.9, 0.999, 0.9999), 
+            eps=(1e-30, 1e-16)
+        )
+    else:
+        optimizer = optimizer_cls(
+            trainable_params_optim,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    # Get the training dataset
+    sample_n_frames_bucket_interval = 4
+
+    train_dataset = ImageVideoDataset(
+        args.train_data_meta, args.train_data_dir,
+        video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, 
+        video_repeat=args.video_repeat, 
+        image_sample_size=args.image_sample_size,
+        enable_bucket=args.enable_bucket, enable_inpaint=True if args.train_mode != "normal" else False,
+    )
+    
+    if args.enable_bucket:
+        aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+        aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = AspectRatioBatchImageVideoSampler(
+            sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, 
+            batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True,
+            aspect_ratios=aspect_ratio_sample_size,
+        )
+        if args.keep_all_node_same_token_length:
+            if args.token_sample_size > 256:
+                numbers_list = list(range(256, args.token_sample_size + 1, 128))
+
+                if numbers_list[-1] != args.token_sample_size:
+                    numbers_list.append(args.token_sample_size)
+            else:
+                numbers_list = [256]
+            numbers_list = [_number * _number * args.video_sample_n_frames for _number in  numbers_list]
+        else:
+            numbers_list = None
+
+        def get_length_to_frame_num(token_length):
+            if args.image_sample_size > args.video_sample_size:
+                sample_sizes = list(range(256, args.image_sample_size + 1, 128))
+
+                if sample_sizes[-1] != args.image_sample_size:
+                    sample_sizes.append(args.image_sample_size)
+            else:
+                sample_sizes = [256]
+            
+            length_to_frame_num = {
+                sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes
+            }
+
+            return length_to_frame_num
+
+        def collate_fn(examples):
+            target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size
+            length_to_frame_num = get_length_to_frame_num(
+                target_token_length, 
+            )
+
+            # Create new output
+            new_examples                 = {}
+            new_examples["target_token_length"] = target_token_length
+            new_examples["pixel_values"] = []
+            new_examples["text"]         = []
+            if args.train_mode != "normal":
+                new_examples["mask_pixel_values"] = []
+                new_examples["mask"] = []
+
+            # Get ratio
+            pixel_value     = examples[0]["pixel_values"]
+            data_type       = examples[0]["data_type"]
+            f, h, w, c      = np.shape(pixel_value)
+            if data_type == 'image':
+                random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size, image_ratio=[args.image_sample_size / args.video_sample_size], rng=rng)
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+                
+                batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+            else:
+                if args.random_hw_adapt:
+                    if args.training_with_video_token_length:
+                        local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples]))
+                        choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25]
+                        if len(choice_list) == 0:
+                            choice_list = list(length_to_frame_num.keys())
+                        if rng is None:
+                            local_video_sample_size = np.random.choice(choice_list)
+                        else:
+                            local_video_sample_size = rng.choice(choice_list)
+                        batch_video_length = length_to_frame_num[local_video_sample_size]
+                        random_downsample_ratio = args.video_sample_size / local_video_sample_size
+                    else:
+                        random_downsample_ratio = get_random_downsample_ratio(
+                                args.video_sample_size, rng=rng)
+                        batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+                else:
+                    random_downsample_ratio = 1
+                    batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval
+
+                aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()}
+
+            closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size)
+            closest_size = [int(x / 16) * 16 for x in closest_size]
+            if args.random_ratio_crop:
+                if rng is None:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                else:
+                    random_sample_size = aspect_ratio_random_crop_sample_size[
+                        rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB)
+                    ]
+                random_sample_size = [int(x / 16) * 16 for x in random_sample_size]
+
+            for example in examples:
+                if args.random_ratio_crop:
+                    # To 0~1
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+
+                    # Get adapt hw for resize
+                    b, c, h, w = pixel_values.size()
+                    th, tw = random_sample_size
+                    if th / tw > h / w:
+                        nh = int(th)
+                        nw = int(w / h * nh)
+                    else:
+                        nw = int(tw)
+                        nh = int(h / w * nw)
+                    
+                    transform = transforms.Compose([
+                        transforms.Resize([nh, nw]),
+                        transforms.CenterCrop([int(x) for x in random_sample_size]),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                else:
+                    closest_size = list(map(lambda x: int(x), closest_size))
+                    if closest_size[0] / h > closest_size[1] / w:
+                        resize_size = closest_size[0], int(w * closest_size[0] / h)
+                    else:
+                        resize_size = int(h * closest_size[1] / w), closest_size[1]
+                    
+                    pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
+                    pixel_values = pixel_values / 255.
+                    transform = transforms.Compose([
+                        transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR),  # Image.BICUBIC
+                        transforms.CenterCrop(closest_size),
+                        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+                    ])
+                new_examples["pixel_values"].append(transform(pixel_values))
+                new_examples["text"].append(example["text"])
+                batch_video_length = int(
+                    min(
+                        batch_video_length,
+                        (len(pixel_values) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1, 
+                    )
+                )
+
+                if batch_video_length == 0:
+                    batch_video_length = 1
+
+                if args.train_mode != "normal":
+                    mask = get_random_mask(new_examples["pixel_values"][-1].size())
+                    mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask
+                    new_examples["mask_pixel_values"].append(mask_pixel_values)
+                    new_examples["mask"].append(mask)
+
+            new_examples["pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["pixel_values"]])
+            if args.train_mode != "normal":
+                new_examples["mask_pixel_values"] = torch.stack([example[:batch_video_length] for example in new_examples["mask_pixel_values"]])
+                new_examples["mask"] = torch.stack([example[:batch_video_length] for example in new_examples["mask"]])
+
+            if args.enable_text_encoder_in_dataloader:
+                prompt_ids = tokenizer(
+                    new_examples['text'], 
+                    max_length=args.tokenizer_max_length, 
+                    padding="max_length", 
+                    add_special_tokens=True, 
+                    truncation=True, 
+                    return_tensors="pt"
+                )
+                encoder_hidden_states = text_encoder(
+                    prompt_ids.input_ids,
+                    return_dict=False
+                )[0]
+                new_examples['encoder_attention_mask'] = prompt_ids.attention_mask
+                new_examples['encoder_hidden_states'] = encoder_hidden_states
+
+            return new_examples
+        
+        # DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        # DataLoaders creation:
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler, 
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+
+    # Prepare everything with our `accelerator`.
+    network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        network, optimizer, train_dataloader, lr_scheduler
+    )
+
+    # Move text_encode and vae to gpu and cast to weight_dtype
+    vae.to(accelerator.device, dtype=weight_dtype)
+    transformer3d.to(accelerator.device, dtype=weight_dtype)
+    if not args.enable_text_encoder_in_dataloader:
+        text_encoder.to(accelerator.device)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initializes automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+
+            pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    _, first_epoch = pickle.load(file)
+            else:
+                first_epoch = global_step // num_update_steps_per_epoch
+            print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.")
+
+            from safetensors.torch import load_file, safe_open
+            state_dict = load_file(os.path.join(os.path.join(args.output_dir, path), "lora_diffusion_pytorch_model.safetensors"))
+            m, u = accelerator.unwrap_model(network).load_state_dict(state_dict, strict=False)
+            print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+    else:
+        initial_global_step = 0
+
+    # function for saving/removing
+    def save_model(ckpt_file, unwrapped_nw):
+        os.makedirs(args.output_dir, exist_ok=True)
+        accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
+        unwrapped_nw.save_weights(ckpt_file, weight_dtype, None)
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    if args.multi_stream and args.train_mode != "normal":
+        # create extra cuda streams to speedup inpaint vae computation
+        vae_stream_1 = torch.cuda.Stream()
+        vae_stream_2 = torch.cuda.Stream()
+    else:
+        vae_stream_1 = None
+        vae_stream_2 = None
+
+    for epoch in range(first_epoch, args.num_train_epochs):
+        train_loss = 0.0
+        batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch)
+        for step, batch in enumerate(train_dataloader):
+            # Data batch sanity check
+            if epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True)
+                for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                    pixel_value = pixel_value[None, ...]
+                    gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}'
+                    save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True)
+                if args.train_mode != "normal":
+                    mask_pixel_values, texts = batch['mask_pixel_values'].cpu(), batch['text']
+                    mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w")
+                    for idx, (pixel_value, text) in enumerate(zip(mask_pixel_values, texts)):
+                        pixel_value = pixel_value[None, ...]
+                        save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.gif", rescale=True)
+
+            with accelerator.accumulate(transformer3d):
+                # Convert images to latent space
+                pixel_values = batch["pixel_values"].to(weight_dtype)
+                if args.training_with_video_token_length:
+                    if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1))
+                        else:
+                            batch['text'] = batch['text'] * 4
+                    elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                        pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1))
+                        if args.enable_text_encoder_in_dataloader:
+                            batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1))
+                            batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1))
+                        else:
+                            batch['text'] = batch['text'] * 2
+                
+                if args.train_mode != "normal":
+                    mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype)
+                    mask = batch["mask"].to(weight_dtype)
+                    if args.training_with_video_token_length:
+                        if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (4, 1, 1, 1, 1))
+                        elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (2, 1, 1, 1, 1))
+
+                def create_special_list(length):
+                    if length == 1:
+                        return [1.0]
+                    if length >= 2:
+                        last_element = 0.90
+                        remaining_sum = 1.0 - last_element
+                        other_elements_value = remaining_sum / (length - 1)
+                        special_list = [other_elements_value] * (length - 1) + [last_element]
+                        return special_list
+                    
+                if args.keep_all_node_same_token_length:
+                    actual_token_length = index_rng.choice(numbers_list)
+
+                    actual_video_length = (min(
+                            actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames
+                    ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1
+                    actual_video_length = int(max(actual_video_length, 1))
+                else:
+                    actual_video_length = None
+
+                if args.random_frame_crop:
+                    select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))]
+                    select_frames_prob = np.array(create_special_list(len(select_frames)))
+                    
+                    if rng is None:
+                        temp_n_frames = np.random.choice(select_frames, p = select_frames_prob)
+                    else:
+                        temp_n_frames = rng.choice(select_frames, p = select_frames_prob)
+                    if args.keep_all_node_same_token_length:
+                        temp_n_frames = min(actual_video_length, temp_n_frames)
+
+                    pixel_values = pixel_values[:, :temp_n_frames, :, :]
+
+                    if args.train_mode != "normal":
+                        mask_pixel_values = mask_pixel_values[:, :temp_n_frames, :, :]
+                        mask = mask[:, :temp_n_frames, :, :]
+
+                if args.train_mode != "normal":
+                    t2v_flag = [(_mask == 1).all() for _mask in mask]
+                    new_t2v_flag = []
+                    for _mask in t2v_flag:
+                        if _mask and np.random.rand() < 0.90:
+                            new_t2v_flag.append(0)
+                        else:
+                            new_t2v_flag.append(1)
+                    t2v_flag = torch.from_numpy(np.array(new_t2v_flag)).to(accelerator.device, dtype=weight_dtype)
+
+                if args.low_vram:
+                    torch.cuda.empty_cache()
+                    vae.to(accelerator.device)
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to("cpu")
+
+                with torch.no_grad():
+                    # This way is quicker when batch grows up
+                    def _slice_vae(pixel_values):
+                        pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                        bs = args.vae_mini_batch
+                        new_pixel_values = []
+                        for i in range(0, pixel_values.shape[0], bs):
+                            pixel_values_bs = pixel_values[i : i + bs]
+                            pixel_values_bs = vae.encode(pixel_values_bs)[0]
+                            pixel_values_bs = pixel_values_bs.sample()
+                            new_pixel_values.append(pixel_values_bs)
+                            vae._clear_fake_context_parallel_cache()
+                        return torch.cat(new_pixel_values, dim = 0)
+                    if vae_stream_1 is not None:
+                        vae_stream_1.wait_stream(torch.cuda.current_stream())
+                        with torch.cuda.stream(vae_stream_1):
+                            latents = _slice_vae(pixel_values)
+                    else:
+                        latents = _slice_vae(pixel_values)
+                    latents = latents * vae.config.scaling_factor
+
+                    if args.train_mode != "normal":
+                        mask = rearrange(mask, "b f c h w -> b c f h w")
+                        mask = 1 - mask
+                        mask = resize_mask(mask, latents)
+
+                        if unwrap_model(transformer3d).config.add_noise_in_inpaint_model:
+                            mask_pixel_values = add_noise_to_reference_video(mask_pixel_values)
+                        mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w")
+                        bs = args.vae_mini_batch
+                        new_mask_pixel_values = []
+                        for i in range(0, mask_pixel_values.shape[0], bs):
+                            mask_pixel_values_bs = mask_pixel_values[i : i + bs]
+                            mask_pixel_values_bs = vae.encode(mask_pixel_values_bs)[0]
+                            mask_pixel_values_bs = mask_pixel_values_bs.sample()
+                            new_mask_pixel_values.append(mask_pixel_values_bs)
+                            vae._clear_fake_context_parallel_cache()
+                        mask_latents = torch.cat(new_mask_pixel_values, dim = 0)
+
+                        if vae_stream_2 is not None:
+                            torch.cuda.current_stream().wait_stream(vae_stream_2) 
+
+                        inpaint_latents = torch.concat([mask, mask_latents], dim=1)
+                        inpaint_latents = t2v_flag[:, None, None, None, None] * inpaint_latents
+                        inpaint_latents = inpaint_latents * vae.config.scaling_factor
+                        inpaint_latents = rearrange(inpaint_latents, "b c f h w -> b f c h w")
+
+                    latents = rearrange(latents, "b c f h w -> b f c h w")
+                        
+                # wait for latents = vae.encode(pixel_values) to complete
+                if vae_stream_1 is not None:
+                    torch.cuda.current_stream().wait_stream(vae_stream_1)
+
+                if args.low_vram:
+                    vae.to('cpu')
+                    torch.cuda.empty_cache()
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                if args.enable_text_encoder_in_dataloader:
+                    prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device)
+                else:
+                    with torch.no_grad():
+                        prompt_ids = tokenizer(
+                            batch['text'], 
+                            max_length=args.tokenizer_max_length, 
+                            padding="max_length", 
+                            add_special_tokens=True, 
+                            truncation=True, 
+                            return_tensors="pt"
+                        )
+                        prompt_embeds = text_encoder(
+                            prompt_ids.input_ids.to(latents.device),
+                            return_dict=False
+                        )[0]
+
+                if args.low_vram and not args.enable_text_encoder_in_dataloader:
+                    text_encoder.to('cpu')
+                    torch.cuda.empty_cache()
+
+                bsz = latents.shape[0]
+                noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype)
+                # Sample a random timestep for each image
+                # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng)
+                timesteps = timesteps.long()
+
+                # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+                def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+                    tw = tgt_width
+                    th = tgt_height
+                    h, w = src
+                    r = h / w
+                    if r > (th / tw):
+                        resize_height = th
+                        resize_width = int(round(th / h * w))
+                    else:
+                        resize_width = tw
+                        resize_height = int(round(tw / w * h))
+
+                    crop_top = int(round((th - resize_height) / 2.0))
+                    crop_left = int(round((tw - resize_width) / 2.0))
+
+                    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+                def _prepare_rotary_positional_embeddings(
+                    height: int,
+                    width: int,
+                    num_frames: int,
+                    device: torch.device
+                ):
+                    vae_scale_factor_spatial = (
+                        2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
+                    )
+                    grid_height = height // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    grid_width = width // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_width = 720 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+                    base_size_height = 480 // (vae_scale_factor_spatial * unwrap_model(transformer3d).config.patch_size)
+
+                    grid_crops_coords = get_resize_crop_region_for_grid(
+                        (grid_height, grid_width), base_size_width, base_size_height
+                    )
+                    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                        embed_dim=unwrap_model(transformer3d).config.attention_head_dim,
+                        crops_coords=grid_crops_coords,
+                        grid_size=(grid_height, grid_width),
+                        temporal_size=num_frames,
+                        use_real=True,
+                    )
+                    freqs_cos = freqs_cos.to(device=device)
+                    freqs_sin = freqs_sin.to(device=device)
+                    return freqs_cos, freqs_sin
+
+                height, width = batch["pixel_values"].size()[-2], batch["pixel_values"].size()[-1]
+                # 7. Create rotary embeds if required
+                image_rotary_emb = (
+                    _prepare_rotary_positional_embeddings(height, width, latents.size(1), latents.device)
+                    if unwrap_model(transformer3d).config.use_rotary_positional_embeddings
+                    else None
+                )
+                prompt_embeds = prompt_embeds.to(device=latents.device)
+
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+                
+                # predict the noise residual
+                noise_pred = transformer3d(
+                    hidden_states=noisy_latents,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timesteps,
+                    image_rotary_emb=image_rotary_emb,
+                    return_dict=False,
+                    inpaint_latents=inpaint_latents if args.train_mode != "normal" else None,
+                )[0]
+                loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                if args.motion_sub_loss and noise_pred.size()[1] > 2:
+                    gt_sub_noise = noise_pred[:, 1:, :].float() - noise_pred[:, :-1, :].float()
+                    pre_sub_noise = target[:, 1:, :].float() - target[:, :-1, :].float()
+                    sub_loss = F.mse_loss(gt_sub_noise, pre_sub_noise, reduction="mean")
+                    loss = loss * (1 - args.motion_sub_loss_ratio) + sub_loss * args.motion_sub_loss_ratio
+
+                # Gather the losses across all processes for logging (if we use distributed training).
+                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+                train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+                # Backpropagate
+                accelerator.backward(loss)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+                accelerator.log({"train_loss": train_loss}, step=global_step)
+                train_loss = 0.0
+
+                if global_step % args.checkpointing_steps == 0:
+                    if args.use_deepspeed or accelerator.is_main_process:
+                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+                        if args.checkpoints_total_limit is not None:
+                            checkpoints = os.listdir(args.output_dir)
+                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+                            if len(checkpoints) >= args.checkpoints_total_limit:
+                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+                                removing_checkpoints = checkpoints[0:num_to_remove]
+
+                                logger.info(
+                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+                                )
+                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+                                for removing_checkpoint in removing_checkpoints:
+                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+                                    shutil.rmtree(removing_checkpoint)
+                        if not args.save_state:
+                            safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors")
+                            save_model(safetensor_save_path, accelerator.unwrap_model(network))
+                            logger.info(f"Saved safetensor to {safetensor_save_path}")
+                        else:
+                            accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                            accelerator.save_state(accelerator_save_path)
+                            logger.info(f"Saved state to {accelerator_save_path}")
+
+                if accelerator.is_main_process:
+                    if args.validation_prompts is not None and global_step % args.validation_steps == 0:
+                        log_validation(
+                            vae,
+                            text_encoder,
+                            tokenizer,
+                            transformer3d,
+                            network,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                        )
+
+            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            progress_bar.set_postfix(**logs)
+
+            if global_step >= args.max_train_steps:
+                break
+
+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    transformer3d,
+                    network,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+
+    # Create the pipeline using the trained modules and save it.
+    accelerator.wait_for_everyone()
+    if accelerator.is_main_process:
+        safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors")
+        accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+        save_model(safetensor_save_path, accelerator.unwrap_model(network))
+        if args.save_state:
+            accelerator.save_state(accelerator_save_path)
+        logger.info(f"Saved state to {accelerator_save_path}")
+
+    accelerator.end_training()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_lora.sh b/scripts/train_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0a8df2f58fad3a71c9a6341d8c3e49db15391c05
--- /dev/null
+++ b/scripts/train_lora.sh
@@ -0,0 +1,38 @@
+export MODEL_NAME="models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
+export DATASET_NAME="datasets/internal_datasets/"
+export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
+export NCCL_IB_DISABLE=1
+export NCCL_P2P_DISABLE=1
+NCCL_DEBUG=INFO
+
+# When train model with multi machines, use "--config_file accelerate.yaml" instead of "--mixed_precision='bf16'".
+accelerate launch --mixed_precision="bf16" scripts/train_lora.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --train_data_dir=$DATASET_NAME \
+  --train_data_meta=$DATASET_META_NAME \
+  --image_sample_size=1280 \
+  --video_sample_size=256 \
+  --token_sample_size=512 \
+  --video_sample_stride=3 \
+  --video_sample_n_frames=49 \
+  --train_batch_size=1 \
+  --video_repeat=1 \
+  --gradient_accumulation_steps=1 \
+  --dataloader_num_workers=8 \
+  --num_train_epochs=100 \
+  --checkpointing_steps=50 \
+  --learning_rate=1e-04 \
+  --seed=42 \
+  --output_dir="output_dir" \
+  --gradient_checkpointing \
+  --mixed_precision="bf16" \
+  --adam_weight_decay=3e-2 \
+  --adam_epsilon=1e-10 \
+  --vae_mini_batch=1 \
+  --max_grad_norm=0.05 \
+  --random_hw_adapt \
+  --training_with_video_token_length \
+  --random_frame_crop \
+  --enable_bucket \
+  --low_vram \
+  --train_mode="inpaint" 
\ No newline at end of file
diff --git a/worker_runpod.py b/worker_runpod.py
new file mode 100644
index 0000000000000000000000000000000000000000..716ea7756ac6ad4559bcecad319d42630d2b9395
--- /dev/null
+++ b/worker_runpod.py
@@ -0,0 +1,138 @@
+import json
+import os
+import numpy as np
+import torch
+from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
+                       DPMSolverMultistepScheduler,
+                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
+                       PNDMScheduler)
+from transformers import T5EncoderModel, T5Tokenizer
+from omegaconf import OmegaConf
+from PIL import Image
+from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
+from cogvideox.models.autoencoder_magvit import AutoencoderKLCogVideoX
+from cogvideox.pipeline.pipeline_cogvideox import CogVideoX_Fun_Pipeline
+from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
+from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid, ASPECT_RATIO_512, get_closest_ratio, to_pil
+from huggingface_hub import HfApi, HfFolder
+
+# Low GPU memory mode
+low_gpu_memory_mode = False
+
+# Model loading section
+model_id = "/content/model"
+transformer = CogVideoXTransformer3DModel.from_pretrained_2d(
+    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+).to(torch.bfloat16)
+
+vae = AutoencoderKLCogVideoX.from_pretrained(
+    model_id, subfolder="vae"
+).to(torch.bfloat16)
+
+text_encoder = T5EncoderModel.from_pretrained(
+    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
+)
+
+sampler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "Euler A": EulerAncestralDiscreteScheduler,
+    "DPM++": DPMSolverMultistepScheduler,
+    "PNDM": PNDMScheduler,
+    "DDIM_Cog": CogVideoXDDIMScheduler,
+    "DDIM_Origin": DDIMScheduler,
+}
+scheduler = sampler_dict["DPM++"].from_pretrained(model_id, subfolder="scheduler")
+
+# Pipeline setup
+if transformer.config.in_channels != vae.config.latent_channels:
+    pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
+        model_id, vae=vae, text_encoder=text_encoder,
+        transformer=transformer, scheduler=scheduler,
+        torch_dtype=torch.bfloat16
+    )
+else:
+    pipeline = CogVideoX_Fun_Pipeline.from_pretrained(
+        model_id, vae=vae, text_encoder=text_encoder,
+        transformer=transformer, scheduler=scheduler,
+        torch_dtype=torch.bfloat16
+    )
+
+if low_gpu_memory_mode:
+    pipeline.enable_sequential_cpu_offload()
+else:
+    pipeline.enable_model_cpu_offload()
+
+@torch.inference_mode()
+def generate(input):
+    values = input["input"]
+    prompt = values["prompt"]
+    negative_prompt = values.get("negative_prompt", "")
+    guidance_scale = values.get("guidance_scale", 6.0)
+    seed = values.get("seed", 42)
+    num_inference_steps = values.get("num_inference_steps", 50)
+    base_resolution = values.get("base_resolution", 512)
+    
+    video_length = values.get("video_length", 53)
+    fps = values.get("fps", 10)
+    lora_weight = values.get("lora_weight", 1.00)
+    save_path = "samples"
+    partial_video_length = values.get("partial_video_length", None)
+    overlap_video_length = values.get("overlap_video_length", 4)
+    validation_image_start = values.get("validation_image_start", "asset/1.png")
+    validation_image_end = values.get("validation_image_end", None)
+    
+    generator = torch.Generator(device="cuda").manual_seed(seed)
+    aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+    start_img = Image.open(validation_image_start)
+    original_width, original_height = start_img.size
+    closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+    height, width = [int(x / 16) * 16 for x in closest_size]
+    sample_size = [height, width]
+    if partial_video_length is not None:
+        # Handle ultra-long video generation if required
+        # ... (existing logic for partial video generation)
+    else:
+        # Standard video generation
+        video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+        input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, validation_image_end, video_length=video_length, sample_size=sample_size)
+        
+        with torch.no_grad():
+            sample = pipeline(
+                prompt=prompt,
+                num_frames=video_length,
+                negative_prompt=negative_prompt,
+                height=sample_size[0],
+                width=sample_size[1],
+                generator=generator,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                video=input_video,
+                mask_video=input_video_mask
+            ).videos
+    
+    if not os.path.exists(save_path):
+        os.makedirs(save_path, exist_ok=True)
+
+    index = len([path for path in os.listdir(save_path)]) + 1
+    prefix = str(index).zfill(8)
+    video_path = os.path.join(save_path, f"{prefix}.mp4")
+    save_videos_grid(sample, video_path, fps=fps)
+
+    # Upload final video to Hugging Face repository
+    #hf_api = HfApi()
+    #repo_id = values.get("repo_id", "your-username/your-repo")  # Set your HF repo
+    #hf_api.upload_file(
+    #    path_or_fileobj=video_path,
+    #    path_in_repo=f"{prefix}.mp4",
+    #    repo_id=repo_id,
+    #    repo_type="model"  # or "dataset" if using a dataset repo
+    #)
+
+    # Prepare output
+    #result_url = f"https://huggingface.co/{repo_id}/blob/main/{prefix}.mp4"
+    result_url = ""
+    job_id = values.get("job_id", "default-job-id")  # For RunPod job tracking
+    return {"jobId": job_id, "result": result_url, "status": "DONE"}
+
+runpod.serverless.start({"handler": generate})